diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 90c4372a9f693..1a7fa6f10565a 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -25,10 +25,10 @@ from contextlib import redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Iterator, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Iterator, List, Optional, Set, Tuple from setproctitle import setproctitle -from sqlalchemy import func, or_ +from sqlalchemy import exc, func, or_ from sqlalchemy.orm.session import Session from airflow import models, settings @@ -52,6 +52,9 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State +if TYPE_CHECKING: + from airflow.models.operator import Operator + DR = models.DagRun TI = models.TaskInstance @@ -625,7 +628,7 @@ def execute_callbacks( self.log.debug("Processing Callback Request: %s", request) try: if isinstance(request, TaskCallbackRequest): - self._execute_task_callbacks(dagbag, request) + self._execute_task_callbacks(dagbag, request, session=session) elif isinstance(request, SlaCallbackRequest): self.manage_slas(dagbag.get_dag(request.dag_id), session=session) elif isinstance(request, DagCallbackRequest): @@ -637,7 +640,27 @@ def execute_callbacks( request.full_filepath, ) - session.commit() + session.flush() + + def execute_callbacks_without_dag( + self, callback_requests: List[CallbackRequest], session: Session + ) -> None: + """ + Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors. + + This is so important so that tasks that failed when there is a parse + error don't get stuck in queued state. + """ + for request in callback_requests: + self.log.debug("Processing Callback Request: %s", request) + if isinstance(request, TaskCallbackRequest): + self._execute_task_callbacks(None, request, session) + else: + self.log.info( + "Not executing %s callback for file %s as there was a dag parse error", + request.__class__.__name__, + request.full_filepath, + ) @provide_session def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): @@ -647,18 +670,51 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session ) - def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): + def _execute_task_callbacks( + self, dagbag: Optional[DagBag], request: TaskCallbackRequest, session: Session + ): + if not request.is_failure_callback: + return + simple_ti = request.simple_task_instance - if simple_ti.dag_id in dagbag.dags: + ti: Optional[TI] = ( + session.query(TI) + .filter_by( + dag_id=simple_ti.dag_id, + run_id=simple_ti.run_id, + task_id=simple_ti.task_id, + map_index=simple_ti.map_index, + ) + .one_or_none() + ) + if not ti: + return + + task: Optional["Operator"] = None + + if dagbag and simple_ti.dag_id in dagbag.dags: dag = dagbag.dags[simple_ti.dag_id] if simple_ti.task_id in dag.task_ids: task = dag.get_task(simple_ti.task_id) - if request.is_failure_callback: - ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index) - # TODO: Use simple_ti to improve performance here in the future - ti.refresh_from_db() - ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE) - self.log.info('Executed failure callback for %s in state %s', ti, ti.state) + else: + # We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run + # `handle_failure` so that the state of the TI gets progressed. + # + # Since handle_failure _really_ wants a task, we do our best effort to give it one + from airflow.models.serialized_dag import SerializedDagModel + + try: + model = session.query(SerializedDagModel).get(simple_ti.dag_id) + if model: + task = model.dag.get_task(simple_ti.task_id) + except (exc.NoResultFound, TaskNotFound): + pass + if task: + ti.refresh_from_task(task) + + ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session) + self.log.info('Executed failure callback for %s in state %s', ti, ti.state) + session.flush() @provide_session def process_file( @@ -666,7 +722,7 @@ def process_file( file_path: str, callback_requests: List[CallbackRequest], pickle_dags: bool = False, - session: Session = None, + session: Session = NEW_SESSION, ) -> Tuple[int, int]: """ Process a Python file containing Airflow DAGs. @@ -702,12 +758,19 @@ def process_file( else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) + if callback_requests: + # If there were callback requests for this file but there was a + # parse error we still need to progress the state of TIs, + # otherwise they might be stuck in queued/running for ever! + self.execute_callbacks_without_dag(callback_requests, session) return 0, len(dagbag.import_errors) - self.execute_callbacks(dagbag, callback_requests) + self.execute_callbacks(dagbag, callback_requests, session) + session.commit() # Save individual DAGs in the ORM - dagbag.sync_to_db() + dagbag.sync_to_db(session) + session.commit() if pickle_dags: paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) diff --git a/airflow/models/log.py b/airflow/models/log.py index d3ba41a071331..3f658c217417d 100644 --- a/airflow/models/log.py +++ b/airflow/models/log.py @@ -55,7 +55,8 @@ def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs): self.task_id = task_instance.task_id self.execution_date = task_instance.execution_date self.map_index = task_instance.map_index - task_owner = task_instance.task.owner + if task_instance.task: + task_owner = task_instance.task.owner if 'task_id' in kwargs: self.task_id = kwargs['task_id'] diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f45b385769ac7..656d7754561a2 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1845,9 +1845,6 @@ def handle_failure( if test_mode is None: test_mode = self.test_mode - if context is None: - context = self.get_template_context() - if error: if isinstance(error, BaseException): tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task) @@ -1859,7 +1856,7 @@ def handle_failure( self.end_date = timezone.utcnow() self.set_duration() - Stats.incr(f'operator_failures_{self.task.task_type}') + Stats.incr(f'operator_failures_{self.operator}') Stats.incr('ti_failures') if not test_mode: session.add(Log(State.FAILED, self)) @@ -1869,6 +1866,10 @@ def handle_failure( self.clear_next_method_args() + # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task. + if context is None and self.task: + context = self.get_template_context(session) + if context is not None: context['exception'] = error @@ -1886,7 +1887,8 @@ def handle_failure( task: Optional[BaseOperator] = None try: - task = self.task.unmap((context, session)) + if self.task and context: + task = self.task.unmap((context, session)) except Exception: self.log.error("Unable to unmap task to determine if we need to send an alert email") @@ -1911,7 +1913,7 @@ def handle_failure( except Exception: self.log.exception('Failed to send email to: %s', task.email) - if callback: + if callback and context: self._run_finished_callback(callback, context, callback_type) if not test_mode: @@ -1924,6 +1926,9 @@ def is_eligible_to_retry(self): # If a task is cleared when running, it goes into RESTARTING state and is always # eligible for retry return True + if not self.task: + # Couldn't load the task, don't know number of retries, guess: + return self.try_number <= self.max_tries return self.task.retries and self.try_number <= self.max_tries diff --git a/tests/conftest.py b/tests/conftest.py index e6329e1991488..b7814c6b42137 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -621,7 +621,7 @@ def cleanup(self): if not dag_ids: return # To isolate problems here with problems from elsewhere on the session object - self.session.flush() + self.session.rollback() self.session.query(SerializedDagModel).filter( SerializedDagModel.dag_id.in_(dag_ids) diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index eb122698f4d0f..d007851092f29 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -30,6 +30,7 @@ from airflow.dag_processing.manager import DagFileProcessorAgent from airflow.dag_processing.processor import DagFileProcessor from airflow.models import DagBag, DagModel, SlaMiss, TaskInstance, errors +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance from airflow.operators.empty import EmptyOperator from airflow.utils import timezone @@ -388,10 +389,44 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" ) ] - dag_file_processor.execute_callbacks(dagbag, requests) + dag_file_processor.execute_callbacks(dagbag, requests, session) + mock_ti_handle_failure.assert_called_once_with( + error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session + ) + + @pytest.mark.parametrize( + ["has_serialized_dag"], + [pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")], + ) + @patch.object(TaskInstance, 'handle_failure') + def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag): + dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + with create_session() as session: + session.query(TaskInstance).delete() + dag = dagbag.get_dag('example_branch_operator') + dagrun = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + ) + task = dag.get_task(task_id='run_this_first') + ti = TaskInstance(task, run_id=dagrun.run_id, state=State.QUEUED) + session.add(ti) + + if has_serialized_dag: + assert SerializedDagModel.write_dag(dag, session=session) is True + session.flush() + + requests = [ + TaskCallbackRequest( + full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" + ) + ] + dag_file_processor.execute_callbacks_without_dag(requests, session) mock_ti_handle_failure.assert_called_once_with( - error="Message", - test_mode=conf.getboolean('core', 'unit_test_mode'), + error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session ) def test_failure_callbacks_should_not_drop_hostname(self): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 39aa3f6cb895a..d60ab901cac88 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2081,7 +2081,7 @@ def test_handle_failure_updates_queued_task_try_number(self, dag_maker): ti = TI(task=task, run_id=dr.run_id) ti.state = State.QUEUED session.merge(ti) - session.commit() + session.flush() assert ti.state == State.QUEUED assert ti.try_number == 1 ti.handle_failure("test queued ti", test_mode=True) @@ -2091,6 +2091,37 @@ def test_handle_failure_updates_queued_task_try_number(self, dag_maker): # Check 'ti.try_number' is bumped to 2. This is try_number for next run assert ti.try_number == 2 + @patch.object(Stats, 'incr') + def test_handle_failure_no_task(self, Stats_incr, dag_maker): + """ + When a zombie is detected for a DAG with a parse error, we need to be able to run handle_failure + _without_ ti.task being set + """ + session = settings.Session() + with dag_maker(): + task = EmptyOperator(task_id="mytask", retries=1) + dr = dag_maker.create_dagrun() + ti = TI(task=task, run_id=dr.run_id) + ti = session.merge(ti) + ti.task = None + ti.state = State.QUEUED + session.flush() + + assert ti.task is None, "Check critical pre-condition" + + assert ti.state == State.QUEUED + assert ti.try_number == 1 + + ti.handle_failure("test queued ti", test_mode=False) + assert ti.state == State.UP_FOR_RETRY + # Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try + assert ti._try_number == 1 + # Check 'ti.try_number' is bumped to 2. This is try_number for next run + assert ti.try_number == 2 + + Stats_incr.assert_any_call('ti_failures') + Stats_incr.assert_any_call('operator_failures_EmptyOperator') + def test_does_not_retry_on_airflow_fail_exception(self, dag_maker): def fail(): raise AirflowFailException("hopeless")