Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from contextlib import redirect_stderr, redirect_stdout, suppress
from datetime import timedelta
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import Iterator, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Iterator, List, Optional, Set, Tuple

from setproctitle import setproctitle
from sqlalchemy import func, or_
from sqlalchemy import exc, func, or_
from sqlalchemy.orm.session import Session

from airflow import models, settings
Expand All @@ -52,6 +52,9 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.operator import Operator

DR = models.DagRun
TI = models.TaskInstance

Expand Down Expand Up @@ -625,7 +628,7 @@ def execute_callbacks(
self.log.debug("Processing Callback Request: %s", request)
try:
if isinstance(request, TaskCallbackRequest):
self._execute_task_callbacks(dagbag, request)
self._execute_task_callbacks(dagbag, request, session=session)
elif isinstance(request, SlaCallbackRequest):
self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
elif isinstance(request, DagCallbackRequest):
Expand All @@ -637,7 +640,27 @@ def execute_callbacks(
request.full_filepath,
)

session.commit()
session.flush()

def execute_callbacks_without_dag(
self, callback_requests: List[CallbackRequest], session: Session
) -> None:
"""
Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors.

This is so important so that tasks that failed when there is a parse

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This is so important so that tasks that failed when there is a parse
This is important so that tasks that failed when there is a parse

error don't get stuck in queued state.
"""
for request in callback_requests:
self.log.debug("Processing Callback Request: %s", request)
if isinstance(request, TaskCallbackRequest):
self._execute_task_callbacks(None, request, session)
else:
self.log.info(
"Not executing %s callback for file %s as there was a dag parse error",
request.__class__.__name__,
request.full_filepath,
)

@provide_session
def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
Expand All @@ -647,26 +670,59 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
)

def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
def _execute_task_callbacks(
self, dagbag: Optional[DagBag], request: TaskCallbackRequest, session: Session
):
if not request.is_failure_callback:
return

simple_ti = request.simple_task_instance
if simple_ti.dag_id in dagbag.dags:
ti: Optional[TI] = (
session.query(TI)
.filter_by(
dag_id=simple_ti.dag_id,
run_id=simple_ti.run_id,
task_id=simple_ti.task_id,
map_index=simple_ti.map_index,
)
.one_or_none()
)
if not ti:
return

task: Optional["Operator"] = None

if dagbag and simple_ti.dag_id in dagbag.dags:
dag = dagbag.dags[simple_ti.dag_id]
if simple_ti.task_id in dag.task_ids:
task = dag.get_task(simple_ti.task_id)
if request.is_failure_callback:
ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
# TODO: Use simple_ti to improve performance here in the future
ti.refresh_from_db()
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
else:
# We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run
# `handle_failure` so that the state of the TI gets progressed.
#
# Since handle_failure _really_ wants a task, we do our best effort to give it one
from airflow.models.serialized_dag import SerializedDagModel

try:
model = session.query(SerializedDagModel).get(simple_ti.dag_id)
if model:
task = model.dag.get_task(simple_ti.task_id)
except (exc.NoResultFound, TaskNotFound):
pass
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
session.flush()

