From c732774e776490d4d9f199212efa40cfc34ecaf9 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Tue, 28 Feb 2023 11:36:25 +0530 Subject: [PATCH 01/11] Add deferrable data factory sensor --- .../microsoft/azure/hooks/data_factory.py | 128 +++++++++++- .../microsoft/azure/sensors/data_factory.py | 65 ++++++- .../microsoft/azure/triggers/__init__.py | 16 ++ .../microsoft/azure/triggers/data_factory.py | 91 +++++++++ .../azure/hooks/test_azure_data_factory.py | 184 +++++++++++++++++- .../azure/sensors/test_azure_data_factory.py | 48 ++++- 6 files changed, 526 insertions(+), 6 deletions(-) create mode 100644 airflow/providers/microsoft/azure/triggers/__init__.py create mode 100644 airflow/providers/microsoft/azure/triggers/data_factory.py diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index a05ae87538867..0dde136de66f0 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -33,11 +33,17 @@ import inspect import time from functools import wraps -from typing import Any, Callable, Union +from typing import Any, Callable, TypeVar, Union, cast +from asgiref.sync import sync_to_async from azure.core.polling import LROPoller from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.identity.aio import ( + ClientSecretCredential as AsyncClientSecretCredential, + DefaultAzureCredential as AsyncDefaultAzureCredential, +) from azure.mgmt.datafactory import DataFactoryManagementClient +from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient from azure.mgmt.datafactory.models import ( CreateRunResponse, DataFlow, @@ -54,6 +60,9 @@ from airflow.typing_compat import TypedDict Credentials = Union[ClientSecretCredential, DefaultAzureCredential] +AsyncCredentials = Union[AsyncClientSecretCredential, AsyncDefaultAzureCredential] + +T = TypeVar("T", bound=Any) def provide_targeted_factory(func: Callable) -> Callable: @@ -1039,3 +1048,120 @@ def test_connection(self) -> tuple[bool, str]: return success except Exception as e: return False, str(e) + + +def provide_targeted_factory_async(func: T) -> T: + """ + Provide the targeted factory to the async decorated function in case it isn't specified. + + If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in + the connection extras. + """ + signature = inspect.signature(func) + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + bound_args = signature.bind(*args, **kwargs) + + async def bind_argument(arg: Any, default_key: str) -> None: + # Check if arg was not included in the function signature or, if it is, the value is not provided. + if arg not in bound_args.arguments or bound_args.arguments[arg] is None: + self = args[0] + conn = await sync_to_async(self.get_connection)(self.conn_id) + default_value = conn.extra_dejson.get(default_key) + if not default_value: + raise AirflowException("Could not determine the targeted data factory.") + + bound_args.arguments[arg] = conn.extra_dejson[default_key] + + await bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name") + await bind_argument("factory_name", "extra__azure_data_factory__factory_name") + + return await func(*bound_args.args, **bound_args.kwargs) + + return cast(T, wrapper) + + +class AzureDataFactoryAsyncHook(AzureDataFactoryHook): + """ + An Async Hook that connects to Azure DataFactory to perform pipeline operations + + :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id`. + """ + + def __init__(self, azure_data_factory_conn_id: str): + self._async_conn: AsyncDataFactoryManagementClient = None + self.conn_id = azure_data_factory_conn_id + super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id) + + async def get_async_conn(self) -> AsyncDataFactoryManagementClient: + """Get async connection and connect to azure data factory""" + if self._conn is not None: + return self._conn + + conn = await sync_to_async(self.get_connection)(self.conn_id) + tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId") + + try: + subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] + except KeyError: + raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") + + credential: AsyncCredentials + if conn.login is not None and conn.password is not None: + if not tenant: + raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.") + + credential = AsyncClientSecretCredential( + client_id=conn.login, client_secret=conn.password, tenant_id=tenant + ) + else: + credential = AsyncDefaultAzureCredential() + + return AsyncDataFactoryManagementClient( + credential=credential, + subscription_id=subscription_id, + ) + + @provide_targeted_factory_async + async def get_pipeline_run( + self, + run_id: str, + resource_group_name: str | None = None, + factory_name: str | None = None, + **config: Any, + ) -> PipelineRun: + """ + Connects to Azure Data Factory asynchronously to get the pipeline run details by run id + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + """ + async with await self.get_async_conn() as client: + try: + pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id) + return pipeline_run + except Exception as e: + raise AirflowException(e) + + async def get_adf_pipeline_run_status( + self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None + ) -> str: + """ + Connects to Azure Data Factory asynchronously and gets the pipeline status by run_id + + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + """ + try: + pipeline_run = await self.get_pipeline_run( + run_id=run_id, + factory_name=factory_name, + resource_group_name=resource_group_name, + ) + status: str = pipeline_run.status + return status + except Exception as e: + raise AirflowException(e) diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index 9d550405a4d57..414ff3fb7f1fb 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -16,13 +16,19 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +import warnings +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Sequence +from airflow import AirflowException from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, ) +from airflow.providers.microsoft.azure.triggers.data_factory import ( + ADFPipelineRunStatusSensorTrigger, +) from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -78,3 +84,60 @@ def poke(self, context: Context) -> bool: raise AzureDataFactoryPipelineRunException(f"Pipeline run {self.run_id} has been cancelled.") return pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED + + +class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunStatusSensor): + """ + Checks the status of a pipeline run asynchronously. + + :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. + :param run_id: The pipeline run identifier. + :param resource_group_name: The resource group name. + :param factory_name: The data factory name. + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + *, + poll_interval: float = 5, + **kwargs: Any, + ): + # TODO: Remove once deprecated + if poll_interval: + self.poke_interval = poll_interval + warnings.warn( + "Argument `poll_interval` is deprecated and will be removed " + "in a future release. Please use `poke_interval` instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(**kwargs) + + def execute(self, context: Context) -> None: + """Defers trigger class to poll for state of the job run until + it reaches a failure state or success state + """ + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=ADFPipelineRunStatusSensorTrigger( + run_id=self.run_id, + azure_data_factory_conn_id=self.azure_data_factory_conn_id, + resource_group_name=self.resource_group_name, + factory_name=self.factory_name, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, str]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event: + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info(event["message"]) + return None diff --git a/airflow/providers/microsoft/azure/triggers/__init__.py b/airflow/providers/microsoft/azure/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/microsoft/azure/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py b/airflow/providers/microsoft/azure/triggers/data_factory.py new file mode 100644 index 0000000000000..246696451d135 --- /dev/null +++ b/airflow/providers/microsoft/azure/triggers/data_factory.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from airflow.providers.microsoft.azure.hooks.data_factory import ( + AzureDataFactoryAsyncHook, + AzureDataFactoryPipelineRunStatus, +) +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class ADFPipelineRunStatusSensorTrigger(BaseTrigger): + """ + ADFPipelineRunStatusSensorTrigger is fired as deferred class with params to run the + task in trigger worker, when ADF Pipeline is running + + :param run_id: The pipeline run identifier. + :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. + :param poke_interval: polling period in seconds to check for the status + :param resource_group_name: The resource group name. + :param factory_name: The data factory name. + """ + + def __init__( + self, + run_id: str, + azure_data_factory_conn_id: str, + poke_interval: float, + resource_group_name: str | None = None, + factory_name: str | None = None, + ): + super().__init__() + self.run_id = run_id + self.azure_data_factory_conn_id = azure_data_factory_conn_id + self.resource_group_name = resource_group_name + self.factory_name = factory_name + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes ADFPipelineRunStatusSensorTrigger arguments and classpath.""" + return ( + "airflow.providers.microsoft.azure.triggers.data_factory.ADFPipelineRunStatusSensorTrigger", + { + "run_id": self.run_id, + "azure_data_factory_conn_id": self.azure_data_factory_conn_id, + "resource_group_name": self.resource_group_name, + "factory_name": self.factory_name, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: + """Make async connection to Azure Data Factory, polls for the pipeline run status""" + hook = AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) + try: + while True: + pipeline_status = await hook.get_adf_pipeline_run_status( + run_id=self.run_id, + resource_group_name=self.resource_group_name, + factory_name=self.factory_name, + ) + if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED: + yield TriggerEvent( + {"status": "error", "message": f"Pipeline run {self.run_id} has Failed."} + ) + elif pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED: + msg = f"Pipeline run {self.run_id} has been Cancelled." + yield TriggerEvent({"status": "error", "message": msg}) + elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED: + msg = f"Pipeline run {self.run_id} has been Succeeded." + yield TriggerEvent({"status": "success", "message": msg}) + await asyncio.sleep(self.poke_interval) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index c24b3edc318f6..f8eba31b41370 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -18,16 +18,19 @@ import json import os +from unittest import mock from unittest.mock import MagicMock, PropertyMock, patch import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.mgmt.datafactory.aio import DataFactoryManagementClient from azure.mgmt.datafactory.models import FactoryListResponse from pytest import fixture, param -from airflow.exceptions import AirflowException +from airflow import AirflowException from airflow.models.connection import Connection from airflow.providers.microsoft.azure.hooks.data_factory import ( + AzureDataFactoryAsyncHook, AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, @@ -36,12 +39,18 @@ from airflow.utils import db DEFAULT_RESOURCE_GROUP = "defaultResourceGroup" -RESOURCE_GROUP = "testResourceGroup" +AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default" +RESOURCE_GROUP_NAME = "team_provider_resource_group_test" +DATAFACTORY_NAME = "ADFProvidersTeamDataFactory" +TASK_ID = "test_adf_pipeline_run_status_sensor" +RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" +DEFAULT_CONNECTION_CLIENT_SECRET = "azure_data_factory_test_client_secret" +MODULE = "astronomer.providers.microsoft.azure" +RESOURCE_GROUP = "testResourceGroup" DEFAULT_FACTORY = "defaultFactory" FACTORY = "testFactory" -DEFAULT_CONNECTION_CLIENT_SECRET = "azure_data_factory_test_client_secret" DEFAULT_CONNECTION_DEFAULT_CREDENTIAL = "azure_data_factory_test_default_credential" MODEL = object() @@ -708,3 +717,172 @@ def test_backcompat_prefix_both_prefers_short(mock_connect): hook = AzureDataFactoryHook("my_conn") hook.delete_factory(factory_name="n/a") mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a") + + +class TestAzureDataFactoryAsyncHook: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status", + ["Queued", "InProgress", "Succeeded", "Failed", "Cancelled"], + ) + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_pipeline_run") + async def test_get_adf_pipeline_run_status(self, mock_get_pipeline_run, mock_conn, mock_status): + """Test get_adf_pipeline_run_status function with mocked status""" + mock_get_pipeline_run.return_value.status = mock_status + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_status + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_pipeline_run") + async def test_get_adf_pipeline_run_status_exception(self, mock_get_pipeline_run, mock_conn): + """Test get_adf_pipeline_run_status function with exception""" + mock_get_pipeline_run.side_effect = Exception("Test exception") + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + with pytest.raises(AirflowException): + await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + + @pytest.mark.asyncio + @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + async def test_get_pipeline_run(self, mock_async_connection, mock_pipeline_run): + """Test get_pipeline_run run function""" + mock_async_connection.return_value.__aenter__.return_value.pipeline_runs.get.return_value = ( + mock_pipeline_run + ) + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_pipeline_run + + @pytest.mark.asyncio + @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + async def test_get_pipeline_run_without_resource_name( + self, mock_async_connection, mock_get_connection, mock_pipeline_run + ): + """Test get_pipeline_run run function without passing the resource name + to check the decorator function""" + mock_connection = Connection( + extra=json.dumps( + { + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + } + ) + ) + mock_get_connection.return_value = mock_connection + mock_async_connection.return_value.__aenter__.return_value.pipeline_runs.get.return_value = ( + mock_pipeline_run + ) + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME) + assert response == mock_pipeline_run + + @pytest.mark.asyncio + @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + async def test_get_pipeline_run_exception_without_resource( + self, mock_conn, mock_get_connection, mock_pipeline_run + ): + """ + Test get_pipeline_run function without passing the resource name to check the decorator function and + raise exception + """ + mock_connection = Connection( + extra=json.dumps({"extra__azure_data_factory__factory_name": DATAFACTORY_NAME}) + ) + mock_get_connection.return_value = mock_connection + mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.return_value = mock_pipeline_run + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + with pytest.raises(AirflowException): + await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME) + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + async def test_get_pipeline_run_exception(self, mock_conn): + """Test get_pipeline_run function with exception""" + mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.side_effect = Exception( + "Test exception" + ) + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + with pytest.raises(AirflowException): + await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + async def test_get_async_conn(self, mock_connection): + """""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps( + { + "extra__azure_data_factory__tenantId": "tenantId", + "extra__azure_data_factory__subscriptionId": "subscriptionId", + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + } + ), + ) + mock_connection.return_value = mock_conn + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_async_conn() + assert isinstance(response, DataFactoryManagementClient) + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + async def test_get_async_conn_without_login_id(self, mock_connection): + """Test get_async_conn function without login id""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + extra=json.dumps( + { + "extra__azure_data_factory__tenantId": "tenantId", + "extra__azure_data_factory__subscriptionId": "subscriptionId", + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + } + ), + ) + mock_connection.return_value = mock_conn + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_async_conn() + assert isinstance(response, DataFactoryManagementClient) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_connection_params", + [ + { + "extra__azure_data_factory__tenantId": "tenantId", + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + }, + { + "extra__azure_data_factory__subscriptionId": "subscriptionId", + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + }, + ], + ) + @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + async def test_get_async_conn_key_error(self, mock_connection, mock_connection_params): + """Test get_async_conn function with raising key error""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps(mock_connection_params), + ) + mock_connection.return_value = mock_conn + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + with pytest.raises(ValueError): + await hook.get_async_conn() diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py index f695efacefb24..d4413b07ae3c7 100644 --- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py @@ -16,16 +16,22 @@ # under the License. from __future__ import annotations +from unittest import mock from unittest.mock import patch import pytest +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, ) -from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor +from airflow.providers.microsoft.azure.sensors.data_factory import ( + AzureDataFactoryPipelineRunStatusAsyncSensor, + AzureDataFactoryPipelineRunStatusSensor, +) +from airflow.providers.microsoft.azure.triggers.data_factory import ADFPipelineRunStatusSensorTrigger class TestPipelineRunStatusSensor: @@ -74,3 +80,43 @@ def test_poke(self, mock_pipeline_run, pipeline_run_status, expected_status): with pytest.raises(AzureDataFactoryPipelineRunException, match=error_message): self.sensor.poke({}) + + +class TestAzureDataFactoryPipelineRunStatusAsyncSensor: + RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" + SENSOR = AzureDataFactoryPipelineRunStatusAsyncSensor( + task_id="pipeline_run_sensor_async", + run_id=RUN_ID, + ) + + def test_adf_pipeline_status_sensor_async(self): + """Assert execute method defer for Azure Data factory pipeline run status sensor""" + + with pytest.raises(TaskDeferred) as exc: + self.SENSOR.execute({}) + assert isinstance( + exc.value.trigger, ADFPipelineRunStatusSensorTrigger + ), "Trigger is not a ADFPipelineRunStatusSensorTrigger" + + def test_adf_pipeline_status_sensor_execute_complete_success(self): + """Assert execute_complete log success message when trigger fire with target status""" + + msg = f"Pipeline run {self.RUN_ID} has been succeeded." + with mock.patch.object(self.SENSOR.log, "info") as mock_log_info: + self.SENSOR.execute_complete(context={}, event={"status": "success", "message": msg}) + mock_log_info.assert_called_with(msg) + + def test_adf_pipeline_status_sensor_execute_complete_failure(self): + """Assert execute_complete method fail""" + + with pytest.raises(AirflowException): + self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) + + def test_poll_interval_deprecation_warning(self): + """Test DeprecationWarning for AzureDataFactoryPipelineRunStatusAsyncSensor + by setting param poll_interval""" + # TODO: Remove once deprecated + with pytest.warns(expected_warning=DeprecationWarning): + AzureDataFactoryPipelineRunStatusAsyncSensor( + task_id="pipeline_run_sensor_async", run_id=self.RUN_ID, poll_interval=5.0 + ) From d1cac5c61c35819d1d5ccceeca4f5a4b24b5e93c Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Tue, 28 Feb 2023 11:47:15 +0530 Subject: [PATCH 02/11] Add deferrable data factory sensor --- .../providers/microsoft/azure/hooks/test_azure_data_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index f8eba31b41370..42e214820765e 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -45,7 +45,7 @@ TASK_ID = "test_adf_pipeline_run_status_sensor" RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" DEFAULT_CONNECTION_CLIENT_SECRET = "azure_data_factory_test_client_secret" -MODULE = "astronomer.providers.microsoft.azure" +MODULE = "airflow.providers.microsoft.azure" RESOURCE_GROUP = "testResourceGroup" DEFAULT_FACTORY = "defaultFactory" From a4c82d7ac1334d66554ddf764556027f5d5637a0 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Tue, 28 Feb 2023 13:08:12 +0530 Subject: [PATCH 03/11] Add tests --- .../microsoft/azure/triggers/__init__.py | 16 ++ .../azure/triggers/test_azure_data_factory.py | 138 ++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 tests/providers/microsoft/azure/triggers/__init__.py create mode 100644 tests/providers/microsoft/azure/triggers/test_azure_data_factory.py diff --git a/tests/providers/microsoft/azure/triggers/__init__.py b/tests/providers/microsoft/azure/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/microsoft/azure/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py new file mode 100644 index 0000000000000..2c76dbe678baa --- /dev/null +++ b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import sys +import time + +import pytest + +from airflow.providers.microsoft.azure.triggers.data_factory import ( + ADFPipelineRunStatusSensorTrigger, +) +from airflow.triggers.base import TriggerEvent + +if sys.version_info < (3, 8): + # For compatibility with Python 3.7 + from asynctest import mock as async_mock +else: + from unittest import mock as async_mock + +RESOURCE_GROUP_NAME = "team_provider_resource_group_test" +DATAFACTORY_NAME = "ADFProvidersTeamDataFactory" +AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default" +RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" +POKE_INTERVAL = 5 +AZ_PIPELINE_RUN_ID = "123" +AZ_RESOURCE_GROUP_NAME = "test-rg" +AZ_FACTORY_NAME = "test-factory" +AZ_DATA_FACTORY_CONN_ID = "test-conn" +AZ_PIPELINE_END_TIME = time.time() + 60 * 60 * 24 * 7 +MODULE = "airflow.providers.microsoft.azure" + + +class TestADFPipelineRunStatusSensorTrigger: + TRIGGER = ADFPipelineRunStatusSensorTrigger( + run_id=RUN_ID, + azure_data_factory_conn_id=AZURE_DATA_FACTORY_CONN_ID, + resource_group_name=RESOURCE_GROUP_NAME, + factory_name=DATAFACTORY_NAME, + poke_interval=POKE_INTERVAL, + ) + + def test_adf_pipeline_run_status_sensors_trigger_serialization(self): + """ + Asserts that the TaskStateTrigger correctly serializes its arguments + and classpath. + """ + + classpath, kwargs = self.TRIGGER.serialize() + assert classpath == f"{MODULE}.triggers.data_factory.ADFPipelineRunStatusSensorTrigger" + assert kwargs == { + "run_id": RUN_ID, + "azure_data_factory_conn_id": AZURE_DATA_FACTORY_CONN_ID, + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + "poke_interval": POKE_INTERVAL, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status", + [ + "Queued", + "InProgress", + ], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_run(self, mock_data_factory, mock_status): + """ + Test if the task is run is in trigger successfully. + """ + mock_data_factory.return_value = mock_status + + task = asyncio.create_task(self.TRIGGER.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status", + ["Succeeded"], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_completed(self, mock_data_factory, mock_status): + """Test if the task pipeline status is in succeeded status.""" + mock_data_factory.return_value = mock_status + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + msg = f"Pipeline run {RUN_ID} has been Succeeded." + assert TriggerEvent({"status": "success", "message": msg}) == actual + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status, mock_message", + [ + ("Failed", f"Pipeline run {RUN_ID} has Failed."), + ("Cancelled", f"Pipeline run {RUN_ID} has been Cancelled."), + ], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_failure_status( + self, mock_data_factory, mock_status, mock_message + ): + """Test if the task is run is in trigger failure status.""" + mock_data_factory.return_value = mock_status + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": mock_message}) == actual + + @pytest.mark.asyncio + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_exception(self, mock_data_factory): + """Test EMR container sensors with raise exception""" + mock_data_factory.side_effect = Exception("Test exception") + + task = [i async for i in self.TRIGGER.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task From 87f0579b2353909f02a1bbf3ff8e02ceb95f4350 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Tue, 28 Feb 2023 17:48:28 +0530 Subject: [PATCH 04/11] Fix tests --- .../microsoft/azure/hooks/data_factory.py | 2 + .../azure/hooks/test_azure_data_factory.py | 152 ++++++++++++------ .../azure/triggers/test_azure_data_factory.py | 43 ++++- 3 files changed, 143 insertions(+), 54 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 0dde136de66f0..0b4bb9e7954d6 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -1140,7 +1140,9 @@ async def get_pipeline_run( """ async with await self.get_async_conn() as client: try: + print("I am here1") pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id) + print("I am here2") return pipeline_run except Exception as e: raise AirflowException(e) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 42e214820765e..02dda16beca97 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -18,7 +18,7 @@ import json import os -from unittest import mock +import sys from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -38,6 +38,12 @@ ) from airflow.utils import db +if sys.version_info < (3, 8): + # For compatibility with Python 3.7 + from asynctest import mock as async_mock +else: + from unittest import mock as async_mock + DEFAULT_RESOURCE_GROUP = "defaultResourceGroup" AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default" RESOURCE_GROUP_NAME = "team_provider_resource_group_test" @@ -723,11 +729,11 @@ class TestAzureDataFactoryAsyncHook: @pytest.mark.asyncio @pytest.mark.parametrize( "mock_status", - ["Queued", "InProgress", "Succeeded", "Failed", "Cancelled"], + ["Queued"], ) - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_pipeline_run") - async def test_get_adf_pipeline_run_status(self, mock_get_pipeline_run, mock_conn, mock_status): + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_queued(self, mock_get_pipeline_run, mock_conn, mock_status): """Test get_adf_pipeline_run_status function with mocked status""" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) @@ -735,56 +741,77 @@ async def test_get_adf_pipeline_run_status(self, mock_get_pipeline_run, mock_con assert response == mock_status @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_pipeline_run") - async def test_get_adf_pipeline_run_status_exception(self, mock_get_pipeline_run, mock_conn): - """Test get_adf_pipeline_run_status function with exception""" - mock_get_pipeline_run.side_effect = Exception("Test exception") + @pytest.mark.parametrize( + "mock_status", + ["InProgress"], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_inprogress( + self, mock_get_pipeline_run, mock_conn, mock_status + ): + """Test get_adf_pipeline_run_status function with mocked status""" + mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) - with pytest.raises(AirflowException): - await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_status @pytest.mark.asyncio - @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") - async def test_get_pipeline_run(self, mock_async_connection, mock_pipeline_run): - """Test get_pipeline_run run function""" - mock_async_connection.return_value.__aenter__.return_value.pipeline_runs.get.return_value = ( - mock_pipeline_run - ) + @pytest.mark.parametrize( + "mock_status", + ["Succeeded"], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_success(self, mock_get_pipeline_run, mock_conn, mock_status): + """Test get_adf_pipeline_run_status function with mocked status""" + mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) - response = await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) - assert response == mock_pipeline_run + response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_status @pytest.mark.asyncio - @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") - async def test_get_pipeline_run_without_resource_name( - self, mock_async_connection, mock_get_connection, mock_pipeline_run - ): - """Test get_pipeline_run run function without passing the resource name - to check the decorator function""" - mock_connection = Connection( - extra=json.dumps( - { - "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, - "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, - } - ) - ) - mock_get_connection.return_value = mock_connection - mock_async_connection.return_value.__aenter__.return_value.pipeline_runs.get.return_value = ( - mock_pipeline_run - ) + @pytest.mark.parametrize( + "mock_status", + ["Failed"], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_failed(self, mock_get_pipeline_run, mock_conn, mock_status): + """Test get_adf_pipeline_run_status function with mocked status""" + mock_get_pipeline_run.return_value.status = mock_status + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_status + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status", + ["Cancelled"], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run, mock_conn, mock_status): + """Test get_adf_pipeline_run_status function with mocked status""" + mock_get_pipeline_run.return_value.status = mock_status + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + assert response == mock_status + + @pytest.mark.asyncio + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") + async def test_get_adf_pipeline_run_status_exception(self, mock_get_pipeline_run, mock_conn): + """Test get_adf_pipeline_run_status function with exception""" + mock_get_pipeline_run.side_effect = Exception("Test exception") hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) - response = await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME) - assert response == mock_pipeline_run + with pytest.raises(AirflowException): + await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) @pytest.mark.asyncio - @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + @async_mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") async def test_get_pipeline_run_exception_without_resource( self, mock_conn, mock_get_connection, mock_pipeline_run ): @@ -802,7 +829,7 @@ async def test_get_pipeline_run_exception_without_resource( await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME) @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_async_conn") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") async def test_get_pipeline_run_exception(self, mock_conn): """Test get_pipeline_run function with exception""" mock_conn.return_value.__aenter__.return_value.pipeline_runs.get.side_effect = Exception( @@ -813,7 +840,7 @@ async def test_get_pipeline_run_exception(self, mock_conn): await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") async def test_get_async_conn(self, mock_connection): """""" mock_conn = Connection( @@ -836,7 +863,7 @@ async def test_get_async_conn(self, mock_connection): assert isinstance(response, DataFactoryManagementClient) @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") async def test_get_async_conn_without_login_id(self, mock_connection): """Test get_async_conn function without login id""" mock_conn = Connection( @@ -864,7 +891,28 @@ async def test_get_async_conn_without_login_id(self, mock_connection): "extra__azure_data_factory__tenantId": "tenantId", "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, - }, + } + ], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") + async def test_get_async_conn_key_error_tenantId(self, mock_connection, mock_connection_params): + """Test get_async_conn function with raising key error""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps(mock_connection_params), + ) + mock_connection.return_value = mock_conn + hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + with pytest.raises(ValueError): + await hook.get_async_conn() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_connection_params", + [ { "extra__azure_data_factory__subscriptionId": "subscriptionId", "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, @@ -872,8 +920,8 @@ async def test_get_async_conn_without_login_id(self, mock_connection): }, ], ) - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryHookAsync.get_connection") - async def test_get_async_conn_key_error(self, mock_connection, mock_connection_params): + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") + async def test_get_async_conn_key_error_subscriptionId(self, mock_connection, mock_connection_params): """Test get_async_conn function with raising key error""" mock_conn = Connection( conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, diff --git a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py index 2c76dbe678baa..141cca91af628 100644 --- a/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/triggers/test_azure_data_factory.py @@ -76,11 +76,33 @@ def test_adf_pipeline_run_status_sensors_trigger_serialization(self): "mock_status", [ "Queued", + ], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_run_queued(self, mock_data_factory, mock_status): + """ + Test if the task is run is in trigger successfully. + """ + mock_data_factory.return_value = mock_status + + task = asyncio.create_task(self.TRIGGER.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status", + [ "InProgress", ], ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") - async def test_adf_pipeline_run_status_sensors_trigger_run(self, mock_data_factory, mock_status): + async def test_adf_pipeline_run_status_sensors_trigger_run_inprogress( + self, mock_data_factory, mock_status + ): """ Test if the task is run is in trigger successfully. """ @@ -113,11 +135,28 @@ async def test_adf_pipeline_run_status_sensors_trigger_completed(self, mock_data "mock_status, mock_message", [ ("Failed", f"Pipeline run {RUN_ID} has Failed."), + ], + ) + @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") + async def test_adf_pipeline_run_status_sensors_trigger_failed( + self, mock_data_factory, mock_status, mock_message + ): + """Test if the task is run is in trigger failure status.""" + mock_data_factory.return_value = mock_status + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": mock_message}) == actual + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_status, mock_message", + [ ("Cancelled", f"Pipeline run {RUN_ID} has been Cancelled."), ], ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status") - async def test_adf_pipeline_run_status_sensors_trigger_failure_status( + async def test_adf_pipeline_run_status_sensors_trigger_cancelled( self, mock_data_factory, mock_status, mock_message ): """Test if the task is run is in trigger failure status.""" From 624ed7c41164d424a36366fe12f741b1cd4e4bd8 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Tue, 28 Feb 2023 17:52:14 +0530 Subject: [PATCH 05/11] Fix tests --- airflow/providers/microsoft/azure/hooks/data_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 0b4bb9e7954d6..0dde136de66f0 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -1140,9 +1140,7 @@ async def get_pipeline_run( """ async with await self.get_async_conn() as client: try: - print("I am here1") pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id) - print("I am here2") return pipeline_run except Exception as e: raise AirflowException(e) From 4810f26f3d57d572a4e586ce6f12fcaf9eae9624 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Sun, 5 Mar 2023 15:10:03 +0530 Subject: [PATCH 06/11] Apply review suggestions --- .../microsoft/azure/hooks/data_factory.py | 20 +++-- .../microsoft/azure/sensors/data_factory.py | 13 +-- .../azure/hooks/test_azure_data_factory.py | 82 +++++++++++++------ 3 files changed, 71 insertions(+), 44 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 0dde136de66f0..b1db9d5211716 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -1068,14 +1068,17 @@ async def bind_argument(arg: Any, default_key: str) -> None: if arg not in bound_args.arguments or bound_args.arguments[arg] is None: self = args[0] conn = await sync_to_async(self.get_connection)(self.conn_id) - default_value = conn.extra_dejson.get(default_key) + extras = conn.extra_dejson + default_value = extras.get(default_key) or extras.get( + f"extra__azure_data_factory__{default_key}" + ) if not default_value: raise AirflowException("Could not determine the targeted data factory.") - bound_args.arguments[arg] = conn.extra_dejson[default_key] + bound_args.arguments[arg] = default_value - await bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name") - await bind_argument("factory_name", "extra__azure_data_factory__factory_name") + await bind_argument("resource_group_name", "resource_group_name") + await bind_argument("factory_name", "factory_name") return await func(*bound_args.args, **bound_args.kwargs) @@ -1100,10 +1103,11 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient: return self._conn conn = await sync_to_async(self.get_connection)(self.conn_id) - tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId") + extras = conn.extra_dejson + tenant = get_field(extras, "tenantId") try: - subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] + subscription_id = get_field(extras, "subscriptionId", strict=True) except KeyError: raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") @@ -1132,7 +1136,7 @@ async def get_pipeline_run( **config: Any, ) -> PipelineRun: """ - Connects to Azure Data Factory asynchronously to get the pipeline run details by run id + Connect to Azure Data Factory asynchronously to get the pipeline run details by run id :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. @@ -1149,7 +1153,7 @@ async def get_adf_pipeline_run_status( self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None ) -> str: """ - Connects to Azure Data Factory asynchronously and gets the pipeline status by run_id + Connect to Azure Data Factory asynchronously and get the pipeline status by run_id :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index 414ff3fb7f1fb..ef0bea3e3ba17 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import warnings from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence @@ -100,18 +99,10 @@ class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunSt def __init__( self, *, - poll_interval: float = 5, + poke_interval: float = 60, **kwargs: Any, ): - # TODO: Remove once deprecated - if poll_interval: - self.poke_interval = poll_interval - warnings.warn( - "Argument `poll_interval` is deprecated and will be removed " - "in a future release. Please use `poke_interval` instead.", - DeprecationWarning, - stacklevel=2, - ) + self.poke_interval = poke_interval super().__init__(**kwargs) def execute(self, context: Context) -> None: diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 02dda16beca97..8613e5a62d1e2 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -34,6 +34,7 @@ AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, + get_field, provide_targeted_factory, ) from airflow.utils import db @@ -727,72 +728,59 @@ def test_backcompat_prefix_both_prefers_short(mock_connect): class TestAzureDataFactoryAsyncHook: @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["Queued"], - ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") - async def test_get_adf_pipeline_run_status_queued(self, mock_get_pipeline_run, mock_conn, mock_status): + async def test_get_adf_pipeline_run_status_queued(self, mock_get_pipeline_run, mock_conn): """Test get_adf_pipeline_run_status function with mocked status""" + mock_status = "Queued" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) assert response == mock_status @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["InProgress"], - ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") async def test_get_adf_pipeline_run_status_inprogress( - self, mock_get_pipeline_run, mock_conn, mock_status + self, + mock_get_pipeline_run, + mock_conn, ): """Test get_adf_pipeline_run_status function with mocked status""" + mock_status = "InProgress" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) assert response == mock_status @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["Succeeded"], - ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") - async def test_get_adf_pipeline_run_status_success(self, mock_get_pipeline_run, mock_conn, mock_status): + async def test_get_adf_pipeline_run_status_success(self, mock_get_pipeline_run, mock_conn): """Test get_adf_pipeline_run_status function with mocked status""" + mock_status = "Succeeded" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) assert response == mock_status @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["Failed"], - ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") - async def test_get_adf_pipeline_run_status_failed(self, mock_get_pipeline_run, mock_conn, mock_status): + async def test_get_adf_pipeline_run_status_failed(self, mock_get_pipeline_run, mock_conn): """Test get_adf_pipeline_run_status function with mocked status""" + mock_status = "Failed" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) assert response == mock_status @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["Cancelled"], - ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn") @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run") - async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run, mock_conn, mock_status): + async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run, mock_conn): """Test get_adf_pipeline_run_status function with mocked status""" + mock_status = "Cancelled" mock_get_pipeline_run.return_value.status = mock_status hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME) @@ -934,3 +922,47 @@ async def test_get_async_conn_key_error_subscriptionId(self, mock_connection, mo hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) with pytest.raises(ValueError): await hook.get_async_conn() + + def test_get_field_prefixed_extras(self): + """Test get_field function for retrieving prefixed extra fields""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + extra=json.dumps( + { + "extra__azure_data_factory__tenantId": "tenantId", + "extra__azure_data_factory__subscriptionId": "subscriptionId", + "extra__azure_data_factory__resource_group_name": RESOURCE_GROUP_NAME, + "extra__azure_data_factory__factory_name": DATAFACTORY_NAME, + } + ), + ) + extras = mock_conn.extra_dejson + assert get_field(extras, "tenantId", strict=True) == "tenantId" + assert get_field(extras, "subscriptionId", strict=True) == "subscriptionId" + assert get_field(extras, "resource_group_name", strict=True) == RESOURCE_GROUP_NAME + assert get_field(extras, "factory_name", strict=True) == DATAFACTORY_NAME + with pytest.raises(KeyError): + get_field(extras, "non-existent-field", strict=True) + + def test_get_field_non_prefixed_extras(self): + """Test get_field function for retrieving non-prefixed extra fields""" + mock_conn = Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + extra=json.dumps( + { + "tenantId": "tenantId", + "subscriptionId": "subscriptionId", + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + } + ), + ) + extras = mock_conn.extra_dejson + assert get_field(extras, "tenantId", strict=True) == "tenantId" + assert get_field(extras, "subscriptionId", strict=True) == "subscriptionId" + assert get_field(extras, "resource_group_name", strict=True) == RESOURCE_GROUP_NAME + assert get_field(extras, "factory_name", strict=True) == DATAFACTORY_NAME + with pytest.raises(KeyError): + get_field(extras, "non-existent-field", strict=True) From 887f7d00a506b865d7996cfe61938b72c9abbc1a Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Sun, 5 Mar 2023 17:34:35 +0530 Subject: [PATCH 07/11] Fix tests --- .../microsoft/azure/sensors/test_azure_data_factory.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py index d4413b07ae3c7..0424bf41c435c 100644 --- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py @@ -111,12 +111,3 @@ def test_adf_pipeline_status_sensor_execute_complete_failure(self): with pytest.raises(AirflowException): self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) - - def test_poll_interval_deprecation_warning(self): - """Test DeprecationWarning for AzureDataFactoryPipelineRunStatusAsyncSensor - by setting param poll_interval""" - # TODO: Remove once deprecated - with pytest.warns(expected_warning=DeprecationWarning): - AzureDataFactoryPipelineRunStatusAsyncSensor( - task_id="pipeline_run_sensor_async", run_id=self.RUN_ID, poll_interval=5.0 - ) From 7405a3f1f87f5509f0cd40d6b3fabfbeec473258 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Thu, 9 Mar 2023 14:46:59 +0530 Subject: [PATCH 08/11] Add docs --- .../operators/adf_run_pipeline.rst | 20 +++++++++++++++++-- .../azure/example_adf_run_pipeline.py | 13 ++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst index 7d3b484ba7cbe..6a45d335448d2 100644 --- a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst +++ b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst @@ -27,7 +27,7 @@ AzureDataFactoryRunPipelineOperator ----------------------------------- Use the :class:`~airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` to execute a pipeline within a data factory. By default, the operator will periodically check on the status of the executed pipeline to terminate with a "Succeeded" status. -This functionality can be disabled for an asynchronous wait -- typically with the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunSensor` -- by setting ``wait_for_termination`` to False. +This functionality can be disabled for an asynchronous wait -- typically with the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor` -- by setting ``wait_for_termination`` to False. Below is an example of using this operator to execute an Azure Data Factory pipeline. @@ -37,7 +37,7 @@ Below is an example of using this operator to execute an Azure Data Factory pipe :start-after: [START howto_operator_adf_run_pipeline] :end-before: [END howto_operator_adf_run_pipeline] -Here is a different example of using this operator to execute a pipeline but coupled with the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunSensor` to perform an asynchronous wait. +Here is a different example of using this operator to execute a pipeline but coupled with the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor` to perform an asynchronous wait. .. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adf_run_pipeline.py :language: python @@ -45,6 +45,22 @@ Here is a different example of using this operator to execute a pipeline but cou :start-after: [START howto_operator_adf_run_pipeline_async] :end-before: [END howto_operator_adf_run_pipeline_async] +Poll for status of a datafactory pipeline run asynchronously +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusAsyncSensor` +(deferrable version) to periodically retrieve the +status of a datafactory pipeline run asynchronously. This sensor will free up the worker slots since +polling for job status happens on the Airflow triggerer, leading to efficient utilization +of resources within Airflow. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adf_run_pipeline.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_adf_run_pipeline_async] + :end-before: [END howto_operator_adf_run_pipeline_async] + + Reference --------- diff --git a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py index 0ab03b33f266c..456bb1a034731 100644 --- a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py +++ b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py @@ -29,7 +29,10 @@ from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator -from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor +from airflow.providers.microsoft.azure.sensors.data_factory import ( + AzureDataFactoryPipelineRunStatusAsyncSensor, + AzureDataFactoryPipelineRunStatusSensor, +) from airflow.utils.edgemodifier import Label ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") @@ -71,11 +74,17 @@ task_id="pipeline_run_sensor", run_id=cast(str, XComArg(run_pipeline2, key="run_id")), ) + + # Performs polling on the Airflow Triggerer thus freeing up resources on Airflow Worker + pipeline_run_async_sensor = AzureDataFactoryPipelineRunStatusAsyncSensor( + task_id="pipeline_run_async_sensor", + run_id=cast(str, XComArg(run_pipeline2, key="run_id")), + ) # [END howto_operator_adf_run_pipeline_async] begin >> Label("No async wait") >> run_pipeline1 begin >> Label("Do async wait with sensor") >> run_pipeline2 - [run_pipeline1, pipeline_run_sensor] >> end + [run_pipeline1, pipeline_run_sensor, pipeline_run_async_sensor] >> end # Task dependency created via `XComArgs`: # run_pipeline2 >> pipeline_run_sensor From ea282d987c2847bc31de18aadc8927ca5f9c492a Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Thu, 9 Mar 2023 16:04:25 +0530 Subject: [PATCH 09/11] Add docs --- .../operators/adf_run_pipeline.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst index 6a45d335448d2..6606b33c12375 100644 --- a/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst +++ b/docs/apache-airflow-providers-microsoft-azure/operators/adf_run_pipeline.rst @@ -45,12 +45,12 @@ Here is a different example of using this operator to execute a pipeline but cou :start-after: [START howto_operator_adf_run_pipeline_async] :end-before: [END howto_operator_adf_run_pipeline_async] -Poll for status of a datafactory pipeline run asynchronously -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Poll for status of a data factory pipeline run asynchronously +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Use the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusAsyncSensor` (deferrable version) to periodically retrieve the -status of a datafactory pipeline run asynchronously. This sensor will free up the worker slots since +status of a data factory pipeline run asynchronously. This sensor will free up the worker slots since polling for job status happens on the Airflow triggerer, leading to efficient utilization of resources within Airflow. From d140f7adac8fcfe645b65a499e62c5013377bba5 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Fri, 10 Mar 2023 22:08:22 +0530 Subject: [PATCH 10/11] Apply review suggestions --- airflow/providers/microsoft/azure/hooks/data_factory.py | 9 ++++++--- .../providers/microsoft/azure/sensors/data_factory.py | 2 +- .../microsoft/azure/hooks/test_azure_data_factory.py | 8 ++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index b1db9d5211716..6518798d61595 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -1092,15 +1092,17 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook): :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id`. """ - def __init__(self, azure_data_factory_conn_id: str): + default_conn_name: str = "azure_data_factory_default" + + def __init__(self, azure_data_factory_conn_id: str = default_conn_name): self._async_conn: AsyncDataFactoryManagementClient = None self.conn_id = azure_data_factory_conn_id super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id) async def get_async_conn(self) -> AsyncDataFactoryManagementClient: """Get async connection and connect to azure data factory""" - if self._conn is not None: - return self._conn + if self._async_conn is not None: + return self._async_conn conn = await sync_to_async(self.get_connection)(self.conn_id) extras = conn.extra_dejson @@ -1141,6 +1143,7 @@ async def get_pipeline_run( :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. """ async with await self.get_async_conn() as client: try: diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index ef0bea3e3ba17..e09f7a0a7d008 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -93,7 +93,7 @@ class AzureDataFactoryPipelineRunStatusAsyncSensor(AzureDataFactoryPipelineRunSt :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. :param factory_name: The data factory name. - :param poll_interval: polling period in seconds to check for the status + :param poke_interval: polling period in seconds to check for the status """ def __init__( diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 8613e5a62d1e2..4d7466fed1408 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -883,8 +883,8 @@ async def test_get_async_conn_without_login_id(self, mock_connection): ], ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn_key_error_tenantId(self, mock_connection, mock_connection_params): - """Test get_async_conn function with raising key error""" + async def test_get_async_conn_key_error_subscription_id(self, mock_connection, mock_connection_params): + """Test get_async_conn function when subscription_id is missing in the connection""" mock_conn = Connection( conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, conn_type="azure_data_factory", @@ -909,8 +909,8 @@ async def test_get_async_conn_key_error_tenantId(self, mock_connection, mock_con ], ) @async_mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn_key_error_subscriptionId(self, mock_connection, mock_connection_params): - """Test get_async_conn function with raising key error""" + async def test_get_async_conn_key_error_tenant_id(self, mock_connection, mock_connection_params): + """Test get_async_conn function when tenant id is missing in the connection""" mock_conn = Connection( conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, conn_type="azure_data_factory", From 264e90dd62f4f6be0498b7df0fb03594b5936153 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Mon, 13 Mar 2023 11:31:05 +0530 Subject: [PATCH 11/11] Set async conn variable --- airflow/providers/microsoft/azure/hooks/data_factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 6518798d61595..8b26066866a11 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -1124,11 +1124,13 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient: else: credential = AsyncDefaultAzureCredential() - return AsyncDataFactoryManagementClient( + self._async_conn = AsyncDataFactoryManagementClient( credential=credential, subscription_id=subscription_id, ) + return self._async_conn + @provide_targeted_factory_async async def get_pipeline_run( self,