diff --git a/providers/microsoft/azure/docs/operators/batch.rst b/providers/microsoft/azure/docs/operators/batch.rst index 8cc5cc6300637..9b7722d46864d 100644 --- a/providers/microsoft/azure/docs/operators/batch.rst +++ b/providers/microsoft/azure/docs/operators/batch.rst @@ -32,6 +32,15 @@ Below is an example of using this operator to trigger a task on Azure Batch :start-after: [START howto_azure_batch_operator] :end-before: [END howto_azure_batch_operator] +Below is an example of using this operator to trigger a task on Azure Batch with the deferrable flag, +so that polling for job/task completion occurs on the Airflow Triggerer. + + .. exampleinclude:: /../tests/system/microsoft/azure/example_azure_batch_operator.py + :language: python + :dedent: 4 + :start-after: [START howto_azure_batch_operator_with_deferrable_flag] + :end-before: [END howto_azure_batch_operator_with_deferrable_flag] + Reference --------- diff --git a/providers/microsoft/azure/provider.yaml b/providers/microsoft/azure/provider.yaml index a5cbd0bc1e8df..801f283bb9990 100644 --- a/providers/microsoft/azure/provider.yaml +++ b/providers/microsoft/azure/provider.yaml @@ -310,6 +310,9 @@ hooks: - airflow.providers.microsoft.azure.hooks.powerbi triggers: + - integration-name: Microsoft Azure Batch + python-modules: + - airflow.providers.microsoft.azure.triggers.batch - integration-name: Microsoft Azure Compute python-modules: - airflow.providers.microsoft.azure.triggers.compute diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py index 0b312ec09a955..40364bf18455a 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py @@ -284,6 +284,10 @@ def get_provider_info(): }, ], "triggers": [ + { + "integration-name": "Microsoft Azure Batch", + 
"python-modules": ["airflow.providers.microsoft.azure.triggers.batch"], + }, { "integration-name": "Microsoft Azure Compute", "python-modules": ["airflow.providers.microsoft.azure.triggers.compute"], diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py index e5b36841c66cf..4e9216ec18198 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py @@ -17,14 +17,16 @@ # under the License. from __future__ import annotations +import time from collections.abc import Sequence from functools import cached_property from typing import TYPE_CHECKING, Any from azure.batch import models as batch_models -from airflow.providers.common.compat.sdk import AirflowException, BaseOperator +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook +from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger if TYPE_CHECKING: from airflow.sdk import Context @@ -91,6 +93,10 @@ class AzureBatchOperator(BaseOperator): :param timeout: The amount of time to wait for the job to complete in minutes. Default is 25 :param should_delete_job: Whether to delete job after execution. Default is False :param should_delete_pool: Whether to delete pool after execution of jobs. Default is False + :param poll_interval: Polling interval in seconds for deferrable mode. Default is 30. + Determines how frequently the trigger checks task completion status when deferrable=True. + :param deferrable: Run operator in deferrable mode. 
+ """ template_fields: Sequence[str] = ( @@ -139,6 +145,8 @@ def __init__( timeout: int = 25, should_delete_job: bool = False, should_delete_pool: bool = False, + poll_interval: int = 30, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -176,6 +184,8 @@ def __init__( self.timeout = timeout self.should_delete_job = should_delete_job self.should_delete_pool = should_delete_pool + self.poll_interval = poll_interval + self.deferrable = deferrable @cached_property def hook(self) -> AzureBatchHook: @@ -265,6 +275,7 @@ def execute(self, context: Context) -> None: start_task=self.batch_start_task, ) self.hook.create_pool(pool) + # Wait for nodes to reach complete state self.hook.wait_for_all_node_state( self.batch_pool_id, @@ -296,6 +307,29 @@ def execute(self, context: Context) -> None: ) # Add task to job self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) + + if self.deferrable: + # Pre-deferral check (node readiness is already enforced by wait_for_all_node_state above) + pool = self.hook.connection.pool.get(self.batch_pool_id) + if pool.resize_errors: + raise RuntimeError(f"Pool resize errors: {pool.resize_errors}") + + nodes = list(self.hook.connection.compute_node.list(self.batch_pool_id)) + self.log.debug("Deferral pre-check: %d nodes present in pool %s", len(nodes), self.batch_pool_id) + end_time = time.time() + (self.timeout * 60) + + self.defer( + timeout=self.execution_timeout, + trigger=AzureBatchTrigger( + job_id=self.batch_job_id, + azure_batch_conn_id=self.azure_batch_conn_id, + end_time=end_time, + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) + return + # Wait for tasks to complete fail_tasks = self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) # Clean up @@ -306,7 +340,44 @@ def execute(self, context: Context) -> None: self.clean_up(self.batch_pool_id) # raise exception if any 
task fail
         if fail_tasks:
-            raise AirflowException(f"Job fail. The failed task are: {fail_tasks}")
+            raise RuntimeError(f"Job fail. The failed task are: {fail_tasks}")
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None) -> None:
+        """
+        Resume after the trigger fires; succeed only if the Batch job succeeded.
+
+        :param context: Airflow task context (not read here).
+        :param event: Trigger payload: ``status``, ``message`` and optional ``failed_tasks``.
+        :raises RuntimeError: On a missing event, or a timeout/error/unrecognized status.
+        """
+        # The missing-event check sits inside the try so the ``finally`` cleanup
+        # also runs when the trigger hands back nothing.
+        try:
+            if event is None:
+                raise RuntimeError("Trigger returned no event.")
+
+            status = event.get("status")
+            message = event.get("message", "No message returned from trigger.")
+            failed_tasks = event.get("failed_tasks")
+
+            if status == "success":
+                self.log.info(message)
+                return
+
+            if status == "timeout":
+                raise RuntimeError(message)
+
+            if status == "error" and failed_tasks:
+                raise RuntimeError(f"{message} Failed tasks: {failed_tasks}")
+            if status == "error":
+                raise RuntimeError(message)
+
+            raise RuntimeError(f"Unexpected trigger event received: {event}")
+        finally:
+            if self.should_delete_job:
+                self.clean_up(job_id=self.batch_job_id)
+            if self.should_delete_pool:
+                self.clean_up(pool_id=self.batch_pool_id)
 
     def on_kill(self) -> None:
         response = self.hook.connection.job.terminate(
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
new file mode 100644
index 0000000000000..65bf5eb0217f6
--- /dev/null
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/batch.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+
+from azure.batch import models as batch_models
+
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AzureBatchTrigger(BaseTrigger):
+    """
+    Trigger when Azure Batch job tasks reach a terminal state.
+
+    :param job_id: Azure Batch job identifier.
+    :param azure_batch_conn_id: Azure Batch connection id.
+    :param end_time: Absolute timeout deadline as determined using ``time.time()``.
+    :param poll_interval: Poll interval in seconds.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        azure_batch_conn_id: str,
+        end_time: float,
+        poll_interval: int = 30,
+    ):
+        super().__init__()
+
+        self.job_id = job_id
+        self.azure_batch_conn_id = azure_batch_conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize trigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__name__}",
+            {
+                "job_id": self.job_id,
+                "azure_batch_conn_id": self.azure_batch_conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    def _get_incomplete_tasks(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> list[batch_models.CloudTask]:
+        """Return tasks that have not yet completed."""
+        return [task for task in tasks if task.state != batch_models.TaskState.completed]
+
+    def _build_trigger_event(
+        self,
+        tasks: list[batch_models.CloudTask],
+    ) -> TriggerEvent | None:
+        """
+        Convert Batch task states to TriggerEvent.
+
+        Returns None if tasks are still running.
+        """
+        if not tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} contains no tasks.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        if self._get_incomplete_tasks(tasks):
+            return None
+
+        failed_tasks = [
+            task.id
+            for task in tasks
+            if task.execution_info and task.execution_info.result == batch_models.TaskExecutionResult.failure
+        ]
+
+        if failed_tasks:
+            return TriggerEvent(
+                {
+                    "status": "error",
+                    "message": f"Azure Batch job {self.job_id} failed.",
+                    "job_id": self.job_id,
+                    "failed_tasks": failed_tasks,
+                }
+            )
+
+        return TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Azure Batch job {self.job_id} completed successfully.",
+                "job_id": self.job_id,
+            }
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Poll Azure Batch job tasks until completion or timeout.
+
+        Every path yields a single event with ``status`` ``success``, ``error`` or ``timeout``.
+        """
+        hook = AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id)
+
+        def list_tasks() -> list[batch_models.CloudTask]:
+            # Synchronous SDK call, so callers hand it to asyncio.to_thread.
+            return list(hook.connection.task.list(self.job_id))
+
+        try:
+            while time.time() <= self.end_time:
+                tasks = await asyncio.to_thread(list_tasks)
+                event = self._build_trigger_event(tasks)
+                if event is not None:
+                    yield event
+                    return
+
+                # Log task ids, not CloudTask objects, to keep the message readable.
+                incomplete_tasks = [task.id for task in self._get_incomplete_tasks(tasks)]
+                self.log.info(
+                    "Azure Batch job %s still running. Incomplete tasks: %s. Sleeping %s seconds.",
+                    self.job_id,
+                    incomplete_tasks,
+                    self.poll_interval,
+                )
+
+                await asyncio.sleep(self.poll_interval)
+
+            # Final check before timeout event in case job completed
+            # during the last sleep interval.
+            tasks = await asyncio.to_thread(list_tasks)
+            event = self._build_trigger_event(tasks)
+            if event is not None:
+                yield event
+                return
+
+            yield TriggerEvent(
+                {
+                    "status": "timeout",
+                    "message": f"Timeout waiting for Azure Batch job {self.job_id}.",
+                    "job_id": self.job_id,
+                }
+            )
+
+        except Exception as e:
+            # Report any polling failure as an error event rather than raising.
+            self.log.exception("Azure Batch trigger failed for job %s.", self.job_id)
+
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                    "job_id": self.job_id,
+                }
+            )
diff --git a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
index f5c563aab77f3..b6f6dded1814f 100644
--- a/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
+++ b/providers/microsoft/azure/tests/system/microsoft/azure/example_azure_batch_operator.py
@@ -57,6 +57,25 @@
     )
     # [END howto_azure_batch_operator]
 
+    # [START howto_azure_batch_operator_with_deferrable_flag]
+    azure_batch_operator_deferrable = AzureBatchOperator(
+        task_id="azure_batch_deferrable",
+        batch_pool_id=POOL_ID,
+        batch_pool_vm_size="standard_d2s_v3",
+        batch_job_id="example-job",
+        batch_task_command_line="/bin/bash -c 'set -e; set -o pipefail; echo hello world!; wait'",
+        batch_task_id="example-task",
+        vm_node_agent_sku_id="batch.node.ubuntu 22.04",
+        vm_publisher="Canonical",
+        vm_offer="0001-com-ubuntu-server-jammy",
+        vm_sku="22_04-lts-gen2",
+        target_dedicated_nodes=1,
+        deferrable=True,
+    )
+    # [END howto_azure_batch_operator_with_deferrable_flag]
+
+    azure_batch_operator >> azure_batch_operator_deferrable
+
 from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py index 160aea3df7e79..7c9fba490cc02 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_batch.py @@ -23,9 +23,10 @@ import pytest from airflow.models import Connection -from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook from airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator +from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger TASK_ID = "MyDag" BATCH_POOL_ID = "MyPool" @@ -247,3 +248,183 @@ def test_cleaning_works(self): self.operator.clean_up("mypool", "myjob") self.batch_client.job.delete.assert_called_with("myjob") self.batch_client.pool.delete.assert_called_with("mypool") + + +class TestAzureBatchOperatorDeferrable: + @pytest.fixture(autouse=True) + def setup_test_cases(self, mocked_batch_service_client, create_mock_connections): + self.batch_client = mock.MagicMock(name="FakeBatchServiceClient") + mocked_batch_service_client.return_value = self.batch_client + + self.test_conn_id = "test_azure_batch" + self.test_account_url = "http://test-endpoint:29000" + + create_mock_connections( + Connection( + conn_id=self.test_conn_id, + conn_type="azure_batch", + extra=json.dumps({"account_url": self.test_account_url}), + ), + ) + + self.operator = AzureBatchOperator( + task_id=TASK_ID, + batch_pool_id=BATCH_POOL_ID, + batch_pool_vm_size=BATCH_VM_SIZE, + batch_job_id=BATCH_JOB_ID, + batch_task_id=BATCH_TASK_ID, + batch_task_command_line="echo hello", + vm_node_agent_sku_id="node-agent", + os_family="4", + target_dedicated_nodes=1, + azure_batch_conn_id=self.test_conn_id, + deferrable=True, + ) + + @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") + def 
test_execute_defers(self, wait_mock): + + wait_mock.return_value = True + self.batch_client.pool.get.return_value.resize_errors = None + + with pytest.raises(TaskDeferred) as ctx: + self.operator.execute(None) + + trigger = ctx.value.trigger + + assert isinstance(trigger, AzureBatchTrigger) + + assert trigger.job_id == BATCH_JOB_ID + assert trigger.azure_batch_conn_id == self.test_conn_id + + self.batch_client.pool.add.assert_called() + self.batch_client.job.add.assert_called() + self.batch_client.task.add.assert_called() + + def test_execute_complete_success(self): + with mock.patch.object(self.operator.log, "info") as mock_log: + self.operator.execute_complete( + context={}, + event={ + "status": "success", + "message": "success", + "job_id": BATCH_JOB_ID, + }, + ) + + mock_log.assert_called_once_with("success") + + def test_execute_complete_error(self): + with pytest.raises(RuntimeError, match="error"): + self.operator.execute_complete( + context={}, + event={ + "status": "error", + "message": "error", + "job_id": BATCH_JOB_ID, + }, + ) + + def test_execute_complete_timeout(self): + with pytest.raises(RuntimeError, match="timeout"): + self.operator.execute_complete( + context={}, + event={ + "status": "timeout", + "message": "timeout", + "job_id": BATCH_JOB_ID, + }, + ) + + def test_execute_complete_no_event(self): + with pytest.raises(RuntimeError, match="no event"): + self.operator.execute_complete( + context={}, + event=None, + ) + + def test_execute_complete_unexpected_event(self): + with pytest.raises(RuntimeError, match="Unexpected"): + self.operator.execute_complete( + context={}, + event={ + "status": "unknown", + "message": "???", + }, + ) + + def test_execute_complete_failed_tasks(self): + with pytest.raises(RuntimeError, match="task1"): + self.operator.execute_complete( + context={}, + event={ + "status": "error", + "message": "job failed", + "failed_tasks": ["task1"], + }, + ) + + @pytest.mark.parametrize( + ("event", "expected_exception"), + [ + ( + 
{ + "status": "success", + "message": "success", + "job_id": BATCH_JOB_ID, + }, + None, + ), + ( + { + "status": "error", + "message": "error", + "job_id": BATCH_JOB_ID, + }, + RuntimeError, + ), + ( + { + "status": "timeout", + "message": "timeout", + "job_id": BATCH_JOB_ID, + }, + RuntimeError, + ), + ( + { + "status": "unknown", + "message": "???", + }, + RuntimeError, + ), + ], + ) + @mock.patch.object(AzureBatchOperator, "clean_up") + def test_execute_complete_cleanup( + self, + clean_up_mock, + event, + expected_exception, + ): + self.operator.should_delete_job = True + self.operator.should_delete_pool = True + + if expected_exception: + with pytest.raises(expected_exception): + self.operator.execute_complete( + context={}, + event=event, + ) + else: + self.operator.execute_complete( + context={}, + event=event, + ) + + clean_up_mock.assert_has_calls( + [ + mock.call(job_id=BATCH_JOB_ID), + mock.call(pool_id=BATCH_POOL_ID), + ] + ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py new file mode 100644 index 0000000000000..4c24fb6821649 --- /dev/null +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_batch.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import time +from unittest import mock + +import pytest +from azure.batch import models as batch_models + +from airflow.providers.microsoft.azure.triggers.batch import AzureBatchTrigger +from airflow.triggers.base import TriggerEvent + +AZURE_BATCH_CONN_ID = "azure_batch_default" +JOB_ID = "test-job" +POKE_INTERVAL = 5 +BATCH_END_TIME = time.time() + 60 * 60 * 24 * 7 +MODULE = "airflow.providers.microsoft.azure" + + +class TestAzureBatchTrigger: + TRIGGER = AzureBatchTrigger( + job_id=JOB_ID, + azure_batch_conn_id=AZURE_BATCH_CONN_ID, + poll_interval=POKE_INTERVAL, + end_time=BATCH_END_TIME, + ) + + def test_batch_trigger_serialization(self): + classpath, kwargs = self.TRIGGER.serialize() + + assert classpath == f"{MODULE}.triggers.batch.AzureBatchTrigger" + + assert kwargs == { + "job_id": JOB_ID, + "azure_batch_conn_id": AZURE_BATCH_CONN_ID, + "poll_interval": POKE_INTERVAL, + "end_time": BATCH_END_TIME, + } + + def test_build_trigger_event_success(self): + completed_task = mock.MagicMock() + completed_task.id = "task1" + completed_task.state = batch_models.TaskState.completed + completed_task.execution_info.result = batch_models.TaskExecutionResult.success + + event = self.TRIGGER._build_trigger_event([completed_task]) + + assert event is not None + + assert event.payload == { + "status": "success", + "message": f"Azure Batch job {JOB_ID} completed successfully.", + "job_id": JOB_ID, + } + + def test_build_trigger_event_failure(self): + failed_task = mock.MagicMock() + failed_task.id = "task1" + failed_task.state = batch_models.TaskState.completed + failed_task.execution_info.result = batch_models.TaskExecutionResult.failure + + event = self.TRIGGER._build_trigger_event([failed_task]) + + assert event is not None + + assert event.payload == { + "status": "error", + "message": f"Azure Batch job {JOB_ID} failed.", + 
"job_id": JOB_ID, + "failed_tasks": ["task1"], + } + + def test_build_trigger_event_mixed_states(self): + completed_task = mock.MagicMock() + completed_task.id = "task1" + completed_task.state = batch_models.TaskState.completed + completed_task.execution_info.result = batch_models.TaskExecutionResult.success + + running_task = mock.MagicMock() + running_task.id = "task2" + running_task.state = batch_models.TaskState.running + + event = self.TRIGGER._build_trigger_event([completed_task, running_task]) + + assert event is None + + def test_build_trigger_event_empty_tasks(self): + event = self.TRIGGER._build_trigger_event([]) + + assert event is not None + + assert event.payload == { + "status": "error", + "message": f"Azure Batch job {JOB_ID} contains no tasks.", + "job_id": JOB_ID, + } + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.asyncio.sleep") + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_run_non_terminal_sleeps( + self, + mock_to_thread, + mock_sleep, + ): + running_task = mock.MagicMock() + running_task.id = "task1" + running_task.state = batch_models.TaskState.running + + completed_task = mock.MagicMock() + completed_task.id = "task1" + completed_task.state = batch_models.TaskState.completed + completed_task.execution_info.result = batch_models.TaskExecutionResult.success + + mock_to_thread.side_effect = [ + [running_task], + [completed_task], + ] + + events = [event async for event in self.TRIGGER.run()] + + assert events == [ + TriggerEvent( + { + "status": "success", + "message": f"Azure Batch job {JOB_ID} completed successfully.", + "job_id": JOB_ID, + } + ) + ] + + mock_sleep.assert_awaited_once_with(POKE_INTERVAL) + + def test_build_trigger_event_non_terminal(self): + running_task = mock.MagicMock() + running_task.id = "task1" + running_task.state = batch_models.TaskState.running + + event = self.TRIGGER._build_trigger_event([running_task]) + + assert event is None + + @pytest.mark.asyncio + 
@mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_run_success(self, mock_to_thread): + completed_task = mock.MagicMock() + completed_task.id = "task1" + completed_task.state = batch_models.TaskState.completed + completed_task.execution_info.result = batch_models.TaskExecutionResult.success + + mock_to_thread.return_value = [completed_task] + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + + assert actual == TriggerEvent( + { + "status": "success", + "message": f"Azure Batch job {JOB_ID} completed successfully.", + "job_id": JOB_ID, + } + ) + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_run_failure(self, mock_to_thread): + failed_task = mock.MagicMock() + failed_task.id = "task1" + failed_task.state = batch_models.TaskState.completed + failed_task.execution_info.result = batch_models.TaskExecutionResult.failure + + mock_to_thread.return_value = [failed_task] + + generator = self.TRIGGER.run() + actual = await generator.asend(None) + + assert actual == TriggerEvent( + { + "status": "error", + "message": f"Azure Batch job {JOB_ID} failed.", + "job_id": JOB_ID, + "failed_tasks": ["task1"], + } + ) + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_exception(self, mock_to_thread): + mock_to_thread.side_effect = Exception("API failure") + + events = [event async for event in self.TRIGGER.run()] + + assert events == [ + TriggerEvent( + { + "status": "error", + "message": "API failure", + "job_id": JOB_ID, + } + ) + ] + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_run_empty_tasks(self, mock_to_thread): + mock_to_thread.return_value = [] + + events = [event async for event in self.TRIGGER.run()] + + assert events == [ + TriggerEvent( + { + "status": "error", + "message": f"Azure Batch job {JOB_ID} contains no tasks.", + "job_id": 
JOB_ID, + } + ) + ] + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.time") + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_timeout_job_already_succeeded( + self, + mock_to_thread, + mock_time, + ): + completed_task = mock.MagicMock() + completed_task.id = "task1" + completed_task.state = batch_models.TaskState.completed + completed_task.execution_info.result = batch_models.TaskExecutionResult.success + + mock_to_thread.return_value = [completed_task] + + mock_time.time.return_value = BATCH_END_TIME + 60 + + events = [event async for event in self.TRIGGER.run()] + + assert events == [ + TriggerEvent( + { + "status": "success", + "message": f"Azure Batch job {JOB_ID} completed successfully.", + "job_id": JOB_ID, + } + ) + ] + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.triggers.batch.time") + @mock.patch(f"{MODULE}.triggers.batch.asyncio.to_thread") + async def test_trigger_timeout(self, mock_to_thread, mock_time): + running_task = mock.MagicMock() + running_task.id = "task1" + running_task.state = batch_models.TaskState.running + + mock_to_thread.return_value = [running_task] + + mock_time.time.return_value = BATCH_END_TIME + 60 + + events = [event async for event in self.TRIGGER.run()] + + assert events == [ + TriggerEvent( + { + "status": "timeout", + "message": f"Timeout waiting for Azure Batch job {JOB_ID}.", + "job_id": JOB_ID, + } + ) + ] diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index bd4570fc55f25..b4cf860387009 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -352,7 +352,7 @@ providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_facto providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py::3 
providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py::2 -providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::10 +providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/batch.py::9 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py::10 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py::1 providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py::1