Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions providers/microsoft/azure/docs/operators/batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ Below is an example of using this operator to trigger a task on Azure Batch
:start-after: [START howto_azure_batch_operator]
:end-before: [END howto_azure_batch_operator]

Below is an example of using this operator to trigger a task on Azure Batch with the deferrable flag,
so that polling for job/task completion occurs on the Airflow Triggerer.

.. exampleinclude:: /../tests/system/microsoft/azure/example_azure_batch_operator.py
:language: python
:dedent: 4
:start-after: [START howto_azure_batch_operator_with_deferrable_flag]
:end-before: [END howto_azure_batch_operator_with_deferrable_flag]


Reference
---------
Expand Down
3 changes: 3 additions & 0 deletions providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ hooks:
- airflow.providers.microsoft.azure.hooks.powerbi

triggers:
- integration-name: Microsoft Azure Batch
python-modules:
- airflow.providers.microsoft.azure.triggers.batch
- integration-name: Microsoft Azure Compute
python-modules:
- airflow.providers.microsoft.azure.triggers.compute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ def get_provider_info():
},
],
"triggers": [
{
"integration-name": "Microsoft Azure Batch",
"python-modules": ["airflow.providers.microsoft.azure.triggers.batch"],
},
{
"integration-name": "Microsoft Azure Compute",
"python-modules": ["airflow.providers.microsoft.azure.triggers.compute"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
# under the License.
from __future__ import annotations

import time
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from azure.batch import models as batch_models

from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger

if TYPE_CHECKING:
from airflow.sdk import Context
Expand Down Expand Up @@ -91,6 +93,10 @@ class AzureBatchOperator(BaseOperator):
:param timeout: The amount of time to wait for the job to complete in minutes. Default is 25
:param should_delete_job: Whether to delete job after execution. Default is False
:param should_delete_pool: Whether to delete pool after execution of jobs. Default is False
:param poll_interval: Polling interval in seconds for deferrable mode. Default is 30.
Determines how frequently the trigger checks task completion status when deferrable=True.
:param deferrable: Run operator in deferrable mode.

"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -139,6 +145,8 @@ def __init__(
timeout: int = 25,
should_delete_job: bool = False,
should_delete_pool: bool = False,
poll_interval: int = 30,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -176,6 +184,8 @@ def __init__(
self.timeout = timeout
self.should_delete_job = should_delete_job
self.should_delete_pool = should_delete_pool
self.poll_interval = poll_interval
self.deferrable = deferrable

@cached_property
def hook(self) -> AzureBatchHook:
Expand Down Expand Up @@ -265,6 +275,7 @@ def execute(self, context: Context) -> None:
start_task=self.batch_start_task,
)
self.hook.create_pool(pool)

# Wait for nodes to reach complete state
self.hook.wait_for_all_node_state(
self.batch_pool_id,
Expand Down Expand Up @@ -296,6 +307,29 @@ def execute(self, context: Context) -> None:
)
# Add task to job
self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task)

if self.deferrable:
# Pre-deferral check (node readiness is already enforced by wait_for_all_node_state above)
pool = self.hook.connection.pool.get(self.batch_pool_id)
if pool.resize_errors:
raise RuntimeError(f"Pool resize errors: {pool.resize_errors}")

nodes = list(self.hook.connection.compute_node.list(self.batch_pool_id))
self.log.debug("Deferral pre-check: %d nodes present in pool %s", len(nodes), self.batch_pool_id)
Comment thread
SameerMesiah97 marked this conversation as resolved.
end_time = time.time() + (self.timeout * 60)

Comment thread
jscheffl marked this conversation as resolved.
self.defer(
timeout=self.execution_timeout,
trigger=AzureBatchTrigger(
job_id=self.batch_job_id,
azure_batch_conn_id=self.azure_batch_conn_id,
end_time=end_time,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
return

# Wait for tasks to complete
fail_tasks = self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout)
# Clean up
Expand All @@ -306,7 +340,44 @@ def execute(self, context: Context) -> None:
self.clean_up(self.batch_pool_id)
# raise exception if any task fail
if fail_tasks:
raise AirflowException(f"Job fail. The failed task are: {fail_tasks}")
raise RuntimeError(f"Job fail. The failed task are: {fail_tasks}")

def execute_complete(self, context: Context, event: dict[str, Any] | None) -> None:
    """
    Resume after the trigger fires and turn its event into the task outcome.

    The trigger communicates the terminal Azure Batch job state through the
    event payload (``status``, ``message`` and, on failure, ``failed_tasks``).
    Whatever the outcome, the job and/or pool are cleaned up afterwards when
    ``should_delete_job`` / ``should_delete_pool`` are set.

    :param context: Airflow task context (not used here).
    :param event: Payload emitted by the trigger, or ``None`` if none was received.
    :raises RuntimeError: If no event was received, or if the event status is
        ``timeout``, ``error`` or anything other than ``success``.
    """
    try:
        # Checked inside the ``try`` so that cleanup also runs when no event arrives.
        if event is None:
            raise RuntimeError("Trigger returned no event.")

        status = event.get("status")
        message = event.get("message", "No message returned from trigger.")

        if status == "success":
            self.log.info(message)
            return

        if status == "timeout":
            raise RuntimeError(message)

        if status == "error":
            failed_tasks = event.get("failed_tasks")
            if failed_tasks:
                raise RuntimeError(f"{message} Failed tasks: {failed_tasks}")

            raise RuntimeError(message)

        raise RuntimeError(f"Unexpected trigger event received: {event}")

    finally:
        # Nested try/finally: a failure while deleting the job must not
        # prevent the pool from being deleted as well.
        try:
            if self.should_delete_job:
                self.clean_up(job_id=self.batch_job_id)
        finally:
            if self.should_delete_pool:
                self.clean_up(pool_id=self.batch_pool_id)

def on_kill(self) -> None:
response = self.hook.connection.job.terminate(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
import time
from collections.abc import AsyncIterator
from typing import Any

from azure.batch import models as batch_models

from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class AzureBatchTrigger(BaseTrigger):
    """
    Trigger when Azure Batch job tasks reach a terminal state.

    The trigger repeatedly lists the tasks of one job and emits a single event
    whose ``status`` is ``success``, ``error`` or ``timeout``.

    :param job_id: Azure Batch job identifier.
    :param azure_batch_conn_id: Azure Batch connection id.
    :param end_time: Absolute timeout deadline as determined using ``time.time()``.
    :param poll_interval: Poll interval in seconds.
    """

    def __init__(
        self,
        job_id: str,
        azure_batch_conn_id: str,
        end_time: float,
        poll_interval: int = 30,
    ):
        super().__init__()

        self.job_id = job_id
        self.azure_batch_conn_id = azure_batch_conn_id
        self.end_time = end_time
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize trigger arguments and classpath."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "job_id": self.job_id,
                "azure_batch_conn_id": self.azure_batch_conn_id,
                "end_time": self.end_time,
                "poll_interval": self.poll_interval,
            },
        )

    def _get_incomplete_tasks(
        self,
        tasks: list[batch_models.CloudTask],
    ) -> list[batch_models.CloudTask]:
        """Return tasks that have not yet completed."""
        return [task for task in tasks if task.state != batch_models.TaskState.completed]

    def _build_trigger_event(
        self,
        tasks: list[batch_models.CloudTask],
    ) -> TriggerEvent | None:
        """
        Convert Batch task states to TriggerEvent.

        :param tasks: Tasks currently listed for the job.
        :return: An ``error`` event if the job has no tasks or any completed task
            reports a failure result, a ``success`` event if every task completed
            without failure, or ``None`` while at least one task is not completed.
        """
        if not tasks:
            return TriggerEvent(
                {
                    "status": "error",
                    "message": f"Azure Batch job {self.job_id} contains no tasks.",
                    "job_id": self.job_id,
                }
            )

        if self._get_incomplete_tasks(tasks):
            return None

        failed_tasks = [
            task.id
            for task in tasks
            if task.execution_info and task.execution_info.result == batch_models.TaskExecutionResult.failure
        ]

        if failed_tasks:
            return TriggerEvent(
                {
                    "status": "error",
                    "message": f"Azure Batch job {self.job_id} failed.",
                    "job_id": self.job_id,
                    "failed_tasks": failed_tasks,
                }
            )

        return TriggerEvent(
            {
                "status": "success",
                "message": f"Azure Batch job {self.job_id} completed successfully.",
                "job_id": self.job_id,
            }
        )

    def _list_tasks(self, hook: AzureBatchHook) -> list[batch_models.CloudTask]:
        """Return the job's tasks as a list; synchronous, so ``run`` calls it via ``asyncio.to_thread``."""
        return list(hook.connection.task.list(self.job_id))

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll Azure Batch job tasks until completion or timeout.

        Yields exactly one event: the terminal job outcome, a ``timeout`` event
        once ``end_time`` has passed, or an ``error`` event carrying the message
        of any exception raised while polling.
        """
        hook = AzureBatchHook(
            azure_batch_conn_id=self.azure_batch_conn_id,
        )

        try:
            while time.time() <= self.end_time:
                tasks = await asyncio.to_thread(self._list_tasks, hook)

                event = self._build_trigger_event(tasks)

                if event:
                    yield event
                    return

                # Log task ids rather than the CloudTask objects themselves.
                incomplete_task_ids = [task.id for task in self._get_incomplete_tasks(tasks)]

                # Never sleep past the deadline, so the timeout is not
                # overshot by up to a full poll interval.
                sleep_for = max(0.0, min(self.poll_interval, self.end_time - time.time()))

                self.log.info(
                    "Azure Batch job %s still running. Incomplete tasks: %s. Sleeping %s seconds.",
                    self.job_id,
                    incomplete_task_ids,
                    round(sleep_for, 1),
                )

                await asyncio.sleep(sleep_for)

            # Final check before timeout event in case job completed
            # during the last sleep interval.
            tasks = await asyncio.to_thread(self._list_tasks, hook)

            event = self._build_trigger_event(tasks)

            if event:
                yield event
                return

            yield TriggerEvent(
                {
                    "status": "timeout",
                    "message": f"Timeout waiting for Azure Batch job {self.job_id}.",
                    "job_id": self.job_id,
                }
            )

        except Exception as e:
            self.log.exception(
                "Azure Batch trigger failed for job %s.",
                self.job_id,
            )

            yield TriggerEvent(
                {
                    "status": "error",
                    "message": str(e),
                    "job_id": self.job_id,
                }
            )
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@
)
# [END howto_azure_batch_operator]

# [START howto_azure_batch_operator_with_deferrable_flag]
azure_batch_operator_deferrable = AzureBatchOperator(
    task_id="azure_batch_deferrable",
    batch_pool_id=POOL_ID,
    batch_pool_vm_size="standard_d2s_v3",
    batch_job_id="example-job",
    batch_task_command_line="/bin/bash -c 'set -e; set -o pipefail; echo hello world!; wait'",
    batch_task_id="example-task",
    vm_node_agent_sku_id="batch.node.ubuntu 22.04",
    vm_publisher="Canonical",
    vm_offer="0001-com-ubuntu-server-jammy",
    vm_sku="22_04-lts-gen2",
    target_dedicated_nodes=1,
    # With deferrable=True the operator defers after adding the task to the job,
    # so polling for job/task completion happens on the Airflow triggerer.
    deferrable=True,
)
# [END howto_azure_batch_operator_with_deferrable_flag]

azure_batch_operator >> azure_batch_operator_deferrable

from tests_common.test_utils.system_tests import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
Expand Down
Loading
Loading