Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions airflow/providers/microsoft/azure/hooks/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
import json
import uuid
from typing import Any
from urllib.parse import urlparse

from azure.cosmos.cosmos_client import CosmosClient
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.identity import DefaultAzureCredential
from azure.mgmt.cosmosdb import CosmosDBManagementClient

from airflow.exceptions import AirflowBadRequest
from airflow.exceptions import AirflowBadRequest, AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field

Expand Down Expand Up @@ -68,6 +71,14 @@ def get_connection_form_widgets() -> dict[str, Any]:
"collection_name": StringField(
lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
),
"subscription_id": StringField(
lazy_gettext("Subscription ID (optional)"),
widget=BS3TextFieldWidget(),
),
"resource_group_name": StringField(
lazy_gettext("Resource Group Name (optional)"),
widget=BS3TextFieldWidget(),
),
}

@staticmethod
Expand All @@ -81,9 +92,11 @@ def get_ui_field_behaviour() -> dict[str, Any]:
},
"placeholders": {
"login": "endpoint uri",
"password": "master key",
"password": "master key (not needed for Azure AD authentication)",
"database_name": "database name",
"collection_name": "collection name",
"subscription_id": "Subscription ID (required for Azure AD authentication)",
"resource_group_name": "Resource Group Name (required for Azure AD authentication)",
},
}

Expand All @@ -109,7 +122,23 @@ def get_conn(self) -> CosmosClient:
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
endpoint_uri = conn.login
master_key = conn.password
resource_group_name = self._get_field(extras, "resource_group_name")

if conn.password:
master_key = conn.password
elif resource_group_name:
management_client = CosmosDBManagementClient(
credential=DefaultAzureCredential(),
subscription_id=self._get_field(extras, "subscription_id"),
)

database_account = urlparse(conn.login).netloc.split(".")[0]
database_account_keys = management_client.database_accounts.list_keys(
resource_group_name, database_account
)
master_key = database_account_keys.primary_master_key
else:
raise AirflowException("Either password or resource_group_name is required")

self.default_database_name = self._get_field(extras, "database_name")
self.default_collection_name = self._get_field(extras, "collection_name")
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ dependencies:
- apache-airflow>=2.4.0
- azure-batch>=8.0.0
- azure-cosmos>=4.0.0
- azure-mgmt-cosmosdb
- azure-datalake-store>=0.0.45
- azure-identity>=1.3.1
- azure-keyvault-secrets>=4.1.0
Expand Down
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-microsoft-azure/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ PIP package Version required
``apache-airflow`` ``>=2.4.0``
``azure-batch`` ``>=8.0.0``
``azure-cosmos`` ``>=4.0.0``
``azure-mgmt-cosmosdb``
``azure-datalake-store`` ``>=0.0.45``
``azure-identity`` ``>=1.3.1``
``azure-keyvault-secrets`` ``>=4.1.0``
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
"azure-mgmt-containerinstance>=1.5.0,<2.0",
"azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
"azure-mgmt-resource>=2.2.0",
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@


class TestAzureCosmosDbHook:

# Set up an environment to test with
def setup_method(self):
# set up some test variables
Expand Down Expand Up @@ -266,6 +265,8 @@ def test_get_ui_field_behaviour_placeholders(self):
"password",
"database_name",
"collection_name",
"subscription_id",
"resource_group_name",
]
if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
raise Exception(
Expand Down