diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst index 8e7c8ad5f3983..20ca4cd2a56a8 100644 --- a/providers/common/ai/docs/operators/agent.rst +++ b/providers/common/ai/docs/operators/agent.rst @@ -209,11 +209,28 @@ cache: **How it works** 1. On first execution, each LLM response and tool result is saved to a JSON - file as the agent progresses. + file as the agent progresses, together with a fingerprint of the request + that produced it (model, message history, settings, and tools for LLM + steps; tool name, arguments, and call id for tool steps). 2. If the task fails and Airflow retries it, completed steps are loaded from the cache and returned without calling the model or tool. Steps not yet in the cache proceed normally. -3. After successful completion, the cache file is deleted. +3. Before a step is replayed, its stored fingerprint is compared against the + current request. If anything changed between attempts -- the system + prompt, the model, the toolset, model settings, or the conversation so + far -- the stale entry is discarded, a warning is logged, and the step + re-runs live. A divergence also invalidates the steps after it: re-running + an LLM step produces fresh tool call ids, so tool results recorded under + the old conversation no longer match. A changed agent costs a re-run; it + never replays responses that belong to a different conversation. +4. After successful completion, the cache file is deleted. + +Replay verification compares the **requests** sent to models and tools, not +the code behind them. Editing a tool's implementation between attempts does +not invalidate an already-cached result for an identical call, and pointing +``llm_conn_id`` at a different endpoint serving the same model name does not +invalidate cached responses -- delete the cache file to force a fully fresh +run. After the run, a single INFO summary line reports how many steps were replayed vs executed fresh. 
Per-step detail is available at DEBUG level. diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py index 0b2f85ecb400e..18f89439d82dd 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py @@ -24,6 +24,8 @@ import structlog from pydantic_ai.models.wrapper import WrapperModel +from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request + log = structlog.get_logger(logger_name="task") if TYPE_CHECKING: @@ -41,8 +43,12 @@ class CachingModel(WrapperModel): Wraps a model to cache responses in ObjectStorage for durable execution. On each ``request()`` call, checks if a cached response exists for the - current step index. If so, returns the cached response without calling - the underlying model. Otherwise, calls the model and caches the response. + current step index and was produced by an equivalent request (same model, + message history, settings, and tools -- compared via fingerprint). If so, + returns the cached response without calling the underlying model. + Otherwise, calls the model and caches the response. A fingerprint + mismatch means the agent changed between attempts; the stale entry is + discarded and the step re-runs live. """ storage: DurableStorage = field(repr=False) @@ -67,15 +73,45 @@ async def request( ) -> ModelResponse: step = self.counter.next_step() key = f"model_step_{step}" + # Fingerprint the *prepared* request, not the raw arguments. Concrete + # models call ``prepare_request()`` at the start of ``request()`` to merge + # their model-level ``settings`` and apply profile-specific transforms + # (thinking resolution, native-tool handling, output-mode defaults) before + # the provider sees the request. 
Fingerprinting the raw arguments would + # miss a change that lives only at the model level -- e.g. a different + # temperature or thinking setting on the connection -- and replay a stale + # response. The raw arguments are still passed to ``wrapped.request()``, + # which re-runs ``prepare_request()`` itself (it is pure and idempotent). + prepared_settings, prepared_parameters = self.wrapped.prepare_request( + model_settings, model_request_parameters + ) + fingerprint = fingerprint_model_request( + f"{self.wrapped.system}:{self.wrapped.model_name}", + messages, + prepared_settings, + prepared_parameters, + ) - cached = self.storage.load_model_response(key) + cached, cached_fingerprint = self.storage.load_model_response(key) if cached is not None: - self.counter.replayed_model += 1 - log.debug("Durable: replayed cached model response", step=step) - return cached + if cached_fingerprint == fingerprint: + self.counter.replayed_model += 1 + log.debug("Durable: replayed cached model response", step=step) + return cached + log.warning( + "Durable: cached model response does not match the current request; " + "re-running this step instead of replaying", + step=step, + reason=( + "entry predates fingerprinting or the request could not be fingerprinted" + if fingerprint is None or cached_fingerprint is None + else "model, prompt, message history, settings, or tools changed since " + "the previous attempt" + ), + ) response = await self.wrapped.request(messages, model_settings, model_request_parameters) - self.storage.save_model_response(key, response) + self.storage.save_model_response(key, response, fingerprint=fingerprint) self.counter.cached_model += 1 log.debug("Durable: cached model response", step=step) return response diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py index 2fd58fe78a40e..045c98aea1f50 100644 --- 
a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py @@ -24,6 +24,8 @@ import structlog from pydantic_ai.toolsets.wrapper import WrapperToolset +from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call + if TYPE_CHECKING: from pydantic_ai.toolsets.abstract import ToolsetTool @@ -39,8 +41,12 @@ class CachingToolset(WrapperToolset[Any]): Wraps a toolset to cache tool call results in ObjectStorage for durable execution. On each ``call_tool()`` invocation, checks if a cached result exists for - the current step index. If so, returns the cached result without executing - the tool. Otherwise, executes the tool and caches the result. + the current step index and was produced by the same call (same tool name, + arguments, and model-issued ``tool_call_id`` -- compared via fingerprint). + If so, returns the cached result without executing the tool. Otherwise, + executes the tool and caches the result. A fingerprint mismatch means the + conversation diverged from the previous attempt; the stale entry is + discarded and the tool runs live. The step index is grabbed before the first ``await``, so parallel tool calls via ``asyncio.gather`` get deterministic indices (tasks start @@ -61,15 +67,28 @@ async def call_tool( # even when multiple tool calls run concurrently via asyncio.gather. 
step = self.counter.next_step() key = f"tool_step_{step}" + fingerprint = fingerprint_tool_call(name, tool_args, ctx.tool_call_id) - found, cached = self.storage.load_tool_result(key) + found, cached, cached_fingerprint = self.storage.load_tool_result(key) if found: - self.counter.replayed_tool += 1 - log.debug("Durable: replayed cached tool result", step=step, tool=name) - return cached + if cached_fingerprint == fingerprint: + self.counter.replayed_tool += 1 + log.debug("Durable: replayed cached tool result", step=step, tool=name) + return cached + log.warning( + "Durable: cached tool result does not match the current tool call; " + "re-running the tool instead of replaying", + step=step, + tool=name, + reason=( + "entry predates fingerprinting or the call could not be fingerprinted" + if fingerprint is None or cached_fingerprint is None + else "the conversation diverged from the previous attempt" + ), + ) result = await self.wrapped.call_tool(name, tool_args, ctx, tool) - self.storage.save_tool_result(key, result) + self.storage.save_tool_result(key, result, fingerprint=fingerprint) self.counter.cached_tool += 1 log.debug("Durable: cached tool result", step=step, tool=name) return result diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py b/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py new file mode 100644 index 0000000000000..6b5a2d490a540 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Request fingerprints for durable replay verification. + +Durable caching keys steps positionally (``model_step_{N}`` / ``tool_step_{N}``). +Position alone cannot tell whether a cached entry still corresponds to the +current request: if the prompt, model, toolset, or message history changed +between the failed attempt and the retry, replaying by position would feed the +agent responses recorded for a different conversation. + +Each cache entry therefore stores a fingerprint of the request that produced +it. On a cache hit the stored fingerprint is compared against the current +request; a mismatch is treated as a cache miss and the step re-runs live. +A divergence invalidates downstream steps too: a fresh model response carries +newly generated ``tool_call_id`` values, which are part of the tool +fingerprint, so stale tool results recorded under the old conversation no +longer match. + +Fields that pydantic-ai regenerates on every attempt (message-level +``timestamp``/``run_id``/``conversation_id`` and part-level ``timestamp``) +are excluded from the fingerprint. Requests that cannot be serialized to +JSON fingerprint as ``None``, which degrades that step to unverified +positional replay (the pre-fingerprint behavior) rather than disabling +caching. 
+""" + +from __future__ import annotations + +import hashlib +import json +from typing import TYPE_CHECKING, Any + +import structlog +from pydantic import TypeAdapter +from pydantic_ai.messages import ModelMessagesTypeAdapter +from pydantic_ai.models import ModelRequestParameters + +if TYPE_CHECKING: + from pydantic_ai.messages import ModelMessage + from pydantic_ai.settings import ModelSettings + +log = structlog.get_logger(logger_name="task") + +_MODEL_REQUEST_PARAMETERS_ADAPTER = TypeAdapter(ModelRequestParameters) + +# Message-level fields regenerated on every attempt. +_VOLATILE_MESSAGE_KEYS = ("timestamp", "run_id", "conversation_id") + +# Settings that control transport, not response content. Excluded from the +# fingerprint: changing them should not invalidate a cached response, and some +# (``timeout`` can be an ``httpx.Timeout``) are not JSON-serializable, which +# would otherwise force the whole fingerprint to ``None`` and silently disable +# replay verification for every step. +_TRANSPORT_ONLY_SETTINGS = frozenset({"timeout"}) + + +def _content_settings(model_settings: ModelSettings | None) -> dict[str, Any] | None: + """Return the content-affecting settings, or ``None`` if there are none.""" + if not model_settings: + return None + content = {k: v for k, v in model_settings.items() if k not in _TRANSPORT_ONLY_SETTINGS} + return content or None + + +def _strip_volatile(messages_dump: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Drop per-attempt fields from a dumped message list. + + Only the levels pydantic-ai regenerates are touched (message-level ids and + timestamps, part-level timestamps); user data such as tool arguments is + never recursed into, so an argument legitimately named ``run_id`` still + affects the fingerprint. 
+ """ + stripped = [] + for message in messages_dump: + cleaned = {k: v for k, v in message.items() if k not in _VOLATILE_MESSAGE_KEYS} + if isinstance(cleaned.get("parts"), list): + cleaned["parts"] = [ + {k: v for k, v in part.items() if k != "timestamp"} if isinstance(part, dict) else part + for part in cleaned["parts"] + ] + stripped.append(cleaned) + return stripped + + +def _digest(payload: Any) -> str: + # No ``default=`` fallback: a non-JSON-serializable value must raise so the + # callers degrade to an unverifiable (None) fingerprint instead of hashing + # process-local reprs like ``<object at 0x...>`` that never match on retry. + canonical = json.dumps(payload, sort_keys=True) + return hashlib.sha256(canonical.encode()).hexdigest() + + +def fingerprint_model_request( + model_identifier: str, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, +) -> str | None: + """ + Fingerprint a model request: model identity, message history, settings, and request parameters. + + The full ``ModelRequestParameters`` object is hashed (tool definitions, + output mode and schema, native tools, ...) so any change to what is sent + to the model invalidates the cached response. + + Returns ``None`` when the request cannot be serialized; ``None`` compares + equal to ``None``, so requests that cannot be fingerprinted degrade to + unverified positional replay rather than disabling caching. 
+ """ + try: + dumped = ModelMessagesTypeAdapter.dump_python(messages, mode="json") + params = _MODEL_REQUEST_PARAMETERS_ADAPTER.dump_python(model_request_parameters, mode="json") + return _digest( + { + "model": model_identifier, + "messages": _strip_volatile(dumped), + "settings": _content_settings(model_settings), + "params": params, + } + ) + except (TypeError, ValueError): + # TypeError from json.dumps, ValueError covers PydanticSerializationError + log.warning( + "Durable: could not fingerprint model request; cached responses for this " + "step replay without verification" + ) + return None + + +def fingerprint_tool_call(name: str, tool_args: dict[str, Any], tool_call_id: str | None) -> str | None: + """ + Fingerprint a tool call: tool name, arguments, and the model-issued call id. + + ``tool_call_id`` round-trips through the model-response cache, so it is + stable under faithful replay but regenerated whenever a live model call + replaces a cached response -- chaining invalidation to downstream tool steps. + """ + try: + return _digest({"name": name, "args": tool_args, "tool_call_id": tool_call_id}) + except (TypeError, ValueError): + log.warning( + "Durable: could not fingerprint tool call; cached results for this " + "step replay without verification", + tool=name, + ) + return None diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py index 85643f9b34a5a..a32b4a1ee3d3d 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py @@ -24,7 +24,9 @@ class DurableStepCounter: Monotonically increasing counter shared between CachingModel and CachingToolset. Each model call and tool call increments the counter. The step index - is used as the cache key, ensuring deterministic replay on retry. 
+ is used as the cache key; replay correctness is verified separately by + comparing the request fingerprint stored with each cache entry (see + ``airflow.providers.common.ai.durable.fingerprint``). """ def __init__(self) -> None: diff --git a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py index d4437b18d11e1..88de494fe91df 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py +++ b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py @@ -99,24 +99,50 @@ def _save_cache(self) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(self._cache)) - def save_model_response(self, key: str, response: ModelResponse) -> None: - """Serialize and store a ModelResponse in the cache.""" + def save_model_response(self, key: str, response: ModelResponse, *, fingerprint: str | None) -> None: + """Serialize and store a ModelResponse with the request fingerprint that produced it.""" cache = self._load_cache() - cache[key] = ModelMessagesTypeAdapter.dump_json([response]).decode() + # Store the dumped messages as native JSON-compatible objects, not a + # pre-encoded string: the whole cache is JSON-encoded once in + # ``_save_cache``, so embedding a string here would double-encode the + # (large) response payload. + cache[key] = { + "fingerprint": fingerprint, + "data": ModelMessagesTypeAdapter.dump_python([response], mode="json"), + } self._save_cache() - def load_model_response(self, key: str) -> ModelResponse | None: - """Load a cached ModelResponse, or return None if not cached.""" + def load_model_response(self, key: str) -> tuple[ModelResponse | None, str | None]: + """ + Load a cached ModelResponse and its stored request fingerprint. + + Returns ``(None, None)`` if not cached. Entries written before + fingerprints existed load with a ``None`` fingerprint. 
+ """ cache = self._load_cache() raw = cache.get(key) if raw is None: - return None - messages = ModelMessagesTypeAdapter.validate_json(raw) - return messages[0] # type: ignore[return-value] - - def save_tool_result(self, key: str, result: Any) -> None: + return None, None + try: + if isinstance(raw, dict): + messages = ModelMessagesTypeAdapter.validate_python(raw["data"]) + fingerprint = raw.get("fingerprint") + else: + # Legacy entry: the adapter JSON (a list) was stored directly as a string. + messages = ModelMessagesTypeAdapter.validate_json(raw) + fingerprint = None + except (KeyError, IndexError, ValueError): + # A torn or malformed entry degrades to a miss (the step re-runs), + # never a task crash -- the cache is best-effort. + log.warning("Durable: ignoring malformed cached model response", key=key) + return None, None + if not messages: + return None, None + return messages[0], fingerprint # type: ignore[return-value] + + def save_tool_result(self, key: str, result: Any, *, fingerprint: str | None) -> None: """ - Store a tool call result in the cache. + Store a tool call result with the call fingerprint that produced it. Non-serializable results (e.g. BinaryContent from MCP tools) are skipped with a warning -- the tool call still succeeds, but won't @@ -124,30 +150,39 @@ def save_tool_result(self, key: str, result: Any) -> None: """ cache = self._load_cache() try: - cache[key] = json.dumps({_SENTINEL: True, "value": result}) - except TypeError: + # Probe serializability before mutating the shared cache: a + # non-serializable result must skip only this entry, not break the + # whole-file ``_save_cache``. TypeError covers unsupported types; + # ValueError covers circular references. 
+ json.dumps(result) + except (TypeError, ValueError): log.warning( "Durable: skipping cache for non-serializable tool result", key=key, type=type(result).__name__, ) return + cache[key] = {_SENTINEL: True, "value": result, "fingerprint": fingerprint} self._save_cache() - def load_tool_result(self, key: str) -> tuple[bool, Any]: + def load_tool_result(self, key: str) -> tuple[bool, Any, str | None]: """ - Load a cached tool result. + Load a cached tool result and its stored call fingerprint. - Returns (found, value) tuple since the cached value itself could be None. + Returns a (found, value, fingerprint) tuple since the cached value + itself could be None. Entries written before fingerprints existed + load with a ``None`` fingerprint. """ cache = self._load_cache() raw = cache.get(key) if raw is None: - return False, None - parsed = json.loads(raw) - if not isinstance(parsed, dict) or _SENTINEL not in parsed: - return False, None - return True, parsed["value"] + return False, None, None + # Legacy entries were stored as a JSON string; new entries are native dicts. + if isinstance(raw, str): + raw = json.loads(raw) + if not isinstance(raw, dict) or _SENTINEL not in raw: + return False, None, None + return True, raw["value"], raw.get("fingerprint") def cleanup(self) -> None: """Delete the cache file after successful execution.""" diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py index b41a0c54d8b2f..7f1179e2be3e9 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py @@ -123,7 +123,11 @@ class AgentOperator(BaseOperator, HITLReviewMixin): or tool budget. ``None`` (default) means no enforcement. :param durable: When ``True``, enables step-level caching of model responses and tool results for durable execution. 
On retry, cached - steps are replayed instead of re-executing. Default ``False``. + steps are replayed instead of re-executing. Each cached step is + verified against the current request before replay: if the prompt, + model, settings, tools, or message history changed since the failed + attempt, the affected steps re-run live (with a warning) instead of + replaying stale results. Default ``False``. Requires ``[common.ai] durable_cache_path`` to be set. **HITL Review parameters** (requires the ``hitl_review`` plugin): diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py index 117f5d4a08c03..2b00fa4a408cc 100644 --- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py @@ -20,15 +20,17 @@ import pytest from pydantic_ai.messages import ModelResponse, TextPart +from pydantic_ai.models import ModelRequestParameters from airflow.providers.common.ai.durable.caching_model import CachingModel +from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request from airflow.providers.common.ai.durable.step_counter import DurableStepCounter @pytest.fixture def mock_storage(): storage = MagicMock() - storage.load_model_response.return_value = None + storage.load_model_response.return_value = (None, None) return storage @@ -44,6 +46,8 @@ def mock_model(): model.system = "test" model.profile = MagicMock() model.settings = None + # CachingModel fingerprints the prepared request; identity keeps prepared == raw. 
+ model.prepare_request = lambda settings, params: (settings, params) return model @@ -59,15 +63,22 @@ def sample_response(): return ModelResponse(parts=[TextPart(content="Hello!")]) +def request_fingerprint(messages=(), settings=None, params=None): + """Fingerprint matching what CachingModel computes for the mock model.""" + return fingerprint_model_request( + "test:test-model", list(messages), settings, params or ModelRequestParameters() + ) + + class TestCachingModelCacheHit: @pytest.mark.asyncio async def test_returns_cached_response_without_calling_model( self, mock_model, mock_storage, counter, sample_response ): - mock_storage.load_model_response.return_value = sample_response + mock_storage.load_model_response.return_value = (sample_response, request_fingerprint()) caching = CachingModel(mock_model, storage=mock_storage, counter=counter) - result = await caching.request([], None, MagicMock()) + result = await caching.request([], None, ModelRequestParameters()) assert result is sample_response mock_model.request.assert_not_called() @@ -75,10 +86,10 @@ async def test_returns_cached_response_without_calling_model( @pytest.mark.asyncio async def test_advances_counter_on_cache_hit(self, mock_model, mock_storage, counter, sample_response): - mock_storage.load_model_response.return_value = sample_response + mock_storage.load_model_response.return_value = (sample_response, request_fingerprint()) caching = CachingModel(mock_model, storage=mock_storage, counter=counter) - await caching.request([], None, MagicMock()) + await caching.request([], None, ModelRequestParameters()) assert counter.total_steps == 1 @@ -89,11 +100,13 @@ async def test_calls_model_and_caches_on_miss(self, mock_model, mock_storage, co mock_model.request = AsyncMock(return_value=sample_response) caching = CachingModel(mock_model, storage=mock_storage, counter=counter) - result = await caching.request([], None, MagicMock()) + result = await caching.request([], None, ModelRequestParameters()) 
assert result is sample_response mock_model.request.assert_called_once() - mock_storage.save_model_response.assert_called_once_with("model_step_0", sample_response) + mock_storage.save_model_response.assert_called_once_with( + "model_step_0", sample_response, fingerprint=request_fingerprint() + ) @pytest.mark.asyncio async def test_sequential_calls_use_incrementing_keys(self, mock_model, mock_storage, counter): @@ -102,8 +115,71 @@ async def test_sequential_calls_use_incrementing_keys(self, mock_model, mock_sto mock_model.request = AsyncMock(side_effect=[response_1, response_2]) caching = CachingModel(mock_model, storage=mock_storage, counter=counter) - await caching.request([], None, MagicMock()) - await caching.request([], None, MagicMock()) + await caching.request([], None, ModelRequestParameters()) + await caching.request([], None, ModelRequestParameters()) keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] assert keys == ["model_step_0", "model_step_1"] + + +class TestCachingModelReplayVerification: + @pytest.mark.asyncio + async def test_fingerprint_mismatch_treated_as_miss( + self, mock_model, mock_storage, counter, sample_response + ): + """A cached entry recorded for a different request must not be replayed.""" + stale = ModelResponse(parts=[TextPart(content="stale")]) + mock_storage.load_model_response.return_value = (stale, "fp_of_old_conversation") + mock_model.request = AsyncMock(return_value=sample_response) + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + result = await caching.request([], None, ModelRequestParameters()) + + assert result is sample_response + mock_model.request.assert_called_once() + assert counter.replayed_model == 0 + mock_storage.save_model_response.assert_called_once_with( + "model_step_0", sample_response, fingerprint=request_fingerprint() + ) + + @pytest.mark.asyncio + async def test_legacy_entry_without_fingerprint_treated_as_miss( + self, mock_model, mock_storage, 
counter, sample_response + ): + """Pre-fingerprint cache entries cannot be verified, so they re-run.""" + stale = ModelResponse(parts=[TextPart(content="stale")]) + mock_storage.load_model_response.return_value = (stale, None) + mock_model.request = AsyncMock(return_value=sample_response) + caching = CachingModel(mock_model, storage=mock_storage, counter=counter) + + result = await caching.request([], None, ModelRequestParameters()) + + assert result is sample_response + mock_model.request.assert_called_once() + + @pytest.mark.asyncio + async def test_fingerprint_uses_prepared_request_not_raw_arguments( + self, mock_storage, counter, sample_response + ): + """Concrete models merge model-level settings in ``prepare_request`` before the + provider sees the request. The fingerprint must reflect the prepared settings, + so a model-level change (e.g. a different temperature on the connection) is not + invisible behind identical raw ``request()`` arguments.""" + model = MagicMock() + model.model_name = "test-model" + model.system = "test" + model.profile = MagicMock() + model.settings = None + model.request = AsyncMock(return_value=sample_response) + # Simulate prepare_request merging a model-level temperature into settings. + model.prepare_request = lambda settings, params: ({"temperature": 0.9}, params) + caching = CachingModel(model, storage=mock_storage, counter=counter) + + await caching.request([], None, ModelRequestParameters()) + + stored_fingerprint = mock_storage.save_model_response.call_args.kwargs["fingerprint"] + # Reflects the prepared settings, not the raw ``None`` the agent passed in. 
+ assert stored_fingerprint == fingerprint_model_request( + "test:test-model", [], {"temperature": 0.9}, ModelRequestParameters() + ) + assert stored_fingerprint != request_fingerprint() diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py index d2d999d9218a4..d7104928a1496 100644 --- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py +++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py @@ -16,21 +16,24 @@ # under the License. from __future__ import annotations +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic_ai.messages import ModelResponse, TextPart +from pydantic_ai.models import ModelRequestParameters from airflow.providers.common.ai.durable.caching_model import CachingModel from airflow.providers.common.ai.durable.caching_toolset import CachingToolset +from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call from airflow.providers.common.ai.durable.step_counter import DurableStepCounter @pytest.fixture def mock_storage(): storage = MagicMock() - storage.load_tool_result.return_value = (False, None) - storage.load_model_response.return_value = None + storage.load_tool_result.return_value = (False, None, None) + storage.load_model_response.return_value = (None, None) return storage @@ -49,13 +52,18 @@ def mock_toolset(): return toolset +def ctx_for(tool_call_id: str | None = "call_1") -> SimpleNamespace: + return SimpleNamespace(tool_call_id=tool_call_id) + + class TestCachingToolsetCacheHit: @pytest.mark.asyncio async def test_returns_cached_result_without_calling_tool(self, mock_toolset, mock_storage, counter): - mock_storage.load_tool_result.return_value = (True, "cached result") + fingerprint = fingerprint_tool_call("search", {"q": "foo"}, "call_1") + mock_storage.load_tool_result.return_value = (True, "cached 
result", fingerprint) caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) - result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock()) + result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock()) assert result == "cached result" mock_toolset.call_tool.assert_not_called() @@ -63,10 +71,11 @@ async def test_returns_cached_result_without_calling_tool(self, mock_toolset, mo @pytest.mark.asyncio async def test_advances_counter_on_cache_hit(self, mock_toolset, mock_storage, counter): - mock_storage.load_tool_result.return_value = (True, "cached") + fingerprint = fingerprint_tool_call("search", {}, "call_1") + mock_storage.load_tool_result.return_value = (True, "cached", fingerprint) caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) - await caching.call_tool("search", {}, MagicMock(), MagicMock()) + await caching.call_tool("search", {}, ctx_for("call_1"), MagicMock()) assert counter.total_steps == 1 @@ -76,24 +85,66 @@ class TestCachingToolsetCacheMiss: async def test_calls_tool_and_caches_on_miss(self, mock_toolset, mock_storage, counter): caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) - result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock()) + result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock()) assert result == "fresh result" mock_toolset.call_tool.assert_called_once() - mock_storage.save_tool_result.assert_called_once_with("tool_step_0", "fresh result") + mock_storage.save_tool_result.assert_called_once_with( + "tool_step_0", "fresh result", fingerprint=fingerprint_tool_call("search", {"q": "foo"}, "call_1") + ) @pytest.mark.asyncio async def test_sequential_calls_use_incrementing_keys(self, mock_toolset, mock_storage, counter): mock_toolset.call_tool = AsyncMock(side_effect=["result_a", "result_b"]) caching = CachingToolset(wrapped=mock_toolset, 
storage=mock_storage, counter=counter) - await caching.call_tool("tool_a", {}, MagicMock(), MagicMock()) - await caching.call_tool("tool_b", {}, MagicMock(), MagicMock()) + await caching.call_tool("tool_a", {}, ctx_for(), MagicMock()) + await caching.call_tool("tool_b", {}, ctx_for(), MagicMock()) keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] assert keys == ["tool_step_0", "tool_step_1"] +class TestCachingToolsetReplayVerification: + @pytest.mark.asyncio + async def test_different_tool_call_treated_as_miss(self, mock_toolset, mock_storage, counter): + """A cached result recorded for a different tool call must not be replayed.""" + stale_fingerprint = fingerprint_tool_call("lookup_order", {"id": "A1"}, "old_call") + mock_storage.load_tool_result.return_value = (True, "stale result", stale_fingerprint) + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + result = await caching.call_tool("charge_card", {"amount": 5}, ctx_for("new_call"), MagicMock()) + + assert result == "fresh result" + mock_toolset.call_tool.assert_called_once() + assert counter.replayed_tool == 0 + + @pytest.mark.asyncio + async def test_changed_tool_call_id_treated_as_miss(self, mock_toolset, mock_storage, counter): + """Same name/args but a new model-issued call id means the conversation diverged.""" + stale_fingerprint = fingerprint_tool_call("search", {"q": "foo"}, "old_call") + mock_storage.load_tool_result.return_value = (True, "stale result", stale_fingerprint) + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + result = await caching.call_tool("search", {"q": "foo"}, ctx_for("new_call"), MagicMock()) + + assert result == "fresh result" + mock_toolset.call_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_legacy_entry_without_fingerprint_treated_as_miss( + self, mock_toolset, mock_storage, counter + ): + """Pre-fingerprint cache entries cannot be verified, so the 
tool re-runs.""" + mock_storage.load_tool_result.return_value = (True, "stale result", None) + caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) + + result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock()) + + assert result == "fresh result" + mock_toolset.call_tool.assert_called_once() + + class TestSharedCounter: @pytest.mark.asyncio async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage): @@ -105,6 +156,7 @@ async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage) mock_model.system = "test" mock_model.profile = MagicMock() mock_model.settings = None + mock_model.prepare_request = lambda settings, params: (settings, params) response = ModelResponse(parts=[TextPart(content="response")]) mock_model.request = AsyncMock(return_value=response) @@ -114,9 +166,9 @@ async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage) caching_toolset = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter) # Simulate: model call -> tool call -> model call - await caching_model.request([], None, MagicMock()) - await caching_toolset.call_tool("search", {}, MagicMock(), MagicMock()) - await caching_model.request([], None, MagicMock()) + await caching_model.request([], None, ModelRequestParameters()) + await caching_toolset.call_tool("search", {}, ctx_for(), MagicMock()) + await caching_model.request([], None, ModelRequestParameters()) model_keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list] tool_keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list] diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py b/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py new file mode 100644 index 0000000000000..94e517d1cc398 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py @@ -0,0 +1,212 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime + +import httpx +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + SystemPromptPart, + ToolCallPart, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.tools import ToolDefinition + +from airflow.providers.common.ai.durable.fingerprint import ( + fingerprint_model_request, + fingerprint_tool_call, +) + + +def make_messages(system: str = "You are a bot.", user: str = "hello", **part_kwargs): + return [ + ModelRequest( + parts=[ + SystemPromptPart(content=system, **part_kwargs), + UserPromptPart(content=user, **part_kwargs), + ] + ) + ] + + +class TestModelRequestFingerprint: + def test_stable_across_part_timestamps(self): + """Part timestamps regenerate on every attempt and must not affect the fingerprint.""" + t1 = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc) + t2 = datetime.datetime(2026, 1, 2, tzinfo=datetime.timezone.utc) + fp1 = fingerprint_model_request("m", make_messages(timestamp=t1), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("m", make_messages(timestamp=t2), None, ModelRequestParameters()) + + assert fp1 == fp2 + + def 
test_stable_across_separate_message_constructions(self): + """run_id/conversation_id and other per-run fields must not affect the fingerprint.""" + fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + + assert fp1 == fp2 + + def test_changes_with_system_prompt(self): + fp1 = fingerprint_model_request("m", make_messages(system="a"), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("m", make_messages(system="b"), None, ModelRequestParameters()) + + assert fp1 != fp2 + + def test_changes_with_user_prompt(self): + fp1 = fingerprint_model_request("m", make_messages(user="a"), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("m", make_messages(user="b"), None, ModelRequestParameters()) + + assert fp1 != fp2 + + def test_changes_with_model_identifier(self): + fp1 = fingerprint_model_request("openai:gpt-5", make_messages(), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("openai:gpt-5-mini", make_messages(), None, ModelRequestParameters()) + + assert fp1 != fp2 + + def test_changes_with_model_settings(self): + fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + fp2 = fingerprint_model_request("m", make_messages(), {"temperature": 0.5}, ModelRequestParameters()) + + assert fp1 != fp2 + + def test_changes_with_toolset(self): + tool = ToolDefinition(name="search", parameters_json_schema={"type": "object"}) + fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + fp2 = fingerprint_model_request( + "m", make_messages(), None, ModelRequestParameters(function_tools=[tool]) + ) + + assert fp1 != fp2 + + def test_changes_with_output_mode(self): + """The full request parameters are hashed, not just the tool list.""" + fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + fp2 = fingerprint_model_request( + "m", 
make_messages(), None, ModelRequestParameters(output_mode="native") + ) + + assert fp1 != fp2 + + def test_changes_with_tool_definition_fields(self): + """Changes inside a tool definition (e.g. strict mode) affect the fingerprint.""" + strict = ToolDefinition(name="t", parameters_json_schema={"type": "object"}, strict=True) + lax = ToolDefinition(name="t", parameters_json_schema={"type": "object"}, strict=False) + fp1 = fingerprint_model_request( + "m", make_messages(), None, ModelRequestParameters(function_tools=[strict]) + ) + fp2 = fingerprint_model_request( + "m", make_messages(), None, ModelRequestParameters(function_tools=[lax]) + ) + + assert fp1 != fp2 + + def test_volatile_keys_inside_user_data_are_not_stripped(self): + """Only pydantic-ai's own message/part fields are volatile; a tool argument + legitimately named run_id must still affect the fingerprint.""" + + def messages_with_args(args): + return [ + ModelRequest(parts=[UserPromptPart(content="q")]), + ModelResponse(parts=[ToolCallPart(tool_name="t", args=args, tool_call_id="id1")]), + ] + + fp1 = fingerprint_model_request( + "m", messages_with_args({"run_id": "a"}), None, ModelRequestParameters() + ) + fp2 = fingerprint_model_request( + "m", messages_with_args({"run_id": "b"}), None, ModelRequestParameters() + ) + + assert fp1 != fp2 + + def test_unserializable_request_returns_none(self): + fp = fingerprint_model_request("m", [object()], None, ModelRequestParameters()) # type: ignore[list-item] + + assert fp is None + + def test_unserializable_settings_returns_none(self): + """Non-JSON settings values degrade to unverified replay instead of hashing + process-local reprs that would never match on retry.""" + fp = fingerprint_model_request( + "m", make_messages(), {"extra_body": object()}, ModelRequestParameters() + ) # type: ignore[typeddict-item] + + assert fp is None + + def test_httpx_timeout_does_not_disable_fingerprint(self): + """``timeout`` may be an ``httpx.Timeout`` (a supported, non-JSON 
shape). + It must not force the fingerprint to None and silently disable verification.""" + fp = fingerprint_model_request( + "m", make_messages(), {"timeout": httpx.Timeout(30.0)}, ModelRequestParameters() + ) + + assert fp is not None + + def test_timeout_excluded_from_fingerprint(self): + """timeout is transport-only -- changing it (or its type) must not invalidate + the cached response, so it is the same fingerprint as no timeout at all.""" + no_timeout = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters()) + float_timeout = fingerprint_model_request( + "m", make_messages(), {"timeout": 30.0}, ModelRequestParameters() + ) + httpx_timeout = fingerprint_model_request( + "m", make_messages(), {"timeout": httpx.Timeout(5.0)}, ModelRequestParameters() + ) + + assert no_timeout == float_timeout == httpx_timeout + + def test_content_settings_still_count_when_timeout_present(self): + """Stripping timeout must not drop content settings sharing the dict.""" + low = fingerprint_model_request( + "m", + make_messages(), + {"temperature": 0.2, "timeout": httpx.Timeout(1.0)}, + ModelRequestParameters(), + ) + high = fingerprint_model_request( + "m", + make_messages(), + {"temperature": 0.9, "timeout": httpx.Timeout(1.0)}, + ModelRequestParameters(), + ) + + assert low is not None + assert high is not None + assert low != high + + +class TestToolCallFingerprint: + def test_stable_for_identical_call(self): + assert fingerprint_tool_call("t", {"a": 1}, "id1") == fingerprint_tool_call("t", {"a": 1}, "id1") + + def test_changes_with_name(self): + assert fingerprint_tool_call("a", {}, "id1") != fingerprint_tool_call("b", {}, "id1") + + def test_changes_with_args(self): + assert fingerprint_tool_call("t", {"a": 1}, "id1") != fingerprint_tool_call("t", {"a": 2}, "id1") + + def test_changes_with_tool_call_id(self): + assert fingerprint_tool_call("t", {}, "id1") != fingerprint_tool_call("t", {}, "id2") + + def test_arg_order_does_not_matter(self): + assert 
fingerprint_tool_call("t", {"a": 1, "b": 2}, "id1") == fingerprint_tool_call( + "t", {"b": 2, "a": 1}, "id1" + ) diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py b/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py new file mode 100644 index 0000000000000..ad9c244b0c60f --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +End-to-end replay verification through a real pydantic-ai agent loop. + +Simulates the retry scenario durable execution exists for: attempt 1 fails +partway with steps cached, attempt 2 starts a fresh counter against the same +cache file. Replay must happen when the agent is unchanged and must NOT happen +when the agent changed between attempts (the positional-keying staleness bug). 
+""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from pydantic_ai import Agent +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.toolsets import FunctionToolset + +from airflow.providers.common.ai.durable.caching_model import CachingModel +from airflow.providers.common.ai.durable.caching_toolset import CachingToolset +from airflow.providers.common.ai.durable.step_counter import DurableStepCounter +from airflow.providers.common.ai.durable.storage import DurableStorage +from airflow.sdk import ObjectStoragePath + + +@pytest.fixture +def storage(tmp_path): + with patch("airflow.providers.common.ai.durable.storage._get_base_path") as mock_base: + mock_base.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + yield DurableStorage(dag_id="d", task_id="t", run_id="r") + + +class AgentHarness: + """A scripted two-step agent: one tool call, then a final answer.""" + + def __init__(self, storage: DurableStorage, *, system_prompt: str, rate: float, fail: bool): + self.live_model_calls = 0 + self.live_tool_calls = 0 + self.counter = DurableStepCounter() + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + self.live_model_calls += 1 + if not any(isinstance(m, ModelResponse) for m in messages): + return ModelResponse(parts=[ToolCallPart(tool_name="get_fx_rate", args={})]) + if fail: + raise RuntimeError("simulated transient failure") + returned = next(p.content for m in messages for p in m.parts if p.part_kind == "tool-return") + return ModelResponse(parts=[TextPart(content=f"rate={returned}")]) + + def get_fx_rate() -> float: + """Return the USD to EUR exchange rate.""" + self.live_tool_calls += 1 + return rate + + # Constructor form: the @toolset.tool decorator is typed for + # RunContext-taking functions only. 
+ toolset = FunctionToolset(tools=[get_fx_rate]) + + self.agent = Agent( + model=CachingModel(FunctionModel(model_fn), storage=storage, counter=self.counter), + system_prompt=system_prompt, + toolsets=[CachingToolset(wrapped=toolset, storage=storage, counter=self.counter)], + ) + + async def run(self) -> str: + result = await self.agent.run("Convert 100 USD to EUR") + return result.output + + +async def failed_first_attempt(storage: DurableStorage, system_prompt: str = "currency bot") -> None: + """Attempt 1: tool result 0.42 gets cached, then the final model call fails.""" + harness = AgentHarness(storage, system_prompt=system_prompt, rate=0.42, fail=True) + with pytest.raises(RuntimeError, match="simulated transient failure"): + await harness.run() + storage._cache = None # retry runs in a new process: in-memory cache is cold + + +class TestUnchangedRetryReplays: + @pytest.mark.asyncio + async def test_completed_steps_replay_only_failed_step_reruns(self, storage): + await failed_first_attempt(storage) + + retry = AgentHarness(storage, system_prompt="currency bot", rate=0.42, fail=False) + output = await retry.run() + + assert output == "rate=0.42" + assert retry.counter.replayed_model == 1 + assert retry.counter.replayed_tool == 1 + assert retry.live_tool_calls == 0 + assert retry.live_model_calls == 1 # only the step that failed on attempt 1 + + +class TestChangedAgentDoesNotReplayStaleSteps: + @pytest.mark.asyncio + async def test_changed_system_prompt_invalidates_replay(self, storage): + """Regression test for positional-keying staleness: a prompt tweak between + attempts must re-run the conversation, not replay the old one.""" + await failed_first_attempt(storage) + + retry = AgentHarness(storage, system_prompt="careful currency bot", rate=0.99, fail=False) + output = await retry.run() + + # The fixed tool ran in the new conversation; nothing stale was replayed. 
+ assert output == "rate=0.99" + assert retry.counter.replayed_model == 0 + assert retry.counter.replayed_tool == 0 + assert retry.live_tool_calls == 1 + + @pytest.mark.asyncio + async def test_divergence_chains_to_downstream_tool_steps(self, storage): + """A live model call mints fresh tool_call_ids, so cached tool results + recorded under the old conversation cannot be cross-wired into the new one.""" + await failed_first_attempt(storage) + + retry = AgentHarness(storage, system_prompt="careful currency bot", rate=0.42, fail=False) + await retry.run() + + # Same tool name and args as attempt 1, but the conversation diverged + # at step 0 -- the tool must run live, not replay. + assert retry.live_tool_calls == 1 + assert retry.counter.replayed_tool == 0 diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_storage.py b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py index 507fd3126ac57..03f85be30e11f 100644 --- a/providers/common/ai/tests/unit/common/ai/durable/test_storage.py +++ b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py @@ -16,13 +16,16 @@ # under the License. 
from __future__ import annotations +import json from unittest.mock import patch import pytest from pydantic_ai.messages import ( + ModelMessagesTypeAdapter, ModelResponse, TextPart, ) +from pydantic_ai.usage import RequestUsage from airflow.providers.common.ai.durable.storage import DurableStorage from airflow.sdk import ObjectStoragePath @@ -61,47 +64,147 @@ def test_cache_id_without_map_index(self): class TestSaveLoadModelResponse: def test_save_and_load_roundtrips(self, storage, sample_response): - storage.save_model_response("model_step_0", sample_response) + storage.save_model_response("model_step_0", sample_response, fingerprint="fp_abc") # Reset in-memory cache to force read from file storage._cache = None - loaded = storage.load_model_response("model_step_0") + loaded, fingerprint = storage.load_model_response("model_step_0") assert loaded is not None assert loaded.parts[0].content == "Hello!" + assert fingerprint == "fp_abc" def test_load_returns_none_when_no_cache(self, storage): - assert storage.load_model_response("model_step_0") is None + assert storage.load_model_response("model_step_0") == (None, None) + + def test_metadata_carrying_response_roundtrips_byte_identical(self, storage): + """Multi-step replay relies on cached responses round-tripping byte-identically: + a later step's fingerprint includes earlier responses in history, metadata + (usage, provider_response_id, finish_reason) and all. If a store/load cycle + altered any of it, every multi-step replay would mismatch and re-run. 
Pin it.""" + resp = ModelResponse( + parts=[TextPart(content="answer")], + usage=RequestUsage(input_tokens=11, output_tokens=22), + model_name="gpt-x", + provider_response_id="resp_xyz", + finish_reason="stop", + ) + before = ModelMessagesTypeAdapter.dump_python([resp], mode="json") + + storage.save_model_response("model_step_0", resp, fingerprint="fp") + storage._cache = None + loaded, _ = storage.load_model_response("model_step_0") + + after = ModelMessagesTypeAdapter.dump_python([loaded], mode="json") + assert after == before + + def test_stored_entry_is_single_encoded(self, storage, sample_response): + """The response payload is stored as native JSON objects, not a nested + JSON string -- the whole cache is encoded exactly once by ``_save_cache``.""" + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + + on_disk = json.loads(storage._get_path().read_text()) + entry = on_disk["model_step_0"] + + assert isinstance(entry, dict) # not a re-encoded JSON string + assert isinstance(entry["data"], list) # not a nested JSON string + assert entry["fingerprint"] == "fp" + + def test_legacy_entry_without_fingerprint_loads(self, storage, sample_response): + """Entries written before fingerprinting (raw adapter JSON) load with a None fingerprint.""" + cache = storage._load_cache() + cache["model_step_0"] = ModelMessagesTypeAdapter.dump_json([sample_response]).decode() + storage._save_cache() + storage._cache = None + + loaded, fingerprint = storage.load_model_response("model_step_0") + + assert loaded is not None + assert loaded.parts[0].content == "Hello!" 
+ assert fingerprint is None class TestSaveLoadToolResult: def test_save_and_load_roundtrips(self, storage): - storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]}) + storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]}, fingerprint="fp") storage._cache = None - found, value = storage.load_tool_result("tool_step_0") + found, value, fingerprint = storage.load_tool_result("tool_step_0") assert found is True assert value == {"rows": [1, 2, 3]} + def test_fingerprint_roundtrips(self, storage): + storage.save_tool_result("tool_step_0", "result", fingerprint="fp_tool") + + storage._cache = None + found, value, fingerprint = storage.load_tool_result("tool_step_0") + + assert found is True + assert fingerprint == "fp_tool" + + def test_legacy_entry_without_fingerprint_loads(self, storage): + """Entries written before fingerprinting load with a None fingerprint.""" + cache = storage._load_cache() + cache["tool_step_0"] = json.dumps({"__durable_cached__": True, "value": "old"}) + storage._save_cache() + storage._cache = None + + found, value, fingerprint = storage.load_tool_result("tool_step_0") + + assert found is True + assert value == "old" + assert fingerprint is None + def test_load_returns_false_when_no_cache(self, storage): - found, value = storage.load_tool_result("tool_step_0") + found, value, fingerprint = storage.load_tool_result("tool_step_0") assert found is False assert value is None + assert fingerprint is None def test_none_result_roundtrips(self, storage): - storage.save_tool_result("tool_step_0", None) + storage.save_tool_result("tool_step_0", None, fingerprint="fp") storage._cache = None - found, value = storage.load_tool_result("tool_step_0") + found, value, fingerprint = storage.load_tool_result("tool_step_0") assert found is True assert value is None + def test_circular_reference_result_is_skipped_not_raised(self, storage): + """A circular reference raises ValueError in json.dumps; it must skip the + entry with a warning, not crash the 
(otherwise successful) tool step.""" + circular: dict = {} + circular["self"] = circular + + storage.save_tool_result("tool_step_0", circular, fingerprint="fp") # must not raise + + found, _, _ = storage.load_tool_result("tool_step_0") + assert found is False + + +class TestMalformedEntries: + def test_empty_data_list_degrades_to_miss(self, storage): + """A torn entry whose data list is empty loads as a miss, not an IndexError.""" + cache = storage._load_cache() + cache["model_step_0"] = {"fingerprint": "fp", "data": []} + storage._save_cache() + storage._cache = None + + assert storage.load_model_response("model_step_0") == (None, None) + + def test_entry_missing_data_key_degrades_to_miss(self, storage): + cache = storage._load_cache() + cache["model_step_0"] = {"fingerprint": "fp"} + storage._save_cache() + storage._cache = None + + assert storage.load_model_response("model_step_0") == (None, None) + class TestCleanup: def test_cleanup_deletes_file(self, storage, sample_response): - storage.save_model_response("model_step_0", sample_response) + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") path = storage._get_path() assert path.exists() @@ -114,9 +217,9 @@ def test_cleanup_on_nonexistent_file(self, storage): class TestInMemoryCaching: def test_multiple_saves_write_single_file(self, storage, sample_response): - storage.save_model_response("model_step_0", sample_response) - storage.save_tool_result("tool_step_1", "result") - storage.save_model_response("model_step_2", sample_response) + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + storage.save_tool_result("tool_step_1", "result", fingerprint="fp") + storage.save_model_response("model_step_2", sample_response, fingerprint="fp") assert "model_step_0" in storage._cache assert "tool_step_1" in storage._cache @@ -124,16 +227,16 @@ def test_multiple_saves_write_single_file(self, storage, sample_response): def test_cache_survives_reload(self, storage, 
sample_response): """Simulate retry: save cache, reset in-memory, reload from file.""" - storage.save_model_response("model_step_0", sample_response) - storage.save_tool_result("tool_step_1", "tool result") + storage.save_model_response("model_step_0", sample_response, fingerprint="fp") + storage.save_tool_result("tool_step_1", "tool result", fingerprint="fp") # Simulate new DurableStorage instance (as on retry) storage._cache = None - loaded_response = storage.load_model_response("model_step_0") + loaded_response, _ = storage.load_model_response("model_step_0") assert loaded_response is not None assert loaded_response.parts[0].content == "Hello!" - found, value = storage.load_tool_result("tool_step_1") + found, value, _ = storage.load_tool_result("tool_step_1") assert found is True assert value == "tool result"