From fff8f3baa5074485e00a671c0436c648781c6840 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 20 Sep 2023 09:29:57 +0800 Subject: [PATCH] fix(providers/microsoft-azure): respect soft_fail argument when exception is raised --- .../microsoft/azure/sensors/data_factory.py | 17 ++++++++++++--- .../providers/microsoft/azure/sensors/wasb.py | 8 ++++++- .../azure/sensors/test_azure_data_factory.py | 18 +++++++++++----- .../microsoft/azure/sensors/test_wasb.py | 21 ++++++++++++------- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index b3e6947f0e5ba..5cede76ad9449 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -85,10 +85,18 @@ def poke(self, context: Context) -> bool: ) if pipeline_run_status == AzureDataFactoryPipelineRunStatus.FAILED: - raise AzureDataFactoryPipelineRunException(f"Pipeline run {self.run_id} has failed.") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Pipeline run {self.run_id} has failed." + if self.soft_fail: + raise AirflowSkipException(message) + raise AzureDataFactoryPipelineRunException(message) if pipeline_run_status == AzureDataFactoryPipelineRunStatus.CANCELLED: - raise AzureDataFactoryPipelineRunException(f"Pipeline run {self.run_id} has been cancelled.") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Pipeline run {self.run_id} has been cancelled." + if self.soft_fail: + raise AirflowSkipException(message) + raise AzureDataFactoryPipelineRunException(message) return pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED @@ -122,6 +130,9 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) return None diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 5abd88a129ac0..3cd3457527584 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.providers.microsoft.azure.triggers.wasb import WasbBlobSensorTrigger, WasbPrefixSensorTrigger from airflow.sensors.base import BaseSensorOperator @@ -102,6 +102,9 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) else: @@ -203,6 +206,9 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ if event: if event["status"] == "error": + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info(event["message"]) else: diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py index ca67444b5a2f1..fb489c7ad7e6b 100644 --- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -111,10 +111,14 @@ def test_adf_pipeline_status_sensor_execute_complete_success(self): self.defered_sensor.execute_complete(context={}, event={"status": "success", "message": msg}) mock_log_info.assert_called_with(msg) - def test_adf_pipeline_status_sensor_execute_complete_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_adf_pipeline_status_sensor_execute_complete_failure(self, soft_fail, expected_exception): """Assert execute_complete method fail""" - with pytest.raises(AirflowException): + self.defered_sensor.soft_fail = soft_fail + with pytest.raises(expected_exception): self.defered_sensor.execute_complete(context={}, event={"status": "error", "message": ""}) @@ -142,8 +146,12 @@ def test_adf_pipeline_status_sensor_execute_complete_success(self): self.SENSOR.execute_complete(context={}, event={"status": "success", "message": msg}) mock_log_info.assert_called_with(msg) - def test_adf_pipeline_status_sensor_execute_complete_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_adf_pipeline_status_sensor_execute_complete_failure(self, soft_fail, expected_exception): """Assert execute_complete method fail""" - with pytest.raises(AirflowException): + self.SENSOR.soft_fail = soft_fail + with pytest.raises(expected_exception): self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py index a94fd0c07b38a..96b24f8cc82cf 100644 --- a/tests/providers/microsoft/azure/sensors/test_wasb.py +++ b/tests/providers/microsoft/azure/sensors/test_wasb.py @@ -24,15 +24,12 @@ import pendulum import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.models import Connection from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance -from airflow.providers.microsoft.azure.sensors.wasb import ( - WasbBlobSensor, - WasbPrefixSensor, -) +from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, WasbPrefixSensor from airflow.providers.microsoft.azure.triggers.wasb import WasbBlobSensorTrigger, WasbPrefixSensorTrigger from airflow.utils import timezone from airflow.utils.types import DagRunType @@ -163,10 +160,14 @@ def test_wasb_blob_sensor_execute_complete_success(self, event): self.SENSOR.execute_complete(context={}, event=event) mock_log_info.assert_called_with(event["message"]) - def test_wasb_blob_sensor_execute_complete_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_wasb_blob_sensor_execute_complete_failure(self, soft_fail, expected_exception): """Assert execute_complete method raises an exception when the triggerer fires an error event.""" - with pytest.raises(AirflowException): + self.SENSOR.soft_fail = soft_fail + with pytest.raises(expected_exception): self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) @@ -288,8 +289,12 @@ def test_wasb_prefix_sensor_execute_complete_success(self, event): self.SENSOR.execute_complete(context={}, event=event) mock_log_info.assert_called_with(event["message"]) - def test_wasb_prefix_sensor_execute_complete_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_wasb_prefix_sensor_execute_complete_failure(self, soft_fail, expected_exception): """Assert execute_complete method raises an exception when the triggerer fires an error event.""" + self.SENSOR.soft_fail = soft_fail with pytest.raises(AirflowException): self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})