From 563fc907c71202b8d2e07110f8c78a7f906516e6 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Wed, 5 Apr 2023 23:15:30 +0530 Subject: [PATCH] Merge WasbBlobAsyncSensor to WasbBlobSensor --- .../providers/microsoft/azure/sensors/wasb.py | 90 ++++++++++--------- .../microsoft/azure/sensors/test_wasb.py | 4 +- .../azure/example_azure_blob_to_gcs.py | 4 +- 3 files changed, 51 insertions(+), 47 deletions(-) diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 017d73720dc92..77f2256a64de7 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import warnings from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence @@ -38,6 +39,8 @@ class WasbBlobSensor(BaseSensorOperator): :param wasb_conn_id: Reference to the :ref:`wasb connection `. :param check_options: Optional keyword arguments that `WasbHook.check_for_blob()` takes. + :param deferrable: Run sensor in the deferrable mode. + :param public_read: whether an anonymous public read access should be used. Default is False """ template_fields: Sequence[str] = ("container_name", "blob_name") @@ -49,6 +52,8 @@ def __init__( blob_name: str, wasb_conn_id: str = "wasb_default", check_options: dict | None = None, + public_read: bool = False, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -58,57 +63,32 @@ def __init__( self.container_name = container_name self.blob_name = blob_name self.check_options = check_options + self.public_read = public_read + self.deferrable = deferrable def poke(self, context: Context): self.log.info("Poking for blob: %s\n in wasb://%s", self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options) - -class WasbBlobAsyncSensor(WasbBlobSensor): - """ - Polls asynchronously for the existence of a blob in a WASB container. - - :param container_name: name of the container in which the blob should be searched for - :param blob_name: name of the blob to check existence for - :param wasb_conn_id: the connection identifier for connecting to Azure WASB - :param poke_interval: polling period in seconds to check for the status - :param public_read: whether an anonymous public read access should be used. Default is False - :param timeout: Time, in seconds before the task times out and fails. - """ - - def __init__( - self, - *, - container_name: str, - blob_name: str, - wasb_conn_id: str = "wasb_default", - public_read: bool = False, - poke_interval: float = 5.0, - **kwargs: Any, - ): - self.container_name = container_name - self.blob_name = blob_name - self.poke_interval = poke_interval - super().__init__(container_name=container_name, blob_name=blob_name, **kwargs) - self.wasb_conn_id = wasb_conn_id - self.public_read = public_read - def execute(self, context: Context) -> None: - """Defers trigger class to poll for state of the job run until it reaches - a failure state or success state + """Defers trigger class to poll for state of the job run until + it reaches a failure state or success state """ - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=WasbBlobSensorTrigger( - container_name=self.container_name, - blob_name=self.blob_name, - wasb_conn_id=self.wasb_conn_id, - public_read=self.public_read, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) + if not self.deferrable: + super().execute(context=context) + else: + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=WasbBlobSensorTrigger( + container_name=self.container_name, + blob_name=self.blob_name, + wasb_conn_id=self.wasb_conn_id, + public_read=self.public_read, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ @@ -124,6 +104,30 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: raise AirflowException("Did not receive valid event from the triggerer") +class WasbBlobAsyncSensor(WasbBlobSensor): + """ + Polls asynchronously for the existence of a blob in a WASB container. + + :param container_name: name of the container in which the blob should be searched for + :param blob_name: name of the blob to check existence for + :param wasb_conn_id: the connection identifier for connecting to Azure WASB + :param poke_interval: polling period in seconds to check for the status + :param public_read: whether an anonymous public read access should be used. Default is False + :param timeout: Time, in seconds before the task times out and fails. + """ + + def __init__(self, **kwargs: Any) -> None: + warnings.warn( + "Class `WasbBlobAsyncSensor` is deprecated and " + "will be removed in a future release. " + "Please use `WasbBlobSensor` and " + "set `deferrable` attribute to `True` instead", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(**kwargs, deferrable=True) + + class WasbPrefixSensor(BaseSensorOperator): """ Waits for blobs matching a prefix to arrive on Azure Blob Storage. diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py index 67830c44e3dff..a7e30e272d482 100644 --- a/tests/providers/microsoft/azure/sensors/test_wasb.py +++ b/tests/providers/microsoft/azure/sensors/test_wasb.py @@ -29,7 +29,6 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.providers.microsoft.azure.sensors.wasb import ( - WasbBlobAsyncSensor, WasbBlobSensor, WasbPrefixSensor, ) @@ -120,11 +119,12 @@ def create_context(self, task, dag=None): "logical_date": execution_date, } - SENSOR = WasbBlobAsyncSensor( + SENSOR = WasbBlobSensor( task_id="wasb_blob_async_sensor", container_name=TEST_DATA_STORAGE_CONTAINER_NAME, blob_name=TEST_DATA_STORAGE_BLOB_NAME, timeout=5, + deferrable=True, ) def test_wasb_blob_sensor_async(self): diff --git a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py index 83da48d985781..8bd5de0d939c2 100644 --- a/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py +++ b/tests/system/providers/microsoft/azure/example_azure_blob_to_gcs.py @@ -21,7 +21,7 @@ from datetime import datetime from airflow import DAG -from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobAsyncSensor, WasbBlobSensor +from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator # Ignore missing args provided by default_args @@ -46,7 +46,7 @@ wait_for_blob = WasbBlobSensor(task_id="wait_for_blob") - wait_for_blob_async = WasbBlobAsyncSensor(task_id="wait_for_blob_async") + wait_for_blob_async = WasbBlobSensor(task_id="wait_for_blob_async", deferrable=True) transfer_files_to_gcs = AzureBlobStorageToGCSOperator( task_id="transfer_files_to_gcs",