Skip to content
43 changes: 43 additions & 0 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import attr
from sqlalchemy import func

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException, RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperatorLink
from airflow.models.dag import DagModel
Expand All @@ -33,11 +34,13 @@
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.external_task import TaskStateTrigger
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import utcnow

if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
Expand Down Expand Up @@ -126,6 +129,8 @@ class ExternalTaskSensor(BaseSensorOperator):
external_task_id is not None) or check if the DAG to wait for exists (when
external_task_id is None), and immediately cease waiting if the external task
or DAG does not exist (default value: False).
:param poll_interval: polling period in seconds to check for the status
:param deferrable: Run sensor in deferrable mode
"""

template_fields = ["external_dag_id", "external_task_id", "external_task_ids", "external_task_group_id"]
Expand All @@ -145,9 +150,12 @@ def __init__(
execution_delta: datetime.timedelta | None = None,
execution_date_fn: Callable | None = None,
check_existence: bool = False,
poll_interval: float = 2.0,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)

self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
self.skipped_states = list(skipped_states) if skipped_states else []
self.failed_states = list(failed_states) if failed_states else []
Expand Down Expand Up @@ -211,6 +219,8 @@ def __init__(
self.external_task_group_id = external_task_group_id
self.check_existence = check_existence
self._has_checked_existence = False
self.deferrable = deferrable
self.poll_interval = poll_interval

def _get_dttm_filter(self, context):
if self.execution_delta:
Expand Down Expand Up @@ -318,6 +328,39 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
count_allowed = self.get_count(dttm_filter, session, self.allowed_states)
return count_allowed == len(dttm_filter)

def execute(self, context: Context) -> None:
    """
    Run the sensor on the worker, or hand the wait over to a trigger.

    When ``deferrable`` is falsy the parent class ``execute`` is called with the
    same context. Otherwise the task is deferred on a ``TaskStateTrigger`` built
    from the external DAG/task ids, the execution dates computed for this
    context, the allowed states and the poll interval; ``execute_complete`` is
    registered as the method to resume in.

    :param context: task context passed through to the parent ``execute`` or
        used to compute the execution dates handed to the trigger.
    """
    if self.deferrable:
        state_trigger = TaskStateTrigger(
            dag_id=self.external_dag_id,
            task_id=self.external_task_id,
            execution_dates=self._get_dttm_filter(context),
            states=self.allowed_states,
            # Start of the window the trigger uses for its own timeout check.
            trigger_start_time=utcnow(),
            poll_interval=self.poll_interval,
        )
        self.defer(trigger=state_trigger, method_name="execute_complete")
        return
    super().execute(context)

def execute_complete(self, context, event=None):
    """
    Resume the sensor after the trigger has fired and translate its event.

    :param context: task context; not used by this method.
    :param event: payload of the trigger event. Its ``"status"`` entry decides
        the outcome.
    :return: ``None`` when the status is ``"success"``.
    :raises AirflowException: when no event payload was delivered, when the
        status is ``"timeout"``, or for any other (or missing) status.
    """
    # ``event`` defaults to None, so guard before reading from it: a missing
    # payload should fail the task with a clear message, not a TypeError.
    if event is None:
        raise AirflowException(
            f"No event was received from the trigger for external task {self.external_task_id}."
        )

    # ``.get`` lets a payload without a status fall through to the generic error below.
    status = event.get("status")
    if status == "success":
        self.log.info("External task %s has executed successfully.", self.external_task_id)
        return None
    if status == "timeout":
        raise AirflowException("Dag was not started within 1 minute, assuming fail.")
    raise AirflowException(
        "Error occurred while trying to retrieve task status. Please, check the "
        "name of executed task and Dag."
    )

def _check_for_existence(self, session) -> None:
dag_to_wait = DagModel.get_current(self.external_dag_id, session)

Expand Down
80 changes: 64 additions & 16 deletions airflow/triggers/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from __future__ import annotations

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added Triggers for external_task, can you please rebase on Main.

For sensors, I would be grateful if you can help the community by maintaining compatibility with astronomer-providers -- There are community members as well as our users using it https://pypistats.org/packages/astronomer-providers

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the great implementation of the trigger; I have updated the sensor to use it.
However, after some testing I found a few things in the trigger that could be improved:

  1. It is not possible to process several external tasks at once using the external_task_ids parameter. For that reason, ExternalTaskSensor now supports processing only one task per run. Could you please take a look and add this capability?
  2. If you pass the name of a DAG that does not exist at all, the trigger runs indefinitely, trying to query the task from that DAG. I have added one additional parameter that sets a deadline for this search to prevent an infinite loop: if the trigger does not find the specified DAG within a minute, it terminates.


import asyncio
import datetime
import typing
from datetime import datetime

from asgiref.sync import sync_to_async
from sqlalchemy import func
Expand All @@ -27,7 +27,8 @@
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import utcnow


class TaskStateTrigger(BaseTrigger):
Expand All @@ -36,27 +37,36 @@ class TaskStateTrigger(BaseTrigger):

:param dag_id: The dag_id that contains the task you want to wait for
:param task_id: The task_id that contains the task you want to
wait for. If ``None`` (default value) the sensor waits for the DAG
wait for.
:param states: allowed states, default is ``['success']``

@OfSixes OfSixes Jul 6, 2023

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While having the default value be ["success"] would make a lot of sense, no default value is currently set for the states parameter.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, for the states parameter, as it is in the description, we have set it to ["success"] by default.
list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value].

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This default is set in the sensor, though; the trigger itself does not implement this behavior. If someone wanted to use the trigger separately, its documentation would be incorrect.

:param execution_dates:
:param execution_dates: task execution time interval
:param poll_interval: The time interval in seconds to check the state.
The default value is 5 sec.
:param trigger_start_time: time in Datetime format when the trigger was started. Is used
to control the execution of trigger to prevent infinite loop in case if specified name
of the dag does not exist in database. It will wait period of time equals _timeout_sec parameter
from the time, when the trigger was started and if the execution lasts more time than expected,
the trigger will terminate with 'timeout' status.
"""

