diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 2471c3f2748d9..935dd4af748cf 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -21,8 +21,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union from airflow.configuration import conf -from airflow.models import TaskInstance -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKeyType from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 6fe8bdbc3f24d..e52a76683672a 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -25,11 +25,12 @@ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient -from airflow import models from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.executors.executor_loader import ExecutorLoader from airflow.models.base import ID_LEN, Base +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from airflow.stats import Stats from airflow.utils import helpers, timezone from airflow.utils.helpers import convert_camel_to_snake @@ -268,22 +269,20 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): running_tis = self.executor.running resettable_states = [State.SCHEDULED, State.QUEUED] - TI = models.TaskInstance - DR = models.DagRun if filter_by_dag_run is None: resettable_tis = ( session - .query(TI) + .query(TaskInstance) .join( - DR, + DagRun, and_( - TI.dag_id == DR.dag_id, - TI.execution_date == DR.execution_date)) + TaskInstance.dag_id == DagRun.dag_id, + TaskInstance.execution_date == DagRun.execution_date)) .filter( # pylint: disable=comparison-with-callable - DR.state == State.RUNNING, - DR.run_type != DagRunType.BACKFILL_JOB.value, - TI.state.in_(resettable_states))).all() + DagRun.state == State.RUNNING, + DagRun.run_type != DagRunType.BACKFILL_JOB.value, + TaskInstance.state.in_(resettable_states))).all() else: resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session) @@ -300,9 +299,9 @@ def query(result, items): if not items: return result - filter_for_tis = TI.filter_for_tis(items) - reset_tis = session.query(TI).filter( - filter_for_tis, TI.state.in_(resettable_states) + filter_for_tis = TaskInstance.filter_for_tis(items) + reset_tis = session.query(TaskInstance).filter( + filter_for_tis, TaskInstance.state.in_(resettable_states) ).with_for_update().all() for ti in reset_tis: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4d2a0d79e49e5..8360097e0ec5a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -131,7 +131,7 @@ def clear_task_instances(tis, TaskInstanceKeyType = Tuple[str, str, datetime, int] -class TaskInstance(Base, LoggingMixin): +class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 """ Task instances store the state of a task instance. This table is the authority and single source of truth around what tasks have run and the @@ -180,6 +180,7 @@ class TaskInstance(Base, LoggingMixin): ) def __init__(self, task, execution_date: datetime, state: Optional[str] = None): + super().__init__() self.dag_id = task.dag_id self.task_id = task.task_id self.task = task @@ -209,6 +210,8 @@ def __init__(self, task, execution_date: datetime, state: Optional[str] = None): # Is this TaskInstance being currently running within `airflow tasks run --raw`. # Not persisted to the database so only valid for the current process self.raw = False + # can be changed when calling 'run' + self.test_mode = False @reconstructor def init_on_load(self): @@ -249,9 +252,10 @@ def prev_attempted_tries(self): @property def next_try_number(self): + """Setting Next Try Number""" return self._try_number + 1 - def command_as_list( + def command_as_list( # pylint: disable=too-many-arguments self, mark_success=False, ignore_all_deps=False, @@ -297,7 +301,7 @@ def command_as_list( cfg_path=cfg_path) @staticmethod - def generate_command(dag_id: str, + def generate_command(dag_id: str, # pylint: disable=too-many-arguments task_id: str, execution_date: datetime, mark_success: Optional[bool] = False, @@ -383,6 +387,7 @@ def generate_command(dag_id: str, @property def log_filepath(self): + """Filepath for TaskInstance""" iso = self.execution_date.isoformat() log = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) return ("{log}/{dag_id}/{task_id}/{iso}.log".format( @@ -390,6 +395,7 @@ def log_filepath(self): @property def log_url(self): + """Log URL for TaskInstance""" iso = quote(self.execution_date.isoformat()) base_url = conf.get('webserver', 'BASE_URL') return base_url + ( @@ -401,6 +407,7 @@ def log_url(self): @property def mark_success_url(self): + """URL to mark TI success""" iso = quote(self.execution_date.isoformat()) base_url = conf.get('webserver', 'BASE_URL') return base_url + ( @@ -418,6 +425,9 @@ def current_state(self, session=None) -> str: Get the very latest state from the database, if a session is passed, we use and looking up the state becomes part of the session, otherwise a new session is used. + + :param session: SQLAlchemy ORM Session + :type session: Session """ ti = session.query(TaskInstance).filter( TaskInstance.dag_id == self.dag_id, @@ -434,6 +444,9 @@ def current_state(self, session=None) -> str: def error(self, session=None): """ Forces the task instance's state to FAILED in the database. + + :param session: SQLAlchemy ORM Session + :type session: Session """ self.log.error("Recording the task instance as FAILED") self.state = State.FAILED @@ -445,10 +458,14 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: """ Refreshes the task instance from the database based on the primary key + :param session: SQLAlchemy ORM Session + :type session: Session :param lock_for_update: if True, indicates that the database should lock the TaskInstance (issuing a FOR UPDATE clause) until the session is committed. + :type lock_for_update: bool """ + self.log.debug("Refreshing TaskInstance %s from DB", self) qry = session.query(TaskInstance).filter( TaskInstance.dag_id == self.dag_id, @@ -467,7 +484,7 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: self.state = ti.state # Get the raw value of try_number column, don't read through the # accessor here otherwise it will be incremented by one already. - self.try_number = ti._try_number + self.try_number = ti._try_number # pylint: disable=protected-access self.max_tries = ti.max_tries self.hostname = ti.hostname self.unixname = ti.unixname @@ -482,6 +499,8 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None: else: self.state = None + self.log.debug("Refreshed TaskInstance %s", self) + def refresh_from_task(self, task, pool_override=None): """ Copy common attributes from the given task. @@ -504,13 +523,18 @@ def refresh_from_task(self, task, pool_override=None): def clear_xcom_data(self, session=None): """ Clears all XCom data from the database for the task instance + + :param session: SQLAlchemy ORM Session + :type session: Session """ + self.log.debug("Clearing XCom data") session.query(XCom).filter( XCom.dag_id == self.dag_id, XCom.task_id == self.task_id, XCom.execution_date == self.execution_date ).delete() session.commit() + self.log.debug("XCom data cleared") @property def key(self) -> TaskInstanceKeyType: @@ -521,6 +545,17 @@ def key(self) -> TaskInstanceKeyType: @provide_session def set_state(self, state, session=None, commit=True): + """ + Set TaskInstance state + + :param state: State to set for the TI + :type state: str + :param session: SQLAlchemy ORM Session + :type session: Session + :param commit: Whether or not to commit session + :type commit: bool + """ + self.log.debug("Setting task state for %s to %s", self, state) self.state = state self.start_date = timezone.utcnow() self.end_date = timezone.utcnow() @@ -546,6 +581,9 @@ def are_dependents_done(self, session=None): This is useful when you do not want to start processing the next schedule of a task until the dependents are done. For instance, if the task DROPs and recreates a table. + + :param session: SQLAlchemy ORM Session + :type session: Session """ task = self.task @@ -571,6 +609,7 @@ def get_previous_ti( The task instance for the task that ran before this task instance. :param state: If passed, it only take into account instances of a specific state. + :param session: SQLAlchemy ORM Session """ dag = self.task.dag if dag: @@ -643,6 +682,7 @@ def get_previous_execution_date( The execution date from property previous_ti_success. :param state: If passed, it only take into account instances of a specific state. + :param session: SQLAlchemy ORM Session """ self.log.debug("previous_execution_date was called") prev_ti = self.get_previous_ti(state=state, session=session) @@ -658,6 +698,7 @@ def get_previous_start_date( The start date from property previous_ti_success. :param state: If passed, it only take into account instances of a specific state. + :param session: SQLAlchemy ORM Session """ self.log.debug("previous_start_date was called") prev_ti = self.get_previous_ti(state=state, session=session) @@ -723,6 +764,7 @@ def get_failed_dep_statuses( self, dep_context=None, session=None): + """Get failed Dependencies""" dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: for dep_status in dep.get_dep_statuses( @@ -756,13 +798,13 @@ def next_retry_datetime(self): # will occurr in the modded_hash calculation. min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2)))) # deterministic per task instance - hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id, - self.task_id, - self.execution_date, - self.try_number) - .encode('utf-8')).hexdigest(), 16) + ti_hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id, + self.task_id, + self.execution_date, + self.try_number) + .encode('utf-8')).hexdigest(), 16) # between 1 and 1.0 * delay * (2^retry_number) - modded_hash = min_backoff + hash % min_backoff + modded_hash = min_backoff + ti_hash % min_backoff # timedelta has a maximum representable value. The exponentiation # here means this value can be exceeded after a certain number # of tries (around 50 if the initial delay is 1s, even fewer if @@ -786,11 +828,11 @@ def ready_for_retry(self): self.next_retry_datetime() < timezone.utcnow()) @provide_session - def get_dagrun(self, session=None): + def get_dagrun(self, session: Session = None): """ Returns the DagRun for this TaskInstance - :param session: + :param session: SQLAlchemy ORM Session :return: DagRun """ from airflow.models.dagrun import DagRun # Avoid circular import @@ -802,7 +844,7 @@ def get_dagrun(self, session=None): return dr @provide_session - def check_and_change_state_before_execution( + def check_and_change_state_before_execution( # pylint: disable=too-many-arguments self, verbose: bool = True, ignore_all_deps: bool = False, @@ -833,11 +875,16 @@ def check_and_change_state_before_execution( :type mark_success: bool :param test_mode: Doesn't record success or failure in the DB :type test_mode: bool + :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID + :type job_id: str :param pool: specifies the pool to use to run the task instance :type pool: str + :param session: SQLAlchemy ORM Session + :type session: Session :return: whether the state was changed to running or not :rtype: bool """ + task = self.task self.refresh_from_task(task, pool_override=pool) self.test_mode = test_mode @@ -849,7 +896,7 @@ def check_and_change_state_before_execution( Stats.incr('previously_succeeded', 1, 1) # TODO: Logging needs cleanup, not clear what is being printed - hr = "\n" + ("-" * 80) # Line break + hr_line_break = "\n" + ("-" * 80) # Line break if not mark_success: # Firstly find non-runnable and non-requeueable tis. @@ -892,22 +939,22 @@ def check_and_change_state_before_execution( session=session, verbose=True): self.state = State.NONE - self.log.warning(hr) + self.log.warning(hr_line_break) self.log.warning( "Rescheduling due to concurrency limits reached " "at task runtime. Attempt %s of " "%s. State set to NONE.", self.try_number, self.max_tries + 1 ) - self.log.warning(hr) + self.log.warning(hr_line_break) self.queued_dttm = timezone.utcnow() session.merge(self) session.commit() return False # print status message - self.log.info(hr) + self.log.info(hr_line_break) self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1) - self.log.info(hr) + self.log.info(hr_line_break) self._try_number += 1 if not test_mode: @@ -957,9 +1004,9 @@ def _run_raw_task( :type test_mode: bool :param pool: specifies the pool to use to run the task instance :type pool: str + :param session: SQLAlchemy ORM Session + :type session: Session """ - from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF - from airflow.sensors.base_sensor_operator import BaseSensorOperator task = self.task self.test_mode = test_mode @@ -974,80 +1021,7 @@ def _run_raw_task( try: if not mark_success: context = self.get_template_context() - - task_copy = task.prepare_for_execution() - - # Sensors in `poke` mode can block execution of DAGs when running - # with single process executor, thus we change the mode to`reschedule` - # to allow parallel task being scheduled and executed - if isinstance(task_copy, BaseSensorOperator) and \ - conf.get('core', 'executor') == "DebugExecutor": - self.log.warning("DebugExecutor changes sensor mode to 'reschedule'.") - task_copy.mode = 'reschedule' - - self.task = task_copy - - def signal_handler(signum, frame): - self.log.error("Received SIGTERM. Terminating subprocesses.") - task_copy.on_kill() - raise AirflowException("Task received SIGTERM signal") - signal.signal(signal.SIGTERM, signal_handler) - - # Don't clear Xcom until the task is certain to execute - self.clear_xcom_data() - - start_time = time.time() - - self.render_templates(context=context) - if STORE_SERIALIZED_DAGS: - RTIF.write(RTIF(ti=self, render_templates=False), session=session) - RTIF.delete_old_records(self.task_id, self.dag_id, session=session) - - # Export context to make it available for operators to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - self.log.info("Exporting the following env vars:\n%s", - '\n'.join(["{}={}".format(k, v) - for k, v in airflow_context_vars.items()])) - os.environ.update(airflow_context_vars) - task_copy.pre_execute(context=context) - - try: - if task.on_execute_callback: - task.on_execute_callback(context) - except Exception as e3: - self.log.error("Failed when executing execute callback") - self.log.exception(e3) - - # If a timeout is specified for the task, make it fail - # if it goes beyond - if task_copy.execution_timeout: - try: - with timeout(int( - task_copy.execution_timeout.total_seconds())): - result = task_copy.execute(context=context) - except AirflowTaskTimeout: - task_copy.on_kill() - raise - else: - result = task_copy.execute(context=context) - - # If the task returns a result, push an XCom containing it - if task_copy.do_xcom_push and result is not None: - self.xcom_push(key=XCOM_RETURN_KEY, value=result) - - task_copy.post_execute(context=context, result=result) - - end_time = time.time() - duration = end_time - start_time - Stats.timing( - 'dag.{dag_id}.{task_id}.duration'.format( - dag_id=task_copy.dag_id, - task_id=task_copy.task_id), - duration) - - Stats.incr('operator_successes_{}'.format( - self.task.__class__.__name__), 1, 1) - Stats.incr('ti_successes') + self._prepare_and_execute_task_with_callbacks(context, session, task) self.refresh_from_db(lock_for_update=True) self.state = State.SUCCESS except AirflowSkipException as e: @@ -1089,13 +1063,7 @@ def signal_handler(signum, frame): finally: Stats.incr('ti.finish.{}.{}.{}'.format(task.dag_id, task.task_id, self.state)) - # Success callback - try: - if task.on_success_callback: - task.on_success_callback(context) - except Exception as e3: - self.log.error("Failed when executing success callback") - self.log.exception(e3) + self._run_success_callback(context, task) # Recording SUCCESS self.end_date = timezone.utcnow() @@ -1114,8 +1082,108 @@ def signal_handler(signum, frame): session.merge(self) session.commit() + def _prepare_and_execute_task_with_callbacks(self, context, session, task): + """ + Prepare Task for Execution + """ + from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF + from airflow.sensors.base_sensor_operator import BaseSensorOperator + + task_copy = task.prepare_for_execution() + # Sensors in `poke` mode can block execution of DAGs when running + # with single process executor, thus we change the mode to`reschedule` + # to allow parallel task being scheduled and executed + if ( + isinstance(task_copy, BaseSensorOperator) and + conf.get('core', 'executor') == "DebugExecutor" + ): + self.log.warning("DebugExecutor changes sensor mode to 'reschedule'.") + task_copy.mode = 'reschedule' + self.task = task_copy + + def signal_handler(signum, frame): # pylint: disable=unused-argument + self.log.error("Received SIGTERM. Terminating subprocesses.") + task_copy.on_kill() + raise AirflowException("Task received SIGTERM signal") + + signal.signal(signal.SIGTERM, signal_handler) + + # Don't clear Xcom until the task is certain to execute + self.clear_xcom_data() + start_time = time.time() + + self.render_templates(context=context) + if STORE_SERIALIZED_DAGS: + RTIF.write(RTIF(ti=self, render_templates=False), session=session) + RTIF.delete_old_records(self.task_id, self.dag_id, session=session) + + # Export context to make it available for operators to use. + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info("Exporting the following env vars:\n%s", + '\n'.join(["{}={}".format(k, v) + for k, v in airflow_context_vars.items()])) + + os.environ.update(airflow_context_vars) + + # Run pre_execute callback + task_copy.pre_execute(context=context) + + # Run on_execute callback + self._run_execute_callback(context, task) + + # Execute the task + result = self._execute_task(context, task_copy) + + # Run post_execute callback + task_copy.post_execute(context=context, result=result) + + end_time = time.time() + duration = end_time - start_time + Stats.timing('dag.{dag_id}.{task_id}.duration'.format(dag_id=task_copy.dag_id, + task_id=task_copy.task_id), + duration) + Stats.incr('operator_successes_{}'.format(self.task.__class__.__name__), 1, 1) + Stats.incr('ti_successes') + + def _run_success_callback(self, context, task): + """Functions that need to be run if Task is successful""" + # Success callback + try: + if task.on_success_callback: + task.on_success_callback(context) + except Exception as exc: # pylint: disable=broad-except + self.log.error("Failed when executing success callback") + self.log.exception(exc) + + def _execute_task(self, context, task_copy): + """Executes Task (optionally with a Timeout) and pushes Xcom results""" + # If a timeout is specified for the task, make it fail + # if it goes beyond + if task_copy.execution_timeout: + try: + with timeout(int(task_copy.execution_timeout.total_seconds())): + result = task_copy.execute(context=context) + except AirflowTaskTimeout: + task_copy.on_kill() + raise + else: + result = task_copy.execute(context=context) + # If the task returns a result, push an XCom containing it + if task_copy.do_xcom_push and result is not None: + self.xcom_push(key=XCOM_RETURN_KEY, value=result) + return result + + def _run_execute_callback(self, context, task): + """Functions that need to be run before a Task is executed""" + try: + if task.on_execute_callback: + task.on_execute_callback(context) + except Exception as exc: # pylint: disable=broad-except + self.log.error("Failed when executing execute callback") + self.log.exception(exc) + @provide_session - def run( + def run( # pylint: disable=too-many-arguments self, verbose: bool = True, ignore_all_deps: bool = False, @@ -1127,6 +1195,7 @@ def run( job_id: Optional[str] = None, pool: Optional[str] = None, session=None) -> None: + """Run TaskInstance""" res = self.check_and_change_state_before_execution( verbose=verbose, ignore_all_deps=ignore_all_deps, @@ -1147,6 +1216,7 @@ def run( session=session) def dry_run(self): + """Only Renders Templates for the TI""" task = self.task task_copy = task.prepare_for_execution() self.task = task_copy @@ -1155,7 +1225,11 @@ def dry_run(self): task_copy.dry_run() @provide_session - def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, context=None, + def _handle_reschedule(self, + actual_start_date, + reschedule_exception, + test_mode=False, + context=None, # pylint: disable=unused-argument session=None): # Don't record reschedule request in test mode if test_mode: @@ -1182,6 +1256,7 @@ def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode= @provide_session def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None): + """Handle Failure for the TaskInstance""" if test_mode is None: test_mode = self.test_mode if context is None: @@ -1236,17 +1311,17 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, if email_for_state and task.email: try: self.email_alert(error) - except Exception as e2: + except Exception as exec2: # pylint: disable=broad-except self.log.error('Failed to send email to: %s', task.email) - self.log.exception(e2) + self.log.exception(exec2) # Handling callbacks pessimistically if callback: try: callback(context) - except Exception as e3: + except Exception as exec3: # pylint: disable=broad-except self.log.error("Failed at executing callback") - self.log.exception(e3) + self.log.exception(exec3) if not test_mode: session.merge(self) @@ -1263,7 +1338,8 @@ def _safe_date(self, date_attr, fmt): return '' @provide_session - def get_template_context(self, session=None) -> Dict[str, Any]: + def get_template_context(self, session=None) -> Dict[str, Any]: # pylint: disable=too-many-locals + """Return TI Context""" task = self.task from airflow import macros @@ -1352,8 +1428,9 @@ def __repr__(self): @staticmethod def get( item: str, - default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, + default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # pylint: disable=W0212 ): + """Get Airflow Variable value""" return Variable.get(item, default_var=default_var) class VariableJsonAccessor: @@ -1378,8 +1455,9 @@ def __repr__(self): @staticmethod def get( item: str, - default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, + default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL, # pylint: disable=W0212 ): + """Get Airflow Variable after deserializing JSON value""" return Variable.get(item, default_var=default_var, deserialize_json=True) return { @@ -1447,7 +1525,9 @@ def get_rendered_template_fields(self): self.render_templates() def overwrite_params_with_dag_run_conf(self, params, dag_run): + """Overwrite Task Params with DagRun.conf""" if dag_run and dag_run.conf: + self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) params.update(dag_run.conf) def render_templates(self, context: Optional[Dict] = None) -> None: @@ -1458,6 +1538,7 @@ def render_templates(self, context: Optional[Dict] = None) -> None: self.task.render_template_fields(context) def email_alert(self, exception): + """Send Email Alert with exception trace""" exception_html = str(exception).replace('\n', '
') jinja_context = self.get_template_context() # This function is called after changing the state @@ -1495,7 +1576,7 @@ def render(key, content): html_content = render('html_content_template', default_html_content) try: send_email(self.task.email, subject, html_content) - except Exception: + except Exception: # pylint: disable=broad-except default_html_content_err = ( 'Try {{try_number}} out of {{max_tries + 1}}
' 'Exception:
Failed attempt to attach error logs
' @@ -1508,10 +1589,12 @@ def render(key, content): send_email(self.task.email, subject, html_content_err) def set_duration(self) -> None: + """Set TI duration""" if self.end_date and self.start_date: self.duration = (self.end_date - self.start_date).total_seconds() else: self.duration = None + self.log.debug("Task Duration set to %s", self.duration) def xcom_push( self, @@ -1545,7 +1628,7 @@ def xcom_push( dag_id=self.dag_id, execution_date=execution_date or self.execution_date) - def xcom_pull( + def xcom_pull( # pylint: disable=inconsistent-return-statements self, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_id: Optional[str] = None, @@ -1605,6 +1688,7 @@ def xcom_pull( @provide_session def get_num_running_task_instances(self, session): + """Return Number of running TIs from the DB""" # .count() is inefficient return session.query(func.count()).filter( TaskInstance.dag_id == self.dag_id, diff --git a/airflow/operators/branch_operator.py b/airflow/operators/branch_operator.py index 247d4cc23375b..a4653416c47ee 100644 --- a/airflow/operators/branch_operator.py +++ b/airflow/operators/branch_operator.py @@ -19,7 +19,8 @@ from typing import Dict, Iterable, Union -from airflow.models import BaseOperator, SkipMixin +from airflow.models import BaseOperator +from airflow.models.skipmixin import SkipMixin class BaseBranchOperator(BaseOperator, SkipMixin): diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 7107e17596827..5bbc715ec1b1a 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -32,8 +32,9 @@ import dill from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, SkipMixin +from airflow.models import BaseOperator from airflow.models.dag import DAG, DagContext +from airflow.models.skipmixin import SkipMixin from airflow.models.xcom_arg import XComArg from airflow.utils.decorators import apply_defaults from airflow.utils.process_utils import execute_in_subprocess diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py index 92c0ad8b3004b..b04daa89d5f3c 100644 --- a/airflow/ti_deps/deps/dagrun_exists_dep.py +++ b/airflow/ti_deps/deps/dagrun_exists_dep.py @@ -34,7 +34,7 @@ def _get_dep_statuses(self, ti, session, dep_context): dagrun = ti.get_dagrun(session) if not dagrun: # The import is needed here to avoid a circular dependency - from airflow.models import DagRun + from airflow.models.dagrun import DagRun running_dagruns = DagRun.find( dag_id=dag.dag_id, state=State.RUNNING, diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index 57e7dee5b4a5e..5f3530c2b5fc9 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from airflow.models import TaskReschedule +from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils import timezone from airflow.utils.session import provide_session diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 06ed6515f5ba7..b8ef213ccdc87 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -23,7 +23,9 @@ import pendulum from airflow import settings -from airflow.models import DAG, SkipMixin, TaskInstance as TI +from airflow.models.dag import DAG +from airflow.models.skipmixin import SkipMixin +from airflow.models.taskinstance import TaskInstance as TI from airflow.operators.dummy_operator import DummyOperator from airflow.utils import timezone from airflow.utils.state import State