Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions providers/common/ai/docs/operators/agent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,28 @@ cache:
**How it works**

1. On first execution, each LLM response and tool result is saved to a JSON
file as the agent progresses.
file as the agent progresses, together with a fingerprint of the request
that produced it (model, message history, settings, and tools for LLM
steps; tool name, arguments, and call id for tool steps).
2. If the task fails and Airflow retries it, completed steps are loaded from
the cache and returned without calling the model or tool. Steps not yet in
the cache proceed normally.
3. After successful completion, the cache file is deleted.
3. Before a step is replayed, its stored fingerprint is compared against the
current request. If anything changed between attempts -- the system
prompt, the model, the toolset, model settings, or the conversation so
far -- the stale entry is discarded, a warning is logged, and the step
re-runs live. A divergence also invalidates the steps after it: re-running
an LLM step produces fresh tool call ids, so tool results recorded under
the old conversation no longer match. A changed agent costs a re-run; it
never replays responses that belong to a different conversation.
4. After successful completion, the cache file is deleted.

Replay verification compares the **requests** sent to models and tools, not
the code behind them. Editing a tool's implementation between attempts does
not invalidate an already-cached result for an identical call, and pointing
``llm_conn_id`` at a different endpoint serving the same model name does not
invalidate cached responses -- delete the cache file to force a fully fresh
run.

After the run, a single INFO summary line reports how many steps were
replayed vs executed fresh. Per-step detail is available at DEBUG level.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import structlog
from pydantic_ai.models.wrapper import WrapperModel

from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request

log = structlog.get_logger(logger_name="task")

if TYPE_CHECKING:
Expand All @@ -41,8 +43,12 @@ class CachingModel(WrapperModel):
Wraps a model to cache responses in ObjectStorage for durable execution.

On each ``request()`` call, checks if a cached response exists for the
current step index. If so, returns the cached response without calling
the underlying model. Otherwise, calls the model and caches the response.
current step index and was produced by an equivalent request (same model,
message history, settings, and tools -- compared via fingerprint). If so,
returns the cached response without calling the underlying model.
Otherwise, calls the model and caches the response. A fingerprint
mismatch means the agent changed between attempts; the stale entry is
discarded and the step re-runs live.
"""

storage: DurableStorage = field(repr=False)
Expand All @@ -67,15 +73,45 @@ async def request(
) -> ModelResponse:
step = self.counter.next_step()
key = f"model_step_{step}"
# Fingerprint the *prepared* request, not the raw arguments. Concrete
# models call ``prepare_request()`` at the start of ``request()`` to merge
# their model-level ``settings`` and apply profile-specific transforms
# (thinking resolution, native-tool handling, output-mode defaults) before
# the provider sees the request. Fingerprinting the raw arguments would
# miss a change that lives only at the model level -- e.g. a different
# temperature or thinking setting on the connection -- and replay a stale
# response. The raw arguments are still passed to ``wrapped.request()``,
# which re-runs ``prepare_request()`` itself (it is pure and idempotent).
prepared_settings, prepared_parameters = self.wrapped.prepare_request(
model_settings, model_request_parameters
)
fingerprint = fingerprint_model_request(
Comment thread
gopidesupavan marked this conversation as resolved.
f"{self.wrapped.system}:{self.wrapped.model_name}",
messages,
prepared_settings,
prepared_parameters,
)

cached = self.storage.load_model_response(key)
cached, cached_fingerprint = self.storage.load_model_response(key)
if cached is not None:
self.counter.replayed_model += 1
log.debug("Durable: replayed cached model response", step=step)
return cached
if cached_fingerprint == fingerprint:
self.counter.replayed_model += 1
log.debug("Durable: replayed cached model response", step=step)
return cached
log.warning(
"Durable: cached model response does not match the current request; "
"re-running this step instead of replaying",
step=step,
reason=(
"entry predates fingerprinting or the request could not be fingerprinted"
if fingerprint is None or cached_fingerprint is None
else "model, prompt, message history, settings, or tools changed since "
"the previous attempt"
),
)

response = await self.wrapped.request(messages, model_settings, model_request_parameters)
self.storage.save_model_response(key, response)
self.storage.save_model_response(key, response, fingerprint=fingerprint)
self.counter.cached_model += 1
log.debug("Durable: cached model response", step=step)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import structlog
from pydantic_ai.toolsets.wrapper import WrapperToolset

from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call

if TYPE_CHECKING:
from pydantic_ai.toolsets.abstract import ToolsetTool

Expand All @@ -39,8 +41,12 @@ class CachingToolset(WrapperToolset[Any]):
Wraps a toolset to cache tool call results in ObjectStorage for durable execution.

On each ``call_tool()`` invocation, checks if a cached result exists for
the current step index. If so, returns the cached result without executing
the tool. Otherwise, executes the tool and caches the result.
the current step index and was produced by the same call (same tool name,
arguments, and model-issued ``tool_call_id`` -- compared via fingerprint).
If so, returns the cached result without executing the tool. Otherwise,
executes the tool and caches the result. A fingerprint mismatch means the
conversation diverged from the previous attempt; the stale entry is
discarded and the tool runs live.

The step index is grabbed before the first ``await``, so parallel tool
calls via ``asyncio.gather`` get deterministic indices (tasks start
Expand All @@ -61,15 +67,28 @@ async def call_tool(
# even when multiple tool calls run concurrently via asyncio.gather.
step = self.counter.next_step()
key = f"tool_step_{step}"
fingerprint = fingerprint_tool_call(name, tool_args, ctx.tool_call_id)

found, cached = self.storage.load_tool_result(key)
found, cached, cached_fingerprint = self.storage.load_tool_result(key)
if found:
self.counter.replayed_tool += 1
log.debug("Durable: replayed cached tool result", step=step, tool=name)
return cached
if cached_fingerprint == fingerprint:
self.counter.replayed_tool += 1
log.debug("Durable: replayed cached tool result", step=step, tool=name)
return cached
log.warning(
"Durable: cached tool result does not match the current tool call; "
"re-running the tool instead of replaying",
step=step,
tool=name,
reason=(
"entry predates fingerprinting or the call could not be fingerprinted"
if fingerprint is None or cached_fingerprint is None
else "the conversation diverged from the previous attempt"
),
)

result = await self.wrapped.call_tool(name, tool_args, ctx, tool)
self.storage.save_tool_result(key, result)
self.storage.save_tool_result(key, result, fingerprint=fingerprint)
self.counter.cached_tool += 1
log.debug("Durable: cached tool result", step=step, tool=name)
return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Request fingerprints for durable replay verification.

Durable caching keys steps positionally (``model_step_{N}`` / ``tool_step_{N}``).
Position alone cannot tell whether a cached entry still corresponds to the
current request: if the prompt, model, toolset, or message history changed
between the failed attempt and the retry, replaying by position would feed the
agent responses recorded for a different conversation.

Each cache entry therefore stores a fingerprint of the request that produced
it. On a cache hit the stored fingerprint is compared against the current
request; a mismatch is treated as a cache miss and the step re-runs live.
A divergence invalidates downstream steps too: a fresh model response carries
newly generated ``tool_call_id`` values, which are part of the tool
fingerprint, so stale tool results recorded under the old conversation no
longer match.

Fields that pydantic-ai regenerates on every attempt (message-level
``timestamp``/``run_id``/``conversation_id`` and part-level ``timestamp``)
are excluded from the fingerprint. Requests that cannot be serialized to
JSON fingerprint as ``None``, which degrades that step to unverified
positional replay (the pre-fingerprint behavior) rather than disabling
caching.
"""

