Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions airflow/providers/microsoft/azure/hooks/adx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from __future__ import annotations

import warnings
from functools import cached_property
from typing import Any

from azure.identity import DefaultAzureCredential
Expand Down Expand Up @@ -76,8 +77,8 @@ class AzureDataExplorerHook(BaseHook):
conn_type = "azure_data_explorer"
hook_name = "Azure Data Explorer"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -94,8 +95,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
Expand All @@ -116,7 +117,11 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_data_explorer_conn_id
self.connection = self.get_conn() # todo: make this a property, or just delete

@cached_property
def connection(self) -> KustoClient:
"""Return a KustoClient object (cached)."""
return self.get_conn()

def get_conn(self) -> KustoClient:
"""Return a KustoClient object."""
Expand Down
23 changes: 11 additions & 12 deletions airflow/providers/microsoft/azure/hooks/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

import time
from datetime import timedelta
from functools import cached_property
from typing import Any

from azure.batch import BatchServiceClient, batch_auth, models as batch_models
from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
from airflow.utils import timezone

Expand All @@ -52,8 +52,8 @@ def _get_field(self, extras, name):
field_name=name,
)

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -63,8 +63,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
"account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "host", "extra"],
Expand All @@ -77,20 +77,19 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_batch_conn_id
self.connection = self.get_conn()

def _connection(self) -> Connection:
"""Get connected to Azure Batch service."""
conn = self.get_connection(self.conn_id)
return conn
@cached_property
def connection(self) -> BatchServiceClient:
"""Get the Batch client connection (cached)."""
return self.get_conn()

def get_conn(self):
def get_conn(self) -> BatchServiceClient:
"""
Get the Batch client connection.

:return: Azure Batch client
"""
conn = self._connection()
conn = self.get_connection(self.conn_id)

batch_account_url = self._get_field(conn.extra_dejson, "account_url")
if not batch_account_url:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import warnings
from functools import cached_property

from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import ContainerGroup
Expand Down Expand Up @@ -47,7 +48,10 @@ class AzureContainerInstanceHook(AzureBaseHook):

def __init__(self, azure_conn_id: str = default_conn_name) -> None:
super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=azure_conn_id)
self.connection = self.get_conn()

@cached_property
def connection(self):
return self.get_conn()

def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
"""
Expand Down
10 changes: 7 additions & 3 deletions airflow/providers/microsoft/azure/hooks/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Hook for Azure Container Registry."""
from __future__ import annotations

from functools import cached_property
from typing import Any

from azure.mgmt.containerinstance.models import ImageRegistryCredential
Expand All @@ -39,8 +40,8 @@ class AzureContainerRegistryHook(BaseHook):
conn_type = "azure_container_registry"
hook_name = "Azure Container Registry"

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
Expand All @@ -59,7 +60,10 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, conn_id: str = "azure_registry") -> None:
super().__init__()
self.conn_id = conn_id
self.connection = self.get_conn()

@cached_property
def connection(self) -> ImageRegistryCredential:
return self.get_conn()

def get_conn(self) -> ImageRegistryCredential:
conn = self.get_connection(self.conn_id)
Expand Down
15 changes: 10 additions & 5 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import Any

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
Expand Down Expand Up @@ -256,8 +257,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
conn_type = "adls"
hook_name = "Azure Date Lake Storage V2"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -272,8 +273,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port"],
Expand All @@ -296,7 +297,11 @@ def __init__(self, adls_conn_id: str, public_read: bool = False) -> None:
super().__init__()
self.conn_id = adls_conn_id
self.public_read = public_read
self.service_client = self.get_conn()

@cached_property
def service_client(self) -> DataLakeServiceClient:
"""Return the DataLakeServiceClient object (cached)."""
return self.get_conn()

def get_conn(self) -> DataLakeServiceClient: # type: ignore[override]
"""Return the DataLakeServiceClient object."""
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import logging
import os
from functools import cached_property
from typing import Any, Union
from urllib.parse import urlparse

Expand Down Expand Up @@ -123,7 +124,6 @@ def __init__(
super().__init__()
self.conn_id = wasb_conn_id
self.public_read = public_read
self.blob_service_client: BlobServiceClient = self.get_conn()

logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
try:
Expand All @@ -142,6 +142,11 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(f"{prefix}{field_name}") or None

@cached_property
def blob_service_client(self) -> BlobServiceClient:
"""Return the BlobServiceClient object (cached)."""
return self.get_conn()

def get_conn(self) -> BlobServiceClient:
"""Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/microsoft/azure/log/wasb_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
self.wasb_container = wasb_container
self.remote_base = wasb_log_folder
self.log_relative_path = ""
self._hook = None
self.closed = False
self.upload_on_close = True
self.delete_local_copy = (
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/microsoft/azure/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def __init__(
self.should_delete_pool = should_delete_pool

@cached_property
def hook(self):
def hook(self) -> AzureBatchHook:
"""Create and return an AzureBatchHook (cached)."""
return self.get_hook()

def _check_inputs(self) -> Any:
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/operators/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import time
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
Expand Down Expand Up @@ -159,8 +160,12 @@ def __init__(
self.check_interval = check_interval
self.deferrable = deferrable

@cached_property
def hook(self) -> AzureDataFactoryHook:
"""Create and return an AzureDataFactoryHook (cached)."""
return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)

def execute(self, context: Context) -> None:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
self.log.info("Executing the %s pipeline.", self.pipeline_name)
response = self.hook.run_pipeline(
pipeline_name=self.pipeline_name,
Expand Down
9 changes: 6 additions & 3 deletions airflow/providers/microsoft/azure/operators/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from azure.synapse.spark.models import SparkBatchJobOptions
Expand Down Expand Up @@ -73,10 +74,12 @@ def __init__(
self.timeout = timeout
self.check_interval = check_interval

@cached_property
def hook(self):
"""Create and return an AzureSynapseHook (cached)."""
return AzureSynapseHook(azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool)

def execute(self, context: Context) -> None:
self.hook = AzureSynapseHook(
azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool
)
self.log.info("Executing the Synapse spark job.")
response = self.hook.run_spark_job(payload=self.payload)
self.log.info(response)
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/sensors/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import warnings
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
Expand Down Expand Up @@ -72,8 +73,12 @@ def __init__(

self.deferrable = deferrable

@cached_property
def hook(self):
"""Create and return an AzureDataFactoryHook (cached)."""
return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)

def poke(self, context: Context) -> bool:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
pipeline_run_status = self.hook.get_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ BaseView
BaseXCom
bashrc
batchGet
BatchServiceClient
bc
bcc
bdist
Expand Down
Loading