def __init__(
self,
dag_id: str,
task_id: str,
states: list[str],
execution_dates: list[datetime.datetime],
poll_interval: float = 5.0,
execution_dates: list[datetime],
trigger_start_time: datetime,
states: list[str] | None = None,
task_id: str | None = None,
poll_interval: float = 2.0,
):
super().__init__()
self.dag_id = dag_id
self.task_id = task_id
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval
self.trigger_start_time = trigger_start_time
self.states = states if states else [TaskInstanceState.SUCCESS.value]
self._timeout_sec = 60

def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serializes TaskStateTrigger arguments and classpath."""
Expand All @@ -68,17 +78,52 @@ def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
"trigger_start_time": self.trigger_start_time,
},
)

async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Checks periodically in the database to see if the task exists and has hit one of the states."""
"""
Checks periodically in the database to see if the dag exists and is in the running state. If found,
wait until the task specified will reach one of the expected states. If dag with specified name was
not in the running state after _timeout_sec seconds after starting execution process of the trigger,
terminate with status 'timeout'.
"""
while True:
# mypy confuses typing here
num_tasks = await self.count_tasks() # type: ignore[call-arg]
if num_tasks == len(self.execution_dates):
yield TriggerEvent(True)
await asyncio.sleep(self.poll_interval)
try:
delta = utcnow() - self.trigger_start_time
if delta.total_seconds() < self._timeout_sec:
# mypy confuses typing here
if await self.count_running_dags() == 0: # type: ignore[call-arg]
self.log.info("Waiting for DAG to start execution...")
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent({"status": "timeout"})
return

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think we need return here

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is no return statement, the trigger logic won't work in Python 3.7: the trigger will keep checking the task status indefinitely and won't stop even after yielding the event. The source code of the base class also says that it is required to return None after the yield statement.
I couldn't find this in the documentation; I discovered it myself while implementing a trigger for Python 3.7. I think the return statement should be there so that the interpreter exits the generator loop when the specified event appears; otherwise we get an endless loop.

# mypy confuses typing here
if await self.count_tasks() == len(self.execution_dates): # type: ignore[call-arg]
yield TriggerEvent({"status": "success"})
return

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
await asyncio.sleep(self.poll_interval)
except Exception:
yield TriggerEvent({"status": "failed"})
return

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above


@sync_to_async
@provide_session
def count_running_dags(self, session: Session):
    """
    Count task instances of the awaited DAG that are running or have succeeded.

    Only ``TaskInstance`` rows whose ``dag_id`` equals ``self.dag_id`` and whose
    ``execution_date`` is one of ``self.execution_dates`` are considered.

    :param session: database session used to run the count query.
    :return: the scalar result of the count query.
    """
    counter = session.query(func.count("*"))
    matching = counter.filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.execution_date.in_(self.execution_dates),
        TaskInstance.state.in_(["running", "success"]),
    )
    return matching.scalar()

@sync_to_async
@provide_session
Expand Down Expand Up @@ -112,7 +157,7 @@ def __init__(
self,
dag_id: str,
states: list[DagRunState],
execution_dates: list[datetime.datetime],
execution_dates: list[datetime],
poll_interval: float = 5.0,
):
super().__init__()
Expand All @@ -134,7 +179,10 @@ def serialize(self) -> tuple[str, dict[str, typing.Any]]:
)