@provide_session
def process_file(
self,
file_path: str,
callback_requests: List[CallbackRequest],
pickle_dags: bool = False,
session: Session = None,
session: Session = NEW_SESSION,
) -> Tuple[int, int]:
"""
Process a Python file containing Airflow DAGs.
Expand Down Expand Up @@ -702,12 +758,19 @@ def process_file(
else:
self.log.warning("No viable dags retrieved from %s", file_path)
self.update_import_errors(session, dagbag)
if callback_requests:
# If there were callback requests for this file but there was a
# parse error we still need to progress the state of TIs,
# otherwise they might be stuck in queued/running for ever!
self.execute_callbacks_without_dag(callback_requests, session)
return 0, len(dagbag.import_errors)

self.execute_callbacks(dagbag, callback_requests)
self.execute_callbacks(dagbag, callback_requests, session)
session.commit()

# Save individual DAGs in the ORM
dagbag.sync_to_db()
dagbag.sync_to_db(session)
session.commit()

if pickle_dags:
paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs):
self.task_id = task_instance.task_id
self.execution_date = task_instance.execution_date
self.map_index = task_instance.map_index
task_owner = task_instance.task.owner
if task_instance.task:
task_owner = task_instance.task.owner

if 'task_id' in kwargs:
self.task_id = kwargs['task_id']
Expand Down
17 changes: 11 additions & 6 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,9 +1845,6 @@ def handle_failure(
if test_mode is None:
test_mode = self.test_mode

if context is None:
context = self.get_template_context()

if error:
if isinstance(error, BaseException):
tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
Expand All @@ -1859,7 +1856,7 @@ def handle_failure(

self.end_date = timezone.utcnow()
self.set_duration()
Stats.incr(f'operator_failures_{self.task.task_type}')
Stats.incr(f'operator_failures_{self.operator}')
Stats.incr('ti_failures')
if not test_mode:
session.add(Log(State.FAILED, self))
Expand All @@ -1869,6 +1866,10 @@ def handle_failure(

self.clear_next_method_args()

# In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
if context is None and self.task:
context = self.get_template_context(session)

if context is not None:
context['exception'] = error

Expand All @@ -1886,7 +1887,8 @@ def handle_failure(

task: Optional[BaseOperator] = None
try:
task = self.task.unmap((context, session))
if self.task and context:
task = self.task.unmap((context, session))
except Exception:
self.log.error("Unable to unmap task to determine if we need to send an alert email")

Expand All @@ -1911,7 +1913,7 @@ def handle_failure(
except Exception:
self.log.exception('Failed to send email to: %s', task.email)

if callback:
if callback and context:
self._run_finished_callback(callback, context, callback_type)

if not test_mode:
Expand All @@ -1924,6 +1926,9 @@ def is_eligible_to_retry(self):
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True
if not self.task:
# Couldn't load the task, don't know number of retries, guess:
return self.try_number <= self.max_tries

return self.task.retries and self.try_number <= self.max_tries

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def cleanup(self):
if not dag_ids:
return
# To isolate problems here with problems from elsewhere on the session object
self.session.flush()
self.session.rollback()

self.session.query(SerializedDagModel).filter(
SerializedDagModel.dag_id.in_(dag_ids)
Expand Down
41 changes: 38 additions & 3 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.dag_processing.manager import DagFileProcessorAgent
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import DagBag, DagModel, SlaMiss, TaskInstance, errors
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
Expand Down Expand Up @@ -388,10 +389,44 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
)
]
dag_file_processor.execute_callbacks(dagbag, requests)
dag_file_processor.execute_callbacks(dagbag, requests, session)
mock_ti_handle_failure.assert_called_once_with(
error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session
)

@pytest.mark.parametrize(
["has_serialized_dag"],
[pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")],
)
@patch.object(TaskInstance, 'handle_failure')
def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
with create_session() as session:
session.query(TaskInstance).delete()
dag = dagbag.get_dag('example_branch_operator')
dagrun = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
)
task = dag.get_task(task_id='run_this_first')
ti = TaskInstance(task, run_id=dagrun.run_id, state=State.QUEUED)
session.add(ti)

if has_serialized_dag:
assert SerializedDagModel.write_dag(dag, session=session) is True
session.flush()

requests = [
TaskCallbackRequest(
full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
)
]
dag_file_processor.execute_callbacks_without_dag(requests, session)
mock_ti_handle_failure.assert_called_once_with(
error="Message",
test_mode=conf.getboolean('core', 'unit_test_mode'),
error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session
)

def test_failure_callbacks_should_not_drop_hostname(self):
Expand Down
33 changes: 32 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ def test_handle_failure_updates_queued_task_try_number(self, dag_maker):
ti = TI(task=task, run_id=dr.run_id)
ti.state = State.QUEUED
session.merge(ti)
session.commit()
session.flush()
assert ti.state == State.QUEUED
assert ti.try_number == 1
ti.handle_failure("test queued ti", test_mode=True)
Expand All @@ -2091,6 +2091,37 @@ def test_handle_failure_updates_queued_task_try_number(self, dag_maker):
# Check 'ti.try_number' is bumped to 2. This is try_number for next run
assert ti.try_number == 2

@patch.object(Stats, 'incr')
def test_handle_failure_no_task(self, Stats_incr, dag_maker):
"""
When a zombie is detected for a DAG with a parse error, we need to be able to run handle_failure
_without_ ti.task being set
"""
session = settings.Session()
with dag_maker():
task = EmptyOperator(task_id="mytask", retries=1)
dr = dag_maker.create_dagrun()
ti = TI(task=task, run_id=dr.run_id)
ti = session.merge(ti)
ti.task = None
ti.state = State.QUEUED
session.flush()

assert ti.task is None, "Check critical pre-condition"

assert ti.state == State.QUEUED
assert ti.try_number == 1

ti.handle_failure("test queued ti", test_mode=False)
assert ti.state == State.UP_FOR_RETRY
# Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try
assert ti._try_number == 1
# Check 'ti.try_number' is bumped to 2. This is try_number for next run
assert ti.try_number == 2

Stats_incr.assert_any_call('ti_failures')
Stats_incr.assert_any_call('operator_failures_EmptyOperator')

def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
def fail():
raise AirflowFailException("hopeless")
Expand Down