From 38f03291d261659ccb97a146f76db01d561b43e9 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 16 Aug 2023 18:35:19 +0800 Subject: [PATCH 1/2] feat(providers/microsoft): add DefaultAzureCredential support to cosmos --- airflow/providers/microsoft/azure/hooks/cosmos.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py index 3a73eeada7d7d..53135dc1d32fd 100644 --- a/airflow/providers/microsoft/azure/hooks/cosmos.py +++ b/airflow/providers/microsoft/azure/hooks/cosmos.py @@ -31,6 +31,7 @@ from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.identity import DefaultAzureCredential from airflow.exceptions import AirflowBadRequest from airflow.hooks.base import BaseHook @@ -109,13 +110,18 @@ def get_conn(self) -> CosmosClient: conn = self.get_connection(self.conn_id) extras = conn.extra_dejson endpoint_uri = conn.login - master_key = conn.password + credential: dict[str, Any] | DefaultAzureCredential + if conn.password: + master_key = conn.password + credential = {"masterKey": master_key} + else: + credential = DefaultAzureCredential() self.default_database_name = self._get_field(extras, "database_name") self.default_collection_name = self._get_field(extras, "collection_name") # Initialize the Python Azure Cosmos DB client - self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key}) + self._conn = CosmosClient(endpoint_uri, credential=credential) return self._conn def __get_database_name(self, database_name: str | None = None) -> str: From 02cf98c19304e89ce03630c4e8c570ef9a8b06d1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 22 Aug 2023 23:24:26 +0800 Subject: [PATCH 2/2] feat: use CosmosDBManagementClient to authenticate cosmos client through DefaultAzureCredential --- .../providers/microsoft/azure/hooks/cosmos.py | 35 +++++++++++++++---- .../providers/microsoft/azure/provider.yaml | 1 + .../index.rst | 1 + generated/provider_dependencies.json | 1 + .../azure/hooks/test_azure_cosmos.py | 3 +- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py index 53135dc1d32fd..2ebf9ed0f9bee 100644 --- a/airflow/providers/microsoft/azure/hooks/cosmos.py +++ b/airflow/providers/microsoft/azure/hooks/cosmos.py @@ -28,12 +28,14 @@ import json import uuid from typing import Any +from urllib.parse import urlparse from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.identity import DefaultAzureCredential +from azure.mgmt.cosmosdb import CosmosDBManagementClient -from airflow.exceptions import AirflowBadRequest +from airflow.exceptions import AirflowBadRequest, AirflowException from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import get_field @@ -69,6 +71,14 @@ def get_connection_form_widgets() -> dict[str, Any]: "collection_name": StringField( lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget() ), + "subscription_id": StringField( + lazy_gettext("Subscription ID (optional)"), + widget=BS3TextFieldWidget(), + ), + "resource_group_name": StringField( + lazy_gettext("Resource Group Name (optional)"), + widget=BS3TextFieldWidget(), + ), } @staticmethod @@ -82,9 +92,11 @@ def get_ui_field_behaviour() -> dict[str, Any]: }, "placeholders": { "login": "endpoint uri", - "password": "master key", + "password": "master key (not needed for Azure AD authentication)", "database_name": "database name", "collection_name": "collection name", + "subscription_id": "Subscription ID (required for Azure AD authentication)", + "resource_group_name": "Resource Group Name (required for Azure AD authentication)", }, } @@ -110,18 +122,29 @@ def get_conn(self) -> CosmosClient: conn = self.get_connection(self.conn_id) extras = conn.extra_dejson endpoint_uri = conn.login - credential: dict[str, Any] | DefaultAzureCredential + resource_group_name = self._get_field(extras, "resource_group_name") + if conn.password: master_key = conn.password - credential = {"masterKey": master_key} + elif resource_group_name: + management_client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id=self._get_field(extras, "subscription_id"), + ) + + database_account = urlparse(conn.login).netloc.split(".")[0] + database_account_keys = management_client.database_accounts.list_keys( + resource_group_name, database_account + ) + master_key = database_account_keys.primary_master_key else: - credential = DefaultAzureCredential() + raise AirflowException("Either password or resource_group_name is required") self.default_database_name = self._get_field(extras, "database_name") self.default_collection_name = self._get_field(extras, "collection_name") # Initialize the Python Azure Cosmos DB client - self._conn = CosmosClient(endpoint_uri, credential=credential) + self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key}) return self._conn def __get_database_name(self, database_name: str | None = None) -> str: diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index a0ed5fbdf2599..8459d583ee9a1 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -66,6 +66,7 @@ dependencies: - apache-airflow>=2.4.0 - azure-batch>=8.0.0 - azure-cosmos>=4.0.0 + - azure-mgmt-cosmosdb - azure-datalake-store>=0.0.45 - azure-identity>=1.3.1 - azure-keyvault-secrets>=4.1.0 diff --git a/docs/apache-airflow-providers-microsoft-azure/index.rst b/docs/apache-airflow-providers-microsoft-azure/index.rst index 73f9a0733c656..a3889484100ff 100644 --- a/docs/apache-airflow-providers-microsoft-azure/index.rst +++ b/docs/apache-airflow-providers-microsoft-azure/index.rst @@ -107,6 +107,7 @@ PIP package Version required ``apache-airflow`` ``>=2.4.0`` ``azure-batch`` ``>=8.0.0`` ``azure-cosmos`` ``>=4.0.0`` +``azure-mgmt-cosmosdb`` ``azure-datalake-store`` ``>=0.0.45`` ``azure-identity`` ``>=1.3.1`` ``azure-keyvault-secrets`` ``>=4.1.0`` diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 5191600a98392..b0dcab34c1747 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -560,6 +560,7 @@ "azure-keyvault-secrets>=4.1.0", "azure-kusto-data>=4.1.0", "azure-mgmt-containerinstance>=1.5.0,<2.0", + "azure-mgmt-cosmosdb", "azure-mgmt-datafactory>=1.0.0,<2.0", "azure-mgmt-datalake-store>=0.5.0", "azure-mgmt-resource>=2.2.0", diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py index af649f1d6e4d9..787ea5c4537c8 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py @@ -34,7 +34,6 @@ class TestAzureCosmosDbHook: - # Set up an environment to test with def setup_method(self): # set up some test variables @@ -266,6 +265,8 @@ def test_get_ui_field_behaviour_placeholders(self): "password", "database_name", "collection_name", + "subscription_id", + "resource_group_name", ] if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5): raise Exception(