from __future__ import annotations

import hashlib
import json
from typing import TYPE_CHECKING, Any

import structlog
from pydantic import TypeAdapter
from pydantic_ai.messages import ModelMessagesTypeAdapter
from pydantic_ai.models import ModelRequestParameters

if TYPE_CHECKING:
from pydantic_ai.messages import ModelMessage
from pydantic_ai.settings import ModelSettings

# Module-level structured logger, requested under the "task" logger name.
log = structlog.get_logger(logger_name="task")

# Built once at import time and reused for every ``ModelRequestParameters`` dump.
_MODEL_REQUEST_PARAMETERS_ADAPTER = TypeAdapter(ModelRequestParameters)

# Message-level fields regenerated on every attempt.
_VOLATILE_MESSAGE_KEYS = ("timestamp", "run_id", "conversation_id")

# Settings that control transport, not response content. Excluded from the
# fingerprint: changing them should not invalidate a cached response, and some
# (``timeout`` can be an ``httpx.Timeout``) are not JSON-serializable, which
# would otherwise force the whole fingerprint to ``None`` and silently disable
# replay verification for every step.
_TRANSPORT_ONLY_SETTINGS = frozenset({"timeout"})


def _content_settings(model_settings: ModelSettings | None) -> dict[str, Any] | None:
    """
    Filter model settings down to the entries that take part in the fingerprint.

    Every key listed in ``_TRANSPORT_ONLY_SETTINGS`` is dropped; the remaining
    entries keep their original order.

    :param model_settings: Settings mapping for the request, or ``None``.
    :return: The filtered mapping, or ``None`` when the input is empty/``None``
        or nothing is left after filtering.
    """
    kept: dict[str, Any] = {}
    for name, value in (model_settings or {}).items():
        if name in _TRANSPORT_ONLY_SETTINGS:
            continue
        kept[name] = value
    # An empty result is normalised to ``None`` so "no settings" and "only
    # transport settings" fingerprint identically.
    return kept if kept else None


