diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 6d8920f008d96..376aa27bb16f0 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -20,6 +20,7 @@ import itertools import logging import os +import re import tempfile import zipfile from datetime import time, timedelta @@ -29,7 +30,7 @@ from airflow import exceptions, settings from airflow.decorators import task as task_deco -from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel @@ -838,6 +839,138 @@ def test_external_task_group_when_there_is_no_TIs(self): ignore_ti_state=True, ) + @pytest.mark.parametrize( + "kwargs, expected_message", + ( + ( + { + "external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE], + "failed_states": [State.FAILED], + }, + f"Some of the external tasks {re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}" + f" in DAG {TEST_DAG_ID} failed.", + ), + ( + { + "external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE], + "failed_states": [State.FAILED], + }, + f"The external task_group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'" + f" in DAG '{TEST_DAG_ID}' failed.", + ), + ( + {"failed_states": [State.FAILED]}, + f"The external DAG {TEST_DAG_ID} failed.", + ), + ), + ) + @pytest.mark.parametrize( + "soft_fail, expected_exception", + ( + ( + False, + AirflowException, + ), + ( + True, + AirflowSkipException, + ), + ), + ) + @mock.patch("airflow.sensors.external_task.ExternalTaskSensor.get_count") + @mock.patch("airflow.sensors.external_task.ExternalTaskSensor._get_dttm_filter") + def test_fail_poke( + self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message + ): + _get_dttm_filter.return_value = [] + get_count.return_value = 1 + op = ExternalTaskSensor( + task_id="test_external_task_duplicate_task_ids", + external_dag_id=TEST_DAG_ID, + allowed_states=["success"], + dag=self.dag, + soft_fail=soft_fail, + deferrable=False, + **kwargs, + ) + with pytest.raises(expected_exception, match=expected_message): + op.execute(context={}) + + @pytest.mark.parametrize( + "response_get_current, response_exists, kwargs, expected_message", + ( + (None, None, {}, f"The external DAG {TEST_DAG_ID} does not exist."), + ( + DAG(dag_id="test"), + False, + {}, + f"The external DAG {TEST_DAG_ID} was deleted.", + ), + ( + DAG(dag_id="test"), + True, + {"external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]}, + f"The external task {TEST_TASK_ID} in DAG {TEST_DAG_ID} does not exist.", + ), + ( + DAG(dag_id="test"), + True, + {"external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]}, + f"The external task group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'" + f" in DAG '{TEST_DAG_ID}' does not exist.", + ), + ), + ) + @pytest.mark.parametrize( + "soft_fail, expected_exception", + ( + ( + False, + AirflowException, + ), + ( + True, + AirflowSkipException, + ), + ), + ) + @mock.patch("airflow.sensors.external_task.ExternalTaskSensor._get_dttm_filter") + @mock.patch("airflow.models.dagbag.DagBag.get_dag") + @mock.patch("os.path.exists") + @mock.patch("airflow.models.dag.DagModel.get_current") + def test_fail__check_for_existence( + self, + get_current, + exists, + get_dag, + _get_dttm_filter, + soft_fail, + expected_exception, + response_get_current, + response_exists, + kwargs, + expected_message, + ): + _get_dttm_filter.return_value = [] + get_current.return_value = response_get_current + exists.return_value = response_exists + get_dag_response = mock.MagicMock() + get_dag.return_value = get_dag_response + get_dag_response.has_task.return_value = False + get_dag_response.has_task_group.return_value = False + op = ExternalTaskSensor( + task_id="test_external_task_duplicate_task_ids", + external_dag_id=TEST_DAG_ID, + allowed_states=["success"], + dag=self.dag, + soft_fail=soft_fail, + check_existence=True, + **kwargs, + ) + expected_message = "Skipping due to soft_fail is set to True." if soft_fail else expected_message + with pytest.raises(expected_exception, match=expected_message): + op.execute(context={}) + class TestExternalTaskAsyncSensor: TASK_ID = "external_task_sensor_check"