Merged
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -511,6 +511,7 @@ du
duckdb
dunder
dup
durable
durations
dylib
Dynamodb
2 changes: 1 addition & 1 deletion providers/common/ai/docs/index.rst
@@ -106,7 +106,7 @@ PIP package Version required
``apache-airflow`` ``>=3.0.0``
``apache-airflow-providers-common-compat`` ``>=1.14.1``
``apache-airflow-providers-standard`` ``>=1.12.1``
-``pydantic-ai-slim`` ``>=1.14.0``
+``pydantic-ai-slim`` ``>=1.34.0``
========================================== ==================

Cross provider package dependencies
104 changes: 104 additions & 0 deletions providers/common/ai/docs/operators/agent.rst
@@ -111,6 +111,106 @@ tasks can consume it.
:end-before: [END howto_agent_chain]


Durable Execution
-----------------

Agent tasks can involve multiple LLM calls and tool invocations. If a task
fails mid-run (network error, timeout, transient API failure), a plain retry
re-executes every LLM call and tool call from scratch -- repeating work that
already succeeded and incurring additional cost.

Setting ``durable=True`` caches each LLM response and tool result to
ObjectStorage as it completes. On retry, completed steps are replayed from the
cache and only the remaining steps run against the live model and tools. The
cache is deleted after successful completion.

Durable execution only helps when the task has retries configured. Without
retries there is nothing to replay.

**Configuration**

Set the cache location in ``airflow.cfg``. The task raises ``ValueError`` at
runtime if ``durable=True`` and the option is missing.

.. code-block:: ini

    [common.ai]
    # Local filesystem -- suitable for development
    durable_cache_path = file:///tmp/airflow_durable_cache

The value is an ObjectStorage URI, so any supported backend works. For
production, use a shared store so retries on a different worker can read the
cache:

.. code-block:: ini

    [common.ai]
    durable_cache_path = s3://my-bucket/airflow/durable-cache

**Operator example**

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py
    :language: python
    :start-after: [START howto_operator_agent_durable]
    :end-before: [END howto_operator_agent_durable]

**Decorator example**

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent_durable.py
    :language: python
    :start-after: [START howto_decorator_agent_durable]
    :end-before: [END howto_decorator_agent_durable]

**How it works**

1. On first execution, each LLM response and tool result is saved to a JSON
   file as the agent progresses.
2. If the task fails and Airflow retries it, completed steps are loaded from
   the cache and returned without calling the model or tool. Steps not yet in
   the cache proceed normally.
3. After successful completion, the cache file is deleted.

After the run, a single INFO summary line reports how many steps were
replayed vs executed fresh. Per-step detail is available at DEBUG level.

The cache file is named ``{dag_id}_{task_id}_{run_id}.json`` (with
``_{map_index}`` appended for mapped tasks) and stored under the configured
``durable_cache_path``. To force a completely fresh run, delete the cache file
for that task.
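
The naming scheme above can be sketched as a small helper. This is an illustration only: ``cache_file_name`` is a hypothetical function, not part of the provider, and it assumes that a negative ``map_index`` marks an unmapped task.

```python
from __future__ import annotations


def cache_file_name(dag_id: str, task_id: str, run_id: str, map_index: int = -1) -> str:
    """Hypothetical helper that mirrors the documented cache file naming."""
    name = f"{dag_id}_{task_id}_{run_id}"
    # Assumption: unmapped tasks carry a negative map_index and get no suffix.
    if map_index >= 0:
        name += f"_{map_index}"
    return f"{name}.json"
```

Joining the result onto the configured ``durable_cache_path`` gives the file to delete when a fresh run is wanted.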

.. note::

    Runs that fail permanently (exhaust all retries) leave their cache file
    behind. These orphaned files do not affect future DAG runs (each run gets
    its own file) but will consume storage. Clean them up periodically or add
    a lifecycle policy to the storage backend.

**Side effects and idempotency**

Durable execution caches **return values**, not side effects. When a step is
replayed, the tool's code does not run -- only the stored return value is
returned. Two things follow from this:

- If a tool completed successfully and its result was cached, the tool will
  **not** run again on retry. Any side effect it produced (writing a file,
  sending a message) already happened during the original run and is not
  repeated.
- If a tool fails *before* its result is cached, it **will** run again on
  retry. A tool that partially completed (e.g. sent an email then raised an
  exception) may produce the side effect a second time.

The built-in toolsets replay safely when used read-only (``SQLToolset`` with
``allow_writes=False``, ``HookToolset`` in read-only mode). For custom tools
with non-idempotent side effects, design the tool to be idempotent: for
example, check whether the operation already completed before acting, or
use database constraints to prevent duplicate writes.
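
One way to apply the database-constraint approach is sketched below with the standard-library ``sqlite3`` module. The table, the ``idempotency_key`` column, and both function names are hypothetical and chosen for illustration; a real tool would use its own schema and derive the key from something stable across retries.

```python
import sqlite3


def make_connection() -> sqlite3.Connection:
    """Create an in-memory database with a uniqueness constraint on the key."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE notifications ("
        "idempotency_key TEXT PRIMARY KEY, body TEXT NOT NULL)"
    )
    return conn


def record_notification(conn: sqlite3.Connection, idempotency_key: str, body: str) -> bool:
    """Insert the row once; return False if the key was already recorded."""
    cur = conn.execute(
        "INSERT OR IGNORE INTO notifications (idempotency_key, body) VALUES (?, ?)",
        (idempotency_key, body),
    )
    conn.commit()
    # rowcount is 1 when the row was inserted and 0 when the constraint
    # caused the insert to be ignored.
    return cur.rowcount == 1
```

If the tool runs a second time after a failure, the second insert is ignored, so the write happens at most once.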

Tool results must be JSON-serializable to be cached. If a tool returns a
non-serializable value (e.g. ``BinaryContent`` from MCP tools), that step is
skipped with a warning and will re-execute on retry instead of replaying from
cache. The task itself still succeeds.
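
A tool author can check in advance whether a return value is likely to be cacheable. The sketch below uses only the standard ``json`` module; ``is_json_cacheable`` is a hypothetical helper, and the provider's own serialization check may differ.

```python
import json
from typing import Any


def is_json_cacheable(value: Any) -> bool:
    """Return True if the standard json module can serialize the value."""
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        # TypeError: unsupported type such as bytes; ValueError: circular reference.
        return False
    return True
```

Plain dicts, lists, strings and numbers pass this check, while raw ``bytes`` do not.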


Parameters
----------

@@ -130,6 +230,10 @@ Parameters
every tool call is logged in real time. Default ``True``.
- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
``Agent`` constructor (e.g. ``retries``, ``model_settings``).
- ``durable``: When ``True``, enables step-level caching of model responses and
  tool results via ObjectStorage. On retry, cached steps are replayed instead of
  re-running LLM calls and tool calls that already completed. Requires the
  ``[common.ai] durable_cache_path`` config option to be set. Default ``False``.


Logging
2 changes: 1 addition & 1 deletion providers/common/ai/pyproject.toml
@@ -69,7 +69,7 @@ dependencies = [
"apache-airflow>=3.0.0",
"apache-airflow-providers-common-compat>=1.14.1",
"apache-airflow-providers-standard>=1.12.1",
-"pydantic-ai-slim>=1.14.0",
+"pydantic-ai-slim>=1.34.0",
]

# The optional dependencies should be modified in place in the generated file
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Caching model wrapper for durable execution."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import structlog
from pydantic_ai.models.wrapper import WrapperModel

log = structlog.get_logger(logger_name="task")

if TYPE_CHECKING:
    from pydantic_ai.messages import ModelMessage, ModelResponse
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.settings import ModelSettings

    from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
    from airflow.providers.common.ai.durable.storage import DurableStorage


@dataclass(init=False)
class CachingModel(WrapperModel):
    """
    Wraps a model to cache responses in ObjectStorage for durable execution.

    On each ``request()`` call, checks if a cached response exists for the
    current step index. If so, returns the cached response without calling
    the underlying model. Otherwise, calls the model and caches the response.
    """

    storage: DurableStorage = field(repr=False)
    counter: DurableStepCounter = field(repr=False)

    def __init__(
        self,
        wrapped: Any,
        *,
        storage: DurableStorage,
        counter: DurableStepCounter,
    ) -> None:
        super().__init__(wrapped)
        self.storage = storage
        self.counter = counter

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        step = self.counter.next_step()
        key = f"model_step_{step}"

        cached = self.storage.load_model_response(key)
        if cached is not None:
            self.counter.replayed_model += 1
            log.debug("Durable: replayed cached model response", step=step)
            return cached

        response = await self.wrapped.request(messages, model_settings, model_request_parameters)
        self.storage.save_model_response(key, response)
        self.counter.cached_model += 1
        log.debug("Durable: cached model response", step=step)
        return response
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Caching toolset wrapper for durable execution."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import structlog
from pydantic_ai.toolsets.wrapper import WrapperToolset

if TYPE_CHECKING:
    from pydantic_ai.toolsets.abstract import ToolsetTool

    from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
    from airflow.providers.common.ai.durable.storage import DurableStorage

log = structlog.get_logger(logger_name="task")


@dataclass
class CachingToolset(WrapperToolset[Any]):
    """
    Wraps a toolset to cache tool call results in ObjectStorage for durable execution.

    On each ``call_tool()`` invocation, checks if a cached result exists for
    the current step index. If so, returns the cached result without executing
    the tool. Otherwise, executes the tool and caches the result.

    The step index is grabbed before the first ``await``, so parallel tool
    calls via ``asyncio.gather`` get deterministic indices (tasks start
    executing their synchronous preamble in creation order).
    """

    storage: DurableStorage = field(repr=False)
    counter: DurableStepCounter = field(repr=False)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: Any,
        tool: ToolsetTool[Any],
    ) -> Any:
        # Grab step index BEFORE any await -- ensures deterministic ordering
        # even when multiple tool calls run concurrently via asyncio.gather.
        step = self.counter.next_step()
        key = f"tool_step_{step}"

        found, cached = self.storage.load_tool_result(key)
        if found:
            self.counter.replayed_tool += 1
            log.debug("Durable: replayed cached tool result", step=step, tool=name)
            return cached

        result = await self.wrapped.call_tool(name, tool_args, ctx, tool)
        self.storage.save_tool_result(key, result)
        self.counter.cached_tool += 1
        log.debug("Durable: cached tool result", step=step, tool=name)
        return result
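
The ordering claim in the ``CachingToolset`` docstring can be checked with a small stand-alone sketch. It does not use pydantic-ai or the provider classes: ``Counter`` and ``fake_tool_call`` are hypothetical stand-ins, and the only assumption is that each call takes its step index before its first ``await``.

```python
import asyncio


class Counter:
    """Minimal stand-in for a shared step counter."""

    def __init__(self) -> None:
        self._step = 0

    def next_step(self) -> int:
        step = self._step
        self._step += 1
        return step


async def fake_tool_call(counter: Counter, name: str, delay: float) -> tuple[str, int]:
    # The index is taken synchronously, before the coroutine first yields.
    step = counter.next_step()
    await asyncio.sleep(delay)
    return name, step


async def main() -> list[tuple[str, int]]:
    counter = Counter()
    # The slowest call is listed first, so completion order is the reverse
    # of creation order.
    return await asyncio.gather(
        fake_tool_call(counter, "slow", 0.03),
        fake_tool_call(counter, "medium", 0.02),
        fake_tool_call(counter, "fast", 0.01),
    )


def demo() -> list[tuple[str, int]]:
    return asyncio.run(main())
```

Although the calls finish in reverse order, the indices follow creation order, which is what keeps the cache keys stable between the original run and a retry.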
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared step counter for durable execution caching."""

from __future__ import annotations


class DurableStepCounter:
    """
    Monotonically increasing counter shared between CachingModel and CachingToolset.

    Each model call and tool call increments the counter. The step index
    is used as the cache key, ensuring deterministic replay on retry.
    """

    def __init__(self) -> None:
        self._step: int = 0
        self.replayed_model: int = 0
        self.replayed_tool: int = 0
        self.cached_model: int = 0
        self.cached_tool: int = 0

    def next_step(self) -> int:
        """Return the current step and advance the counter."""
        step = self._step
        self._step += 1
        return step

    @property
    def total_steps(self) -> int:
        """Total number of steps executed so far."""
        return self._step
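
To illustrate how a step-indexed cache gives replay on retry, here is a minimal sketch that swaps ObjectStorage for a plain dict. ``StepCounter``, ``run_agent`` and ``demo`` are hypothetical names written for this example and are not part of the provider; the steps are ordinary callables rather than model or tool calls.

```python
from __future__ import annotations

from typing import Callable


class StepCounter:
    """Minimal counter: hands out 0, 1, 2, ... in call order."""

    def __init__(self) -> None:
        self._step = 0

    def next_step(self) -> int:
        step = self._step
        self._step += 1
        return step


def run_agent(cache: dict[str, str], steps: list[Callable[[], str]]) -> tuple[list[str], int]:
    """Run steps in order, replaying cached results; return (results, replayed)."""
    counter = StepCounter()  # a fresh counter per attempt, so keys line up again
    results: list[str] = []
    replayed = 0
    for call in steps:
        key = f"step_{counter.next_step()}"
        if key in cache:
            results.append(cache[key])
            replayed += 1
            continue
        result = call()  # may raise; nothing is cached for a failed step
        cache[key] = result
        results.append(result)
    return results, replayed


def demo() -> tuple[list[str], int, dict[str, int]]:
    cache: dict[str, str] = {}
    calls = {"a": 0, "b": 0}

    def step_a() -> str:
        calls["a"] += 1
        return "A"

    def step_b() -> str:
        calls["b"] += 1
        if calls["b"] == 1:
            raise ConnectionError("transient failure")
        return "B"

    try:
        run_agent(cache, [step_a, step_b])  # first attempt fails at step_b
    except ConnectionError:
        pass
    results, replayed = run_agent(cache, [step_a, step_b])  # the retry
    return results, replayed, calls
```

On the retry, ``step_a`` is served from the cache and only ``step_b`` runs again, which mirrors the replayed and cached counts that the counter tracks.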