diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 52c1bb1c0eb4f..a05ae87538867 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -74,14 +74,17 @@ def bind_argument(arg, default_key): if arg not in bound_args.arguments or bound_args.arguments[arg] is None: self = args[0] conn = self.get_connection(self.conn_id) - default_value = conn.extra_dejson.get(default_key) + extras = conn.extra_dejson + default_value = extras.get(default_key) or extras.get( + f"extra__azure_data_factory__{default_key}" + ) if not default_value: raise AirflowException("Could not determine the targeted data factory.") - bound_args.arguments[arg] = conn.extra_dejson[default_key] + bound_args.arguments[arg] = default_value - bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name") - bind_argument("factory_name", "extra__azure_data_factory__factory_name") + bind_argument("resource_group_name", "resource_group_name") + bind_argument("factory_name", "factory_name") return func(*bound_args.args, **bound_args.kwargs) @@ -113,6 +116,23 @@ class AzureDataFactoryPipelineRunException(AirflowException): """An exception that indicates a pipeline run failed to complete.""" +def get_field(extras: dict, field_name: str, strict: bool = False): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__azure_data_factory__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + if prefixed_name in extras: + return extras[prefixed_name] or None + if strict: + raise KeyError(f"Field {field_name} not found in extras") + + class AzureDataFactoryHook(BaseHook): """ A hook to interact with Azure Data Factory. @@ -133,18 +153,12 @@ def get_connection_form_widgets() -> dict[str, Any]: from wtforms import StringField return { - "extra__azure_data_factory__tenantId": StringField( - lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget() - ), - "extra__azure_data_factory__subscriptionId": StringField( - lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget() - ), - "extra__azure_data_factory__resource_group_name": StringField( + "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()), + "resource_group_name": StringField( lazy_gettext("Resource Group Name"), widget=BS3TextFieldWidget() ), - "extra__azure_data_factory__factory_name": StringField( - lazy_gettext("Factory Name"), widget=BS3TextFieldWidget() - ), + "factory_name": StringField(lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()), } @staticmethod @@ -168,10 +182,11 @@ def get_conn(self) -> DataFactoryManagementClient: return self._conn conn = self.get_connection(self.conn_id) - tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId") + extras = conn.extra_dejson + tenant = get_field(extras, "tenantId") try: - subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] + subscription_id = get_field(extras, "subscriptionId", strict=True) except KeyError: raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index a97a9578aab4d..da1c47d2d819e 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -24,6 +24,7 @@ AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, + get_field, ) if TYPE_CHECKING: @@ -53,17 +54,16 @@ def get_link( task_id=operator.task_id, execution_date=dttm, ) - - conn = BaseHook.get_connection(operator.azure_data_factory_conn_id) - subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] + conn_id = operator.azure_data_factory_conn_id + conn = BaseHook.get_connection(conn_id) + extras = conn.extra_dejson + subscription_id = get_field(extras, "subscriptionId") + if not subscription_id: + raise KeyError(f"Param subscriptionId not found in conn_id '{conn_id}'") # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory # connection or passed directly to the operator. - resource_group_name = operator.resource_group_name or conn.extra_dejson.get( - "extra__azure_data_factory__resource_group_name" - ) - factory_name = operator.factory_name or conn.extra_dejson.get( - "extra__azure_data_factory__factory_name" - ) + resource_group_name = operator.resource_group_name or get_field(extras, "resource_group_name") + factory_name = operator.factory_name or get_field(extras, "factory_name") url = ( f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}" f"?factory=/subscriptions/{subscription_id}/" diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/adf.rst b/docs/apache-airflow-providers-microsoft-azure/connections/adf.rst index 5400bb20ae9b7..25f231efdeb11 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/adf.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/adf.rst @@ -58,22 +58,22 @@ Tenant ID Specify the Azure tenant ID used for the initial connection. This is needed for *token credentials* authentication mechanism. It can be left out to fall back on ``DefaultAzureCredential``. - Use the key ``extra__azure_data_factory__tenantId`` to pass in the tenant ID. + Use extra param ``tenantId`` to pass in the tenant ID. Subscription ID Specify the ID of the subscription used for the initial connection. This is needed for all authentication mechanisms. - Use the key ``extra__azure_data_factory__subscriptionId`` to pass in the Azure subscription ID. + Use extra param ``subscriptionId`` to pass in the Azure subscription ID. Factory Name (optional) Specify the Azure Data Factory to interface with. If not specified in the connection, this needs to be passed in directly to hooks, operators, and sensors. - Use the key ``extra__azure_data_factory__factory_name`` to pass in the factory name. + Use extra param ``factory_name`` to pass in the factory name. Resource Group Name (optional) Specify the Azure Resource Group Name under which the desired data factory resides. If not specified in the connection, this needs to be passed in directly to hooks, operators, and sensors. - Use the key ``extra__azure_data_factory__resource_group_name`` to pass in the resource group name. + Use extra param ``resource_group_name`` to pass in the resource group name. When specifying the connection in environment variable you should specify @@ -86,8 +86,8 @@ Examples .. code-block:: bash - export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?extra__azure_data_factory__tenantId=tenant+id&extra__azure_data_factory__subscriptionId=subscription+id&extra__azure_data_factory__resource_group_name=group+name&extra__azure_data_factory__factory_name=factory+name' + export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?tenantId=tenant+id&subscriptionId=subscription+id&resource_group_name=group+name&factory_name=factory+name' .. code-block:: bash - export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?extra__azure_data_factory__tenantId=tenant+id&extra__azure_data_factory__subscriptionId=subscription+id' + export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?tenantId=tenant+id&subscriptionId=subscription+id' diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index cfc387c8f9950..c24b3edc318f6 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -17,12 +17,13 @@ from __future__ import annotations import json +import os from unittest.mock import MagicMock, PropertyMock, patch import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.mgmt.datafactory.models import FactoryListResponse -from pytest import fixture +from pytest import fixture, param from airflow.exceptions import AirflowException from airflow.models.connection import Connection @@ -56,10 +57,10 @@ def setup_module(): password="clientSecret", extra=json.dumps( { - "extra__azure_data_factory__tenantId": "tenantId", - "extra__azure_data_factory__subscriptionId": "subscriptionId", - "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, - "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + "tenantId": "tenantId", + "subscriptionId": "subscriptionId", + "resource_group_name": DEFAULT_RESOURCE_GROUP, + "factory_name": DEFAULT_FACTORY, } ), ) @@ -68,9 +69,9 @@ def setup_module(): conn_type="azure_data_factory", extra=json.dumps( { - "extra__azure_data_factory__subscriptionId": "subscriptionId", - "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, - "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + "subscriptionId": "subscriptionId", + "resource_group_name": DEFAULT_RESOURCE_GROUP, + "factory_name": DEFAULT_FACTORY, } ), ) @@ -81,9 +82,9 @@ def setup_module(): password="clientSecret", extra=json.dumps( { - "extra__azure_data_factory__tenantId": "tenantId", - "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, - "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + "tenantId": "tenantId", + "resource_group_name": DEFAULT_RESOURCE_GROUP, + "factory_name": DEFAULT_FACTORY, } ), ) @@ -94,9 +95,9 @@ def setup_module(): password="clientSecret", extra=json.dumps( { - "extra__azure_data_factory__subscriptionId": "subscriptionId", - "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, - "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + "subscriptionId": "subscriptionId", + "resource_group_name": DEFAULT_RESOURCE_GROUP, + "factory_name": DEFAULT_FACTORY, } ), ) @@ -149,8 +150,8 @@ def echo(_, resource_group_name=None, factory_name=None): assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, FACTORY) == (RESOURCE_GROUP, FACTORY) conn.extra_dejson = { - "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, - "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + "resource_group_name": DEFAULT_RESOURCE_GROUP, + "factory_name": DEFAULT_FACTORY, } assert provide_targeted_factory(echo)(hook) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY) assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, None) == (RESOURCE_GROUP, DEFAULT_FACTORY) @@ -653,3 +654,57 @@ def test_connection_failure_missing_tenant_id(): assert status is False assert msg == "A Tenant ID is required when authenticating with Client ID and Secret." + + +@pytest.mark.parametrize( + "uri", + [ + param( + "a://?extra__azure_data_factory__resource_group_name=abc" + "&extra__azure_data_factory__factory_name=abc", + id="prefix", + ), + param("a://?resource_group_name=abc&factory_name=abc", id="no-prefix"), + ], +) +@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn") +def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri): + with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}): + hook = AzureDataFactoryHook("my_conn") + hook.delete_factory() + mock_connect.return_value.factories.delete.assert_called_with("abc", "abc") + + +@pytest.mark.parametrize( + "uri", + [ + param( + "a://hi:yo@?extra__azure_data_factory__tenantId=ten" + "&extra__azure_data_factory__subscriptionId=sub", + id="prefix", + ), + param("a://hi:yo@?tenantId=ten&subscriptionId=sub", id="no-prefix"), + ], +) +@patch("airflow.providers.microsoft.azure.hooks.data_factory.ClientSecretCredential") +@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook._create_client") +def test_get_conn_backcompat_prefix_works(mock_create, mock_cred, uri): + with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}): + hook = AzureDataFactoryHook("my_conn") + hook.get_conn() + mock_cred.assert_called_with(client_id="hi", client_secret="yo", tenant_id="ten") + mock_create.assert_called_with(mock_cred.return_value, "sub") + + +@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn") +def test_backcompat_prefix_both_prefers_short(mock_connect): + with patch.dict( + os.environ, + { + "AIRFLOW_CONN_MY_CONN": "a://?resource_group_name=non-prefixed" + "&extra__azure_data_factory__resource_group_name=prefixed" + }, + ): + hook = AzureDataFactoryHook("my_conn") + hook.delete_factory(factory_name="n/a") + mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a") diff --git a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py index 38c8c7bf6fc2c..91b4a7dcb6e22 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py @@ -36,10 +36,10 @@ AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_test" PIPELINE_NAME = "pipeline1" CONN_EXTRAS = { - "extra__azure_data_factory__subscriptionId": SUBSCRIPTION_ID, - "extra__azure_data_factory__tenantId": "my-tenant-id", - "extra__azure_data_factory__resource_group_name": "my-resource-group-name-from-conn", - "extra__azure_data_factory__factory_name": "my-factory-name-from-conn", + "subscriptionId": SUBSCRIPTION_ID, + "tenantId": "my-tenant-id", + "resource_group_name": "my-resource-group-name-from-conn", + "factory_name": "my-factory-name-from-conn", } PIPELINE_RUN_RESPONSE = {"additional_properties": {}, "run_id": "run_id"} EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = ( @@ -241,8 +241,8 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i ) conn = AzureDataFactoryHook.get_connection("azure_data_factory_test") - conn_resource_group_name = conn.extra_dejson["extra__azure_data_factory__resource_group_name"] - conn_factory_name = conn.extra_dejson["extra__azure_data_factory__factory_name"] + conn_resource_group_name = conn.extra_dejson["resource_group_name"] + conn_factory_name = conn.extra_dejson["factory_name"] assert url == ( EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK.format(