diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index b5c1c1321f7ed..5e42820ffe997 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -25,6 +25,7 @@ import attr from sqlalchemy import func +from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException, RemovedInAirflow3Warning from airflow.models.baseoperator import BaseOperatorLink from airflow.models.dag import DagModel @@ -33,11 +34,13 @@ from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.sensors.base import BaseSensorOperator +from airflow.triggers.external_task import TaskStateTrigger from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import State, TaskInstanceState +from airflow.utils.timezone import utcnow if TYPE_CHECKING: from sqlalchemy.orm import Query, Session @@ -126,6 +129,8 @@ class ExternalTaskSensor(BaseSensorOperator): external_task_id is not None) or check if the DAG to wait for exists (when external_task_id is None), and immediately cease waiting if the external task or DAG does not exist (default value: False). + :param poll_interval: polling period in seconds to check for the status + :param deferrable: Run sensor in deferrable mode """ template_fields = ["external_dag_id", "external_task_id", "external_task_ids", "external_task_group_id"] @@ -145,9 +150,12 @@ def __init__( execution_delta: datetime.timedelta | None = None, execution_date_fn: Callable | None = None, check_existence: bool = False, + poll_interval: float = 2.0, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) + self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value] self.skipped_states = list(skipped_states) if skipped_states else [] self.failed_states = list(failed_states) if failed_states else [] @@ -211,6 +219,8 @@ def __init__( self.external_task_group_id = external_task_group_id self.check_existence = check_existence self._has_checked_existence = False + self.deferrable = deferrable + self.poll_interval = poll_interval def _get_dttm_filter(self, context): if self.execution_delta: @@ -318,6 +328,39 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: count_allowed = self.get_count(dttm_filter, session, self.allowed_states) return count_allowed == len(dttm_filter) + def execute(self, context: Context) -> None: + """ + Airflow runs this method on the worker and defers using the triggers + if deferrable is set to True. + """ + if not self.deferrable: + super().execute(context) + else: + self.defer( + trigger=TaskStateTrigger( + dag_id=self.external_dag_id, + task_id=self.external_task_id, + execution_dates=self._get_dttm_filter(context), + states=self.allowed_states, + trigger_start_time=utcnow(), + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context, event=None): + """Callback for when the trigger fires - returns immediately.""" + if event["status"] == "success": + self.log.info("External task %s has executed successfully.", self.external_task_id) + return None + elif event["status"] == "timeout": + raise AirflowException("Dag was not started within 1 minute, assuming fail.") + else: + raise AirflowException( + "Error occurred while trying to retrieve task status. Please, check the " + "name of executed task and Dag." + ) + def _check_for_existence(self, session) -> None: dag_to_wait = DagModel.get_current(self.external_dag_id, session) diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py index e739c7a7cb4e8..f179cba259ca4 100644 --- a/airflow/triggers/external_task.py +++ b/airflow/triggers/external_task.py @@ -17,8 +17,8 @@ from __future__ import annotations import asyncio -import datetime import typing +from datetime import datetime from asgiref.sync import sync_to_async from sqlalchemy import func @@ -27,7 +27,8 @@ from airflow.models import DagRun, TaskInstance from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import DagRunState +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.timezone import utcnow class TaskStateTrigger(BaseTrigger): @@ -36,20 +37,26 @@ class TaskStateTrigger(BaseTrigger): :param dag_id: The dag_id that contains the task you want to wait for :param task_id: The task_id that contains the task you want to - wait for. If ``None`` (default value) the sensor waits for the DAG + wait for. :param states: allowed states, default is ``['success']`` - :param execution_dates: + :param execution_dates: task execution time interval :param poll_interval: The time interval in seconds to check the state. The default value is 5 sec. + :param trigger_start_time: time in Datetime format when the trigger was started. Is used + to control the execution of trigger to prevent infinite loop in case if specified name + of the dag does not exist in database. It will wait period of time equals _timeout_sec parameter + from the time, when the trigger was started and if the execution lasts more time than expected, + the trigger will terminate with 'timeout' status. """ def __init__( self, dag_id: str, - task_id: str, - states: list[str], - execution_dates: list[datetime.datetime], - poll_interval: float = 5.0, + execution_dates: list[datetime], + trigger_start_time: datetime, + states: list[str] | None = None, + task_id: str | None = None, + poll_interval: float = 2.0, ): super().__init__() self.dag_id = dag_id @@ -57,6 +64,9 @@ def __init__( self.states = states self.execution_dates = execution_dates self.poll_interval = poll_interval + self.trigger_start_time = trigger_start_time + self.states = states if states else [TaskInstanceState.SUCCESS.value] + self._timeout_sec = 60 def serialize(self) -> tuple[str, dict[str, typing.Any]]: """Serializes TaskStateTrigger arguments and classpath.""" @@ -68,17 +78,52 @@ def serialize(self) -> tuple[str, dict[str, typing.Any]]: "states": self.states, "execution_dates": self.execution_dates, "poll_interval": self.poll_interval, + "trigger_start_time": self.trigger_start_time, }, ) async def run(self) -> typing.AsyncIterator[TriggerEvent]: - """Checks periodically in the database to see if the task exists and has hit one of the states.""" + """ + Checks periodically in the database to see if the dag exists and is in the running state. If found, + wait until the task specified will reach one of the expected states. If dag with specified name was + not in the running state after _timeout_sec seconds after starting execution process of the trigger, + terminate with status 'timeout'. + """ while True: - # mypy confuses typing here - num_tasks = await self.count_tasks() # type: ignore[call-arg] - if num_tasks == len(self.execution_dates): - yield TriggerEvent(True) - await asyncio.sleep(self.poll_interval) + try: + delta = utcnow() - self.trigger_start_time + if delta.total_seconds() < self._timeout_sec: + # mypy confuses typing here + if await self.count_running_dags() == 0: # type: ignore[call-arg] + self.log.info("Waiting for DAG to start execution...") + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "timeout"}) + return + # mypy confuses typing here + if await self.count_tasks() == len(self.execution_dates): # type: ignore[call-arg] + yield TriggerEvent({"status": "success"}) + return + self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval) + await asyncio.sleep(self.poll_interval) + except Exception: + yield TriggerEvent({"status": "failed"}) + return + + @sync_to_async + @provide_session + def count_running_dags(self, session: Session): + """Count how many dag instances in running state in the database.""" + dags = ( + session.query(func.count("*")) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.execution_date.in_(self.execution_dates), + TaskInstance.state.in_(["running", "success"]), + ) + .scalar() + ) + return dags @sync_to_async @provide_session @@ -112,7 +157,7 @@ def __init__( self, dag_id: str, states: list[DagRunState], - execution_dates: list[datetime.datetime], + execution_dates: list[datetime], poll_interval: float = 5.0, ): super().__init__() @@ -134,7 +179,10 @@ def serialize(self) -> tuple[str, dict[str, typing.Any]]: ) async def run(self) -> typing.AsyncIterator[TriggerEvent]: - """Checks periodically in the database to see if the dag run exists and has hit one of the states.""" + """ + Checks periodically in the database to see if the dag run exists, and has + hit one of the states yet, or not. + """ while True: # mypy confuses typing here num_dags = await self.count_dags() # type: ignore[call-arg] diff --git a/docs/apache-airflow/howto/operator/external_task_sensor.rst b/docs/apache-airflow/howto/operator/external_task_sensor.rst index 923f8ec3d1161..f6f53f87e55fa 100644 --- a/docs/apache-airflow/howto/operator/external_task_sensor.rst +++ b/docs/apache-airflow/howto/operator/external_task_sensor.rst @@ -53,6 +53,15 @@ via ``allowed_states`` and ``failed_states`` parameters. :start-after: [START howto_operator_external_task_sensor] :end-before: [END howto_operator_external_task_sensor] +Also for this action you can use sensor in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/core/example_external_task_parent_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_external_task_async_sensor] + :end-before: [END howto_external_task_async_sensor] + + ExternalTaskSensor with task_group dependency --------------------------------------------- In Addition, we can also use the :class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks on a DAG diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index a5259084b13f6..e84b3f69f48e0 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -22,12 +22,13 @@ import tempfile import zipfile from datetime import time, timedelta +from unittest import mock import pytest from airflow import exceptions, settings from airflow.decorators import task as task_deco -from airflow.exceptions import AirflowException, AirflowSensorTimeout +from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel @@ -35,9 +36,14 @@ from airflow.operators.bash import BashOperator from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator -from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink +from airflow.sensors.external_task import ( + ExternalTaskMarker, + ExternalTaskSensor, + ExternalTaskSensorLink, +) from airflow.sensors.time_sensor import TimeSensor from airflow.serialization.serialized_objects import SerializedBaseOperator +from airflow.triggers.external_task import TaskStateTrigger from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState @@ -54,6 +60,9 @@ TEST_TASK_ID_ALTERNATE = "time_sensor_check_alternate" TEST_TASK_GROUP_ID = "time_sensor_group_id" DEV_NULL = "/dev/null" +TASK_ID = "external_task_sensor_check" +EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on +EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on @pytest.fixture(autouse=True) @@ -829,6 +838,75 @@ def test_external_task_group_when_there_is_no_TIs(self): ) +class TestExternalTaskAsyncSensor: + TASK_ID = "external_task_sensor_check" + EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on + EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on + + def test_defer_and_fire_task_state_trigger(self): + """ + Asserts that a task is deferred and TaskStateTrigger will be fired + when the ExternalTaskAsyncSensor is provided with all required arguments + (i.e. including the external_task_id). + """ + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context=mock.MagicMock()) + + assert isinstance(exc.value.trigger, TaskStateTrigger), "Trigger is not a TaskStateTrigger" + + def test_defer_and_fire_failed_state_trigger(self): + """Tests that an AirflowException is raised in case of error event""" + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + ) + + with pytest.raises(AirflowException): + sensor.execute_complete( + context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} + ) + + def test_defer_and_fire_timeout_state_trigger(self): + """Tests that an AirflowException is raised in case of timeout event""" + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + ) + + with pytest.raises(AirflowException): + sensor.execute_complete( + context=mock.MagicMock(), + event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."}, + ) + + def test_defer_execute_check_correct_logging(self): + """Asserts that logging occurs as expected""" + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + ) + + with mock.patch.object(sensor.log, "info") as mock_log_info: + sensor.execute_complete( + context=mock.MagicMock(), + event={"status": "success"}, + ) + mock_log_info.assert_called_with("External task %s has executed successfully.", EXTERNAL_TASK_ID) + + def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker): with dag_zip_maker("test_external_task_sensor_check_existense.py") as dagbag: with create_session() as session: diff --git a/tests/system/providers/core/__init__.py b/tests/system/providers/core/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/system/providers/core/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/system/providers/core/example_external_task_child_deferrable.py b/tests/system/providers/core/example_external_task_child_deferrable.py new file mode 100644 index 0000000000000..f75eb4f23479f --- /dev/null +++ b/tests/system/providers/core/example_external_task_child_deferrable.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.operators.bash import BashOperator + +with DAG( + dag_id="child_dag", + start_date=datetime(2022, 1, 1), + schedule="@once", + catchup=False, + tags=["example", "async", "core"], +) as dag: + dummy_task = BashOperator( + task_id="child_task", + bash_command="echo 1; sleep 1; echo 2; sleep 2; echo 3; sleep 3", + ) + + +from tests.system.utils import get_test_run + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/core/example_external_task_parent_deferrable.py b/tests/system/providers/core/example_external_task_parent_deferrable.py new file mode 100644 index 0000000000000..7cec2ce13815a --- /dev/null +++ b/tests/system/providers/core/example_external_task_parent_deferrable.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow import DAG +from airflow.operators.dummy import DummyOperator +from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.sensors.external_task import ExternalTaskSensor +from airflow.utils.timezone import datetime + +with DAG( + dag_id="example_external_task", + start_date=datetime(2022, 1, 1), + schedule="@once", + catchup=False, + tags=["example", "async", "core"], +) as dag: + start = DummyOperator(task_id="start") + + # [START howto_external_task_async_sensor] + external_task_sensor = ExternalTaskSensor( + task_id="parent_task_sensor", + external_task_id="child_task", + external_dag_id="child_dag", + deferrable=True, + ) + # [END howto_external_task_async_sensor] + + trigger_child_task = TriggerDagRunOperator( + task_id="trigger_child_task", + trigger_dag_id="child_dag", + allowed_states=[ + "success", + "failed", + ], + execution_date="{{execution_date}}", + poke_interval=5, + reset_dag_run=True, + wait_for_completion=True, + ) + + end = DummyOperator(task_id="end") + + start >> [trigger_child_task, external_task_sensor] >> end + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py index e8a4a67ba5d57..a8569a9c558d1 100644 --- a/tests/triggers/test_external_task.py +++ b/tests/triggers/test_external_task.py @@ -26,6 +26,7 @@ from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.timezone import utcnow class TestTaskStateTrigger: @@ -40,6 +41,7 @@ async def test_task_state_trigger(self, session): Asserts that the TaskStateTrigger only goes off on or after a TaskInstance reaches an allowed state (i.e. SUCCESS). """ + trigger_start_time = utcnow() dag = DAG(self.DAG_ID, start_date=timezone.datetime(2022, 1, 1)) dag_run = DagRun( dag_id=dag.dag_id, @@ -61,6 +63,7 @@ async def test_task_state_trigger(self, session): states=self.STATES, execution_dates=[timezone.datetime(2022, 1, 1)], poll_interval=0.2, + trigger_start_time=trigger_start_time, ) task = asyncio.create_task(trigger.run().__anext__()) @@ -83,12 +86,14 @@ def test_serialization(self): Asserts that the TaskStateTrigger correctly serializes its arguments and classpath. """ + trigger_start_time = utcnow() trigger = TaskStateTrigger( dag_id=self.DAG_ID, task_id=self.TASK_ID, states=self.STATES, execution_dates=[timezone.datetime(2022, 1, 1)], poll_interval=5, + trigger_start_time=trigger_start_time, ) classpath, kwargs = trigger.serialize() assert classpath == "airflow.triggers.external_task.TaskStateTrigger" @@ -98,6 +103,7 @@ def test_serialization(self): "states": self.STATES, "execution_dates": [timezone.datetime(2022, 1, 1)], "poll_interval": 5, + "trigger_start_time": trigger_start_time, }