def _strip_volatile(messages_dump: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Return a copy of a dumped message list without its per-attempt fields.

    Two levels are cleaned and nothing deeper: the keys named in
    ``_VOLATILE_MESSAGE_KEYS`` are removed from each message, and
    ``timestamp`` is removed from each dict found in a message's ``parts``
    list. Values inside a part (for example tool arguments) are passed
    through untouched, so user data that happens to use one of those key
    names still contributes to the fingerprint.

    :param messages_dump: Messages already dumped to plain dicts.
    :return: A new list of new message dicts; the input is not mutated.
    """

    def _drop_part_timestamp(part: Any) -> Any:
        # Non-dict parts are passed through unchanged.
        if not isinstance(part, dict):
            return part
        return {name: value for name, value in part.items() if name != "timestamp"}

    result = []
    for entry in messages_dump:
        slim = {name: value for name, value in entry.items() if name not in _VOLATILE_MESSAGE_KEYS}
        parts = slim.get("parts")
        # Only a real list of parts is rewritten; a missing or differently
        # typed ``parts`` value is left exactly as dumped.
        if isinstance(parts, list):
            slim["parts"] = [_drop_part_timestamp(part) for part in parts]
        result.append(slim)
    return result


def _digest(payload: Any) -> str:
    """
    Return the SHA-256 hex digest of ``payload`` serialized as key-sorted JSON.

    :param payload: Any value ``json.dumps`` accepts.
    :return: 64-character lowercase hexadecimal digest.
    :raises TypeError: If ``payload`` contains a value that is not JSON-serializable.
    """
    # Deliberately no ``default=`` fallback: a value that JSON cannot encode
    # has to raise so callers fall back to an unverifiable (None) fingerprint
    # rather than hashing a process-local repr such as ``<object at 0x...>``
    # that would never match on retry.
    serialized = json.dumps(payload, sort_keys=True)
    hasher = hashlib.sha256()
    hasher.update(serialized.encode())
    return hasher.hexdigest()


def fingerprint_model_request(
    model_identifier: str,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> str | None:
    """
    Compute a fingerprint covering everything that makes up one model request.

    The hashed payload has four entries: the model identifier, the message
    history with per-attempt fields removed, the content-affecting settings,
    and a full dump of ``ModelRequestParameters``. Dumping the whole
    parameters object means a change to any of its fields changes the
    fingerprint.

    :param model_identifier: String identifying the model being called.
    :param messages: Message history sent with the request.
    :param model_settings: Settings for the request, or ``None``.
    :param model_request_parameters: Request parameters for the call.
    :return: Hex digest, or ``None`` when the request cannot be serialized.
        Because ``None == None``, a caller that compares fingerprints for
        equality treats two unserializable requests as matching, i.e. the
        step falls back to unverified positional replay instead of never
        replaying.
    """
    digest: str | None
    try:
        message_dump = ModelMessagesTypeAdapter.dump_python(messages, mode="json")
        parameter_dump = _MODEL_REQUEST_PARAMETERS_ADAPTER.dump_python(model_request_parameters, mode="json")
        payload = {
            "model": model_identifier,
            "messages": _strip_volatile(message_dump),
            "settings": _content_settings(model_settings),
            "params": parameter_dump,
        }
        digest = _digest(payload)
    except (TypeError, ValueError):
        # TypeError comes from json.dumps; ValueError is caught for
        # PydanticSerializationError.
        log.warning(
            "Durable: could not fingerprint model request; cached responses for this "
            "step replay without verification"
        )
        digest = None
    return digest


def fingerprint_tool_call(name: str, tool_args: dict[str, Any], tool_call_id: str | None) -> str | None:
    """
    Compute a fingerprint for one tool invocation.

    The hashed payload holds the tool name, its arguments, and the call id
    issued by the model. Including the call id ties a cached tool result to
    the model response that requested it: a model response that is generated
    again rather than replayed carries new call ids, so tool results recorded
    against the earlier ids stop matching.

    :param name: Name of the tool being called.
    :param tool_args: Arguments passed to the tool.
    :param tool_call_id: Call id issued by the model, or ``None``.
    :return: Hex digest, or ``None`` when the call cannot be serialized to JSON.
    """
    digest: str | None
    try:
        payload = {"name": name, "args": tool_args, "tool_call_id": tool_call_id}
        digest = _digest(payload)
    except (TypeError, ValueError):
        log.warning(
            "Durable: could not fingerprint tool call; cached results for this "
            "step replay without verification",
            tool=name,
        )
        digest = None
    return digest
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class DurableStepCounter:
Monotonically increasing counter shared between CachingModel and CachingToolset.

Each model call and tool call increments the counter. The step index
is used as the cache key, ensuring deterministic replay on retry.
is used as the cache key; replay correctness is verified separately by
comparing the request fingerprint stored with each cache entry (see
``airflow.providers.common.ai.durable.fingerprint``).
"""

def __init__(self) -> None:
Expand Down
Loading
Loading