From eebd16e8751c1d2beed965f957b36f8ab1da7f12 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 25 Mar 2026 18:36:35 -0700 Subject: [PATCH 1/6] Implement comm handlers for callback supervisor --- .../src/airflow/executors/base_executor.py | 2 + .../src/airflow/jobs/triggerer_job_runner.py | 35 +--- .../unit/executors/test_local_executor.py | 4 + .../sdk/execution_time/callback_supervisor.py | 155 ++++++++++++---- .../sdk/execution_time/request_handlers.py | 101 ++++++++++ .../airflow/sdk/execution_time/supervisor.py | 174 ++++++++---------- .../test_callback_supervisor.py | 149 ++++++++++++++- .../execution_time/test_supervisor.py | 2 +- 8 files changed, 461 insertions(+), 161 deletions(-) create mode 100644 task-sdk/src/airflow/sdk/execution_time/request_handlers.py diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index cc708545cfac0..e37e68683b353 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -651,6 +651,8 @@ def run_workload( callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, bundle_info=workload.bundle_info, + token=workload.token, + server=server, ) raise ValueError(f"Unknown workload type: {type(workload).__name__}") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 7522349c461d0..5d3ecdb529565 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -84,6 +84,11 @@ _new_encoder, _RequestFrame, ) +from airflow.sdk.execution_time.request_handlers import ( + handle_get_connection, + handle_get_variable, + handle_mask_secret, +) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.serialization.serialized_objects import DagSerialization @@ -447,14 +452,12 @@ def client(self) -> Client: def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: from airflow.sdk.api.datamodels._generated import ( - ConnectionResponse, TaskStatesResponse, - VariableResponse, XComResponse, ) resp: BaseModel | None = None - dump_opts = {} + dump_opts: dict[str, bool] = {} if isinstance(msg, messages.TriggerStateChanges): if msg.events: @@ -482,29 +485,11 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = response elif isinstance(msg, GetConnection): - conn = self.client.connections.get(msg.conn_id) - if isinstance(conn, ConnectionResponse): - conn_result = ConnectionResult.from_conn_response(conn) - resp = conn_result - # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model - dump_opts = {"exclude_unset": True, "by_alias": True} - else: - resp = conn + resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, DeleteVariable): resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): - var = self.client.variables.get(msg.key) - if isinstance(var, VariableResponse): - # TODO: call for help to figure out why this is needed - if var.value: - from airflow.sdk.log import mask_secret - - mask_secret(var.value, var.key) - var_result = VariableResult.from_variable_response(var) - resp = var_result - dump_opts = {"exclude_unset": True} - else: - resp = var + resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, DeleteXCom): @@ -583,9 +568,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r api_resp = self.client.hitl.get_detail_response(ti_id=msg.ti_id) resp = HITLDetailResponseResult.from_api_response(response=api_resp) elif isinstance(msg, MaskSecret): - from airflow.sdk.log import mask_secret - - mask_secret(msg.value, msg.name) + handle_mask_secret(msg) else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 0f3c4deca41d2..532a56eff8a68 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -394,6 +394,8 @@ def test_global_executor_without_team_name(self): class TestLocalExecutorCallbackSupport: CALLBACK_UUID = "12345678-1234-5678-1234-567812345678" + TEST_TOKEN = "test_token" + TEST_SERVER = "http://localhost:8080/execution/" def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() @@ -451,6 +453,8 @@ def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback callback_kwargs={"arg1": "val1"}, log_path="test.log", bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token=TestLocalExecutorCallbackSupport.TEST_TOKEN, + server=TestLocalExecutorCallbackSupport.TEST_SERVER, ) @mock.patch( diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 58a1e766e52ca..1f53beb9b666a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -21,25 +21,46 @@ import sys import time from importlib import import_module -from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol +from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Protocol from uuid import UUID import attrs import structlog -from pydantic import TypeAdapter +from pydantic import Field, TypeAdapter from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args +from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetAssetByName, + GetAssetByUri, + GetConnection, + GetVariable, + GetXCom, + MaskSecret, +) +from airflow.sdk.execution_time.request_handlers import ( + handle_get_asset_by_name, + handle_get_asset_by_uri, + handle_get_connection, + handle_get_variable, + handle_get_xcom, + handle_mask_secret, +) from airflow.sdk.execution_time.supervisor import ( MIN_HEARTBEAT_INTERVAL, SOCKET_CLEANUP_TIMEOUT, WatchedSubprocess, + _ensure_client, _make_process_nondumpable, ) if TYPE_CHECKING: + from pydantic import BaseModel from structlog.typing import FilteringBoundLogger from typing_extensions import Self + from airflow.sdk.api.client import Client + # Core (airflow.executors.workloads.base.BundleInfo) and SDK (airflow.sdk.api.datamodels._generated.BundleInfo) # are structurally identical, but MyPy treats them as different types. This Protocol makes MyPy happy. class _BundleInfoLike(Protocol): @@ -52,6 +73,15 @@ class _BundleInfoLike(Protocol): log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") +# The set of messages that a callback subprocess can send to the supervisor. +# This is a minimal subset of ToSupervisor: read-only access to Connections, +# Variables, XCom, and Assets, plus MaskSecret for the secrets masker. +CallbackToSupervisor = Annotated[ + GetAssetByName | GetAssetByUri | GetConnection | GetVariable | GetXCom | MaskSecret, + Field(discriminator="type"), +] + + def execute_callback( callback_path: str, callback_kwargs: dict, @@ -123,14 +153,6 @@ def execute_callback( return False, error_msg -# An empty message set; the callback subprocess doesn't currently communicate back to the -# supervisor. This means callback code cannot access runtime services like Connection.get() -# or Variable.get() which require the supervisor to pass requests to the API server. -# To enable this, add the needed message types here and implement _handle_request accordingly. -# See ActivitySubprocess.decoder in supervisor.py for the full task message set and examples. -_EmptyMessage: TypeAdapter[None] = TypeAdapter(None) - - @attrs.define(kw_only=True) class CallbackSubprocess(WatchedSubprocess): """ @@ -138,9 +160,15 @@ class CallbackSubprocess(WatchedSubprocess): Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling while keeping a simple lifecycle: start, run callback, exit. + + Provides a limited set of comms channels (Connections, Variables, XCom, + Assets) so that callback code can access runtime services like + ``Connection.get()`` and ``Variable.get()`` via the supervisor's API client. """ - decoder: ClassVar[TypeAdapter] = _EmptyMessage + client: Client # The HTTP client to use for communication with the API server. + + decoder: ClassVar[TypeAdapter[CallbackToSupervisor]] = TypeAdapter(CallbackToSupervisor) @classmethod def start( # type: ignore[override] @@ -150,6 +178,7 @@ def start( # type: ignore[override] callback_path: str, callback_kwargs: dict, bundle_info: _BundleInfoLike | None = None, + client: Client, logger: FilteringBoundLogger | None = None, **kwargs, ) -> Self: @@ -159,7 +188,12 @@ def start( # type: ignore[override] # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child # inherits the parent's memory space and the variables are available directly. def _target(): + + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import CommsDecoder, ToTask + _log = structlog.get_logger(logger_name="callback_runner") + task_runner.SUPERVISOR_COMMS = CommsDecoder[ToTask, CallbackToSupervisor](log=_log) # If bundle info is provided, initialize the bundle and ensure its path is importable. # This is needed for user-defined callbacks that live inside a DAG bundle rather than @@ -192,6 +226,7 @@ def _target(): return super().start( id=UUID(id) if not isinstance(id, UUID) else id, + client=client, target=_target, logger=logger, **kwargs, @@ -241,9 +276,43 @@ def _monitor_subprocess(self): ) self._cleanup_open_sockets() - def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: - """Handle incoming requests from the callback subprocess (currently none expected).""" - log.warning("Unexpected request from callback subprocess", msg=msg) + def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, req_id: int) -> None: + """Handle incoming requests from the callback subprocess.""" + from airflow.sdk.exceptions import ErrorType + + if isinstance(msg, MaskSecret): + log.debug("Received request from callback (body omitted)", msg=type(msg)) + else: + log.debug("Received request from callback", msg=msg) + + resp: BaseModel | None = None + dump_opts: dict[str, bool] = {} + + if isinstance(msg, GetConnection): + resp, dump_opts = handle_get_connection(self.client, msg) + elif isinstance(msg, GetVariable): + resp, dump_opts = handle_get_variable(self.client, msg) + elif isinstance(msg, GetXCom): + resp, dump_opts = handle_get_xcom(self.client, msg) + elif isinstance(msg, GetAssetByName): + resp, dump_opts = handle_get_asset_by_name(self.client, msg) + elif isinstance(msg, GetAssetByUri): + resp, dump_opts = handle_get_asset_by_uri(self.client, msg) + elif isinstance(msg, MaskSecret): + handle_mask_secret(msg) + else: + log.warning("Unhandled request from callback subprocess", msg=msg) + self.send_msg( + None, + request_id=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) + return + + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def _configure_logging(log_path: str) -> tuple[FilteringBoundLogger, BinaryIO]: @@ -266,6 +335,9 @@ def supervise_callback( callback_kwargs: dict, log_path: str | None = None, bundle_info: _BundleInfoLike | None = None, + token: str = "", + server: str | None = None, + client: Client | None = None, ) -> int: """ Run a single callback execution to completion in a supervised subprocess. @@ -275,6 +347,9 @@ def supervise_callback( :param callback_kwargs: Keyword arguments to pass to the callback. :param log_path: Path to write logs, if required. :param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable. + :param token: Authentication token for the API client. + :param server: Base URL of the API server. + :param client: Optional preconfigured client for communication with the server (mostly for tests). :return: Exit code of the subprocess (0 = success). """ _make_process_nondumpable() @@ -290,28 +365,30 @@ def supervise_callback( # so logs are clearly separated from task logs. logger = structlog.get_logger(logger_name="callback").bind() - try: - process = CallbackSubprocess.start( - id=id, - callback_path=callback_path, - callback_kwargs=callback_kwargs, - bundle_info=bundle_info, - logger=logger, - subprocess_logs_to_stdout=True, - ) - - exit_code = process.wait() - end = time.monotonic() - log.info( - "Workload finished", - workload_type="ExecutorCallback", - workload_id=id, - exit_code=exit_code, - duration=end - start, - ) - if exit_code != 0: - raise RuntimeError(f"Callback subprocess exited with code {exit_code}") - return exit_code - finally: - if log_path and log_file_descriptor: - log_file_descriptor.close() + with _ensure_client(server, token, client=client) as client: + try: + process = CallbackSubprocess.start( + id=id, + callback_path=callback_path, + callback_kwargs=callback_kwargs, + bundle_info=bundle_info, + client=client, + logger=logger, + subprocess_logs_to_stdout=True, + ) + + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecutorCallback", + workload_id=id, + exit_code=exit_code, + duration=end - start, + ) + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py new file mode 100644 index 0000000000000..9692f810512e8 --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared request handlers for supervised subprocess comms channels. + +These functions implement the supervisor-side logic for message types that are +used by more than one subprocess type (tasks, callbacks, triggerer). Each +handler accepts a ``Client`` and a request message and returns +``(response_model | None, dump_opts)`` so the caller can forward the result +via ``send_msg``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.sdk.api.datamodels._generated import AssetResponse, ConnectionResponse, VariableResponse +from airflow.sdk.execution_time.comms import ( + AssetResult, + ConnectionResult, + GetAssetByName, + GetAssetByUri, + GetConnection, + GetVariable, + GetXCom, + MaskSecret, + VariableResult, + XComResult, +) +from airflow.sdk.log import mask_secret + +if TYPE_CHECKING: + from pydantic import BaseModel + + from airflow.sdk.api.client import Client + + +def handle_get_connection(client: Client, msg: GetConnection) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch a connection and mask its sensitive fields.""" + conn = client.connections.get(msg.conn_id) + if isinstance(conn, ConnectionResponse): + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) + return ConnectionResult.from_conn_response(conn), {"exclude_unset": True, "by_alias": True} + return conn, {} + + +def handle_get_variable(client: Client, msg: GetVariable) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch a variable and mask its value.""" + var = client.variables.get(msg.key) + if isinstance(var, VariableResponse): + if var.value: + mask_secret(var.value, var.key) + return VariableResult.from_variable_response(var), {"exclude_unset": True} + return var, {} + + +def handle_get_xcom(client: Client, msg: GetXCom) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch a single XCom value.""" + xcom = client.xcoms.get( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates + ) + return XComResult.from_xcom_response(xcom), {} + + +def handle_get_asset_by_name(client: Client, msg: GetAssetByName) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch an asset by name.""" + asset_resp = client.assets.get(name=msg.name) + if isinstance(asset_resp, AssetResponse): + return AssetResult.from_asset_response(asset_resp), {"exclude_unset": True} + return asset_resp, {} + + +def handle_get_asset_by_uri(client: Client, msg: GetAssetByUri) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch an asset by URI.""" + asset_resp = client.assets.get(uri=msg.uri) + if isinstance(asset_resp, AssetResponse): + return AssetResult.from_asset_response(asset_resp), {"exclude_unset": True} + return asset_resp, {} + + +def handle_mask_secret(msg: MaskSecret) -> None: + """Register a value with the secrets masker.""" + mask_secret(msg.value, msg.name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index c50199df7e490..854033f9371ee 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -50,12 +50,10 @@ from airflow.sdk._shared.logging.structlog import reconfigure_logger from airflow.sdk.api.client import Client, ServerResponseError from airflow.sdk.api.datamodels._generated import ( - AssetResponse, ConnectionResponse, TaskInstance, TaskInstanceState, TaskStatesResponse, - VariableResponse, XComSequenceIndexResponse, ) from airflow.sdk.configuration import conf @@ -63,7 +61,6 @@ from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, - AssetResult, ConnectionResult, CreateHITLDetailPayload, DagResult, @@ -115,14 +112,19 @@ ToSupervisor, TriggerDagRun, ValidateInletsAndOutlets, - VariableResult, - XComResult, XComSequenceIndexResult, XComSequenceSliceResult, _RequestFrame, _ResponseFrame, ) -from airflow.sdk.log import mask_secret +from airflow.sdk.execution_time.request_handlers import ( + handle_get_asset_by_name, + handle_get_asset_by_uri, + handle_get_connection, + handle_get_variable, + handle_get_xcom, + handle_mask_secret, +) try: from socket import send_fds @@ -902,6 +904,33 @@ def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | Non return result +@contextlib.contextmanager +def _ensure_client( + server: str | None, + token: str, + client: Client | None = None, + dry_run: bool = False, +) -> Generator[Client, None, None]: + """ + Yield an API client, creating one if not provided. + + If a client is created internally, it will be closed when the context exits. + Pre-existing clients are yielded as-is and left open for the caller to manage. + """ + if client: + yield client + return + + limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) + new_client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) + log.debug("Connecting to execution API server", server=server) + try: + yield new_client + finally: + with suppress(Exception): + new_client.close() + + @contextlib.contextmanager def _remote_logging_conn(client: Client): """ @@ -1252,7 +1281,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: else: log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None - dump_opts = {} + dump_opts: dict[str, bool] = {} if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() @@ -1278,33 +1307,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: rendered_map_index=self._rendered_map_index, ) elif isinstance(msg, GetConnection): - conn = self.client.connections.get(msg.conn_id) - if isinstance(conn, ConnectionResponse): - if conn.password: - mask_secret(conn.password) - if conn.extra: - mask_secret(conn.extra) - conn_result = ConnectionResult.from_conn_response(conn) - resp = conn_result - dump_opts = {"exclude_unset": True, "by_alias": True} - else: - resp = conn + resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, GetVariable): - var = self.client.variables.get(msg.key) - if isinstance(var, VariableResponse): - if var.value: - mask_secret(var.value, var.key) - var_result = VariableResult.from_variable_response(var) - resp = var_result - dump_opts = {"exclude_unset": True} - else: - resp = var + resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, GetXCom): - xcom = self.client.xcoms.get( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates - ) - xcom_result = XComResult.from_xcom_response(xcom) - resp = xcom_result + resp, dump_opts = handle_get_xcom(self.client, msg) elif isinstance(msg, GetXComCount): resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) elif isinstance(msg, GetXComSequenceItem): @@ -1356,21 +1363,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, SetRenderedMapIndex): self.client.task_instances.set_rendered_map_index(self.id, msg.rendered_map_index) elif isinstance(msg, GetAssetByName): - asset_resp = self.client.assets.get(name=msg.name) - if isinstance(asset_resp, AssetResponse): - asset_result = AssetResult.from_asset_response(asset_resp) - resp = asset_result - dump_opts = {"exclude_unset": True} - else: - resp = asset_resp + resp, dump_opts = handle_get_asset_by_name(self.client, msg) elif isinstance(msg, GetAssetByUri): - asset_resp = self.client.assets.get(uri=msg.uri) - if isinstance(asset_resp, AssetResponse): - asset_result = AssetResult.from_asset_response(asset_resp) - resp = asset_result - dump_opts = {"exclude_unset": True} - else: - resp = asset_resp + resp, dump_opts = handle_get_asset_by_uri(self.client, msg) elif isinstance(msg, GetAssetEventByAsset): asset_event_resp = self.client.asset_events.get( uri=msg.uri, @@ -1484,7 +1479,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp = HITLDetailRequestResult.from_api_response(hitl_detail_request) dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): - mask_secret(msg.value, msg.name) + handle_mask_secret(msg) elif isinstance(msg, GetDag): dag = self.client.dags.get( dag_id=msg.dag_id, @@ -2075,58 +2070,49 @@ def supervise_task( if not dag_rel_path: raise ValueError("dag_path is required") - close_client = False - if not client: - limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) - client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) - close_client = True - log.debug("Connecting to execution API server", server=server) - - start = time.monotonic() + with _ensure_client(server, token, client=client, dry_run=dry_run) as client: + start = time.monotonic() - # TODO: Use logging providers to handle the chunked upload for us etc. - logger: FilteringBoundLogger | None = None - log_file_descriptor: BinaryIO | TextIO | None = None - if log_path: - logger, log_file_descriptor = _configure_logging(log_path, client) + # TODO: Use logging providers to handle the chunked upload for us etc. + logger: FilteringBoundLogger | None = None + log_file_descriptor: BinaryIO | TextIO | None = None + if log_path: + logger, log_file_descriptor = _configure_logging(log_path, client) - backends = ensure_secrets_backend_loaded() - log.info( - "Secrets backends loaded for worker", - count=len(backends), - backend_classes=[type(b).__name__ for b in backends], - ) + backends = ensure_secrets_backend_loaded() + log.info( + "Secrets backends loaded for worker", + count=len(backends), + backend_classes=[type(b).__name__ for b in backends], + ) - reset_secrets_masker() + reset_secrets_masker() - try: - process = ActivitySubprocess.start( - dag_rel_path=dag_rel_path, - what=ti, - client=client, - logger=logger, - bundle_info=bundle_info, - subprocess_logs_to_stdout=subprocess_logs_to_stdout, - sentry_integration=sentry_integration, - ) + try: + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=ti, + client=client, + logger=logger, + bundle_info=bundle_info, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + sentry_integration=sentry_integration, + ) - exit_code = process.wait() - end = time.monotonic() - log.info( - "Workload finished", - workload_type="ExecuteTask", - workload_id=str(ti.id), - exit_code=exit_code, - duration=end - start, - final_state=process.final_state, - ) - return exit_code - finally: - if log_path and log_file_descriptor: - log_file_descriptor.close() - if close_client and client: - with suppress(Exception): - client.close() + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecuteTask", + workload_id=str(ti.id), + exit_code=exit_code, + duration=end - start, + final_state=process.final_state, + ) + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() def supervise(**kwargs) -> int: diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 93b459d200e87..3c4cc57646bd2 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -18,10 +19,29 @@ from __future__ import annotations +import socket +from dataclasses import dataclass +from operator import attrgetter +from typing import Any +from unittest.mock import patch + import pytest import structlog -from airflow.sdk.execution_time.callback_supervisor import execute_callback +from airflow.sdk.api.datamodels._generated import AssetResponse +from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback +from airflow.sdk.execution_time.comms import ( + ConnectionResult, + GetAssetByName, + GetAssetByUri, + GetConnection, + GetVariable, + GetXCom, + MaskSecret, + VariableResult, + XComResult, + _RequestFrame, +) def callback_no_args(): @@ -113,3 +133,130 @@ def test_execute_callback(self, path, kwargs, expect_success, error_contains): assert error_contains in error else: assert error is None + + +class TestCallbackHandleRequest: + """Verify that CallbackSubprocess._handle_request dispatches each message type to the correct handler.""" + + @dataclass + class ClientMock: + method_path: str + args: tuple = () + kwargs: dict | None = None + response: Any = None + + def __post_init__(self): + if self.kwargs is None: + self.kwargs = {} + + @dataclass + class RequestCase: + message: Any + test_id: str + client_mock: Any = None # Should be ClientMock but Python can't forward-ref sibling nested classes + mask_secret_args: tuple | None = None + + REQUEST_CASES = [ + RequestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql"), + ), + ), + RequestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection_with_password", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", password="secret"), + ), + mask_secret_args=("secret",), + ), + RequestCase( + message=GetVariable(key="test_key"), + test_id="get_variable", + client_mock=ClientMock( + method_path="variables.get", + args=("test_key",), + response=VariableResult(key="test_key", value="test_value"), + ), + ), + RequestCase( + message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + test_id="get_xcom", + client_mock=ClientMock( + method_path="xcoms.get", + args=("test_dag", "test_run", "test_task", "test_key", None, False), + response=XComResult(key="test_key", value="test_value"), + ), + ), + RequestCase( + message=GetAssetByName(name="my_asset"), + test_id="get_asset_by_name", + client_mock=ClientMock( + method_path="assets.get", + kwargs={"name": "my_asset"}, + response=AssetResponse(name="my_asset", uri="s3://bucket/key", group="default"), + ), + ), + RequestCase( + message=GetAssetByUri(uri="s3://bucket/key"), + test_id="get_asset_by_uri", + client_mock=ClientMock( + method_path="assets.get", + kwargs={"uri": "s3://bucket/key"}, + response=AssetResponse(name="my_asset", uri="s3://bucket/key", group="default"), + ), + ), + RequestCase( + message=MaskSecret(value="super_secret", name="api_key"), + test_id="mask_secret", + mask_secret_args=("super_secret", "api_key"), + ), + ] + + @pytest.fixture + def callback_subprocess(self, mocker): + read_end, write_end = socket.socketpair() + proc = CallbackSubprocess( + process_log=mocker.MagicMock(), + id="12345678-1234-5678-1234-567812345678", + pid=12345, + stdin=write_end, + client=mocker.Mock(), + process=mocker.Mock(), + ) + return proc, read_end + + @patch("airflow.sdk.execution_time.request_handlers.mask_secret") + @pytest.mark.parametrize("test_case", REQUEST_CASES, ids=lambda tc: tc.test_id) + def test_handle_requests( + self, + mock_mask_secret, + callback_subprocess, + mocker, + test_case, + ): + client_mock = test_case.client_mock + + proc, _read_end = callback_subprocess + + if client_mock: + mock_client_method = attrgetter(client_mock.method_path)(proc.client) + mock_client_method.return_value = client_mock.response + + generator = proc.handle_requests(log=mocker.Mock()) + next(generator) + + req_frame = _RequestFrame(id=42, body=test_case.message.model_dump()) + generator.send(req_frame) + + if test_case.mask_secret_args is not None: + mock_mask_secret.assert_called_with(*test_case.mask_secret_args) + + if client_mock: + mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 17fe9ec1defa8..1f4e55d73c1fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2576,7 +2576,7 @@ def watched_subprocess(self, mocker): return subprocess, read_end - @patch("airflow.sdk.execution_time.supervisor.mask_secret") + @patch("airflow.sdk.execution_time.request_handlers.mask_secret") @pytest.mark.parametrize("test_case", REQUEST_TEST_CASES, ids=lambda tc: tc.test_id) def test_handle_requests( self, From afa020510ee93aea648e663087fcbf421a89f3e4 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 14 Apr 2026 17:00:03 -0700 Subject: [PATCH 2/6] add news fragment --- airflow-core/newsfragments/65269.significant.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 airflow-core/newsfragments/65269.significant.rst diff --git a/airflow-core/newsfragments/65269.significant.rst b/airflow-core/newsfragments/65269.significant.rst new file mode 100644 index 0000000000000..5b253d81c9c71 --- /dev/null +++ b/airflow-core/newsfragments/65269.significant.rst @@ -0,0 +1 @@ +Synchronous deadline callbacks (``SyncCallback``) can now access Connections, Variables, XCom, and Assets from the Airflow metadata database. From ede7c6d96a3f0b059f29fb9860dc5a25356a4936 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 14 Apr 2026 17:48:10 -0700 Subject: [PATCH 3/6] kaxil fixes --- .../airflow/sdk/execution_time/callback_supervisor.py | 3 +-- .../airflow/sdk/execution_time/request_handlers.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 1f53beb9b666a..592d153aa75b9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -29,6 +29,7 @@ from pydantic import Field, TypeAdapter from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args +from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( ErrorResponse, GetAssetByName, @@ -278,8 +279,6 @@ def _monitor_subprocess(self): def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, req_id: int) -> None: """Handle incoming requests from the callback subprocess.""" - from airflow.sdk.exceptions import ErrorType - if isinstance(msg, MaskSecret): log.debug("Received request from callback (body omitted)", msg=type(msg)) else: diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index 9692f810512e8..81ec55d024eb4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -29,7 +29,12 @@ from typing import TYPE_CHECKING -from airflow.sdk.api.datamodels._generated import AssetResponse, ConnectionResponse, VariableResponse +from airflow.sdk.api.datamodels._generated import ( + AssetResponse, + ConnectionResponse, + VariableResponse, + XComResponse, +) from airflow.sdk.execution_time.comms import ( AssetResult, ConnectionResult, @@ -77,7 +82,9 @@ def handle_get_xcom(client: Client, msg: GetXCom) -> tuple[BaseModel | None, dic xcom = client.xcoms.get( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates ) - return XComResult.from_xcom_response(xcom), {} + if isinstance(xcom, XComResponse): + return XComResult.from_xcom_response(xcom), {} + return xcom, {} def handle_get_asset_by_name(client: Client, msg: GetAssetByName) -> tuple[BaseModel | None, dict[str, bool]]: From 5fd9a3b95317184ebe105951a04b3d3b00ab60bd Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 20 Apr 2026 12:49:04 -0700 Subject: [PATCH 4/6] revert xcom handler portion --- .../sdk/execution_time/callback_supervisor.py | 10 +++------- .../airflow/sdk/execution_time/request_handlers.py | 13 ------------- .../src/airflow/sdk/execution_time/supervisor.py | 8 ++++++-- .../execution_time/test_callback_supervisor.py | 11 ----------- 4 files changed, 9 insertions(+), 33 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 592d153aa75b9..c21d6d97828f4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -36,7 +36,6 @@ GetAssetByUri, GetConnection, GetVariable, - GetXCom, MaskSecret, ) from airflow.sdk.execution_time.request_handlers import ( @@ -44,7 +43,6 @@ handle_get_asset_by_uri, handle_get_connection, handle_get_variable, - handle_get_xcom, handle_mask_secret, ) from airflow.sdk.execution_time.supervisor import ( @@ -76,9 +74,9 @@ class _BundleInfoLike(Protocol): # The set of messages that a callback subprocess can send to the supervisor. # This is a minimal subset of ToSupervisor: read-only access to Connections, -# Variables, XCom, and Assets, plus MaskSecret for the secrets masker. +# Variables, and Assets, plus MaskSecret for the secrets masker. CallbackToSupervisor = Annotated[ - GetAssetByName | GetAssetByUri | GetConnection | GetVariable | GetXCom | MaskSecret, + GetAssetByName | GetAssetByUri | GetConnection | GetVariable | MaskSecret, Field(discriminator="type"), ] @@ -162,7 +160,7 @@ class CallbackSubprocess(WatchedSubprocess): Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling while keeping a simple lifecycle: start, run callback, exit. - Provides a limited set of comms channels (Connections, Variables, XCom, + Provides a limited set of comms channels (Connections, Variables, Assets) so that callback code can access runtime services like ``Connection.get()`` and ``Variable.get()`` via the supervisor's API client. """ @@ -291,8 +289,6 @@ def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) - elif isinstance(msg, GetXCom): - resp, dump_opts = handle_get_xcom(self.client, msg) elif isinstance(msg, GetAssetByName): resp, dump_opts = handle_get_asset_by_name(self.client, msg) elif isinstance(msg, GetAssetByUri): diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index 81ec55d024eb4..46fe81115daa7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -33,7 +33,6 @@ AssetResponse, ConnectionResponse, VariableResponse, - XComResponse, ) from airflow.sdk.execution_time.comms import ( AssetResult, @@ -42,10 +41,8 @@ GetAssetByUri, GetConnection, GetVariable, - GetXCom, MaskSecret, VariableResult, - XComResult, ) from airflow.sdk.log import mask_secret @@ -77,16 +74,6 @@ def handle_get_variable(client: Client, msg: GetVariable) -> tuple[BaseModel | N return var, {} -def handle_get_xcom(client: Client, msg: GetXCom) -> tuple[BaseModel | None, dict[str, bool]]: - """Fetch a single XCom value.""" - xcom = client.xcoms.get( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates - ) - if isinstance(xcom, XComResponse): - return XComResult.from_xcom_response(xcom), {} - return xcom, {} - - def handle_get_asset_by_name(client: Client, msg: GetAssetByName) -> tuple[BaseModel | None, dict[str, bool]]: """Fetch an asset by name.""" asset_resp = client.assets.get(name=msg.name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 854033f9371ee..a55ba5dd661a3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -112,6 +112,7 @@ ToSupervisor, TriggerDagRun, ValidateInletsAndOutlets, + XComResult, XComSequenceIndexResult, XComSequenceSliceResult, _RequestFrame, @@ -122,7 +123,6 @@ handle_get_asset_by_uri, handle_get_connection, handle_get_variable, - handle_get_xcom, handle_mask_secret, ) @@ -1311,7 +1311,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, GetXCom): - resp, dump_opts = handle_get_xcom(self.client, msg) + xcom = self.client.xcoms.get( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates + ) + xcom_result = XComResult.from_xcom_response(xcom) + resp = xcom_result elif isinstance(msg, GetXComCount): resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) elif isinstance(msg, GetXComSequenceItem): diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 3c4cc57646bd2..6762c7e7a6493 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -36,10 +36,8 @@ GetAssetByUri, GetConnection, GetVariable, - GetXCom, MaskSecret, VariableResult, - XComResult, _RequestFrame, ) @@ -185,15 +183,6 @@ class RequestCase: response=VariableResult(key="test_key", value="test_value"), ), ), - RequestCase( - message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - test_id="get_xcom", - client_mock=ClientMock( - method_path="xcoms.get", - args=("test_dag", "test_run", "test_task", "test_key", None, False), - response=XComResult(key="test_key", value="test_value"), - ), - ), RequestCase( message=GetAssetByName(name="my_asset"), test_id="get_asset_by_name", From 1957c7b207e022eb9fb80935ad3c0b2abab8f19b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 22 Apr 2026 11:48:09 -0700 Subject: [PATCH 5/6] revert asset handlers --- .../sdk/execution_time/callback_supervisor.py | 18 +++++----------- .../airflow/sdk/execution_time/supervisor.py | 20 ++++++++++++++---- .../test_callback_supervisor.py | 21 ------------------- 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index c21d6d97828f4..13b736f14ce60 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -32,15 +32,11 @@ from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( ErrorResponse, - GetAssetByName, - GetAssetByUri, GetConnection, GetVariable, MaskSecret, ) from airflow.sdk.execution_time.request_handlers import ( - handle_get_asset_by_name, - handle_get_asset_by_uri, handle_get_connection, handle_get_variable, handle_mask_secret, @@ -73,10 +69,10 @@ class _BundleInfoLike(Protocol): # The set of messages that a callback subprocess can send to the supervisor. -# This is a minimal subset of ToSupervisor: read-only access to Connections, -# Variables, and Assets, plus MaskSecret for the secrets masker. +# This is a minimal subset of ToSupervisor: read-only access to Connections +# and Variables, plus MaskSecret for the secrets masker. CallbackToSupervisor = Annotated[ - GetAssetByName | GetAssetByUri | GetConnection | GetVariable | MaskSecret, + GetConnection | GetVariable | MaskSecret, Field(discriminator="type"), ] @@ -160,8 +156,8 @@ class CallbackSubprocess(WatchedSubprocess): Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling while keeping a simple lifecycle: start, run callback, exit. - Provides a limited set of comms channels (Connections, Variables, - Assets) so that callback code can access runtime services like + Provides a limited set of comms channels (Connections and Variables) so + that callback code can access runtime services like ``Connection.get()`` and ``Variable.get()`` via the supervisor's API client. """ @@ -289,10 +285,6 @@ def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) - elif isinstance(msg, GetAssetByName): - resp, dump_opts = handle_get_asset_by_name(self.client, msg) - elif isinstance(msg, GetAssetByUri): - resp, dump_opts = handle_get_asset_by_uri(self.client, msg) elif isinstance(msg, MaskSecret): handle_mask_secret(msg) else: diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index a55ba5dd661a3..6249765bf579d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -50,6 +50,7 @@ from airflow.sdk._shared.logging.structlog import reconfigure_logger from airflow.sdk.api.client import Client, ServerResponseError from airflow.sdk.api.datamodels._generated import ( + AssetResponse, ConnectionResponse, TaskInstance, TaskInstanceState, @@ -61,6 +62,7 @@ from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, + AssetResult, ConnectionResult, CreateHITLDetailPayload, DagResult, @@ -119,8 +121,6 @@ _ResponseFrame, ) from airflow.sdk.execution_time.request_handlers import ( - handle_get_asset_by_name, - handle_get_asset_by_uri, handle_get_connection, handle_get_variable, handle_mask_secret, @@ -1367,9 +1367,21 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, SetRenderedMapIndex): self.client.task_instances.set_rendered_map_index(self.id, msg.rendered_map_index) elif isinstance(msg, GetAssetByName): - resp, dump_opts = handle_get_asset_by_name(self.client, msg) + asset_resp = self.client.assets.get(name=msg.name) + if isinstance(asset_resp, AssetResponse): + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result + dump_opts = {"exclude_unset": True} + else: + resp = asset_resp elif isinstance(msg, GetAssetByUri): - resp, dump_opts = handle_get_asset_by_uri(self.client, msg) + asset_resp = self.client.assets.get(uri=msg.uri) + if isinstance(asset_resp, AssetResponse): + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result + dump_opts = {"exclude_unset": True} + else: + resp = asset_resp elif isinstance(msg, GetAssetEventByAsset): asset_event_resp = self.client.asset_events.get( uri=msg.uri, diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 6762c7e7a6493..8cb9fdcc8167a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -28,12 +28,9 @@ import pytest import structlog -from airflow.sdk.api.datamodels._generated import AssetResponse from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback from airflow.sdk.execution_time.comms import ( ConnectionResult, - GetAssetByName, - GetAssetByUri, GetConnection, GetVariable, MaskSecret, @@ -183,24 +180,6 @@ class RequestCase: response=VariableResult(key="test_key", value="test_value"), ), ), - RequestCase( - message=GetAssetByName(name="my_asset"), - test_id="get_asset_by_name", - client_mock=ClientMock( - method_path="assets.get", - kwargs={"name": "my_asset"}, - response=AssetResponse(name="my_asset", uri="s3://bucket/key", group="default"), - ), - ), - RequestCase( - message=GetAssetByUri(uri="s3://bucket/key"), - test_id="get_asset_by_uri", - client_mock=ClientMock( - method_path="assets.get", - kwargs={"uri": "s3://bucket/key"}, - response=AssetResponse(name="my_asset", uri="s3://bucket/key", group="default"), - ), - ), RequestCase( message=MaskSecret(value="super_secret", name="api_key"), test_id="mask_secret", From 9468ee9a15e24836b9e61714fdb39ff6d1535fc5 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 22 Apr 2026 16:19:53 -0700 Subject: [PATCH 6/6] revert asset handlers fixes --- .../newsfragments/65269.significant.rst | 2 +- .../sdk/execution_time/callback_supervisor.py | 1 - .../sdk/execution_time/request_handlers.py | 20 ------------------- 3 files changed, 1 insertion(+), 22 deletions(-) diff --git a/airflow-core/newsfragments/65269.significant.rst b/airflow-core/newsfragments/65269.significant.rst index 5b253d81c9c71..c5f68aa9457a0 100644 --- a/airflow-core/newsfragments/65269.significant.rst +++ b/airflow-core/newsfragments/65269.significant.rst @@ -1 +1 @@ -Synchronous deadline callbacks (``SyncCallback``) can now access Connections, Variables, XCom, and Assets from the Airflow metadata database. +Synchronous deadline callbacks (``SyncCallback``) can now access Connections and Variables from the Airflow metadata database. diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 13b736f14ce60..94d84193192db 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -183,7 +183,6 @@ def start( # type: ignore[override] # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child # inherits the parent's memory space and the variables are available directly. def _target(): - from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import CommsDecoder, ToTask diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index 46fe81115daa7..eed3e840e395b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -30,15 +30,11 @@ from typing import TYPE_CHECKING from airflow.sdk.api.datamodels._generated import ( - AssetResponse, ConnectionResponse, VariableResponse, ) from airflow.sdk.execution_time.comms import ( - AssetResult, ConnectionResult, - GetAssetByName, - GetAssetByUri, GetConnection, GetVariable, MaskSecret, @@ -74,22 +70,6 @@ def handle_get_variable(client: Client, msg: GetVariable) -> tuple[BaseModel | N return var, {} -def handle_get_asset_by_name(client: Client, msg: GetAssetByName) -> tuple[BaseModel | None, dict[str, bool]]: - """Fetch an asset by name.""" - asset_resp = client.assets.get(name=msg.name) - if isinstance(asset_resp, AssetResponse): - return AssetResult.from_asset_response(asset_resp), {"exclude_unset": True} - return asset_resp, {} - - -def handle_get_asset_by_uri(client: Client, msg: GetAssetByUri) -> tuple[BaseModel | None, dict[str, bool]]: - """Fetch an asset by URI.""" - asset_resp = client.assets.get(uri=msg.uri) - if isinstance(asset_resp, AssetResponse): - return AssetResult.from_asset_response(asset_resp), {"exclude_unset": True} - return asset_resp, {} - - def handle_mask_secret(msg: MaskSecret) -> None: """Register a value with the secrets masker.""" mask_secret(msg.value, msg.name)