Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion airflow/providers/microsoft/azure/hooks/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import inspect
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

from azure.core.polling import LROPoller
from azure.identity import ClientSecretCredential, DefaultAzureCredential
Expand Down Expand Up @@ -891,3 +891,23 @@ def cancel_trigger(
:param config: Extra parameters for the ADF client.
"""
self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config)

def test_connection(self) -> Tuple[bool, str]:
"""Test a configured Azure Data Factory connection."""
success = (True, "Successfully connected to Azure Data Factory.")

try:
# Attempt to list existing factories under the configured subscription and retrieve the first in
# the returned iterator. The Azure Data Factory API does allow for creation of a
# DataFactoryManagementClient with incorrect values but then will fail properly once items are
# retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
# connection.
next(self.get_conn().factories.list())
return success
except StopIteration:
# If the iterator returned is empty it should still be considered a successful connection since
# it's possible to create a Data Factory via the ``AzureDataFactoryHook`` and none could
# legitimately exist yet.
return success
except Exception as e:
return False, str(e)
68 changes: 67 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

import json
from typing import Type
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.models import FactoryListResponse
from pytest import fixture

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -74,9 +75,37 @@ def setup_module():
}
),
)
connection_missing_subscription_id = Connection(
conn_id="azure_data_factory_missing_subscription_id",
conn_type="azure_data_factory",
login="clientId",
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__tenantId": "tenantId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
}
),
)
connection_missing_tenant_id = Connection(
conn_id="azure_data_factory_missing_tenant_id",
conn_type="azure_data_factory",
login="clientId",
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__subscriptionId": "subscriptionId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
}
),
)

db.merge_conn(connection_client_secret)
db.merge_conn(connection_default_credential)
db.merge_conn(connection_missing_subscription_id)
db.merge_conn(connection_missing_tenant_id)


@fixture
Expand Down Expand Up @@ -526,3 +555,40 @@ def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
hook.cancel_trigger(*user_args)

hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)


@pytest.mark.parametrize(
argnames="factory_list_result",
argvalues=[iter([FactoryListResponse]), iter([])],
ids=["factory_exists", "factory_does_not_exist"],
)
def test_connection_success(hook, factory_list_result):
hook.get_conn().factories.list.return_value = factory_list_result
status, msg = hook.test_connection()

assert status is True
assert msg == "Successfully connected to Azure Data Factory."


def test_connection_failure(hook):
hook.get_conn().factories.list = PropertyMock(side_effect=Exception("Authentication failed."))
status, msg = hook.test_connection()

assert status is False
assert msg == "Authentication failed."


def test_connection_failure_missing_subscription_id():
hook = AzureDataFactoryHook("azure_data_factory_missing_subscription_id")
status, msg = hook.test_connection()

assert status is False
assert msg == "A Subscription ID is required to connect to Azure Data Factory."


def test_connection_failure_missing_tenant_id():
hook = AzureDataFactoryHook("azure_data_factory_missing_tenant_id")
status, msg = hook.test_connection()

assert status is False
assert msg == "A Tenant ID is required when authenticating with Client ID and Secret."