diff --git a/airflow-core/newsfragments/65269.significant.rst b/airflow-core/newsfragments/65269.significant.rst new file mode 100644 index 0000000000000..c5f68aa9457a0 --- /dev/null +++ b/airflow-core/newsfragments/65269.significant.rst @@ -0,0 +1 @@ +Synchronous deadline callbacks (``SyncCallback``) can now access Connections and Variables from the Airflow metadata database. diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index cc708545cfac0..e37e68683b353 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -651,6 +651,8 @@ def run_workload( callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, bundle_info=workload.bundle_info, + token=workload.token, + server=server, ) raise ValueError(f"Unknown workload type: {type(workload).__name__}") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 7522349c461d0..5d3ecdb529565 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -84,6 +84,11 @@ _new_encoder, _RequestFrame, ) +from airflow.sdk.execution_time.request_handlers import ( + handle_get_connection, + handle_get_variable, + handle_mask_secret, +) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.serialization.serialized_objects import DagSerialization @@ -447,14 +452,12 @@ def client(self) -> Client: def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: from airflow.sdk.api.datamodels._generated import ( - ConnectionResponse, TaskStatesResponse, - VariableResponse, XComResponse, ) resp: BaseModel | None = None - dump_opts = {} + dump_opts: dict[str, bool] = {} if isinstance(msg, messages.TriggerStateChanges): if msg.events: @@ -482,29 +485,11 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = response elif isinstance(msg, GetConnection): - conn = self.client.connections.get(msg.conn_id) - if isinstance(conn, ConnectionResponse): - conn_result = ConnectionResult.from_conn_response(conn) - resp = conn_result - # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model - dump_opts = {"exclude_unset": True, "by_alias": True} - else: - resp = conn + resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, DeleteVariable): resp = self.client.variables.delete(msg.key) elif isinstance(msg, GetVariable): - var = self.client.variables.get(msg.key) - if isinstance(var, VariableResponse): - # TODO: call for help to figure out why this is needed - if var.value: - from airflow.sdk.log import mask_secret - - mask_secret(var.value, var.key) - var_result = VariableResult.from_variable_response(var) - resp = var_result - dump_opts = {"exclude_unset": True} - else: - resp = var + resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, DeleteXCom): @@ -583,9 +568,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r api_resp = self.client.hitl.get_detail_response(ti_id=msg.ti_id) resp = HITLDetailResponseResult.from_api_response(response=api_resp) elif isinstance(msg, MaskSecret): - from airflow.sdk.log import mask_secret - - mask_secret(msg.value, msg.name) + handle_mask_secret(msg) else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 0f3c4deca41d2..532a56eff8a68 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -394,6 +394,8 @@ def test_global_executor_without_team_name(self): class TestLocalExecutorCallbackSupport: CALLBACK_UUID = "12345678-1234-5678-1234-567812345678" + TEST_TOKEN = "test_token" + TEST_SERVER = "http://localhost:8080/execution/" def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() @@ -451,6 +453,8 @@ def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback callback_kwargs={"arg1": "val1"}, log_path="test.log", bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token=TestLocalExecutorCallbackSupport.TEST_TOKEN, + server=TestLocalExecutorCallbackSupport.TEST_SERVER, ) @mock.patch( diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 58a1e766e52ca..94d84193192db 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -21,25 +21,41 @@ import sys import time from importlib import import_module -from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol +from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Protocol from uuid import UUID import attrs import structlog -from pydantic import TypeAdapter +from pydantic import Field, TypeAdapter from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetConnection, + GetVariable, + MaskSecret, +) +from airflow.sdk.execution_time.request_handlers import ( + handle_get_connection, + handle_get_variable, + handle_mask_secret, +) from airflow.sdk.execution_time.supervisor import ( MIN_HEARTBEAT_INTERVAL, SOCKET_CLEANUP_TIMEOUT, WatchedSubprocess, + _ensure_client, _make_process_nondumpable, ) if TYPE_CHECKING: + from pydantic import BaseModel from structlog.typing import FilteringBoundLogger from typing_extensions import Self + from airflow.sdk.api.client import Client + # Core (airflow.executors.workloads.base.BundleInfo) and SDK (airflow.sdk.api.datamodels._generated.BundleInfo) # are structurally identical, but MyPy treats them as different types. This Protocol makes MyPy happy. class _BundleInfoLike(Protocol): @@ -52,6 +68,15 @@ class _BundleInfoLike(Protocol): log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") +# The set of messages that a callback subprocess can send to the supervisor. +# This is a minimal subset of ToSupervisor: read-only access to Connections +# and Variables, plus MaskSecret for the secrets masker. +CallbackToSupervisor = Annotated[ + GetConnection | GetVariable | MaskSecret, + Field(discriminator="type"), +] + + def execute_callback( callback_path: str, callback_kwargs: dict, @@ -123,14 +148,6 @@ def execute_callback( return False, error_msg -# An empty message set; the callback subprocess doesn't currently communicate back to the -# supervisor. This means callback code cannot access runtime services like Connection.get() -# or Variable.get() which require the supervisor to pass requests to the API server. -# To enable this, add the needed message types here and implement _handle_request accordingly. -# See ActivitySubprocess.decoder in supervisor.py for the full task message set and examples. -_EmptyMessage: TypeAdapter[None] = TypeAdapter(None) - - @attrs.define(kw_only=True) class CallbackSubprocess(WatchedSubprocess): """ @@ -138,9 +155,15 @@ class CallbackSubprocess(WatchedSubprocess): Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling while keeping a simple lifecycle: start, run callback, exit. + + Provides a limited set of comms channels (Connections and Variables) so + that callback code can access runtime services like + ``Connection.get()`` and ``Variable.get()`` via the supervisor's API client. """ - decoder: ClassVar[TypeAdapter] = _EmptyMessage + client: Client # The HTTP client to use for communication with the API server. + + decoder: ClassVar[TypeAdapter[CallbackToSupervisor]] = TypeAdapter(CallbackToSupervisor) @classmethod def start( # type: ignore[override] @@ -150,6 +173,7 @@ def start( # type: ignore[override] callback_path: str, callback_kwargs: dict, bundle_info: _BundleInfoLike | None = None, + client: Client, logger: FilteringBoundLogger | None = None, **kwargs, ) -> Self: @@ -159,7 +183,11 @@ def start( # type: ignore[override] # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child # inherits the parent's memory space and the variables are available directly. def _target(): + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import CommsDecoder, ToTask + _log = structlog.get_logger(logger_name="callback_runner") + task_runner.SUPERVISOR_COMMS = CommsDecoder[ToTask, CallbackToSupervisor](log=_log) # If bundle info is provided, initialize the bundle and ensure its path is importable. # This is needed for user-defined callbacks that live inside a DAG bundle rather than @@ -192,6 +220,7 @@ def _target(): return super().start( id=UUID(id) if not isinstance(id, UUID) else id, + client=client, target=_target, logger=logger, **kwargs, @@ -241,9 +270,35 @@ def _monitor_subprocess(self): ) self._cleanup_open_sockets() - def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: - """Handle incoming requests from the callback subprocess (currently none expected).""" - log.warning("Unexpected request from callback subprocess", msg=msg) + def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, req_id: int) -> None: + """Handle incoming requests from the callback subprocess.""" + if isinstance(msg, MaskSecret): + log.debug("Received request from callback (body omitted)", msg=type(msg)) + else: + log.debug("Received request from callback", msg=msg) + + resp: BaseModel | None = None + dump_opts: dict[str, bool] = {} + + if isinstance(msg, GetConnection): + resp, dump_opts = handle_get_connection(self.client, msg) + elif isinstance(msg, GetVariable): + resp, dump_opts = handle_get_variable(self.client, msg) + elif isinstance(msg, MaskSecret): + handle_mask_secret(msg) + else: + log.warning("Unhandled request from callback subprocess", msg=msg) + self.send_msg( + None, + request_id=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) + return + + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def _configure_logging(log_path: str) -> tuple[FilteringBoundLogger, BinaryIO]: @@ -266,6 +321,9 @@ def supervise_callback( callback_kwargs: dict, log_path: str | None = None, bundle_info: _BundleInfoLike | None = None, + token: str = "", + server: str | None = None, + client: Client | None = None, ) -> int: """ Run a single callback execution to completion in a supervised subprocess. @@ -275,6 +333,9 @@ def supervise_callback( :param callback_kwargs: Keyword arguments to pass to the callback. :param log_path: Path to write logs, if required. :param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable. + :param token: Authentication token for the API client. + :param server: Base URL of the API server. + :param client: Optional preconfigured client for communication with the server (mostly for tests). :return: Exit code of the subprocess (0 = success). """ _make_process_nondumpable() @@ -290,28 +351,30 @@ def supervise_callback( # so logs are clearly separated from task logs. logger = structlog.get_logger(logger_name="callback").bind() - try: - process = CallbackSubprocess.start( - id=id, - callback_path=callback_path, - callback_kwargs=callback_kwargs, - bundle_info=bundle_info, - logger=logger, - subprocess_logs_to_stdout=True, - ) - - exit_code = process.wait() - end = time.monotonic() - log.info( - "Workload finished", - workload_type="ExecutorCallback", - workload_id=id, - exit_code=exit_code, - duration=end - start, - ) - if exit_code != 0: - raise RuntimeError(f"Callback subprocess exited with code {exit_code}") - return exit_code - finally: - if log_path and log_file_descriptor: - log_file_descriptor.close() + with _ensure_client(server, token, client=client) as client: + try: + process = CallbackSubprocess.start( + id=id, + callback_path=callback_path, + callback_kwargs=callback_kwargs, + bundle_info=bundle_info, + client=client, + logger=logger, + subprocess_logs_to_stdout=True, + ) + + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecutorCallback", + workload_id=id, + exit_code=exit_code, + duration=end - start, + ) + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py new file mode 100644 index 0000000000000..eed3e840e395b --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared request handlers for supervised subprocess comms channels. + +These functions implement the supervisor-side logic for message types that are +used by more than one subprocess type (tasks, callbacks, triggerer). Each +handler accepts a ``Client`` and a request message and returns +``(response_model | None, dump_opts)`` so the caller can forward the result +via ``send_msg``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + VariableResponse, +) +from airflow.sdk.execution_time.comms import ( + ConnectionResult, + GetConnection, + GetVariable, + MaskSecret, + VariableResult, +) +from airflow.sdk.log import mask_secret + +if TYPE_CHECKING: + from pydantic import BaseModel + + from airflow.sdk.api.client import Client + + +def handle_get_connection(client: Client, msg: GetConnection) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch a connection and mask its sensitive fields.""" + conn = client.connections.get(msg.conn_id) + if isinstance(conn, ConnectionResponse): + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) + return ConnectionResult.from_conn_response(conn), {"exclude_unset": True, "by_alias": True} + return conn, {} + + +def handle_get_variable(client: Client, msg: GetVariable) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch a variable and mask its value.""" + var = client.variables.get(msg.key) + if isinstance(var, VariableResponse): + if var.value: + mask_secret(var.value, var.key) + return VariableResult.from_variable_response(var), {"exclude_unset": True} + return var, {} + + +def handle_mask_secret(msg: MaskSecret) -> None: + """Register a value with the secrets masker.""" + mask_secret(msg.value, msg.name) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index c50199df7e490..6249765bf579d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -55,7 +55,6 @@ TaskInstance, TaskInstanceState, TaskStatesResponse, - VariableResponse, XComSequenceIndexResponse, ) from airflow.sdk.configuration import conf @@ -115,14 +114,17 @@ ToSupervisor, TriggerDagRun, ValidateInletsAndOutlets, - VariableResult, XComResult, XComSequenceIndexResult, XComSequenceSliceResult, _RequestFrame, _ResponseFrame, ) -from airflow.sdk.log import mask_secret +from airflow.sdk.execution_time.request_handlers import ( + handle_get_connection, + handle_get_variable, + handle_mask_secret, +) try: from socket import send_fds @@ -902,6 +904,33 @@ def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | Non return result +@contextlib.contextmanager +def _ensure_client( + server: str | None, + token: str, + client: Client | None = None, + dry_run: bool = False, +) -> Generator[Client, None, None]: + """ + Yield an API client, creating one if not provided. + + If a client is created internally, it will be closed when the context exits. + Pre-existing clients are yielded as-is and left open for the caller to manage. + """ + if client: + yield client + return + + limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) + new_client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) + log.debug("Connecting to execution API server", server=server) + try: + yield new_client + finally: + with suppress(Exception): + new_client.close() + + @contextlib.contextmanager def _remote_logging_conn(client: Client): """ @@ -1252,7 +1281,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: else: log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None - dump_opts = {} + dump_opts: dict[str, bool] = {} if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() @@ -1278,27 +1307,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: rendered_map_index=self._rendered_map_index, ) elif isinstance(msg, GetConnection): - conn = self.client.connections.get(msg.conn_id) - if isinstance(conn, ConnectionResponse): - if conn.password: - mask_secret(conn.password) - if conn.extra: - mask_secret(conn.extra) - conn_result = ConnectionResult.from_conn_response(conn) - resp = conn_result - dump_opts = {"exclude_unset": True, "by_alias": True} - else: - resp = conn + resp, dump_opts = handle_get_connection(self.client, msg) elif isinstance(msg, GetVariable): - var = self.client.variables.get(msg.key) - if isinstance(var, VariableResponse): - if var.value: - mask_secret(var.value, var.key) - var_result = VariableResult.from_variable_response(var) - resp = var_result - dump_opts = {"exclude_unset": True} - else: - resp = var + resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get( msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates @@ -1484,7 +1495,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp = HITLDetailRequestResult.from_api_response(hitl_detail_request) dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): - mask_secret(msg.value, msg.name) + handle_mask_secret(msg) elif isinstance(msg, GetDag): dag = self.client.dags.get( dag_id=msg.dag_id, @@ -2075,58 +2086,49 @@ def supervise_task( if not dag_rel_path: raise ValueError("dag_path is required") - close_client = False - if not client: - limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) - client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) - close_client = True - log.debug("Connecting to execution API server", server=server) - - start = time.monotonic() + with _ensure_client(server, token, client=client, dry_run=dry_run) as client: + start = time.monotonic() - # TODO: Use logging providers to handle the chunked upload for us etc. - logger: FilteringBoundLogger | None = None - log_file_descriptor: BinaryIO | TextIO | None = None - if log_path: - logger, log_file_descriptor = _configure_logging(log_path, client) + # TODO: Use logging providers to handle the chunked upload for us etc. + logger: FilteringBoundLogger | None = None + log_file_descriptor: BinaryIO | TextIO | None = None + if log_path: + logger, log_file_descriptor = _configure_logging(log_path, client) - backends = ensure_secrets_backend_loaded() - log.info( - "Secrets backends loaded for worker", - count=len(backends), - backend_classes=[type(b).__name__ for b in backends], - ) + backends = ensure_secrets_backend_loaded() + log.info( + "Secrets backends loaded for worker", + count=len(backends), + backend_classes=[type(b).__name__ for b in backends], + ) - reset_secrets_masker() + reset_secrets_masker() - try: - process = ActivitySubprocess.start( - dag_rel_path=dag_rel_path, - what=ti, - client=client, - logger=logger, - bundle_info=bundle_info, - subprocess_logs_to_stdout=subprocess_logs_to_stdout, - sentry_integration=sentry_integration, - ) + try: + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=ti, + client=client, + logger=logger, + bundle_info=bundle_info, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + sentry_integration=sentry_integration, + ) - exit_code = process.wait() - end = time.monotonic() - log.info( - "Workload finished", - workload_type="ExecuteTask", - workload_id=str(ti.id), - exit_code=exit_code, - duration=end - start, - final_state=process.final_state, - ) - return exit_code - finally: - if log_path and log_file_descriptor: - log_file_descriptor.close() - if close_client and client: - with suppress(Exception): - client.close() + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecuteTask", + workload_id=str(ti.id), + exit_code=exit_code, + duration=end - start, + final_state=process.final_state, + ) + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() def supervise(**kwargs) -> int: diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 93b459d200e87..8cb9fdcc8167a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -18,10 +19,24 @@ from __future__ import annotations +import socket +from dataclasses import dataclass +from operator import attrgetter +from typing import Any +from unittest.mock import patch + import pytest import structlog -from airflow.sdk.execution_time.callback_supervisor import execute_callback +from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback +from airflow.sdk.execution_time.comms import ( + ConnectionResult, + GetConnection, + GetVariable, + MaskSecret, + VariableResult, + _RequestFrame, +) def callback_no_args(): @@ -113,3 +128,103 @@ def test_execute_callback(self, path, kwargs, expect_success, error_contains): assert error_contains in error else: assert error is None + + +class TestCallbackHandleRequest: + """Verify that CallbackSubprocess._handle_request dispatches each message type to the correct handler.""" + + @dataclass + class ClientMock: + method_path: str + args: tuple = () + kwargs: dict | None = None + response: Any = None + + def __post_init__(self): + if self.kwargs is None: + self.kwargs = {} + + @dataclass + class RequestCase: + message: Any + test_id: str + client_mock: Any = None # Should be ClientMock but Python can't forward-ref sibling nested classes + mask_secret_args: tuple | None = None + + REQUEST_CASES = [ + RequestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql"), + ), + ), + RequestCase( + message=GetConnection(conn_id="test_conn"), + test_id="get_connection_with_password", + client_mock=ClientMock( + method_path="connections.get", + args=("test_conn",), + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", password="secret"), + ), + mask_secret_args=("secret",), + ), + RequestCase( + message=GetVariable(key="test_key"), + test_id="get_variable", + client_mock=ClientMock( + method_path="variables.get", + args=("test_key",), + response=VariableResult(key="test_key", value="test_value"), + ), + ), + RequestCase( + message=MaskSecret(value="super_secret", name="api_key"), + test_id="mask_secret", + mask_secret_args=("super_secret", "api_key"), + ), + ] + + @pytest.fixture + def callback_subprocess(self, mocker): + read_end, write_end = socket.socketpair() + proc = CallbackSubprocess( + process_log=mocker.MagicMock(), + id="12345678-1234-5678-1234-567812345678", + pid=12345, + stdin=write_end, + client=mocker.Mock(), + process=mocker.Mock(), + ) + return proc, read_end + + @patch("airflow.sdk.execution_time.request_handlers.mask_secret") + @pytest.mark.parametrize("test_case", REQUEST_CASES, ids=lambda tc: tc.test_id) + def test_handle_requests( + self, + mock_mask_secret, + callback_subprocess, + mocker, + test_case, + ): + client_mock = test_case.client_mock + + proc, _read_end = callback_subprocess + + if client_mock: + mock_client_method = attrgetter(client_mock.method_path)(proc.client) + mock_client_method.return_value = client_mock.response + + generator = proc.handle_requests(log=mocker.Mock()) + next(generator) + + req_frame = _RequestFrame(id=42, body=test_case.message.model_dump()) + generator.send(req_frame) + + if test_case.mask_secret_args is not None: + mock_mask_secret.assert_called_with(*test_case.mask_secret_args) + + if client_mock: + mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 17fe9ec1defa8..1f4e55d73c1fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2576,7 +2576,7 @@ def watched_subprocess(self, mocker): return subprocess, read_end - @patch("airflow.sdk.execution_time.supervisor.mask_secret") + @patch("airflow.sdk.execution_time.request_handlers.mask_secret") @pytest.mark.parametrize("test_case", REQUEST_TEST_CASES, ids=lambda tc: tc.test_id) def test_handle_requests( self,