From b707e6806b5887d39d397ff177a7b120b87c65c8 Mon Sep 17 00:00:00 2001 From: Josh Fell Date: Tue, 1 Mar 2022 09:53:03 -0500 Subject: [PATCH] Add `test_connection` method to `AzureDataFactoryHook` --- .../microsoft/azure/hooks/data_factory.py | 22 +++++- .../azure/hooks/test_azure_data_factory.py | 68 ++++++++++++++++++- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 8e1e97d2a3fbe..3b1a79675f9dc 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -30,7 +30,7 @@ import inspect import time from functools import wraps -from typing import Any, Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union from azure.core.polling import LROPoller from azure.identity import ClientSecretCredential, DefaultAzureCredential @@ -891,3 +891,23 @@ def cancel_trigger( :param config: Extra parameters for the ADF client. """ self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config) + + def test_connection(self) -> Tuple[bool, str]: + """Test a configured Azure Data Factory connection.""" + success = (True, "Successfully connected to Azure Data Factory.") + + try: + # Attempt to list existing factories under the configured subscription and retrieve the first in + # the returned iterator. The Azure Data Factory API does allow for creation of a + # DataFactoryManagementClient with incorrect values but then will fail properly once items are + # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the + # connection. + next(self.get_conn().factories.list()) + return success + except StopIteration: + # If the iterator returned is empty it should still be considered a successful connection since + # it's possible to create a Data Factory via the ``AzureDataFactoryHook`` and none could + # legitimately exist yet. + return success + except Exception as e: + return False, str(e) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 85efa95cfc317..a5730fb9a1d40 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -18,10 +18,11 @@ import json from typing import Type -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.mgmt.datafactory.models import FactoryListResponse from pytest import fixture from airflow.exceptions import AirflowException @@ -74,9 +75,37 @@ def setup_module(): } ), ) + connection_missing_subscription_id = Connection( + conn_id="azure_data_factory_missing_subscription_id", + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps( + { + "extra__azure_data_factory__tenantId": "tenantId", + "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, + "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + } + ), + ) + connection_missing_tenant_id = Connection( + conn_id="azure_data_factory_missing_tenant_id", + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra=json.dumps( + { + "extra__azure_data_factory__subscriptionId": "subscriptionId", + "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP, + "extra__azure_data_factory__factory_name": DEFAULT_FACTORY, + } + ), + ) db.merge_conn(connection_client_secret) db.merge_conn(connection_default_credential) + db.merge_conn(connection_missing_subscription_id) + db.merge_conn(connection_missing_tenant_id) @fixture @@ -526,3 +555,40 @@ def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): hook.cancel_trigger(*user_args) hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args) + + +@pytest.mark.parametrize( + argnames="factory_list_result", + argvalues=[iter([FactoryListResponse]), iter([])], + ids=["factory_exists", "factory_does_not_exist"], +) +def test_connection_success(hook, factory_list_result): + hook.get_conn().factories.list.return_value = factory_list_result + status, msg = hook.test_connection() + + assert status is True + assert msg == "Successfully connected to Azure Data Factory." + + +def test_connection_failure(hook): + hook.get_conn().factories.list = PropertyMock(side_effect=Exception("Authentication failed.")) + status, msg = hook.test_connection() + + assert status is False + assert msg == "Authentication failed." + + +def test_connection_failure_missing_subscription_id(): + hook = AzureDataFactoryHook("azure_data_factory_missing_subscription_id") + status, msg = hook.test_connection() + + assert status is False + assert msg == "A Subscription ID is required to connect to Azure Data Factory." + + +def test_connection_failure_missing_tenant_id(): + hook = AzureDataFactoryHook("azure_data_factory_missing_tenant_id") + status, msg = hook.test_connection() + + assert status is False + assert msg == "A Tenant ID is required when authenticating with Client ID and Secret."