From 5d48c43ed79b0881f667dd072c7c914893cd8d0f Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Fri, 25 Aug 2023 16:34:51 +0400 Subject: [PATCH 1/3] Postpone obtain connection in MS Azure provider --- airflow/providers/microsoft/azure/hooks/adx.py | 15 ++++++++++----- .../providers/microsoft/azure/hooks/batch.py | 17 +++++++++++------ .../microsoft/azure/hooks/container_instance.py | 6 +++++- .../microsoft/azure/hooks/container_registry.py | 10 +++++++--- .../microsoft/azure/hooks/data_lake.py | 15 ++++++++++----- airflow/providers/microsoft/azure/hooks/wasb.py | 7 ++++++- .../microsoft/azure/log/wasb_task_handler.py | 1 - .../microsoft/azure/operators/batch.py | 3 ++- .../microsoft/azure/operators/data_factory.py | 7 ++++++- .../microsoft/azure/operators/synapse.py | 9 ++++++--- .../microsoft/azure/sensors/data_factory.py | 7 ++++++- .../microsoft/azure/hooks/test_azure_batch.py | 7 +++++-- 12 files changed, 74 insertions(+), 30 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py index f2cfd1a6c7e53..53da21e39605e 100644 --- a/airflow/providers/microsoft/azure/hooks/adx.py +++ b/airflow/providers/microsoft/azure/hooks/adx.py @@ -26,6 +26,7 @@ from __future__ import annotations import warnings +from functools import cached_property from typing import Any from azure.identity import DefaultAzureCredential @@ -76,8 +77,8 @@ class AzureDataExplorerHook(BaseHook): conn_type = "azure_data_explorer" hook_name = "Azure Data Explorer" - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: """Returns connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext @@ -94,8 +95,8 @@ def get_connection_form_widgets() -> dict[str, Any]: ), } - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["schema", "port", "extra"], @@ -116,7 +117,11 @@ def get_ui_field_behaviour() -> dict[str, Any]: def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_data_explorer_conn_id - self.connection = self.get_conn() # todo: make this a property, or just delete + + @cached_property + def connection(self) -> KustoClient: + """Return a KustoClient object (cached).""" + return self.get_conn() def get_conn(self) -> KustoClient: """Return a KustoClient object.""" diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py index deca28216dbb9..108e23389ea2a 100644 --- a/airflow/providers/microsoft/azure/hooks/batch.py +++ b/airflow/providers/microsoft/azure/hooks/batch.py @@ -19,6 +19,7 @@ import time from datetime import timedelta +from functools import cached_property from typing import Any from azure.batch import BatchServiceClient, batch_auth, models as batch_models @@ -52,8 +53,8 @@ def _get_field(self, extras, name): field_name=name, ) - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: """Returns connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext @@ -63,8 +64,8 @@ def get_connection_form_widgets() -> dict[str, Any]: "account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()), } - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["schema", "port", "host", "extra"], @@ -77,14 +78,18 @@ def get_ui_field_behaviour() -> dict[str, Any]: def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_batch_conn_id - self.connection = self.get_conn() def _connection(self) -> Connection: """Get connected to Azure Batch service.""" conn = self.get_connection(self.conn_id) return conn - def get_conn(self): + @cached_property + def connection(self) -> BatchServiceClient: + """Get the Batch client connection (cached).""" + return self.get_conn() + + def get_conn(self) -> BatchServiceClient: """ Get the Batch client connection. diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py index 9a0d0ec21067f..8fc845bf13b19 100644 --- a/airflow/providers/microsoft/azure/hooks/container_instance.py +++ b/airflow/providers/microsoft/azure/hooks/container_instance.py @@ -18,6 +18,7 @@ from __future__ import annotations import warnings +from functools import cached_property from azure.mgmt.containerinstance import ContainerInstanceManagementClient from azure.mgmt.containerinstance.models import ContainerGroup @@ -47,7 +48,10 @@ class AzureContainerInstanceHook(AzureBaseHook): def __init__(self, azure_conn_id: str = default_conn_name) -> None: super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=azure_conn_id) - self.connection = self.get_conn() + + @cached_property + def connection(self): + return self.get_conn() def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None: """ diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py index a3298117ccf43..c1217e3a86db7 100644 --- a/airflow/providers/microsoft/azure/hooks/container_registry.py +++ b/airflow/providers/microsoft/azure/hooks/container_registry.py @@ -18,6 +18,7 @@ """Hook for Azure Container Registry.""" from __future__ import annotations +from functools import cached_property from typing import Any from azure.mgmt.containerinstance.models import ImageRegistryCredential @@ -39,8 +40,8 @@ class AzureContainerRegistryHook(BaseHook): conn_type = "azure_container_registry" hook_name = "Azure Container Registry" - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["schema", "port", "extra"], @@ -59,7 +60,10 @@ def get_ui_field_behaviour() -> dict[str, Any]: def __init__(self, conn_id: str = "azure_registry") -> None: super().__init__() self.conn_id = conn_id - self.connection = self.get_conn() + + @cached_property + def connection(self) -> ImageRegistryCredential: + return self.get_conn() def get_conn(self) -> ImageRegistryCredential: conn = self.get_connection(self.conn_id) diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py index 95ef4c6cc2789..3849727e861d3 100644 --- a/airflow/providers/microsoft/azure/hooks/data_lake.py +++ b/airflow/providers/microsoft/azure/hooks/data_lake.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import Any from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError @@ -256,8 +257,8 @@ class AzureDataLakeStorageV2Hook(BaseHook): conn_type = "adls" hook_name = "Azure Date Lake Storage V2" - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: """Returns connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext @@ -272,8 +273,8 @@ def get_connection_form_widgets() -> dict[str, Any]: ), } - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour.""" return { "hidden_fields": ["schema", "port"], @@ -296,7 +297,11 @@ def __init__(self, adls_conn_id: str, public_read: bool = False) -> None: super().__init__() self.conn_id = adls_conn_id self.public_read = public_read - self.service_client = self.get_conn() + + @cached_property + def service_client(self) -> DataLakeServiceClient: + """Return the DataLakeServiceClient object (cached).""" + return self.get_conn() def get_conn(self) -> DataLakeServiceClient: # type: ignore[override] """Return the DataLakeServiceClient object.""" diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 95d4764b675cb..55c1aba08632b 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -27,6 +27,7 @@ import logging import os +from functools import cached_property from typing import Any, Union from urllib.parse import urlparse @@ -123,7 +124,6 @@ def __init__( super().__init__() self.conn_id = wasb_conn_id self.public_read = public_read - self.blob_service_client: BlobServiceClient = self.get_conn() logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") try: @@ -142,6 +142,11 @@ def _get_field(self, extra_dict, field_name): return extra_dict[field_name] or None return extra_dict.get(f"{prefix}{field_name}") or None + @cached_property + def blob_service_client(self) -> BlobServiceClient: + """Return the BlobServiceClient object (cached).""" + return self.get_conn() + def get_conn(self) -> BlobServiceClient: """Return the BlobServiceClient object.""" conn = self.get_connection(self.conn_id) diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py index 97a8af5ae1d67..21e96f1003700 100644 --- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -67,7 +67,6 @@ def __init__( self.wasb_container = wasb_container self.remote_base = wasb_log_folder self.log_relative_path = "" - self._hook = None self.closed = False self.upload_on_close = True self.delete_local_copy = ( diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py index e26f56dd6e62b..06122ed1f9f00 100644 --- a/airflow/providers/microsoft/azure/operators/batch.py +++ b/airflow/providers/microsoft/azure/operators/batch.py @@ -179,7 +179,8 @@ def __init__( self.should_delete_pool = should_delete_pool @cached_property - def hook(self): + def hook(self) -> AzureBatchHook: + """Create and return an AzureBatchHook (cached).""" return self.get_hook() def _check_inputs(self) -> Any: diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index d6b4592e35b6e..12962e5610228 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -18,6 +18,7 @@ import time import warnings +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf @@ -159,8 +160,12 @@ def __init__( self.check_interval = check_interval self.deferrable = deferrable + @cached_property + def hook(self) -> AzureDataFactoryHook: + """Create and return an AzureDataFactoryHook (cached).""" + return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) + def execute(self, context: Context) -> None: - self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) self.log.info("Executing the %s pipeline.", self.pipeline_name) response = self.hook.run_pipeline( pipeline_name=self.pipeline_name, diff --git a/airflow/providers/microsoft/azure/operators/synapse.py b/airflow/providers/microsoft/azure/operators/synapse.py index b9d97704c57f6..dd6dda55557f0 100644 --- a/airflow/providers/microsoft/azure/operators/synapse.py +++ b/airflow/providers/microsoft/azure/operators/synapse.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, Sequence from azure.synapse.spark.models import SparkBatchJobOptions @@ -73,10 +74,12 @@ def __init__( self.timeout = timeout self.check_interval = check_interval + @cached_property + def hook(self): + """Create and return an AzureSynapseHook (cached).""" + return AzureSynapseHook(azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool) + def execute(self, context: Context) -> None: - self.hook = AzureSynapseHook( - azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool - ) self.log.info("Executing the Synapse spark job.") response = self.hook.run_spark_job(payload=self.payload) self.log.info(response) diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index 4caa26a99d84b..91bc4072c0dc8 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -18,6 +18,7 @@ import warnings from datetime import timedelta +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf @@ -72,8 +73,12 @@ def __init__( self.deferrable = deferrable + @cached_property + def hook(self): + """Create and return an AzureDataFactoryHook (cached).""" + return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) + def poke(self, context: Context) -> bool: - self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) pipeline_run_status = self.hook.get_pipeline_run_status( run_id=self.run_id, resource_group_name=self.resource_group_name, diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py index a3a421f5a00d9..31eb5487270c4 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py @@ -68,6 +68,9 @@ def test_connection_and_client(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) assert isinstance(hook._connection(), Connection) assert isinstance(hook.get_conn(), BatchServiceClient) + conn = hook.connection + assert isinstance(conn, BatchServiceClient) + assert hook.connection is conn, "`connection` property should be cached" @mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials") @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter") @@ -195,7 +198,7 @@ def test_wait_for_all_task_to_complete(self, mock_batch): @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_connection_success(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - hook.get_conn().job.return_value = {} + hook.connection.job.return_value = {} status, msg = hook.test_connection() assert status is True assert msg == "Successfully connected to Azure Batch." @@ -203,7 +206,7 @@ def test_connection_success(self, mock_batch): @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_connection_failure(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - hook.get_conn().job.list = PropertyMock(side_effect=Exception("Authentication failed.")) + hook.connection.job.list = PropertyMock(side_effect=Exception("Authentication failed.")) status, msg = hook.test_connection() assert status is False assert msg == "Authentication failed." From aa665f724afc741d93a23c38092a82402a000064 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 26 Aug 2023 01:23:56 +0400 Subject: [PATCH 2/3] Azure Provider tests refactoring --- .../providers/microsoft/azure/hooks/batch.py | 8 +- .../microsoft/azure/hooks/test_adx.py | 289 +++++---- .../microsoft/azure/hooks/test_asb.py | 42 +- .../microsoft/azure/hooks/test_azure_batch.py | 28 +- .../hooks/test_azure_container_instance.py | 17 +- .../hooks/test_azure_container_registry.py | 15 +- .../hooks/test_azure_container_volume.py | 31 +- .../azure/hooks/test_azure_cosmos.py | 18 +- .../azure/hooks/test_azure_data_factory.py | 232 ++++--- .../azure/hooks/test_azure_data_lake.py | 14 +- .../azure/hooks/test_azure_fileshare.py | 34 +- .../azure/hooks/test_azure_synapse.py | 83 +-- .../microsoft/azure/hooks/test_base_azure.py | 72 ++- .../microsoft/azure/hooks/test_wasb.py | 595 ++++++++---------- .../azure/operators/test_azure_batch.py | 33 +- .../azure/operators/test_azure_cosmos.py | 16 +- .../operators/test_azure_data_factory.py | 10 +- .../azure/operators/test_azure_synapse.py | 21 +- tests/providers/microsoft/conftest.py | 68 ++ 19 files changed, 798 insertions(+), 828 deletions(-) create mode 100644 tests/providers/microsoft/conftest.py diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py index 108e23389ea2a..594725c0da089 100644 --- a/airflow/providers/microsoft/azure/hooks/batch.py +++ b/airflow/providers/microsoft/azure/hooks/batch.py @@ -27,7 +27,6 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.models import Connection from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field from airflow.utils import timezone @@ -79,11 +78,6 @@ def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_batch_conn_id - def _connection(self) -> Connection: - """Get connected to Azure Batch service.""" - conn = self.get_connection(self.conn_id) - return conn - @cached_property def connection(self) -> BatchServiceClient: """Get the Batch client connection (cached).""" @@ -95,7 +89,7 @@ def get_conn(self) -> BatchServiceClient: :return: Azure Batch client """ - conn = self._connection() + conn = self.get_connection(self.conn_id) batch_account_url = self._get_field(conn.extra_dejson, "account_url") if not batch_account_url: diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py b/tests/providers/microsoft/azure/hooks/test_adx.py index 4cf61c440b4ae..2268f8b8a6df2 100644 --- a/tests/providers/microsoft/azure/hooks/test_adx.py +++ b/tests/providers/microsoft/azure/hooks/test_adx.py @@ -17,10 +17,7 @@ # under the License. from __future__ import annotations -import json -import os from unittest import mock -from unittest.mock import patch import pytest from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder @@ -29,196 +26,220 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook -from airflow.utils import db -from airflow.utils.session import create_session from tests.test_utils.providers import get_provider_min_airflow_version ADX_TEST_CONN_ID = "adx_test_connection_id" class TestAzureDataExplorerHook: - def teardown_method(self): - with create_session() as session: - session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete() - - def test_conn_missing_method(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="missing_method", conn_type="azure_data_explorer", login="client_id", password="client secret", host="https://help.kusto.windows.net", - extra=json.dumps({}), + extra={}, ) - ) - with pytest.raises(AirflowException) as ctx: - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - assert "is missing: `data_explorer__auth_method`" in str(ctx.value) + ], + indirect=True, + ) + def test_conn_missing_method(self, mocked_connection): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) + error_pattern = "is missing: `auth_method`" + with pytest.raises(AirflowException, match=error_pattern): + assert hook.get_conn() + with pytest.raises(AirflowException, match=error_pattern): + assert hook.connection - def test_conn_unknown_method(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="unknown_method", conn_type="azure_data_explorer", login="client_id", password="client secret", host="https://help.kusto.windows.net", - extra=json.dumps({"auth_method": "AAD_OTHER"}), - ) - ) - with pytest.raises(AirflowException) as ctx: - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - assert "Unknown authentication method: AAD_OTHER" in str(ctx.value) + extra={"auth_method": "AAD_OTHER"}, + ), + ], + indirect=True, + ) + def test_conn_unknown_method(self, mocked_connection): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) + error_pattern = "Unknown authentication method: AAD_OTHER" + with pytest.raises(AirflowException, match=error_pattern): + assert hook.get_conn() + with pytest.raises(AirflowException, match=error_pattern): + assert hook.connection - def test_conn_missing_cluster(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="missing_cluster", conn_type="azure_data_explorer", login="client_id", password="client secret", - extra=json.dumps({}), - ) - ) - with pytest.raises(AirflowException) as ctx: - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - assert "Host connection option is required" in str(ctx.value) + extra={}, + ), + ], + indirect=True, + ) + def test_conn_missing_cluster(self, mocked_connection): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) + error_pattern = "Host connection option is required" + with pytest.raises(AirflowException, match=error_pattern): + assert hook.get_conn() + with pytest.raises(AirflowException, match=error_pattern): + assert hook.connection - @mock.patch.object(KustoClient, "__init__") - def test_conn_method_aad_creds(self, mock_init): - mock_init.return_value = None - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="method_aad_creds", conn_type="azure_data_explorer", login="client_id", password="client secret", host="https://help.kusto.windows.net", - extra=json.dumps( - { - "tenant": "tenant", - "auth_method": "AAD_CREDS", - } - ), + extra={"tenant": "tenant", "auth_method": "AAD_CREDS"}, ) - ) - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) + ], + indirect=True, + ) + @mock.patch.object(KustoClient, "__init__") + def test_conn_method_aad_creds(self, mock_init, mocked_connection): + mock_init.return_value = None + AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn() assert mock_init.called_with( KustoConnectionStringBuilder.with_aad_user_password_authentication( "https://help.kusto.windows.net", "client_id", "client secret", "tenant" ) ) - @mock.patch("azure.identity._credentials.environment.ClientSecretCredential") - def test_conn_method_token_creds(self, mock1): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="method_token_creds", conn_type="azure_data_explorer", host="https://help.kusto.windows.net", - extra=json.dumps( - { - "auth_method": "AZURE_TOKEN_CRED", - } - ), - ) + extra={ + "auth_method": "AZURE_TOKEN_CRED", + }, + ), + ], + indirect=True, + ) + @mock.patch("azure.identity._credentials.environment.ClientSecretCredential") + def test_conn_method_token_creds(self, mock1, mocked_connection, monkeypatch): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) + + monkeypatch.setenv("AZURE_TENANT_ID", "tenant") + monkeypatch.setenv("AZURE_CLIENT_ID", "client") + monkeypatch.setenv("AZURE_CLIENT_SECRET", "secret") + + assert hook.connection._kcsb.data_source == "https://help.kusto.windows.net" + mock1.assert_called_once_with( + tenant_id="tenant", + client_id="client", + client_secret="secret", + authority="https://login.microsoftonline.com", ) - with patch.dict( - in_dict=os.environ, - values={ - "AZURE_TENANT_ID": "tenant", - "AZURE_CLIENT_ID": "client", - "AZURE_CLIENT_SECRET": "secret", - }, - ): - hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - assert hook.connection._kcsb.data_source == "https://help.kusto.windows.net" - mock1.assert_called_once_with( - tenant_id="tenant", - client_id="client", - client_secret="secret", - authority="https://login.microsoftonline.com", - ) - @mock.patch.object(KustoClient, "__init__") - def test_conn_method_aad_app(self, mock_init): - mock_init.return_value = None - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="method_aad_app", conn_type="azure_data_explorer", login="app_id", password="app key", host="https://help.kusto.windows.net", - extra=json.dumps( - { - "tenant": "tenant", - "auth_method": "AAD_APP", - } - ), + extra={ + "tenant": "tenant", + "auth_method": "AAD_APP", + }, ) - ) - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) + ], + indirect=True, + ) + @mock.patch.object(KustoClient, "__init__") + def test_conn_method_aad_app(self, mock_init, mocked_connection): + mock_init.return_value = None + AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn() assert mock_init.called_with( KustoConnectionStringBuilder.with_aad_application_key_authentication( "https://help.kusto.windows.net", "app_id", "app key", "tenant" ) ) - @mock.patch.object(KustoClient, "__init__") - def test_conn_method_aad_app_cert(self, mock_init): - mock_init.return_value = None - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( - conn_id=ADX_TEST_CONN_ID, + conn_id="method_aad_app", conn_type="azure_data_explorer", - login="client_id", + login="app_id", + password="app key", host="https://help.kusto.windows.net", - extra=json.dumps( - { - "tenant": "tenant", - "auth_method": "AAD_APP_CERT", - "certificate": "PEM", - "thumbprint": "thumbprint", - } - ), + extra={ + "tenant": "tenant", + "auth_method": "AAD_APP", + }, ) - ) - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) + ], + indirect=True, + ) + @mock.patch.object(KustoClient, "__init__") + def test_conn_method_aad_app_cert(self, mock_init, mocked_connection): + mock_init.return_value = None + AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn() assert mock_init.called_with( KustoConnectionStringBuilder.with_aad_application_certificate_authentication( "https://help.kusto.windows.net", "client_id", "PEM", "thumbprint", "tenant" ) ) - @mock.patch.object(KustoClient, "__init__") - def test_conn_method_aad_device(self, mock_init): - mock_init.return_value = None - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( conn_id=ADX_TEST_CONN_ID, conn_type="azure_data_explorer", host="https://help.kusto.windows.net", - extra=json.dumps({"auth_method": "AAD_DEVICE"}), + extra={"auth_method": "AAD_DEVICE"}, ) - ) - AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) + ], + indirect=True, + ) + @mock.patch.object(KustoClient, "__init__") + def test_conn_method_aad_device(self, mock_init, mocked_connection): + mock_init.return_value = None + AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn() assert mock_init.called_with( KustoConnectionStringBuilder.with_aad_device_authentication("https://help.kusto.windows.net") ) - @mock.patch.object(KustoClient, "execute") - def test_run_query(self, mock_execute): - mock_execute.return_value = None - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( conn_id=ADX_TEST_CONN_ID, conn_type="azure_data_explorer", host="https://help.kusto.windows.net", - extra=json.dumps({"auth_method": "AAD_DEVICE"}), + extra={"auth_method": "AAD_DEVICE"}, ) - ) + ], + indirect=True, + ) + @mock.patch.object(KustoClient, "execute") + def test_run_query(self, mock_execute, mocked_connection): + mock_execute.return_value = None hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) hook.run_query("Database", "Logs | schema", options={"option1": "option_value"}) properties = ClientRequestProperties() @@ -246,7 +267,7 @@ def test_get_ui_field_behaviour_placeholders(self): ) @pytest.mark.parametrize( - "uri", + "mocked_connection", [ param( "a://usr:pw@host?extra__azure_data_explorer__tenant=my-tenant" @@ -255,24 +276,28 @@ def test_get_ui_field_behaviour_placeholders(self): ), param("a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP", id="no-prefix"), ], + indirect=True, ) - def test_backcompat_prefix_works(self, uri): - with patch.dict(os.environ, AIRFLOW_CONN_MY_CONN=uri): - hook = AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is called in init + def test_backcompat_prefix_works(self, mocked_connection): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) assert hook.connection._kcsb.data_source == "host" assert hook.connection._kcsb.application_client_id == "usr" assert hook.connection._kcsb.application_key == "pw" assert hook.connection._kcsb.authority_id == "my-tenant" - def test_backcompat_prefix_both_causes_warning(self): - with patch.dict( - in_dict=os.environ, - AIRFLOW_CONN_MY_CONN="a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP" - "&extra__azure_data_explorer__auth_method=AAD_APP", - ): - with pytest.warns(Warning, match="Using value for `auth_method`"): - hook = AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") - assert hook.connection._kcsb.data_source == "host" - assert hook.connection._kcsb.application_client_id == "usr" - assert hook.connection._kcsb.application_key == "pw" - assert hook.connection._kcsb.authority_id == "my-tenant" + @pytest.mark.parametrize( + "mocked_connection", + [ + ( + "a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP" + "&extra__azure_data_explorer__auth_method=AAD_APP" + ) + ], + indirect=True, + ) + def test_backcompat_prefix_both_causes_warning(self, mocked_connection): + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id) + assert hook.connection._kcsb.data_source == "host" + assert hook.connection._kcsb.application_client_id == "usr" + assert hook.connection._kcsb.application_key == "pw" + assert hook.connection._kcsb.authority_id == "my-tenant" diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py b/tests/providers/microsoft/azure/hooks/test_asb.py index a9a3851561623..5f626d6c290cf 100644 --- a/tests/providers/microsoft/azure/hooks/test_asb.py +++ b/tests/providers/microsoft/azure/hooks/test_asb.py @@ -34,22 +34,23 @@ class TestAdminClientHook: - def setup_class(self) -> None: - self.queue_name: str = "test_queue" - self.conn_id: str = "azure_service_bus_default" + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): + self.queue_name = "test_queue" + self.conn_id = "azure_service_bus_default" self.connection_string = ( "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" ) - self.mock_conn = Connection( - conn_id="azure_service_bus_default", - conn_type="azure_service_bus", - schema=self.connection_string, + self.mock_conn = create_mock_connection( + Connection( + conn_id=self.conn_id, + conn_type="azure_service_bus", + schema=self.connection_string, + ) ) - @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection") - def test_get_conn(self, mock_connection): - mock_connection.return_value = self.mock_conn + def test_get_conn(self): hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) assert isinstance(hook.get_conn(), ServiceBusAdministrationClient) @@ -124,26 +125,27 @@ def test_delete_subscription_exception( class TestMessageHook: - def setup_class(self) -> None: - self.queue_name: str = "test_queue" - self.conn_id: str = "azure_service_bus_default" + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): + self.queue_name = "test_queue" + self.conn_id = "azure_service_bus_default" self.connection_string = ( "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;" "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc" ) - self.conn = Connection( - conn_id="azure_service_bus_default", - conn_type="azure_service_bus", - schema=self.connection_string, + self.mock_conn = create_mock_connection( + Connection( + conn_id=self.conn_id, + conn_type="azure_service_bus", + schema=self.connection_string, + ) ) - @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection") - def test_get_service_bus_message_conn(self, mock_connection): + def test_get_service_bus_message_conn(self): """ Test get_conn() function and check whether the get_conn() function returns value is instance of ServiceBusClient """ - mock_connection.return_value = self.conn hook = MessageHook(azure_service_bus_conn_id=self.conn_id) assert isinstance(hook.get_conn(), ServiceBusClient) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py index 31eb5487270c4..cd5ab1013433b 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py @@ -17,22 +17,21 @@ # under the License. from __future__ import annotations -import json from unittest import mock from unittest.mock import PropertyMock +import pytest from azure.batch import BatchServiceClient, models as batch_models from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook -from airflow.utils import db MODULE = "airflow.providers.microsoft.azure.hooks.batch" class TestAzureBatchHook: - # set up the test environment - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connections): # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm" self.test_cloud_conn_id = "test_azure_batch_cloud" @@ -47,26 +46,23 @@ def setup_method(self): self.test_cloud_os_version = "test-version" self.test_node_agent_sku = "test-node-agent-sku" - # connect with vm configuration - db.merge_conn( + create_mock_connections( + # connect with vm configuration Connection( conn_id=self.test_vm_conn_id, - conn_type="azure_batch", - extra=json.dumps({"account_url": self.test_account_url}), - ) - ) - # connect with cloud service - db.merge_conn( + conn_type="azure-batch", + extra={"account_url": self.test_account_url}, + ), + # connect with cloud service Connection( conn_id=self.test_cloud_conn_id, - conn_type="azure_batch", - extra=json.dumps({"account_url": self.test_account_url}), - ) + conn_type="azure-batch", + extra={"account_url": self.test_account_url}, + ), ) def test_connection_and_client(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - assert isinstance(hook._connection(), Connection) assert isinstance(hook.get_conn(), BatchServiceClient) conn = hook.connection assert isinstance(conn, BatchServiceClient) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py index a10dde9c651ae..786df4eb16cd0 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py @@ -17,9 +17,9 @@ # under the License. from __future__ import annotations -import json from unittest.mock import patch +import pytest from azure.mgmt.containerinstance.models import ( Container, ContainerGroup, @@ -30,27 +30,26 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook -from airflow.utils import db class TestAzureContainerInstanceHook: - def setup_method(self): - db.merge_conn( + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): + mock_connection = create_mock_connection( Connection( conn_id="azure_container_instance_test", conn_type="azure_container_instances", login="login", password="key", - extra=json.dumps({"tenantId": "tenant_id", "subscriptionId": "subscription_id"}), + extra={"tenantId": "tenant_id", "subscriptionId": "subscription_id"}, ) ) - self.resources = ResourceRequirements(requests=ResourceRequests(memory_in_gb="4", cpu="1")) - with patch( + self.hook = AzureContainerInstanceHook(azure_conn_id=mock_connection.conn_id) + with patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"), patch( "azure.common.credentials.ServicePrincipalCredentials.__init__", autospec=True, return_value=None ): - with patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"): - self.hook = AzureContainerInstanceHook(azure_conn_id="azure_container_instance_test") + yield @patch("azure.mgmt.containerinstance.models.ContainerGroup") @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update") diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py index 91ae933d2b0fc..38f326d298a2f 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py @@ -17,14 +17,16 @@ # under the License. from __future__ import annotations +import pytest + from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook -from airflow.utils import db class TestAzureContainerRegistryHook: - def test_get_conn(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( conn_id="azure_container_registry", conn_type="azure_container_registry", @@ -32,8 +34,11 @@ def test_get_conn(self): password="password", host="test.cr", ) - ) - hook = AzureContainerRegistryHook(conn_id="azure_container_registry") + ], + indirect=True, + ) + def test_get_conn(self, mocked_connection): + hook = AzureContainerRegistryHook(conn_id=mocked_connection.conn_id) assert hook.connection is not None assert hook.connection.username == "myuser" assert hook.connection.password == "password" diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py index 3dcbfd9a894e5..b4c7b8d1c71be 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py @@ -17,22 +17,25 @@ # under the License. from __future__ import annotations -import json +import pytest from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook -from airflow.utils import db from tests.test_utils.providers import get_provider_min_airflow_version class TestAzureContainerVolumeHook: - def test_get_file_volume(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( conn_id="azure_container_test_connection", conn_type="wasb", login="login", password="key" ) - ) - hook = AzureContainerVolumeHook(azure_container_volume_conn_id="azure_container_test_connection") + ], + indirect=True, + ) + def test_get_file_volume(self, mocked_connection): + hook = AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id) volume = hook.get_file_volume( mount_name="mount", share_name="share", storage_account_name="storage", read_only=True ) @@ -43,19 +46,21 @@ def test_get_file_volume(self): assert volume.azure_file.storage_account_name == "storage" assert volume.azure_file.read_only is True - def test_get_file_volume_connection_string(self): - db.merge_conn( + @pytest.mark.parametrize( + "mocked_connection", + [ Connection( conn_id="azure_container_test_connection_connection_string", conn_type="wasb", login="login", password="key", - extra=json.dumps({"connection_string": "a=b;AccountKey=1"}), + extra={"connection_string": "a=b;AccountKey=1"}, ) - ) - hook = AzureContainerVolumeHook( - azure_container_volume_conn_id="azure_container_test_connection_connection_string" - ) + ], + indirect=True, + ) + def test_get_file_volume_connection_string(self, mocked_connection): + hook = AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id) volume = hook.get_file_volume( mount_name="mount", share_name="share", storage_account_name="storage", read_only=True ) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py index af649f1d6e4d9..f63b8e8dbd7e2 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import json import logging import uuid from unittest import mock @@ -29,14 +28,14 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook -from airflow.utils import db from tests.test_utils.providers import get_provider_min_airflow_version class TestAzureCosmosDbHook: # Set up an environment to test with - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): # set up some test variables self.test_end_point = "https://test_endpoint:443" self.test_master_key = "magic_test_key" @@ -44,25 +43,22 @@ def setup_method(self): self.test_collection_name = "test_collection_name" self.test_database_default = "test_database_default" self.test_collection_default = "test_collection_default" - db.merge_conn( + create_mock_connection( Connection( conn_id="azure_cosmos_test_key_id", conn_type="azure_cosmos", login=self.test_end_point, password=self.test_master_key, - extra=json.dumps( - { - "database_name": self.test_database_default, - "collection_name": self.test_collection_default, - } - ), + extra={ + "database_name": self.test_database_default, + "collection_name": self.test_collection_default, + }, ) ) @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient", autospec=True) def test_client(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id") - assert hook._conn is None assert isinstance(hook.get_conn(), CosmosClient) @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient") diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 4ca2eb4920f3f..63a22614dcd8b 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json import os from unittest import mock from unittest.mock import MagicMock, PropertyMock, patch @@ -25,9 +24,9 @@ from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.mgmt.datafactory.aio import DataFactoryManagementClient from azure.mgmt.datafactory.models import FactoryListResponse -from pytest import fixture, param +from pytest import param -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow.models.connection import Connection from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryAsyncHook, @@ -37,7 +36,6 @@ get_field, provide_targeted_factory, ) -from airflow.utils import db DEFAULT_RESOURCE_GROUP = "defaultResourceGroup" AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default" @@ -59,66 +57,60 @@ ID = "testId" -def setup_module(): - connection_client_secret = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps( - { +@pytest.fixture(autouse=True) +def setup_connections(create_mock_connections): + create_mock_connections( + # connection_client_secret + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ "tenantId": "tenantId", "subscriptionId": "subscriptionId", "resource_group_name": DEFAULT_RESOURCE_GROUP, "factory_name": DEFAULT_FACTORY, - } + }, ), - ) - connection_default_credential = Connection( - conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, - conn_type="azure_data_factory", - extra=json.dumps( - { + # connection_default_credential + Connection( + conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, + conn_type="azure_data_factory", + extra={ "subscriptionId": "subscriptionId", "resource_group_name": DEFAULT_RESOURCE_GROUP, "factory_name": DEFAULT_FACTORY, - } + }, ), - ) - connection_missing_subscription_id = Connection( - conn_id="azure_data_factory_missing_subscription_id", - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps( - { + Connection( + # connection_missing_subscription_id + conn_id="azure_data_factory_missing_subscription_id", + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ "tenantId": "tenantId", "resource_group_name": DEFAULT_RESOURCE_GROUP, "factory_name": DEFAULT_FACTORY, - } + }, ), - ) - connection_missing_tenant_id = Connection( - conn_id="azure_data_factory_missing_tenant_id", - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps( - { + # connection_missing_tenant_id + Connection( + conn_id="azure_data_factory_missing_tenant_id", + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ "subscriptionId": "subscriptionId", "resource_group_name": DEFAULT_RESOURCE_GROUP, "factory_name": DEFAULT_FACTORY, - } + }, ), ) - db.merge_conn(connection_client_secret) - db.merge_conn(connection_default_credential) - db.merge_conn(connection_missing_subscription_id) - db.merge_conn(connection_missing_tenant_id) - -@fixture +@pytest.fixture def hook(): client = AzureDataFactoryHook(azure_data_factory_conn_id=DEFAULT_CONNECTION_CLIENT_SECRET) client._conn = MagicMock( @@ -799,7 +791,7 @@ async def test_get_pipeline_run_exception_without_resource( Test get_pipeline_run function without passing the resource name to check the decorator function and raise exception """ - mock_connection = Connection(extra=json.dumps({"factory_name": DATAFACTORY_NAME})) + mock_connection = Connection(extra={"factory_name": DATAFACTORY_NAME}) mock_get_connection.return_value = mock_connection mock_conn.return_value.pipeline_runs.get.return_value = mock_pipeline_run hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) @@ -807,98 +799,98 @@ async def test_get_pipeline_run_exception_without_resource( await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME) @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn(self, mock_connection): - """""" - mock_conn = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps( - { + @pytest.mark.parametrize( + "mocked_connection", + [ + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ "tenantId": "tenantId", "subscriptionId": "subscriptionId", "resource_group_name": RESOURCE_GROUP_NAME, "factory_name": DATAFACTORY_NAME, - } - ), - ) - mock_connection.return_value = mock_conn - hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + }, + ) + ], + indirect=True, + ) + async def test_get_async_conn(self, mocked_connection): + """""" + hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id) response = await hook.get_async_conn() assert isinstance(response, DataFactoryManagementClient) @pytest.mark.asyncio - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn_without_login_id(self, mock_connection): - """Test get_async_conn function without login id""" - mock_conn = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_data_factory", - extra=json.dumps( - { + @pytest.mark.parametrize( + "mocked_connection", + [ + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + extra={ "tenantId": "tenantId", "subscriptionId": "subscriptionId", "resource_group_name": RESOURCE_GROUP_NAME, "factory_name": DATAFACTORY_NAME, - } + }, ), - ) - mock_connection.return_value = mock_conn - hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + ], + indirect=True, + ) + async def test_get_async_conn_without_login_id(self, mocked_connection): + """Test get_async_conn function without login id""" + hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id) response = await hook.get_async_conn() assert isinstance(response, DataFactoryManagementClient) @pytest.mark.asyncio @pytest.mark.parametrize( - "mock_connection_params", + "mocked_connection", [ - { - "tenantId": "tenantId", - "resource_group_name": RESOURCE_GROUP_NAME, - "factory_name": DATAFACTORY_NAME, - } + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ + "tenantId": "tenantId", + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + }, + ) ], + indirect=True, ) - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn_key_error_subscription_id(self, mock_connection, mock_connection_params): + async def test_get_async_conn_key_error_subscription_id(self, mocked_connection): """Test get_async_conn function when subscription_id is missing in the connection""" - mock_conn = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps(mock_connection_params), - ) - mock_connection.return_value = mock_conn - hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id) with pytest.raises(ValueError): await hook.get_async_conn() @pytest.mark.asyncio @pytest.mark.parametrize( - "mock_connection_params", + "mocked_connection", [ - { - "subscriptionId": "subscriptionId", - "resource_group_name": RESOURCE_GROUP_NAME, - "factory_name": DATAFACTORY_NAME, - }, + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_data_factory", + login="clientId", + password="clientSecret", + extra={ + "subscriptionId": "subscriptionId", + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + }, + ) ], + indirect=True, ) - @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection") - async def test_get_async_conn_key_error_tenant_id(self, mock_connection, mock_connection_params): + async def test_get_async_conn_key_error_tenant_id(self, mocked_connection): """Test get_async_conn function when tenant id is missing in the connection""" - mock_conn = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_data_factory", - login="clientId", - password="clientSecret", - extra=json.dumps(mock_connection_params), - ) - mock_connection.return_value = mock_conn - hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID) + hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id) with pytest.raises(ValueError): await hook.get_async_conn() @@ -907,14 +899,12 @@ def test_get_field_prefixed_extras(self): mock_conn = Connection( conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, conn_type="azure_data_factory", - extra=json.dumps( - { - "tenantId": "tenantId", - "subscriptionId": "subscriptionId", - "resource_group_name": RESOURCE_GROUP_NAME, - "factory_name": DATAFACTORY_NAME, - } - ), + extra={ + "tenantId": "tenantId", + "subscriptionId": "subscriptionId", + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + }, ) extras = mock_conn.extra_dejson assert get_field(extras, "tenantId", strict=True) == "tenantId" @@ -929,14 +919,12 @@ def test_get_field_non_prefixed_extras(self): mock_conn = Connection( conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, conn_type="azure_data_factory", - extra=json.dumps( - { - "tenantId": "tenantId", - "subscriptionId": "subscriptionId", - "resource_group_name": RESOURCE_GROUP_NAME, - "factory_name": DATAFACTORY_NAME, - } - ), + extra={ + "tenantId": "tenantId", + "subscriptionId": "subscriptionId", + "resource_group_name": RESOURCE_GROUP_NAME, + "factory_name": DATAFACTORY_NAME, + }, ) extras = mock_conn.extra_dejson assert get_field(extras, "tenantId", strict=True) == "tenantId" diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py index 122949beac610..f5e2e8be5c9ef 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import json from unittest import mock from unittest.mock import PropertyMock @@ -26,18 +25,18 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeStorageV2Hook -from airflow.utils import db class TestAzureDataLakeHook: - def setup_method(self): - db.merge_conn( + @pytest.fixture(autouse=True) + def setup_connections(self, create_mock_connections): + create_mock_connections( Connection( conn_id="adl_test_key", conn_type="azure_data_lake", login="client_id", password="client secret", - extra=json.dumps({"tenant": "tenant", "account_name": "accountname"}), + extra={"tenant": "tenant", "account_name": "accountname"}, ) ) @@ -58,9 +57,10 @@ def test_conn(self, mock_lib): def test_check_for_blob(self, mock_lib, mock_filesystem): from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook + mocked_glob = mock_filesystem.return_value.glob hook = AzureDataLakeHook(azure_data_lake_conn_id="adl_test_key") hook.check_for_file("file_path") - mock_filesystem.glob.called + mocked_glob.assert_called() @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader", autospec=True) @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True) @@ -140,7 +140,7 @@ def test_remove(self, mock_lib, mock_fs): class TestAzureDataLakeStorageV2Hook: def setup_class(self) -> None: - self.conn_id: str = "adls_conn_id" + self.conn_id: str = "adls_conn_id1" self.file_system_name = "test_file_system" self.directory_name = "test_directory" self.file_name = "test_file_name" diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py index 1529eadcd9a66..a99ca9e90bba5 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py @@ -25,7 +25,6 @@ """ from __future__ import annotations -import json import os from unittest import mock from unittest.mock import patch @@ -36,51 +35,36 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook -from airflow.utils import db class TestAzureFileshareHook: - def setup_method(self): - db.merge_conn( + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connections): + create_mock_connections( Connection( conn_id="azure_fileshare_test_key", conn_type="azure_file_share", login="login", password="key", - ) - ) - db.merge_conn( + ), Connection( conn_id="azure_fileshare_extras", conn_type="azure_fileshare", login="login", - extra=json.dumps( - { - "sas_token": "token", - "protocol": "http", - } - ), - ) - ) - db.merge_conn( + extra={"sas_token": "token", "protocol": "http"}, + ), # Neither password nor sas_token present Connection( conn_id="azure_fileshare_missing_credentials", conn_type="azure_fileshare", login="login", - ) - ) - db.merge_conn( + ), Connection( conn_id="azure_fileshare_extras_wrong", conn_type="azure_fileshare", login="login", - extra=json.dumps( - { - "wrong_key": "token", - } - ), - ) + extra={"wrong_key": "token"}, + ), ) def test_key_and_connection(self): diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py index 3a82efd81217c..b63dacc6da21b 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json from unittest.mock import MagicMock, patch import pytest @@ -26,7 +25,6 @@ from airflow.models.connection import Connection from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus -from airflow.utils import db DEFAULT_SPARK_POOL = "defaultSparkPool" @@ -42,60 +40,45 @@ JOB_ID = 1 -def setup_module(): - connection_client_secret = Connection( - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - conn_type="azure_synapse", - host="https://testsynapse.dev.azuresynapse.net", - login="clientId", - password="clientSecret", - extra=json.dumps( - { - "tenantId": "tenantId", - "subscriptionId": "subscriptionId", - } +@pytest.fixture(autouse=True) +def setup_connections(create_mock_connections): + create_mock_connections( + # connection_client_secret + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_synapse", + host="https://testsynapse.dev.azuresynapse.net", + login="clientId", + password="clientSecret", + extra={"tenantId": "tenantId", "subscriptionId": "subscriptionId"}, ), - ) - connection_default_credential = Connection( - conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, - conn_type="azure_synapse", - host="https://testsynapse.dev.azuresynapse.net", - extra=json.dumps( - { - "subscriptionId": "subscriptionId", - } + # connection_default_credential + Connection( + conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, + conn_type="azure_synapse", + host="https://testsynapse.dev.azuresynapse.net", + extra={"subscriptionId": "subscriptionId"}, ), - ) - connection_missing_subscription_id = Connection( - conn_id="azure_synapse_missing_subscription_id", - conn_type="azure_synapse", - host="https://testsynapse.dev.azuresynapse.net", - login="clientId", - password="clientSecret", - extra=json.dumps( - { - "tenantId": "tenantId", - } + Connection( + # connection_missing_subscription_id + conn_id="azure_synapse_missing_subscription_id", + conn_type="azure_synapse", + host="https://testsynapse.dev.azuresynapse.net", + login="clientId", + password="clientSecret", + extra={"tenantId": "tenantId"}, ), - ) - connection_missing_tenant_id = Connection( - conn_id="azure_synapse_missing_tenant_id", - conn_type="azure_synapse", - host="https://testsynapse.dev.azuresynapse.net", - login="clientId", - password="clientSecret", - extra=json.dumps( - { - "subscriptionId": "subscriptionId", - } + # connection_missing_tenant_id + Connection( + conn_id="azure_synapse_missing_tenant_id", + conn_type="azure_synapse", + host="https://testsynapse.dev.azuresynapse.net", + login="clientId", + password="clientSecret", + extra={"subscriptionId": "subscriptionId"}, ), ) - db.merge_conn(connection_client_secret) - db.merge_conn(connection_default_credential) - db.merge_conn(connection_missing_subscription_id) - db.merge_conn(connection_missing_tenant_id) - @fixture def hook(): diff --git a/tests/providers/microsoft/azure/hooks/test_base_azure.py b/tests/providers/microsoft/azure/hooks/test_base_azure.py index 7b587c41217de..53e2614a69a8f 100644 --- a/tests/providers/microsoft/azure/hooks/test_base_azure.py +++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py @@ -18,63 +18,71 @@ from unittest.mock import Mock, patch +import pytest + from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook class TestBaseAzureHook: - @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file") - @patch( - "airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection", - return_value=Connection(conn_id="azure_default", extra='{ "key_path": "key_file.json" }'), + @pytest.mark.parametrize( + "mocked_connection", + [Connection(conn_id="azure_default", extra={"key_path": "key_file.json"})], + indirect=True, ) - def test_get_conn_with_key_path(self, mock_connection, mock_get_client_from_auth_file): + @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file") + def test_get_conn_with_key_path(self, mock_get_client_from_auth_file, mocked_connection): + mock_get_client_from_auth_file.return_value = "foo-bar" mock_sdk_client = Mock() auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn() mock_get_client_from_auth_file.assert_called_once_with( - client_class=mock_sdk_client, auth_path=mock_connection.return_value.extra_dejson["key_path"] + client_class=mock_sdk_client, auth_path=mocked_connection.extra_dejson["key_path"] ) - assert auth_sdk_client == mock_get_client_from_auth_file.return_value + assert auth_sdk_client == "foo-bar" - @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict") - @patch( - "airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection", - return_value=Connection(conn_id="azure_default", extra='{ "key_json": { "test": "test" } }'), + @pytest.mark.parametrize( + "mocked_connection", + [Connection(conn_id="azure_default", extra={"key_json": {"test": "test"}})], + indirect=True, ) - def test_get_conn_with_key_json(self, mock_connection, mock_get_client_from_json_dict): + @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict") + def test_get_conn_with_key_json(self, mock_get_client_from_json_dict, mocked_connection): mock_sdk_client = Mock() - + mock_get_client_from_json_dict.return_value = "foo-bar" auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn() mock_get_client_from_json_dict.assert_called_once_with( - client_class=mock_sdk_client, config_dict=mock_connection.return_value.extra_dejson["key_json"] + client_class=mock_sdk_client, config_dict=mocked_connection.extra_dejson["key_json"] ) - assert auth_sdk_client == mock_get_client_from_json_dict.return_value + assert auth_sdk_client == "foo-bar" @patch("airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials") - @patch( - "airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection", - return_value=Connection( - conn_id="azure_default", - login="my_login", - password="my_password", - extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }', - ), + @pytest.mark.parametrize( + "mocked_connection", + [ + Connection( + conn_id="azure_default", + login="my_login", + password="my_password", + extra={"tenantId": "my_tenant", "subscriptionId": "my_subscription"}, + ) + ], + indirect=True, ) - def test_get_conn_with_credentials(self, mock_connection, mock_spc): - mock_sdk_client = Mock() - + def test_get_conn_with_credentials(self, mock_spc, mocked_connection): + mock_sdk_client = Mock(return_value="spam-egg") + mock_spc.return_value = "foo-bar" auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn() mock_spc.assert_called_once_with( - client_id=mock_connection.return_value.login, - secret=mock_connection.return_value.password, - tenant=mock_connection.return_value.extra_dejson["tenantId"], + client_id=mocked_connection.login, + secret=mocked_connection.password, + tenant=mocked_connection.extra_dejson["tenantId"], ) mock_sdk_client.assert_called_once_with( - credentials=mock_spc.return_value, - subscription_id=mock_connection.return_value.extra_dejson["subscriptionId"], + credentials="foo-bar", + subscription_id=mocked_connection.extra_dejson["subscriptionId"], ) - assert auth_sdk_client == mock_sdk_client.return_value + assert auth_sdk_client == "spam-egg" diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 1f48a7b011eb8..06ef4eedfbf32 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import json from unittest import mock import pytest @@ -34,269 +33,262 @@ ) ACCESS_KEY_STRING = "AccountName=name;skdkskd" +PROXIES = {"http": "http_proxy_uri", "https": "https_proxy_uri"} + + +@pytest.fixture +def mocked_blob_service_client(): + with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") as m: + yield m + + +@pytest.fixture +def mocked_default_azure_credential(): + with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") as m: + yield m + + +@pytest.fixture +def mocked_client_secret_credential(): + with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential") as m: + yield m class TestWasbHook: - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_method(self, create_mock_connections): self.login = "login" self.wasb_test_key = "wasb_test_key" self.connection_type = "wasb" - self.connection_string_id = "azure_test_connection_string" - self.shared_key_conn_id = "azure_shared_key_test" - self.shared_key_conn_id_without_host = "azure_shared_key_test_wihout_host" - self.ad_conn_id = "azure_AD_test" + self.azure_test_connection_string = "azure_test_connection_string" + self.azure_shared_key_test = "azure_shared_key_test" + self.ad_conn_id = "ad_conn_id" self.sas_conn_id = "sas_token_id" - self.extra__wasb__sas_conn_id = "extra__sas_token_id" - self.http_sas_conn_id = "http_sas_token_id" - self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id" + self.extra__wasb__sas_conn_id = "extra__wasb__sas_conn_id" + self.http_sas_conn_id = "http_sas_conn_id" + self.extra__wasb__http_sas_conn_id = "extra__wasb__http_sas_conn_id" self.public_read_conn_id = "pub_read_id" self.public_read_conn_id_without_host = "pub_read_id_without_host" - self.managed_identity_conn_id = "managed_identity" + self.managed_identity_conn_id = "managed_identity_conn_id" self.authority = "https://test_authority.com" - self.proxies = {"http": "http_proxy_uri", "https": "https_proxy_uri"} + self.proxies = PROXIES self.client_secret_auth_config = { "proxies": self.proxies, "connection_verify": False, "authority": self.authority, } - self.connection_map = { - self.wasb_test_key: Connection( + + conns = create_mock_connections( + Connection( conn_id="wasb_test_key", conn_type=self.connection_type, login=self.login, password="key", ), - self.public_read_conn_id: Connection( + Connection( conn_id=self.public_read_conn_id, conn_type=self.connection_type, host="https://accountname.blob.core.windows.net", - extra=json.dumps({"proxies": self.proxies}), + extra={"proxies": self.proxies}, ), - self.public_read_conn_id_without_host: Connection( + Connection( conn_id=self.public_read_conn_id_without_host, conn_type=self.connection_type, login=self.login, - extra=json.dumps({"proxies": self.proxies}), + extra={"proxies": self.proxies}, ), - self.connection_string_id: Connection( - conn_id=self.connection_string_id, + Connection( + conn_id=self.azure_test_connection_string, conn_type=self.connection_type, - extra=json.dumps({"connection_string": CONN_STRING, "proxies": self.proxies}), + extra={"connection_string": CONN_STRING, "proxies": self.proxies}, ), - self.shared_key_conn_id: Connection( - conn_id=self.shared_key_conn_id, + Connection( + conn_id=self.azure_shared_key_test, conn_type=self.connection_type, host="https://accountname.blob.core.windows.net", - extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}), + extra={"shared_access_key": "token", "proxies": self.proxies}, ), - self.shared_key_conn_id_without_host: Connection( - conn_id=self.shared_key_conn_id_without_host, - conn_type=self.connection_type, - login=self.login, - extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}), - ), - self.ad_conn_id: Connection( + Connection( conn_id=self.ad_conn_id, conn_type=self.connection_type, host="conn_host", login="appID", password="appsecret", - extra=json.dumps( - { - "tenant_id": "token", - "proxies": self.proxies, - "client_secret_auth_config": self.client_secret_auth_config, - } - ), + extra={ + "tenant_id": "token", + "proxies": self.proxies, + "client_secret_auth_config": self.client_secret_auth_config, + }, ), - self.managed_identity_conn_id: Connection( + Connection( conn_id=self.managed_identity_conn_id, conn_type=self.connection_type, - extra=json.dumps({"proxies": self.proxies}), + extra={"proxies": self.proxies}, ), - self.sas_conn_id: Connection( - conn_id=self.sas_conn_id, + Connection( + conn_id="sas_conn_id", conn_type=self.connection_type, login=self.login, - extra=json.dumps({"sas_token": "token", "proxies": self.proxies}), + extra={"sas_token": "token", "proxies": self.proxies}, ), - self.extra__wasb__sas_conn_id: Connection( + Connection( conn_id=self.extra__wasb__sas_conn_id, conn_type=self.connection_type, login=self.login, - extra=json.dumps({"extra__wasb__sas_token": "token", "proxies": self.proxies}), + extra={"extra__wasb__sas_token": "token", "proxies": self.proxies}, ), - self.http_sas_conn_id: Connection( + Connection( conn_id=self.http_sas_conn_id, conn_type=self.connection_type, - extra=json.dumps( - {"sas_token": "https://login.blob.core.windows.net/token", "proxies": self.proxies} - ), + extra={"sas_token": "https://login.blob.core.windows.net/token", "proxies": self.proxies}, ), - self.extra__wasb__http_sas_conn_id: Connection( + Connection( conn_id=self.extra__wasb__http_sas_conn_id, conn_type=self.connection_type, - extra=json.dumps( - { - "extra__wasb__sas_token": "https://login.blob.core.windows.net/token", - "proxies": self.proxies, - } - ), + extra={ + "extra__wasb__sas_token": "https://login.blob.core.windows.net/token", + "proxies": self.proxies, + }, ), - } - - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_key(self, mock_get_conn, mock_blob_service_client): - conn = self.connection_map[self.wasb_test_key] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.wasb_test_key) - assert mock_blob_service_client.call_args == mock.call( - account_url=f"https://{self.login}.blob.core.windows.net/", - credential=conn.password, + ) + self.connection_map = {conn.conn_id: conn for conn in conns} + + def test_key(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.wasb_test_key) + mocked_blob_service_client.assert_not_called() # Not expected during initialisation + hook.get_conn() + mocked_blob_service_client.assert_called_once_with( + account_url=f"https://{self.login}.blob.core.windows.net/", credential="key" ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_public_read(self, mock_get_conn, mock_blob_service_client): - conn = self.connection_map[self.public_read_conn_id] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True) - assert mock_blob_service_client.call_args == mock.call( - account_url=conn.host, - proxies=conn.extra_dejson["proxies"], + def test_public_read(self, mocked_blob_service_client): + WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True).get_conn() + mocked_blob_service_client.assert_called_once_with( + account_url="https://accountname.blob.core.windows.net", proxies=self.proxies ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_connection_string(self, mock_get_conn, mock_blob_service_client): - conn = self.connection_map[self.connection_string_id] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.connection_string_id) - mock_blob_service_client.from_connection_string.assert_called_once_with( - CONN_STRING, - proxies=conn.extra_dejson["proxies"], - connection_string=conn.extra_dejson["connection_string"], + def test_connection_string(self, mocked_blob_service_client): + WasbHook(wasb_conn_id=self.azure_test_connection_string).get_conn() + mocked_blob_service_client.from_connection_string.assert_called_once_with( + CONN_STRING, proxies=self.proxies, connection_string=CONN_STRING ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_shared_key_connection(self, mock_get_conn, mock_blob_service_client): - conn = self.connection_map[self.shared_key_conn_id] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.shared_key_conn_id) - mock_blob_service_client.assert_called_once_with( - account_url=conn.host, - credential=conn.extra_dejson["shared_access_key"], - proxies=conn.extra_dejson["proxies"], - shared_access_key=conn.extra_dejson["shared_access_key"], + def test_shared_key_connection(self, mocked_blob_service_client): + WasbHook(wasb_conn_id=self.azure_shared_key_test).get_conn() + mocked_blob_service_client.assert_called_once_with( + account_url="https://accountname.blob.core.windows.net", + credential="token", + proxies=self.proxies, + shared_access_key="token", ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_managed_identity(self, mock_get_conn, mock_credential, mock_blob_service_client): - conn = self.connection_map[self.managed_identity_conn_id] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.managed_identity_conn_id) - mock_blob_service_client.assert_called_once_with( - account_url=f"https://{conn.login}.blob.core.windows.net/", - credential=mock_credential.return_value, - proxies=conn.extra_dejson["proxies"], + def test_managed_identity(self, mocked_default_azure_credential, mocked_blob_service_client): + mocked_default_azure_credential.return_value = "foo-bar" + WasbHook(wasb_conn_id=self.managed_identity_conn_id).get_conn() + mocked_blob_service_client.assert_called_once_with( + account_url="https://None.blob.core.windows.net/", + credential="foo-bar", + proxies=self.proxies, ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_azure_directory_connection(self, mock_get_conn, mock_credential, mock_blob_service_client): - conn = self.connection_map[self.ad_conn_id] - mock_get_conn.return_value = conn - WasbHook(wasb_conn_id=self.ad_conn_id) - mock_credential.assert_called_once_with( - conn.extra_dejson["tenant_id"], - conn.login, - conn.password, + def test_azure_directory_connection(self, mocked_client_secret_credential, mocked_blob_service_client): + mocked_client_secret_credential.return_value = "spam-egg" + WasbHook(wasb_conn_id=self.ad_conn_id).get_conn() + mocked_client_secret_credential.assert_called_once_with( + "token", + "appID", + "appsecret", proxies=self.client_secret_auth_config["proxies"], connection_verify=self.client_secret_auth_config["connection_verify"], authority=self.client_secret_auth_config["authority"], ) - mock_blob_service_client.assert_called_once_with( - account_url=f"https://{conn.login}.blob.core.windows.net/", - credential=mock_credential.return_value, - tenant_id=conn.extra_dejson["tenant_id"], - proxies=conn.extra_dejson["proxies"], + mocked_blob_service_client.assert_called_once_with( + account_url="https://appID.blob.core.windows.net/", + credential="spam-egg", + tenant_id="token", + proxies=self.proxies, ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_active_directory_ID_used_as_host(self, mock_get_conn, mock_credential, mock_blob_service_client): - mock_get_conn.return_value = Connection( - conn_id="testconn", - conn_type=self.connection_type, - login="testaccountname", - host="testaccountID", - ) - WasbHook(wasb_conn_id="testconn") - assert mock_blob_service_client.call_args == mock.call( + @pytest.mark.parametrize( + "mocked_connection", + [ + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + host="testaccountID", + ) + ], + indirect=True, + ) + def test_active_directory_id_used_as_host( + self, mocked_connection, mocked_default_azure_credential, mocked_blob_service_client + ): + mocked_default_azure_credential.return_value = "fake-credential" + WasbHook(wasb_conn_id="testconn").get_conn() + mocked_blob_service_client.assert_called_once_with( account_url="https://testaccountname.blob.core.windows.net/", - credential=mock_credential.return_value, + credential="fake-credential", ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_sas_token_provided_and_active_directory_ID_used_as_host( - self, mock_get_conn, mock_blob_service_client + @pytest.mark.parametrize( + "mocked_connection", + [ + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + host="testaccountID", + extra={"sas_token": "SAStoken"}, + ) + ], + indirect=True, + ) + def test_sas_token_provided_and_active_directory_id_used_as_host( + self, mocked_connection, mocked_blob_service_client ): - mock_get_conn.return_value = Connection( - conn_id="testconn", - conn_type=self.connection_type, - login="testaccountname", - host="testaccountID", - extra=json.dumps({"sas_token": "SAStoken"}), - ) - WasbHook(wasb_conn_id="testconn") - assert mock_blob_service_client.call_args == mock.call( + WasbHook(wasb_conn_id="testconn").get_conn() + mocked_blob_service_client.assert_called_once_with( account_url="https://testaccountname.blob.core.windows.net/SAStoken", sas_token="SAStoken", ) @pytest.mark.parametrize( - argnames="conn_id_str", - argvalues=[ - "wasb_test_key", - "shared_key_conn_id_without_host", - "public_read_conn_id_without_host", + "mocked_connection", + [ + pytest.param( + Connection( + conn_type="wasb", + login="foo", + extra={"shared_access_key": "token", "proxies": PROXIES}, + ), + id="shared-key-without-host", + ), + pytest.param( + Connection(conn_type="wasb", login="foo", extra={"proxies": PROXIES}), + id="public-read-without-host", + ), ], + indirect=True, ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") def test_account_url_without_host( - self, mock_get_conn, mock_credential, mock_blob_service_client, conn_id_str + self, mocked_connection, mocked_blob_service_client, mocked_default_azure_credential ): - conn_id = self.__getattribute__(conn_id_str) - connection = self.connection_map[conn_id] - mock_get_conn.return_value = connection - WasbHook(wasb_conn_id=conn_id) - if conn_id_str == "wasb_test_key": - mock_blob_service_client.assert_called_once_with( - account_url=f"https://{connection.login}.blob.core.windows.net/", - credential=connection.password, - ) - elif conn_id_str == "shared_key_conn_id_without_host": - mock_blob_service_client.assert_called_once_with( - account_url=f"https://{connection.login}.blob.core.windows.net/", - credential=connection.extra_dejson["shared_access_key"], - proxies=connection.extra_dejson["proxies"], - shared_access_key=connection.extra_dejson["shared_access_key"], + mocked_default_azure_credential.return_value = "default-creds" + WasbHook(wasb_conn_id=mocked_connection.conn_id).get_conn() + if "shared_access_key" in mocked_connection.extra_dejson: + mocked_blob_service_client.assert_called_once_with( + account_url=f"https://{mocked_connection.login}.blob.core.windows.net/", + credential=mocked_connection.extra_dejson["shared_access_key"], + proxies=mocked_connection.extra_dejson["proxies"], + shared_access_key=mocked_connection.extra_dejson["shared_access_key"], ) else: - mock_blob_service_client.assert_called_once_with( - account_url=f"https://{connection.login}.blob.core.windows.net/", - credential=mock_credential.return_value, - proxies=connection.extra_dejson["proxies"], + mocked_blob_service_client.assert_called_once_with( + account_url=f"https://{mocked_connection.login}.blob.core.windows.net/", + credential="default-creds", + proxies=mocked_connection.extra_dejson["proxies"], ) @pytest.mark.parametrize( @@ -308,25 +300,22 @@ def test_account_url_without_host( ("extra__wasb__http_sas_conn_id", "extra__wasb__sas_token"), ], ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_sas_token_connection(self, mock_get_conn, conn_id_str, extra_key): - conn_id = self.__getattribute__(conn_id_str) - mock_get_conn.return_value = self.connection_map[conn_id] - hook = WasbHook(wasb_conn_id=conn_id) + def test_sas_token_connection(self, conn_id_str, extra_key): + hook = WasbHook(wasb_conn_id=conn_id_str) conn = hook.get_conn() hook_conn = hook.get_connection(hook.conn_id) sas_token = hook_conn.extra_dejson[extra_key] assert isinstance(conn, BlobServiceClient) assert conn.url.startswith("https://") if hook_conn.login: - assert conn.url.__contains__(hook_conn.login) + assert hook_conn.login in conn.url assert conn.url.endswith(sas_token + "/") @pytest.mark.parametrize( argnames="conn_id_str", argvalues=[ - "connection_string_id", - "shared_key_conn_id", + "azure_test_connection_string", + "azure_shared_key_test", "ad_conn_id", "managed_identity_conn_id", "sas_conn_id", @@ -335,27 +324,17 @@ def test_sas_token_connection(self, mock_get_conn, conn_id_str, extra_key): "extra__wasb__http_sas_conn_id", ], ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_connection_extra_arguments(self, mock_get_conn, conn_id_str): - conn_id = self.__getattribute__(conn_id_str) - mock_get_conn.return_value = self.connection_map[conn_id] - hook = WasbHook(wasb_conn_id=conn_id) - conn = hook.get_conn() + def test_connection_extra_arguments(self, conn_id_str): + conn = WasbHook(wasb_conn_id=conn_id_str).get_conn() assert conn._config.proxy_policy.proxies == self.proxies - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_connection_extra_arguments_public_read(self, mock_get_conn): - conn_id = self.public_read_conn_id - mock_get_conn.return_value = self.connection_map[conn_id] - hook = WasbHook(wasb_conn_id=conn_id, public_read=True) + def test_connection_extra_arguments_public_read(self): + hook = WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True) conn = hook.get_conn() assert conn._config.proxy_policy.proxies == self.proxies - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_extra_client_secret_auth_config_ad_connection(self, mock_get_conn): - mock_get_conn.return_value = self.connection_map[self.ad_conn_id] - conn_id = self.ad_conn_id - hook = WasbHook(wasb_conn_id=conn_id) + def test_extra_client_secret_auth_config_ad_connection(self): + hook = WasbHook(wasb_conn_id=self.ad_conn_id) conn = hook.get_conn() assert conn.credential._authority == self.authority @@ -371,79 +350,62 @@ def test_extra_client_secret_auth_config_ad_connection(self, mock_get_conn): ("testhost.blob.net", "testhost.blob.net"), ], ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") def test_proper_account_url_update( - self, mock_get_conn, mock_blob_service_client, provided_host, expected_host + self, mocked_blob_service_client, provided_host, expected_host, create_mock_connection ): - mock_get_conn.return_value = Connection( - conn_id="test_conn", - conn_type=self.connection_type, - password="testpass", - login="accountlogin", - host=provided_host, + conn = create_mock_connection( + Connection( + conn_type=self.connection_type, + password="testpass", + login="accountlogin", + host=provided_host, + ) ) - WasbHook(wasb_conn_id=self.shared_key_conn_id) - mock_blob_service_client.assert_called_once_with(account_url=expected_host, credential="testpass") - - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_check_for_blob(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + WasbHook(wasb_conn_id=conn.conn_id).get_conn() + mocked_blob_service_client.assert_called_once_with(account_url=expected_host, credential="testpass") + + def test_check_for_blob(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) assert hook.check_for_blob(container_name="mycontainer", blob_name="myblob") - mock_blob_client = mock_service.return_value.get_blob_client + mock_blob_client = mocked_blob_service_client.return_value.get_blob_client mock_blob_client.assert_called_once_with(container="mycontainer", blob="myblob") mock_blob_client.return_value.get_blob_properties.assert_called() - @mock.patch.object(WasbHook, "get_blobs_list") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_check_for_prefix(self, mock_get_conn, get_blobs_list): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - get_blobs_list.return_value = ["blobs"] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + @mock.patch.object(WasbHook, "get_blobs_list", return_value=["blobs"]) + def test_check_for_prefix(self, get_blobs_list): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) assert hook.check_for_prefix("container", "prefix", timeout=3) get_blobs_list.assert_called_once_with(container_name="container", prefix="prefix", timeout=3) - @mock.patch.object(WasbHook, "get_blobs_list") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_check_for_prefix_empty(self, mock_get_conn, get_blobs_list): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - get_blobs_list.return_value = [] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + @mock.patch.object(WasbHook, "get_blobs_list", return_value=[]) + def test_check_for_prefix_empty(self, get_blobs_list): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) assert not hook.check_for_prefix("container", "prefix", timeout=3) get_blobs_list.assert_called_once_with(container_name="container", prefix="prefix", timeout=3) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_blobs_list(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_get_blobs_list(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.get_blobs_list(container_name="mycontainer", prefix="my", include=None, delimiter="/") - mock_service.return_value.get_container_client.assert_called_once_with("mycontainer") - mock_service.return_value.get_container_client.return_value.walk_blobs.assert_called_once_with( + mock_container_client = mocked_blob_service_client.return_value.get_container_client + mock_container_client.assert_called_once_with("mycontainer") + mock_container_client.return_value.walk_blobs.assert_called_once_with( name_starts_with="my", include=None, delimiter="/" ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_blobs_list_recursive(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_get_blobs_list_recursive(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.get_blobs_list_recursive( container_name="mycontainer", prefix="test", include=None, endswith="file_extension" ) - mock_service.return_value.get_container_client.assert_called_once_with("mycontainer") - mock_service.return_value.get_container_client.return_value.list_blobs.assert_called_once_with( + mock_container_client = mocked_blob_service_client.return_value.get_container_client + mock_container_client.assert_called_once_with("mycontainer") + mock_container_client.return_value.list_blobs.assert_called_once_with( name_starts_with="test", include=None ) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_blobs_list_recursive_endswith(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) - mock_service.return_value.get_container_client.return_value.list_blobs.return_value = [ + def test_get_blobs_list_recursive_endswith(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + mocked_blob_service_client.return_value.get_container_client.return_value.list_blobs.return_value = [ BlobProperties(name="test/abc.py"), BlobProperties(name="test/inside_test/abc.py"), BlobProperties(name="test/abc.csv"), @@ -455,11 +417,9 @@ def test_get_blobs_list_recursive_endswith(self, mock_get_conn, mock_service): @pytest.mark.parametrize(argnames="create_container", argvalues=[True, False]) @mock.patch.object(WasbHook, "upload") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_load_file(self, mock_get_conn, mock_upload, create_container): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] + def test_load_file(self, mock_upload, create_container): with mock.patch("builtins.open", mock.mock_open(read_data="data")): - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.load_file("path", "container", "blob", create_container, max_connections=1) mock_upload.assert_called_with( @@ -472,10 +432,8 @@ def test_load_file(self, mock_get_conn, mock_upload, create_container): @pytest.mark.parametrize(argnames="create_container", argvalues=[True, False]) @mock.patch.object(WasbHook, "upload") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_load_string(self, mock_get_conn, mock_upload, create_container): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_load_string(self, mock_upload, create_container): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.load_string("big string", "container", "blob", create_container, max_connections=1) mock_upload.assert_called_once_with( container_name="container", @@ -486,30 +444,22 @@ def test_load_string(self, mock_get_conn, mock_upload, create_container): ) @mock.patch.object(WasbHook, "download") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_file(self, mock_get_conn, mock_download): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] + def test_get_file(self, mock_download): with mock.patch("builtins.open", mock.mock_open(read_data="data")): - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.get_file("path", "container", "blob", max_connections=1) mock_download.assert_called_once_with(container_name="container", blob_name="blob", max_connections=1) mock_download.return_value.readall.assert_called() - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") @mock.patch.object(WasbHook, "download") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_read_file(self, mock_get_conn, mock_download, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_read_file(self, mock_download, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.read_file("container", "blob", max_connections=1) mock_download.assert_called_once_with("container", "blob", max_connections=1) @pytest.mark.parametrize(argnames="create_container", argvalues=[True, False]) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_upload(self, mock_get_conn, mock_service, create_container): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_upload(self, mocked_blob_service_client, create_container): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.upload( container_name="mycontainer", blob_name="myblob", @@ -518,80 +468,61 @@ def test_upload(self, mock_get_conn, mock_service, create_container): blob_type="BlockBlob", length=4, ) - mock_blob_client = mock_service.return_value.get_blob_client + mock_blob_client = mocked_blob_service_client.return_value.get_blob_client mock_blob_client.assert_called_once_with(container="mycontainer", blob="myblob") mock_blob_client.return_value.upload_blob.assert_called_once_with(b"mydata", "BlockBlob", length=4) - mock_container_client = mock_service.return_value.get_container_client + mock_container_client = mocked_blob_service_client.return_value.get_container_client if create_container: mock_container_client.assert_called_with("mycontainer") else: mock_container_client.assert_not_called() - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_download(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - blob_client = mock_service.return_value.get_blob_client - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_download(self, mocked_blob_service_client): + blob_client = mocked_blob_service_client.return_value.get_blob_client + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.download(container_name="mycontainer", blob_name="myblob", offset=2, length=4) blob_client.assert_called_once_with(container="mycontainer", blob="myblob") blob_client.return_value.download_blob.assert_called_once_with(offset=2, length=4) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_container_client(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_get_container_client(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook._get_container_client("mycontainer") - mock_service.return_value.get_container_client.assert_called_once_with("mycontainer") + mocked_blob_service_client.return_value.get_container_client.assert_called_once_with("mycontainer") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_get_blob_client(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_get_blob_client(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook._get_blob_client(container_name="mycontainer", blob_name="myblob") - mock_instance = mock_service.return_value.get_blob_client + mock_instance = mocked_blob_service_client.return_value.get_blob_client mock_instance.assert_called_once_with(container="mycontainer", blob="myblob") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_create_container(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_create_container(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.create_container(container_name="mycontainer") - mock_instance = mock_service.return_value.get_container_client + mock_instance = mocked_blob_service_client.return_value.get_container_client mock_instance.assert_called_once_with("mycontainer") mock_instance.return_value.create_container.assert_called() - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_container(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_delete_container(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_container("mycontainer") - mock_service.return_value.get_container_client.assert_called_once_with("mycontainer") - mock_service.return_value.get_container_client.return_value.delete_container.assert_called() + mocked_container_client = mocked_blob_service_client.return_value.get_container_client + mocked_container_client.assert_called_once_with("mycontainer") + mocked_container_client.return_value.delete_container.assert_called() - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") @mock.patch.object(WasbHook, "delete_blobs") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_single_blob(self, mock_get_conn, delete_blobs, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_delete_single_blob(self, delete_blobs, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_file("container", "blob", is_prefix=False) delete_blobs.assert_called_once_with("container", "blob") @mock.patch.object(WasbHook, "delete_blobs") @mock.patch.object(WasbHook, "get_blobs_list") @mock.patch.object(WasbHook, "check_for_blob") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_multiple_blobs(self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] + def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs): mock_check.return_value = False mock_get_blobslist.return_value = ["blob_prefix/blob1", "blob_prefix/blob2"] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_file("container", "blob_prefix", is_prefix=True) mock_get_blobslist.assert_called_once_with("container", prefix="blob_prefix", delimiter="") mock_delete_blobs.assert_any_call( @@ -604,14 +535,10 @@ def test_delete_multiple_blobs(self, mock_get_conn, mock_check, mock_get_blobsli @mock.patch.object(WasbHook, "delete_blobs") @mock.patch.object(WasbHook, "get_blobs_list") @mock.patch.object(WasbHook, "check_for_blob") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_more_than_256_blobs( - self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs - ): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] + def test_delete_more_than_256_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs): mock_check.return_value = False mock_get_blobslist.return_value = [f"blob_prefix/blob{i}" for i in range(300)] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_file("container", "blob_prefix", is_prefix=True) mock_get_blobslist.assert_called_once_with("container", prefix="blob_prefix", delimiter="") # The maximum number of blobs that can be deleted in a single request is 256 using the underlying @@ -620,33 +547,26 @@ def test_delete_more_than_256_blobs( # `ContainerClient.delete_blobs()` in this test. assert mock_delete_blobs.call_count == 2 - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") @mock.patch.object(WasbHook, "get_blobs_list") @mock.patch.object(WasbHook, "check_for_blob") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_nonexisting_blob_fails(self, mock_get_conn, mock_check, mock_getblobs, mock_service): + def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs, mocked_blob_service_client): mock_getblobs.return_value = [] mock_check.return_value = False with pytest.raises(Exception) as ctx: - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_file("container", "nonexisting_blob", is_prefix=False, ignore_if_missing=False) assert isinstance(ctx.value, AirflowException) @mock.patch.object(WasbHook, "get_blobs_list") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_delete_multiple_nonexisting_blobs_fails(self, mock_get_conn, mock_getblobs): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] + def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs): mock_getblobs.return_value = [] with pytest.raises(Exception) as ctx: - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.delete_file("container", "nonexisting_blob_prefix", is_prefix=True, ignore_if_missing=False) assert isinstance(ctx.value, AirflowException) - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_connection_success(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_connection_success(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.get_conn().get_account_information().return_value = { "sku_name": "Standard_RAGRS", "account_kind": "StorageV2", @@ -656,11 +576,8 @@ def test_connection_success(self, mock_get_conn, mock_service): assert status is True assert msg == "Successfully connected to Azure Blob Storage." - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") - @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection") - def test_connection_failure(self, mock_get_conn, mock_service): - mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id] - hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + def test_connection_failure(self, mocked_blob_service_client): + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook.get_conn().get_account_information = mock.PropertyMock( side_effect=Exception("Authentication failed.") ) diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index e920f7c1d9ed4..e70db80913293 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -26,7 +26,6 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook from airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator -from airflow.utils import db TASK_ID = "MyDag" BATCH_POOL_ID = "MyPool" @@ -40,11 +39,20 @@ $TargetDedicated = $isWorkingWeekdayHour ? 20:10;""" +@pytest.fixture +def mocked_batch_service_client(): + with mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") as m: + yield m + + class TestAzureBatchOperator: # set up the test environment - @mock.patch("airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook") - @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") - def setup_method(self, method, mock_batch, mock_hook): + @pytest.fixture(autouse=True) + def setup_test_cases(self, mocked_batch_service_client, create_mock_connections): + # set up mocked Azure Batch client + self.batch_client = mock.MagicMock(name="FakeBatchServiceClient") + mocked_batch_service_client.return_value = self.batch_client + # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm2" self.test_cloud_conn_id = "test_azure_batch_cloud2" @@ -59,22 +67,21 @@ def setup_method(self, method, mock_batch, mock_hook): self.test_cloud_os_version = "test-version" self.test_node_agent_sku = "test-node-agent-sku" - # connect with vm configuration - db.merge_conn( + create_mock_connections( + # connect with vm configuration Connection( conn_id=self.test_vm_conn_id, conn_type="azure_batch", extra=json.dumps({"account_url": self.test_account_url}), - ) - ) - # connect with cloud service - db.merge_conn( + ), + # connect with cloud service Connection( conn_id=self.test_cloud_conn_id, conn_type="azure_batch", extra=json.dumps({"account_url": self.test_account_url}), - ) + ), ) + self.operator = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, @@ -159,10 +166,6 @@ def setup_method(self, method, mock_batch, mock_hook): target_dedicated_nodes=1, timeout=2, ) - self.batch_client = mock_batch.return_value - self.mock_instance = mock_hook.return_value - assert self.batch_client == self.operator.hook.connection - assert self.batch_client == self.operator2_pass.hook.connection @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") def test_execute_without_failures(self, wait_mock): diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py index 57e50970069d4..599e24889e27f 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py @@ -17,33 +17,35 @@ # under the License. from __future__ import annotations -import json import uuid from unittest import mock +import pytest + from airflow.models import Connection from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator -from airflow.utils import db class TestAzureCosmosDbHook: # Set up an environment to test with - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): # set up some test variables self.test_end_point = "https://test_endpoint:443" self.test_master_key = "magic_test_key" self.test_database_name = "test_database_name" self.test_collection_name = "test_collection_name" - db.merge_conn( + create_mock_connection( Connection( conn_id="azure_cosmos_test_key_id", conn_type="azure_cosmos", login=self.test_end_point, password=self.test_master_key, - extra=json.dumps( - {"database_name": self.test_database_name, "collection_name": self.test_collection_name} - ), + extra={ + "database_name": self.test_database_name, + "collection_name": self.test_collection_name, + }, ) ) diff --git a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py index 2a545cd1371f2..e4dd61c1d2058 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json from unittest import mock from unittest.mock import MagicMock, patch @@ -35,7 +34,7 @@ ) from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator from airflow.providers.microsoft.azure.triggers.data_factory import AzureDataFactoryTrigger -from airflow.utils import db, timezone +from airflow.utils import timezone from airflow.utils.types import DagRunType DEFAULT_DATE = timezone.datetime(2021, 1, 1) @@ -60,7 +59,8 @@ class TestAzureDataFactoryRunPipelineOperator: - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): self.mock_ti = MagicMock() self.mock_context = {"ti": self.mock_ti} self.config = { @@ -73,13 +73,13 @@ def setup_method(self): "timeout": 3, } - db.merge_conn( + create_mock_connection( Connection( conn_id="azure_data_factory_test", conn_type="azure_data_factory", login="client-id", password="client-secret", - extra=json.dumps(CONN_EXTRAS), + extra=CONN_EXTRAS, ) ) diff --git a/tests/providers/microsoft/azure/operators/test_azure_synapse.py b/tests/providers/microsoft/azure/operators/test_azure_synapse.py index c43b11ef7b888..233e1c57fdc79 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_synapse.py +++ b/tests/providers/microsoft/azure/operators/test_azure_synapse.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json from unittest import mock from unittest.mock import MagicMock @@ -24,7 +23,7 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunSparkBatchOperator -from airflow.utils import db, timezone +from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2021, 1, 1) SUBSCRIPTION_ID = "my-subscription-id" @@ -39,7 +38,8 @@ class TestAzureSynapseRunSparkBatchOperator: - def setup_method(self): + @pytest.fixture(autouse=True) + def setup_test_cases(self, create_mock_connection): self.mock_ti = MagicMock() self.mock_context = {"ti": self.mock_ti} self.config = { @@ -50,22 +50,21 @@ def setup_method(self): "timeout": 3, } - db.merge_conn( + create_mock_connection( Connection( conn_id=AZURE_SYNAPSE_CONN_ID, conn_type="azure_synapse", host="https://synapsetest.net", login="client-id", password="client-secret", - extra=json.dumps(CONN_EXTRAS), + extra=CONN_EXTRAS, ) ) @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status") - @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn") @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job") def test_azure_synapse_run_spark_batch_operator_success( - self, mock_run_spark_job, mock_conn, mock_get_job_run_status + self, mock_run_spark_job, mock_get_job_run_status ): mock_get_job_run_status.return_value = "success" mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE) @@ -76,11 +75,8 @@ def test_azure_synapse_run_spark_batch_operator_success( assert op.job_id == JOB_RUN_RESPONSE["id"] @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status") - @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn") @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job") - def test_azure_synapse_run_spark_batch_operator_error( - self, mock_run_spark_job, mock_conn, mock_get_job_run_status - ): + def test_azure_synapse_run_spark_batch_operator_error(self, mock_run_spark_job, mock_get_job_run_status): mock_get_job_run_status.return_value = "error" mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE) op = AzureSynapseRunSparkBatchOperator( @@ -93,11 +89,10 @@ def test_azure_synapse_run_spark_batch_operator_error( op.execute(context=self.mock_context) @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status") - @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn") @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job") @mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.cancel_job_run") def test_azure_synapse_run_spark_batch_operator_on_kill( - self, mock_cancel_job_run, mock_run_spark_job, mock_conn, mock_get_job_run_status + self, mock_cancel_job_run, mock_run_spark_job, mock_get_job_run_status ): mock_get_job_run_status.return_value = "success" mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py new file mode 100644 index 0000000000000..aa75c95203c68 --- /dev/null +++ b/tests/providers/microsoft/conftest.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import string +from random import choices +from typing import TypeVar + +import pytest + +from airflow.models import Connection + +T = TypeVar("T", dict, str, Connection) + + +@pytest.fixture +def create_mock_connection(monkeypatch): + """Helper fixture for create test connection.""" + + def wrapper(conn: T, conn_id: str | None = None): + conn_id = conn_id or "test_conn_" + "".join(choices(string.ascii_lowercase + string.digits, k=6)) + if isinstance(conn, dict): + conn = Connection.from_json(conn) + elif isinstance(conn, str): + conn = Connection(uri=conn) + + if not isinstance(conn, Connection): + raise TypeError( + f"Fixture expected either JSON, URI or Connection type, but got {type(conn).__name__}" + ) + if not conn.conn_id: + conn.conn_id = conn_id + + monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri()) + return conn + + return wrapper + + +@pytest.fixture +def create_mock_connections(create_mock_connection): + """Helper fixture for create multiple test connections.""" + + def wrapper(*conns: T): + return list(map(create_mock_connection, conns)) + + return wrapper + + +@pytest.fixture +def mocked_connection(request, create_mock_connection): + """Helper indirect fixture for create test connection.""" + return create_mock_connection(request.param) From a6bf98bd8519ec45a7ffcdecc233487aa85936fc Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 26 Aug 2023 02:26:03 +0400 Subject: [PATCH 3/3] Add 'BatchServiceClient' to spellcheck --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index ec01067b12bfe..eb80939a4c4e1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -154,6 +154,7 @@ BaseView BaseXCom bashrc batchGet +BatchServiceClient bc bcc bdist