diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index f76ac7439aacb..b69e999c5b8bf 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -33,7 +33,7 @@ from asgiref.sync import sync_to_async from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError -from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.identity import ClientSecretCredential from azure.identity.aio import ( ClientSecretCredential as AsyncClientSecretCredential, DefaultAzureCredential as AsyncDefaultAzureCredential, @@ -47,6 +47,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import get_default_azure_credential if TYPE_CHECKING: from azure.storage.blob._models import BlobProperties @@ -94,6 +95,12 @@ def get_connection_form_widgets() -> dict[str, Any]: lazy_gettext("Tenant Id (Active Directory Auth)"), widget=BS3TextFieldWidget() ), "sas_token": PasswordField(lazy_gettext("SAS Token (optional)"), widget=BS3PasswordFieldWidget()), + "managed_identity_client_id": StringField( + lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget() + ), + "workload_identity_tenant_id": StringField( + lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget() + ), } @staticmethod @@ -115,6 +122,8 @@ def get_ui_field_behaviour() -> dict[str, Any]: "shared_access_key": "shared access key", "sas_token": "account url or token", "extra": "additional options for use with ClientSecretCredential or DefaultAzureCredential", + "managed_identity_client_id": "Managed Identity Client ID", + "workload_identity_tenant_id": "Workload Identity Tenant ID", }, } @@ -207,7 +216,10 @@ def get_conn(self) -> BlobServiceClient: # Fall back to old auth (password) or use managed identity if not provided. 
credential = conn.password if not credential: - credential = DefaultAzureCredential() + managed_identity_client_id = self._get_field(extra, "managed_identity_client_id") + workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id") + credential = get_default_azure_credential(managed_identity_client_id, workload_identity_tenant_id) + self.log.info("Using DefaultAzureCredential as credential") return BlobServiceClient( account_url=account_url, diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst index 4ab7f2f43f1d5..e7e2f78b596ad 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst @@ -27,7 +27,7 @@ The Microsoft Azure Blob Storage connection type enables the Azure Blob Storage Authenticating to Azure Blob Storage ------------------------------------ -There are five ways to connect to Azure Blob Storage using Airflow. +There are six ways to connect to Azure Blob Storage using Airflow. 1. Use `token credentials`_ i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection. @@ -37,7 +37,8 @@ There are five ways to connect to Azure Blob Storage using Airflow. i.e. add a key config to ``sas_token`` in the Airflow connection. 4. Use a `Connection String`_ i.e. add connection string to ``connection_string`` in the Airflow connection. -5. Fallback on DefaultAzureCredential_. +5. Use managed identity by setting ``managed_identity_client_id``, ``workload_identity_tenant_id`` (under the hood, it uses DefaultAzureCredential_ with these arguments). +6. Fallback on DefaultAzureCredential_. This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI, etc. Only one authorization method can be used at a time. 
If you need to manage multiple credentials or keys then you should @@ -84,6 +85,8 @@ Extra (optional) The following parameters are all optional: * ``client_secret_auth_config``: Extra config to pass while authenticating as a service principal using `ClientSecretCredential`_ It can be left out to fall back on DefaultAzureCredential_. + * ``managed_identity_client_id``: The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, both are passed to ``DefaultAzureCredential``. + * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, both are passed to ``DefaultAzureCredential``. When specifying the connection in environment variable you should specify it using URI syntax. @@ -96,9 +99,14 @@ For example connect with token credentials: export AIRFLOW_CONN_WASB_DEFAULT='wasb://blob%20username:blob%20password@myblob.com?tenant_id=tenant+id' + .. _token credentials: https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-token-credentials .. _Azure Shared Key Credential: https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key .. _SAS Token: https://docs.microsoft.com/en-us/rest/api/storageservices/create-account-sas .. _Connection String: https://docs.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/storage .. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential .. _ClientSecretCredential: https://learn.microsoft.com/en-in/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python + +.. 
spelling:word-list:: + + Entra diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 5344091048565..b5a87175475e9 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -47,7 +47,7 @@ def mocked_blob_service_client(): @pytest.fixture def mocked_default_azure_credential(): - with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential") as m: + with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.get_default_azure_credential") as m: yield m