diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 526e95a8aa886..bee00082c44a4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -511,6 +511,7 @@ du duckdb dunder dup +durable durations dylib Dynamodb diff --git a/providers/common/ai/docs/index.rst b/providers/common/ai/docs/index.rst index fcb8abc021d17..c421b139ce71f 100644 --- a/providers/common/ai/docs/index.rst +++ b/providers/common/ai/docs/index.rst @@ -106,7 +106,7 @@ PIP package Version required ``apache-airflow`` ``>=3.0.0`` ``apache-airflow-providers-common-compat`` ``>=1.14.1`` ``apache-airflow-providers-standard`` ``>=1.12.1`` -``pydantic-ai-slim`` ``>=1.14.0`` +``pydantic-ai-slim`` ``>=1.34.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst index 8052facd59f8b..71d55617ec07e 100644 --- a/providers/common/ai/docs/operators/agent.rst +++ b/providers/common/ai/docs/operators/agent.rst @@ -111,6 +111,106 @@ tasks can consume it. :end-before: [END howto_agent_chain] +Durable Execution +----------------- + +Agent tasks can involve multiple LLM calls and tool invocations. If a task +fails mid-run (network error, timeout, transient API failure), a plain retry +re-executes every LLM call and tool call from scratch -- repeating work that +already succeeded and incurring additional cost. + +Setting ``durable=True`` caches each LLM response and tool result to +ObjectStorage as it completes. On retry, completed steps are replayed from the +cache and only the remaining steps run against the live model and tools. The +cache is deleted after successful completion. + +Durable execution only helps when the task has retries configured. Without +retries there is nothing to replay. + +**Configuration** + +Set the cache location in ``airflow.cfg``. The task raises ``ValueError`` at +runtime if ``durable=True`` and the option is missing. + +.. code-block:: ini + + [common.ai] + # Local filesystem -- suitable for development + durable_cache_path = file:///tmp/airflow_durable_cache + +The value is an ObjectStorage URI, so any supported backend works. For +production, use a shared store so retries on a different worker can read the +cache: + +.. code-block:: ini + + [common.ai] + durable_cache_path = s3://my-bucket/airflow/durable-cache + +**Operator example** + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py + :language: python + :start-after: [START howto_operator_agent_durable] + :end-before: [END howto_operator_agent_durable] + +**Decorator example** + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py + :language: python + :start-after: [START howto_decorator_agent_durable] + :end-before: [END howto_decorator_agent_durable] + +**How it works** + +1. On first execution, each LLM response and tool result is saved to a JSON + file as the agent progresses. +2. If the task fails and Airflow retries it, completed steps are loaded from + the cache and returned without calling the model or tool. Steps not yet in + the cache proceed normally. +3. After successful completion, the cache file is deleted. + +After the run, a single INFO summary line reports how many steps were +replayed vs executed fresh. Per-step detail is available at DEBUG level. + +The cache file is named ``{dag_id}_{task_id}_{run_id}.json`` (with +``_{map_index}`` appended for mapped tasks) and stored under the configured +``durable_cache_path``. To force a completely fresh run, delete the cache file +for that task. + +.. note:: + + Runs that fail permanently (exhaust all retries) leave their cache file + behind. These orphaned files do not affect future DAG runs (each run gets + its own file) but will consume storage. Clean them up periodically or add + a lifecycle policy to the storage backend. + +**Side effects and idempotency** + +Durable execution caches **return values**, not side effects. When a step is +replayed, the tool's code does not run -- only the stored return value is +returned. Two things follow from this: + +- If a tool completed successfully and its result was cached, the tool will + **not** run again on retry. Any side effect it produced (writing a file, + sending a message) already happened during the original run and is not + repeated. +- If a tool fails *before* its result is cached, it **will** run again on + retry. A tool that partially completed (e.g. sent an email then raised an + exception) may produce the side effect a second time. + +All built-in toolsets (``SQLToolset`` with ``allow_writes=False``, +``HookToolset`` in read-only mode) are read-only and replay safely. For custom +tools with non-idempotent side effects, design the tool to be idempotent. For +example, check whether the operation already completed before acting, or +use database constraints to prevent duplicate writes. + +Tool results must be JSON-serializable to be cached. If a tool returns a +non-serializable value (e.g. ``BinaryContent`` from MCP tools), that step is +skipped with a warning and will re-execute on retry instead of replaying from +cache. The task itself still succeeds. + + Parameters ---------- @@ -130,6 +230,10 @@ Parameters every tool call is logged in real time. Default ``True``. - ``agent_params``: Additional keyword arguments passed to the pydantic-ai ``Agent`` constructor (e.g. ``retries``, ``model_settings``). +- ``durable``: When ``True``, enables step-level caching of model responses and + tool results via ObjectStorage. On retry, cached steps are replayed instead of + re-executing expensive LLM calls. Requires the ``[common.ai] durable_cache_path`` + config option to be set. Default ``False``. Logging diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml index 203497e401dd3..9dff4932318ae 100644 --- a/providers/common/ai/pyproject.toml +++ b/providers/common/ai/pyproject.toml @@ -69,7 +69,7 @@ dependencies = [ "apache-airflow>=3.0.0", "apache-airflow-providers-common-compat>=1.14.1", "apache-airflow-providers-standard>=1.12.1", - "pydantic-ai-slim>=1.14.0", + "pydantic-ai-slim>=1.34.0", ] # The optional dependencies should be modified in place in the generated file diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/__init__.py b/providers/common/ai/src/airflow/providers/common/ai/durable/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py new file mode 100644 index 0000000000000..0b2f85ecb400e --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Caching model wrapper for durable execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import structlog +from pydantic_ai.models.wrapper import WrapperModel + +log = structlog.get_logger(logger_name="task") + +if TYPE_CHECKING: + from pydantic_ai.messages import ModelMessage, ModelResponse + from pydantic_ai.models import ModelRequestParameters + from pydantic_ai.settings import ModelSettings + + from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + from airflow.providers.common.ai.durable.storage import DurableStorage + + +@dataclass(init=False) +class CachingModel(WrapperModel): + """ + Wraps a model to cache responses in ObjectStorage for durable execution. + + On each ``request()`` call, checks if a cached response exists for the + current step index. If so, returns the cached response without calling + the underlying model. Otherwise, calls the model and caches the response. + """ + + storage: DurableStorage = field(repr=False) + counter: DurableStepCounter = field(repr=False) + + def __init__( + self, + wrapped: Any, + *, + storage: DurableStorage, + counter: DurableStepCounter, + ) -> None: + super().__init__(wrapped) + self.storage = storage + self.counter = counter + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + step = self.counter.next_step() + key = f"model_step_{step}" + + cached = self.storage.load_model_response(key) + if cached is not None: + self.counter.replayed_model += 1 + log.debug("Durable: replayed cached model response", step=step) + return cached + + response = await self.wrapped.request(messages, model_settings, model_request_parameters) + self.storage.save_model_response(key, response) + self.counter.cached_model += 1 + log.debug("Durable: cached model response", step=step) + return response diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py new file mode 100644 index 0000000000000..2fd58fe78a40e --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Caching toolset wrapper for durable execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import structlog +from pydantic_ai.toolsets.wrapper import WrapperToolset + +if TYPE_CHECKING: + from pydantic_ai.toolsets.abstract import ToolsetTool + + from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + from airflow.providers.common.ai.durable.storage import DurableStorage + +log = structlog.get_logger(logger_name="task") + + +@dataclass +class CachingToolset(WrapperToolset[Any]): + """ + Wraps a toolset to cache tool call results in ObjectStorage for durable execution. + + On each ``call_tool()`` invocation, checks if a cached result exists for + the current step index. If so, returns the cached result without executing + the tool. Otherwise, executes the tool and caches the result. + + The step index is grabbed before the first ``await``, so parallel tool + calls via ``asyncio.gather`` get deterministic indices (tasks start + executing their synchronous preamble in creation order). + """ + + storage: DurableStorage = field(repr=False) + counter: DurableStepCounter = field(repr=False) + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: Any, + tool: ToolsetTool[Any], + ) -> Any: + # Grab step index BEFORE any await -- ensures deterministic ordering + # even when multiple tool calls run concurrently via asyncio.gather. + step = self.counter.next_step() + key = f"tool_step_{step}" + + found, cached = self.storage.load_tool_result(key) + if found: + self.counter.replayed_tool += 1 + log.debug("Durable: replayed cached tool result", step=step, tool=name) + return cached + + result = await self.wrapped.call_tool(name, tool_args, ctx, tool) + self.storage.save_tool_result(key, result) + self.counter.cached_tool += 1 + log.debug("Durable: cached tool result", step=step, tool=name) + return result diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py new file mode 100644 index 0000000000000..85643f9b34a5a --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Shared step counter for durable execution caching.""" + +from __future__ import annotations + + +class DurableStepCounter: + """ + Monotonically increasing counter shared between CachingModel and CachingToolset. + + Each model call and tool call increments the counter. The step index + is used as the cache key, ensuring deterministic replay on retry. + """ + + def __init__(self) -> None: + self._step: int = 0 + self.replayed_model: int = 0 + self.replayed_tool: int = 0 + self.cached_model: int = 0 + self.cached_tool: int = 0 + + def next_step(self) -> int: + """Return the current step and advance the counter.""" + step = self._step + self._step += 1 + return step + + @property + def total_steps(self) -> int: + """Total number of steps executed so far.""" + return self._step diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py new file mode 100644 index 0000000000000..7a6a03e33b7a9 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ObjectStorage-backed durable storage for pydantic-ai agent step caching.""" + +from __future__ import annotations + +import json +from functools import lru_cache +from typing import Any + +import structlog +from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelResponse + +log = structlog.get_logger(logger_name="task") + +# Sentinel to distinguish "cached None" from "no cache entry" for tool results. +_SENTINEL = "__durable_cached__" + +SECTION = "common.ai" + + +@lru_cache(maxsize=1) +def _get_base_path(): + from airflow.configuration import conf + from airflow.sdk import ObjectStoragePath + + path = conf.get(SECTION, "durable_cache_path", fallback="") + if not path: + raise ValueError( + "durable=True requires [common.ai] durable_cache_path to be set. " + "Example: durable_cache_path = file:///tmp/airflow_durable_cache" + ) + return ObjectStoragePath(path) + + +class DurableStorage: + """ + Stores step-level caches in a single JSON file on ObjectStorage. + + All step caches (model responses and tool results) are stored as entries + in a single JSON blob, written to a file named after the task execution: + ``{base_path}/{dag_id}_{task_id}_{run_id}[_{map_index}].json``. + + The file survives Airflow task retries since it lives outside the + XCom system. It is deleted on successful task completion. + + :param dag_id: DAG ID of the running task. + :param task_id: Task ID of the running task. + :param run_id: DAG run ID. + :param map_index: Map index for mapped tasks (``-1`` for non-mapped). + """ + + def __init__( + self, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + ) -> None: + suffix = f"_{map_index}" if map_index >= 0 else "" + self._cache_id = f"{dag_id}_{task_id}_{run_id}{suffix}" + self._cache: dict[str, Any] | None = None + + def _get_path(self): + return _get_base_path() / f"{self._cache_id}.json" + + def _load_cache(self) -> dict[str, Any]: + """Load the full cache blob from storage, with in-memory caching.""" + if self._cache is not None: + return self._cache + + path = self._get_path() + try: + self._cache = json.loads(path.read_text()) + except (FileNotFoundError, OSError, json.JSONDecodeError, ValueError): + self._cache = {} + + return self._cache + + def _save_cache(self) -> None: + """Persist the in-memory cache blob to storage.""" + path = self._get_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(self._cache)) + + def save_model_response(self, key: str, response: ModelResponse) -> None: + """Serialize and store a ModelResponse in the cache.""" + cache = self._load_cache() + cache[key] = ModelMessagesTypeAdapter.dump_json([response]).decode() + self._save_cache() + + def load_model_response(self, key: str) -> ModelResponse | None: + """Load a cached ModelResponse, or return None if not cached.""" + cache = self._load_cache() + raw = cache.get(key) + if raw is None: + return None + messages = ModelMessagesTypeAdapter.validate_json(raw) + return messages[0] # type: ignore[return-value] + + def save_tool_result(self, key: str, result: Any) -> None: + """ + Store a tool call result in the cache. + + Non-serializable results (e.g. BinaryContent from MCP tools) are + skipped with a warning -- the tool call still succeeds, but won't + be replayed on retry. + """ + cache = self._load_cache() + try: + cache[key] = json.dumps({_SENTINEL: True, "value": result}) + except TypeError: + log.warning( + "Durable: skipping cache for non-serializable tool result", + key=key, + type=type(result).__name__, + ) + return + self._save_cache() + + def load_tool_result(self, key: str) -> tuple[bool, Any]: + """ + Load a cached tool result. + + Returns (found, value) tuple since the cached value itself could be None. + """ + cache = self._load_cache() + raw = cache.get(key) + if raw is None: + return False, None + parsed = json.loads(raw) + if not isinstance(parsed, dict) or _SENTINEL not in parsed: + return False, None + return True, parsed["value"] + + def cleanup(self) -> None: + """Delete the cache file after successful execution.""" + try: + self._get_path().unlink() + except (FileNotFoundError, OSError): + pass # Best-effort cleanup + self._cache = None + log.debug("Durable cache cleaned up") diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py new file mode 100644 index 0000000000000..e2d3ec6e3fc25 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAGs demonstrating durable execution with AgentOperator and @task.agent.""" + +from __future__ import annotations + +from datetime import timedelta + +from airflow.providers.common.ai.operators.agent import AgentOperator +from airflow.providers.common.ai.toolsets.sql import SQLToolset +from airflow.providers.common.compat.sdk import dag, task + +# --------------------------------------------------------------------------- +# 1. Durable AgentOperator: resumes from last model call on retry +# --------------------------------------------------------------------------- + + +# [START howto_operator_agent_durable] +@dag(default_args={"retries": 3, "retry_delay": timedelta(seconds=30)}) +def example_agent_durable_operator(): + """Agent with durable execution -- resumes from the last model call on retry.""" + AgentOperator( + task_id="durable_analyst", + prompt="What are the top 5 customers by order count?", + llm_conn_id="pydanticai_default", + system_prompt=( + "You are a SQL analyst. Use the available tools to explore " + "the schema and answer the question with data." + ), + durable=True, + toolsets=[ + SQLToolset( + db_conn_id="postgres_default", + allowed_tables=["customers", "orders"], + max_rows=20, + ) + ], + ) + + +# [END howto_operator_agent_durable] + +example_agent_durable_operator() + + +# --------------------------------------------------------------------------- +# 2. Durable @task.agent decorator +# --------------------------------------------------------------------------- + + +# [START howto_decorator_agent_durable] +@dag(default_args={"retries": 3, "retry_delay": timedelta(seconds=30)}) +def example_agent_durable_decorator(): + @task.agent( + llm_conn_id="pydanticai_default", + system_prompt="You are a data analyst. Use tools to answer questions.", + durable=True, + toolsets=[ + SQLToolset( + db_conn_id="postgres_default", + allowed_tables=["orders"], + ) + ], + ) + def analyze(question: str): + return f"Answer this question about our orders data: {question}" + + analyze("What was our total revenue last month?") + + +# [END howto_decorator_agent_durable] + +example_agent_durable_decorator() diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py index da8ecafc1c562..d4d42c718f9bf 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py @@ -40,6 +40,8 @@ from pydantic_ai import Agent from pydantic_ai.toolsets.abstract import AbstractToolset + from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + from airflow.providers.common.ai.durable.storage import DurableStorage from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context @@ -101,6 +103,10 @@ class AgentOperator(BaseOperator, HITLReviewMixin): arguments at DEBUG level. Set to ``False`` to disable. :param agent_params: Additional keyword arguments passed to the pydantic-ai ``Agent`` constructor (e.g. ``retries``, ``model_settings``). + :param durable: When ``True``, enables step-level caching of model + responses and tool results for durable execution. On retry, cached + steps are replayed instead of re-executing. Default ``False``. + Requires ``[common.ai] durable_cache_path`` to be set. **HITL Review parameters** (requires the ``hitl_review`` plugin): @@ -142,6 +148,7 @@ def __init__( toolsets: list[AbstractToolset] | None = None, enable_tool_logging: bool = True, agent_params: dict[str, Any] | None = None, + durable: bool = False, # Agent feedback parameters enable_hitl_review: bool = False, max_hitl_iterations: int = 5, @@ -160,6 +167,11 @@ def __init__( self.enable_tool_logging = enable_tool_logging self.agent_params = agent_params or {} + self.durable = durable + + if durable and enable_hitl_review: + raise ValueError("durable=True and enable_hitl_review=True cannot be used together.") + self.enable_hitl_review = enable_hitl_review self.max_hitl_iterations = max_hitl_iterations self.hitl_timeout = hitl_timeout @@ -182,20 +194,84 @@ def _build_agent(self) -> Agent[None, Any]: """Build and return a pydantic-ai Agent from the operator's config.""" extra_kwargs = dict(self.agent_params) if self.toolsets: + toolsets = self.toolsets + if self.durable and self._durable_storage is not None and self._durable_counter is not None: + toolsets = self._build_durable_toolsets( + toolsets, self._durable_storage, self._durable_counter + ) if self.enable_tool_logging: - extra_kwargs["toolsets"] = wrap_toolsets_for_logging(self.toolsets, self.log) - else: - extra_kwargs["toolsets"] = self.toolsets + toolsets = wrap_toolsets_for_logging(toolsets, self.log) + extra_kwargs["toolsets"] = toolsets return self.llm_hook.create_agent( output_type=self.output_type, instructions=self.system_prompt, **extra_kwargs, ) + def _build_durable_toolsets( + self, toolsets: list[AbstractToolset], storage: DurableStorage, counter: DurableStepCounter + ) -> list[AbstractToolset]: + """Wrap each toolset with CachingToolset for durable execution.""" + from airflow.providers.common.ai.durable.caching_toolset import CachingToolset + + return [CachingToolset(wrapped=ts, storage=storage, counter=counter) for ts in toolsets] + def execute(self, context: Context) -> Any: + self._durable_storage = None + self._durable_counter = None + + if self.durable: + from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + from airflow.providers.common.ai.durable.storage import DurableStorage + + ti = context["task_instance"] + self._durable_storage = DurableStorage( + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index if ti.map_index is not None else -1, + ) + self._durable_counter = DurableStepCounter() + agent = self._build_agent() - result = agent.run_sync(self.prompt) + + storage = self._durable_storage + counter = self._durable_counter + if self.durable and storage is not None and counter is not None: + from pydantic_ai.models import infer_model + + from airflow.providers.common.ai.durable.caching_model import CachingModel + + if agent.model is None: + raise ValueError("Agent model must be set when durable=True") + resolved_model = infer_model(agent.model) + caching_model = CachingModel(resolved_model, storage=storage, counter=counter) + with agent.override(model=caching_model): + result = agent.run_sync(self.prompt) + else: + result = agent.run_sync(self.prompt) + log_run_summary(self.log, result) + + if self._durable_counter is not None: + c = self._durable_counter + replayed = c.replayed_model + c.replayed_tool + cached = c.cached_model + c.cached_tool + if replayed: + self.log.info( + "Durable: replayed %d cached steps (%d model, %d tool), " + "executed %d new steps (%d model, %d tool)", + replayed, + c.replayed_model, + c.replayed_tool, + cached, + c.cached_model, + c.cached_tool, + ) + + if self._durable_storage is not None: + self._durable_storage.cleanup() + output = result.output if self.enable_hitl_review: diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py index f576632514009..6865db9a1dd77 100644 --- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py +++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py @@ -137,3 +137,22 @@ class Summary(BaseModel): result = op.execute(context={}) assert result == {"text": "Great results"} + + def test_durable_kwarg_passes_through_to_operator(self): + """durable=True is forwarded to AgentOperator via **kwargs.""" + op = _AgentDecoratedOperator( + task_id="test", + python_callable=lambda: "prompt", + llm_conn_id="my_llm", + durable=True, + ) + assert op.durable is True + + def test_durable_default_false_through_decorator(self): + """durable defaults to False when not specified.""" + op = _AgentDecoratedOperator( + task_id="test", + python_callable=lambda: "prompt", + llm_conn_id="my_llm", + ) + assert op.durable is False diff --git a/providers/common/ai/tests/unit/common/ai/durable/__init__.py b/providers/common/ai/tests/unit/common/ai/durable/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py new file mode 100644 index 0000000000000..117f5d4a08c03 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic_ai.messages import ModelResponse, TextPart + +from airflow.providers.common.ai.durable.caching_model import CachingModel +from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + + +@pytest.fixture +def mock_storage(): + storage = MagicMock() + storage.load_model_response.return_value = None + return storage + + +@pytest.fixture +def counter(): + return DurableStepCounter() + + +@pytest.fixture +def mock_model(): + model = MagicMock() + model.model_name = "test-model" + model.system = "test" + model.profile = MagicMock() + model.settings = None + return model + + +@pytest.fixture(autouse=True) +def _patch_infer_model(): + """Prevent WrapperModel.__init__ from resolving the mock as a real model.""" + with patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m): + yield + + +@pytest.fixture +def sample_response(): + return ModelResponse(parts=[TextPart(content="Hello!")]) + + +class TestCachingModelCacheHit: + @pytest.mark.asyncio + async def test_returns_cached_response_without_calling_model( + self, mock_model, mock_storage, counter, sample_response + ): + mock_storage.load_model_response.return_value = sample_response + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + result = await caching.request([], None, MagicMock()) + + assert result is sample_response + mock_model.request.assert_not_called() + mock_storage.load_model_response.assert_called_once_with("model_step_0") + + @pytest.mark.asyncio + async def test_advances_counter_on_cache_hit(self, mock_model, mock_storage, counter, sample_response): + mock_storage.load_model_response.return_value = sample_response + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + await caching.request([], None, MagicMock()) + + assert counter.total_steps == 1 + + +class TestCachingModelCacheMiss: + @pytest.mark.asyncio + async def test_calls_model_and_caches_on_miss(self, mock_model, mock_storage, counter, sample_response): + mock_model.request = AsyncMock(return_value=sample_response) + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + result = await caching.request([], None, MagicMock()) + + assert result is sample_response + mock_model.request.assert_called_once() + mock_storage.save_model_response.assert_called_once_with("model_step_0", sample_response) + + @pytest.mark.asyncio + async def test_sequential_calls_use_incrementing_keys(self, mock_model, mock_storage, counter): + response_1 = ModelResponse(parts=[TextPart(content="First")]) + response_2 = ModelResponse(parts=[TextPart(content="Second")]) + mock_model.request = AsyncMock(side_effect=[response_1, response_2]) + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + await caching.request([], None, MagicMock()) + await caching.request([], None, MagicMock()) + + keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] + assert keys == ["model_step_0", "model_step_1"] diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py new file mode 100644 index 0000000000000..d2d999d9218a4 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic_ai.messages import ModelResponse, TextPart + +from airflow.providers.common.ai.durable.caching_model import CachingModel +from airflow.providers.common.ai.durable.caching_toolset import CachingToolset +from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + + +@pytest.fixture +def mock_storage(): + storage = MagicMock() + storage.load_tool_result.return_value = (False, None) + storage.load_model_response.return_value = None + return storage + + +@pytest.fixture +def counter(): + return DurableStepCounter() + + +@pytest.fixture +def mock_toolset(): + toolset = MagicMock() + toolset.call_tool = AsyncMock(return_value="fresh result") + toolset.get_tools = AsyncMock(return_value={}) + toolset.__aenter__ = AsyncMock(return_value=toolset) + toolset.__aexit__ = AsyncMock(return_value=None) + return toolset + + +class TestCachingToolsetCacheHit: + @pytest.mark.asyncio + async def test_returns_cached_result_without_calling_tool(self, mock_toolset, mock_storage, counter): + mock_storage.load_tool_result.return_value = (True, "cached result") + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock()) + + assert result == "cached result" + mock_toolset.call_tool.assert_not_called() + mock_storage.load_tool_result.assert_called_once_with("tool_step_0") + + @pytest.mark.asyncio + async def test_advances_counter_on_cache_hit(self, mock_toolset, mock_storage, counter): + mock_storage.load_tool_result.return_value = (True, "cached") + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + await caching.call_tool("search", {}, MagicMock(), MagicMock()) + + assert counter.total_steps == 1 + + +class TestCachingToolsetCacheMiss: + @pytest.mark.asyncio + async def test_calls_tool_and_caches_on_miss(self, mock_toolset, mock_storage, counter): + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock()) + + assert result == "fresh result" + mock_toolset.call_tool.assert_called_once() + mock_storage.save_tool_result.assert_called_once_with("tool_step_0", "fresh result") + + @pytest.mark.asyncio + async def test_sequential_calls_use_incrementing_keys(self, mock_toolset, mock_storage, counter): + mock_toolset.call_tool = AsyncMock(side_effect=["result_a", "result_b"]) + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + await caching.call_tool("tool_a", {}, MagicMock(), MagicMock()) + await caching.call_tool("tool_b", {}, MagicMock(), MagicMock()) + + keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] + assert keys == ["tool_step_0", "tool_step_1"] + + +class TestSharedCounter: + @pytest.mark.asyncio + async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage): + """When CachingModel and CachingToolset share a counter, steps interleave correctly.""" + counter = DurableStepCounter() + + mock_model = MagicMock() + mock_model.model_name = "test" + mock_model.system = "test" + mock_model.profile = MagicMock() + mock_model.settings = None + + response = ModelResponse(parts=[TextPart(content="response")]) + mock_model.request = AsyncMock(return_value=response) + + with patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m): + caching_model = CachingModel(mock_model, storage=mock_storage, counter=counter) + caching_toolset = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + # Simulate: model call -> tool call -> model call + await caching_model.request([], None, MagicMock()) + await caching_toolset.call_tool("search", {}, MagicMock(), MagicMock()) + await caching_model.request([], None, MagicMock()) + + model_keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] + tool_keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] + + assert model_keys == ["model_step_0", "model_step_2"] + assert tool_keys == ["tool_step_1"] + assert counter.total_steps == 3 diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_step_counter.py b/providers/common/ai/tests/unit/common/ai/durable/test_step_counter.py new file mode 100644 index 0000000000000..cd074989425a4 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_step_counter.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + + +class TestDurableStepCounter: + def test_starts_at_zero(self): + counter = DurableStepCounter() + assert counter.next_step() == 0 + + def test_increments_monotonically(self): + counter = DurableStepCounter() + assert counter.next_step() == 0 + assert counter.next_step() == 1 + assert counter.next_step() == 2 + + def test_total_steps_tracks_count(self): + counter = DurableStepCounter() + assert counter.total_steps == 0 + counter.next_step() + counter.next_step() + assert counter.total_steps == 2 diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_storage.py b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py new file mode 100644 index 0000000000000..507fd3126ac57 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from pydantic_ai.messages import ( + ModelResponse, + TextPart, +) + +from airflow.providers.common.ai.durable.storage import DurableStorage +from airflow.sdk import ObjectStoragePath + + +@pytest.fixture +def tmp_cache_path(tmp_path): + """Return a file:// path to a temporary directory for cache files.""" + return f"file://{tmp_path.as_posix()}" + + +@pytest.fixture +def storage(tmp_cache_path): + with patch("airflow.providers.common.ai.durable.storage._get_base_path") as mock_base: + mock_base.return_value = ObjectStoragePath(tmp_cache_path) + yield DurableStorage(dag_id="test_dag", task_id="my_task", run_id="run_1", map_index=-1) + + +@pytest.fixture +def sample_response(): + return ModelResponse(parts=[TextPart(content="Hello!")]) + + +class TestDurableStorageInit: + def test_cache_id_format(self, storage): + assert storage._cache_id == "test_dag_my_task_run_1" + + def test_cache_id_with_map_index(self): + s = DurableStorage(dag_id="d", task_id="t", run_id="r", map_index=3) + assert s._cache_id == "d_t_r_3" + + def test_cache_id_without_map_index(self): + s = DurableStorage(dag_id="d", task_id="t", run_id="r", map_index=-1) + assert "_-1" not in s._cache_id + + +class TestSaveLoadModelResponse: + def test_save_and_load_roundtrips(self, storage, sample_response): + storage.save_model_response("model_step_0", sample_response) + + # Reset in-memory cache to force read from file + storage._cache = None + loaded = storage.load_model_response("model_step_0") + + assert loaded is not None + assert loaded.parts[0].content == "Hello!" + + def test_load_returns_none_when_no_cache(self, storage): + assert storage.load_model_response("model_step_0") is None + + +class TestSaveLoadToolResult: + def test_save_and_load_roundtrips(self, storage): + storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]}) + + storage._cache = None + found, value = storage.load_tool_result("tool_step_0") + + assert found is True + assert value == {"rows": [1, 2, 3]} + + def test_load_returns_false_when_no_cache(self, storage): + found, value = storage.load_tool_result("tool_step_0") + assert found is False + assert value is None + + def test_none_result_roundtrips(self, storage): + storage.save_tool_result("tool_step_0", None) + + storage._cache = None + found, value = storage.load_tool_result("tool_step_0") + + assert found is True + assert value is None + + +class TestCleanup: + def test_cleanup_deletes_file(self, storage, sample_response): + storage.save_model_response("model_step_0", sample_response) + path = storage._get_path() + assert path.exists() + + storage.cleanup() + assert not path.exists() + + def test_cleanup_on_nonexistent_file(self, storage): + storage.cleanup() # Should not raise + + +class TestInMemoryCaching: + def test_multiple_saves_write_single_file(self, storage, sample_response): + storage.save_model_response("model_step_0", sample_response) + storage.save_tool_result("tool_step_1", "result") + storage.save_model_response("model_step_2", sample_response) + + assert "model_step_0" in storage._cache + assert "tool_step_1" in storage._cache + assert "model_step_2" in storage._cache + + def test_cache_survives_reload(self, storage, sample_response): + """Simulate retry: save cache, reset in-memory, reload from file.""" + storage.save_model_response("model_step_0", sample_response) + storage.save_tool_result("tool_step_1", "tool result") + + # Simulate new DurableStorage instance (as on retry) + storage._cache = None + + loaded_response = storage.load_model_response("model_step_0") + assert loaded_response is not None + assert loaded_response.parts[0].content == "Hello!" + + found, value = storage.load_tool_result("tool_step_1") + assert found is True + assert value == "tool result" diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py index 13410b35c55bb..9d2af0d29581a 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py @@ -405,3 +405,61 @@ class Summary(BaseModel): ) assert output == '{"text":"Revised"}' + + +class TestAgentOperatorDurable: + def test_durable_param_stored(self): + op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm", durable=True) + assert op.durable is True + + def test_durable_default_false(self): + op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm") + assert op.durable is False + + @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m) + @patch("pydantic_ai.models.infer_model", autospec=True) + @patch("airflow.providers.common.ai.durable.storage._get_base_path") + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_execute_durable_wraps_model_and_cleans_up( + self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path + ): + """durable=True wraps model with CachingModel and cleans up on success.""" + from airflow.sdk import ObjectStoragePath + + mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + + mock_agent = MagicMock() + mock_agent.run_sync.return_value = _make_mock_run_result("ok") + mock_agent.model = "test-model" + mock_agent.override = MagicMock() + mock_agent.override.return_value.__enter__ = MagicMock(return_value=None) + mock_agent.override.return_value.__exit__ = MagicMock(return_value=False) + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + mock_resolved = MagicMock() + mock_infer.return_value = mock_resolved + + context = MagicMock() + context.__getitem__ = MagicMock( + return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1) + ) + + op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm", durable=True) + result = op.execute(context=context) + + assert result == "ok" + mock_agent.override.assert_called_once() + override_kwargs = mock_agent.override.call_args[1] + assert "model" in override_kwargs + + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_execute_non_durable_does_not_wrap(self, mock_hook_cls): + """Default (durable=False) does not use override.""" + mock_agent = _make_mock_agent("ok") + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm") + op.execute(context=MagicMock()) + + # run_sync called directly, no override + mock_agent.run_sync.assert_called_once_with("test") diff --git a/uv.lock b/uv.lock index ae5a9cbb425d7..37c0acf8a38b4 100644 --- a/uv.lock +++ b/uv.lock @@ -3925,7 +3925,7 @@ requires-dist = [ { name = "apache-airflow-providers-common-sql", marker = "extra == 'common-sql'", editable = "providers/common/sql" }, { name = "apache-airflow-providers-common-sql", marker = "extra == 'sql'", editable = "providers/common/sql" }, { name = "apache-airflow-providers-standard", editable = "providers/standard" }, - { name = "pydantic-ai-slim", specifier = ">=1.14.0" }, + { name = "pydantic-ai-slim", specifier = ">=1.34.0" }, { name = "pydantic-ai-slim", extras = ["anthropic"], marker = "extra == 'anthropic'" }, { name = "pydantic-ai-slim", extras = ["bedrock"], marker = "extra == 'bedrock'" }, { name = "pydantic-ai-slim", extras = ["google"], marker = "extra == 'google'" },