async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Checks periodically in the database to see if the dag run exists and has hit one of the states."""
"""
Checks periodically in the database to see if the dag run exists, and has
hit one of the states yet, or not.
"""
while True:
# mypy confuses typing here
num_dags = await self.count_dags() # type: ignore[call-arg]
Expand Down
9 changes: 9 additions & 0 deletions docs/apache-airflow/howto/operator/external_task_sensor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ via ``allowed_states`` and ``failed_states`` parameters.
:start-after: [START howto_operator_external_task_sensor]
:end-before: [END howto_operator_external_task_sensor]

The sensor can also be used in deferrable mode for this purpose:

.. exampleinclude:: /../../tests/system/providers/core/example_external_task_parent_deferrable.py
:language: python
:dedent: 4
:start-after: [START howto_external_task_async_sensor]
:end-before: [END howto_external_task_async_sensor]


ExternalTaskSensor with task_group dependency
---------------------------------------------
In Addition, we can also use the :class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks on a DAG
Expand Down
82 changes: 80 additions & 2 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,28 @@
import tempfile
import zipfile
from datetime import time, timedelta
from unittest import mock

import pytest

from airflow import exceptions, settings
from airflow.decorators import task as task_deco
from airflow.exceptions import AirflowException, AirflowSensorTimeout
from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink
from airflow.sensors.external_task import (
ExternalTaskMarker,
ExternalTaskSensor,
ExternalTaskSensorLink,
)
from airflow.sensors.time_sensor import TimeSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.triggers.external_task import TaskStateTrigger
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
Expand All @@ -54,6 +60,9 @@
TEST_TASK_ID_ALTERNATE = "time_sensor_check_alternate"
TEST_TASK_GROUP_ID = "time_sensor_group_id"
DEV_NULL = "/dev/null"
TASK_ID = "external_task_sensor_check"
EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -829,6 +838,75 @@ def test_external_task_group_when_there_is_no_TIs(self):
)


class TestExternalTaskAsyncSensor:
    """Tests for ``ExternalTaskSensor`` running in deferrable mode."""

    TASK_ID = "external_task_sensor_check"
    EXTERNAL_DAG_ID = "child_dag"  # DAG the external task sensor is waiting on
    EXTERNAL_TASK_ID = "child_task"  # Task the external task sensor is waiting on

    def _make_sensor(self):
        """Build the deferrable sensor shared by every test in this class."""
        return ExternalTaskSensor(
            task_id=self.TASK_ID,
            external_task_id=self.EXTERNAL_TASK_ID,
            external_dag_id=self.EXTERNAL_DAG_ID,
            deferrable=True,
        )

    def test_defer_and_fire_task_state_trigger(self):
        """
        Asserts that a task is deferred and TaskStateTrigger will be fired
        when the deferrable ExternalTaskSensor is provided with all required
        arguments (i.e. including the external_task_id).
        """
        sensor = self._make_sensor()

        with pytest.raises(TaskDeferred) as exc:
            sensor.execute(context=mock.MagicMock())

        assert isinstance(exc.value.trigger, TaskStateTrigger), "Trigger is not a TaskStateTrigger"

    def test_defer_and_fire_failed_state_trigger(self):
        """Tests that an AirflowException is raised in case of a failure event"""
        sensor = self._make_sensor()

        # "failed" is the status used for the trigger's failure event; ``match``
        # pins the generic-error branch so it cannot be confused with the timeout one.
        with pytest.raises(AirflowException, match="Error occurred while trying to retrieve task status"):
            sensor.execute_complete(
                context=mock.MagicMock(), event={"status": "failed", "message": "test failure message"}
            )

    def test_defer_and_fire_timeout_state_trigger(self):
        """Tests that an AirflowException is raised in case of timeout event"""
        sensor = self._make_sensor()

        with pytest.raises(AirflowException, match="Dag was not started within 1 minute"):
            sensor.execute_complete(
                context=mock.MagicMock(),
                event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."},
            )

    def test_defer_execute_check_correct_logging(self):
        """Asserts that logging occurs as expected"""
        sensor = self._make_sensor()

        with mock.patch.object(sensor.log, "info") as mock_log_info:
            sensor.execute_complete(
                context=mock.MagicMock(),
                event={"status": "success"},
            )
        mock_log_info.assert_called_with(
            "External task %s has executed successfully.", self.EXTERNAL_TASK_ID
        )


def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
with dag_zip_maker("test_external_task_sensor_check_existense.py") as dagbag:
with create_session() as session:
Expand Down
16 changes: 16 additions & 0 deletions tests/system/providers/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading