Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 47 additions & 43 deletions airflow/providers/microsoft/azure/sensors/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

Expand All @@ -38,6 +39,8 @@ class WasbBlobSensor(BaseSensorOperator):
:param wasb_conn_id: Reference to the :ref:`wasb connection <howto/connection:wasb>`.
:param check_options: Optional keyword arguments that
`WasbHook.check_for_blob()` takes.
:param deferrable: Run sensor in the deferrable mode.
:param public_read: whether an anonymous public read access should be used. Default is False
"""

template_fields: Sequence[str] = ("container_name", "blob_name")
Expand All @@ -49,6 +52,8 @@ def __init__(
blob_name: str,
wasb_conn_id: str = "wasb_default",
check_options: dict | None = None,
public_read: bool = False,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -58,57 +63,32 @@ def __init__(
self.container_name = container_name
self.blob_name = blob_name
self.check_options = check_options
self.public_read = public_read
self.deferrable = deferrable

def poke(self, context: Context):
    """
    Check once whether the target blob currently exists.

    :param context: task context supplied by the caller; it is not read here.
    :return: the result of ``WasbHook.check_for_blob()`` for the configured
        container and blob.
    """
    self.log.info("Poking for blob: %s\n in wasb://%s", self.blob_name, self.container_name)
    # Forward ``public_read`` so the synchronous path honours the same
    # anonymous-access setting that ``execute`` passes to the trigger.
    hook = WasbHook(wasb_conn_id=self.wasb_conn_id, public_read=self.public_read)
    # ``check_options`` is declared as ``dict | None``; unpacking ``None`` would raise.
    return hook.check_for_blob(self.container_name, self.blob_name, **(self.check_options or {}))


class WasbBlobAsyncSensor(WasbBlobSensor):
    """
    Polls asynchronously for the existence of a blob in a WASB container.

    :param container_name: name of the container in which the blob should be searched for
    :param blob_name: name of the blob to check existence for
    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
    :param public_read: whether anonymous public read access should be used. Default is False
    :param poke_interval: polling period in seconds to check for the status
    :param timeout: Time, in seconds before the task times out and fails.
    """

    def __init__(
        self,
        *,
        container_name: str,
        blob_name: str,
        wasb_conn_id: str = "wasb_default",
        public_read: bool = False,
        poke_interval: float = 5.0,
        **kwargs: Any,
    ):
        # Pass ``wasb_conn_id`` and ``poke_interval`` to the parent constructor rather
        # than assigning them on ``self`` beforehand: a value set before
        # ``super().__init__()`` can be replaced by the parent's own assignments.
        # NOTE(review): assumes the sensor base class accepts ``poke_interval`` and
        # stores it on ``self`` -- confirm against BaseSensorOperator.
        super().__init__(
            container_name=container_name,
            blob_name=blob_name,
            wasb_conn_id=wasb_conn_id,
            poke_interval=poke_interval,
            **kwargs,
        )
        # Assigned after the parent constructor so the caller's value is the final one.
        self.public_read = public_read
def execute(self, context: Context) -> None:
    """
    Run the sensor, either through the parent ``execute`` or by deferring.

    When ``deferrable`` is false the parent class's ``execute`` is called.
    Otherwise the task defers to a ``WasbBlobSensorTrigger`` and is resumed
    through ``execute_complete``.

    :param context: task context; forwarded to the parent ``execute`` in
        non-deferrable mode.
    """
    if not self.deferrable:
        super().execute(context=context)
        return

    # Deferrable mode: hand the blob lookup to the trigger, bounded by the
    # sensor's own timeout.
    self.defer(
        timeout=timedelta(seconds=self.timeout),
        trigger=WasbBlobSensorTrigger(
            container_name=self.container_name,
            blob_name=self.blob_name,
            wasb_conn_id=self.wasb_conn_id,
            public_read=self.public_read,
            poke_interval=self.poke_interval,
        ),
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
Expand All @@ -124,6 +104,30 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
raise AirflowException("Did not receive valid event from the triggerer")


class WasbBlobAsyncSensor(WasbBlobSensor):
    """
    Deprecated variant of ``WasbBlobSensor`` that always sets ``deferrable=True``.

    Creating an instance emits a ``DeprecationWarning``; use ``WasbBlobSensor``
    with ``deferrable=True`` instead.

    :param kwargs: keyword arguments passed unchanged to ``WasbBlobSensor``,
        for example ``container_name``, ``blob_name``, ``wasb_conn_id``,
        ``public_read``, ``poke_interval`` and ``timeout``.
    """

    def __init__(self, **kwargs: Any) -> None:
        deprecation_message = (
            "Class `WasbBlobAsyncSensor` is deprecated and "
            "will be removed in a future release. "
            "Please use `WasbBlobSensor` and "
            "set `deferrable` attribute to `True` instead"
        )
        # stacklevel=2 attributes the warning to the code instantiating the sensor.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        super().__init__(deferrable=True, **kwargs)


class WasbPrefixSensor(BaseSensorOperator):
"""
Waits for blobs matching a prefix to arrive on Azure Blob Storage.
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/microsoft/azure/sensors/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.microsoft.azure.sensors.wasb import (
WasbBlobAsyncSensor,
WasbBlobSensor,
WasbPrefixSensor,
)
Expand Down Expand Up @@ -120,11 +119,12 @@ def create_context(self, task, dag=None):
"logical_date": execution_date,
}

SENSOR = WasbBlobAsyncSensor(
SENSOR = WasbBlobSensor(
task_id="wasb_blob_async_sensor",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
timeout=5,
deferrable=True,
)

def test_wasb_blob_sensor_async(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobAsyncSensor, WasbBlobSensor
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator

# Ignore missing args provided by default_args
Expand All @@ -46,7 +46,7 @@

wait_for_blob = WasbBlobSensor(task_id="wait_for_blob")

wait_for_blob_async = WasbBlobAsyncSensor(task_id="wait_for_blob_async")
wait_for_blob_async = WasbBlobSensor(task_id="wait_for_blob_async", deferrable=True)

transfer_files_to_gcs = AzureBlobStorageToGCSOperator(
task_id="transfer_files_to_gcs",
Expand Down