From c7501cdde68975de78c1194ebe0dcb753ad30284 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Wed, 25 Mar 2026 21:51:52 -0400 Subject: [PATCH 01/10] refresh tee configs periodically --- src/opengradient/client/llm.py | 34 ++++++++++++++++++++++++++++------ uv.lock | 2 +- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index bd7f856f..bceecbf9 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -1,8 +1,10 @@ """LLM chat and completion via TEE-verified execution with x402 payments.""" +import asyncio import json import logging import ssl +import time from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import httpx @@ -32,6 +34,7 @@ _CHAT_ENDPOINT = "/v1/chat/completions" _COMPLETION_ENDPOINT = "/v1/completions" _REQUEST_TIMEOUT = 60 +_TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes @dataclass @@ -107,6 +110,8 @@ def __init__( register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) self._connect_tee() + self._tee_refreshed_at: float = time.monotonic() + self._refresh_lock = asyncio.Lock() # ── TEE resolution and connection ─────────────────────────────────────────── @@ -127,12 +132,27 @@ def _connect_tee(self) -> None: async def _refresh_tee(self) -> None: """Re-resolve TEE from the registry and rebuild the HTTP client.""" - old_http_client = self._http_client - self._connect_tee() - try: - await old_http_client.aclose() - except Exception: - logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + async with self._refresh_lock: + old_http_client = self._http_client + self._connect_tee() + self._tee_refreshed_at = time.monotonic() + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + + async def _maybe_refresh_tee(self) -> None: + """Re-resolve TEE if the current one is older than ``_TEE_REFRESH_INTERVAL``. + + Skips the refresh for explicit ``llm_server_url`` overrides since they + bypass the registry entirely. + """ + if self._llm_server_url is not None: + return + if time.monotonic() - self._tee_refreshed_at < _TEE_REFRESH_INTERVAL: + return + logger.debug("TEE endpoint stale (>%ds); refreshing from registry.", _TEE_REFRESH_INTERVAL) + await self._refresh_tee() @staticmethod @@ -212,6 +232,7 @@ async def _call_with_tee_retry( Only retries when the request never reached the server (no HTTP response). Server-side errors (4xx/5xx) are not retried. """ + await self._maybe_refresh_tee() try: return await call() except httpx.HTTPStatusError: @@ -448,6 +469,7 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: """Async SSE streaming implementation.""" + await self._maybe_refresh_tee() headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) diff --git a/uv.lock b/uv.lock index 9c74921b..6d0327a6 100644 --- a/uv.lock +++ b/uv.lock @@ -1867,7 +1867,7 @@ wheels = [ [[package]] name = "opengradient" -version = "0.9.2" +version = "0.9.3" source = { editable = "." } dependencies = [ { name = "click" }, From eb09eb56659182fda613984f1ed20b3f54cd2c5e Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:11:59 -0400 Subject: [PATCH 02/10] use frozen --- src/opengradient/client/llm.py | 152 +++++-------------- src/opengradient/client/tee_connection.py | 172 ++++++++++++++++++++++ src/opengradient/client/tee_registry.py | 2 +- tests/client_test.py | 4 +- tests/llm_test.py | 75 +++++----- 5 files changed, 253 insertions(+), 152 deletions(-) create mode 100644 src/opengradient/client/tee_connection.py diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index bceecbf9..1467bc01 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -1,10 +1,7 @@ """LLM chat and completion via TEE-verified execution with x402 payments.""" -import asyncio import json import logging -import ssl -import time from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import httpx @@ -12,14 +9,14 @@ from eth_account import Account from eth_account.account import LocalAccount from x402 import x402Client -from x402.http.clients import x402HttpxClient from x402.mechanisms.evm import EthAccountSigner from x402.mechanisms.evm.exact.register import register_exact_evm_client from x402.mechanisms.evm.upto.register import register_upto_evm_client from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval -from .tee_registry import TEERegistry, build_ssl_context_from_der +from .tee_connection import TEEConnection +from .tee_registry import TEERegistry logger = logging.getLogger(__name__) T = TypeVar("T") @@ -34,10 +31,9 @@ _CHAT_ENDPOINT = "/v1/chat/completions" _COMPLETION_ENDPOINT = "/v1/completions" _REQUEST_TIMEOUT = 60 -_TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes -@dataclass +@dataclass(frozen=True) class _ChatParams: """Bundles the common parameters for chat/completion requests.""" @@ -99,96 +95,30 @@ def __init__( llm_server_url: Optional[str] = None, ): self._wallet_account: LocalAccount = Account.from_key(private_key) - self._rpc_url = rpc_url - self._tee_registry_address = tee_registry_address - self._llm_server_url = llm_server_url # x402 payment stack (created once, reused across TEE refreshes) signer = EthAccountSigner(self._wallet_account) - self._x402_client = x402Client() - register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - - self._connect_tee() - self._tee_refreshed_at: float = time.monotonic() - self._refresh_lock = asyncio.Lock() - - # ── TEE resolution and connection ─────────────────────────────────────────── - - def _connect_tee(self) -> None: - """Resolve TEE from registry and create a secure HTTP client for it.""" - endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - self._llm_server_url, - self._rpc_url, - self._tee_registry_address, + x402_client = x402Client() + register_exact_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + + registry: Optional[TEERegistry] = ( + TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) + if llm_server_url is None + else None ) - self._tee_id = tee_id - self._tee_endpoint = endpoint - self._tee_payment_address = tee_payment_address - - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) - self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) - - async def _refresh_tee(self) -> None: - """Re-resolve TEE from the registry and rebuild the HTTP client.""" - async with self._refresh_lock: - old_http_client = self._http_client - self._connect_tee() - self._tee_refreshed_at = time.monotonic() - try: - await old_http_client.aclose() - except Exception: - logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) - - async def _maybe_refresh_tee(self) -> None: - """Re-resolve TEE if the current one is older than ``_TEE_REFRESH_INTERVAL``. - - Skips the refresh for explicit ``llm_server_url`` overrides since they - bypass the registry entirely. - """ - if self._llm_server_url is not None: - return - if time.monotonic() - self._tee_refreshed_at < _TEE_REFRESH_INTERVAL: - return - logger.debug("TEE endpoint stale (>%ds); refreshing from registry.", _TEE_REFRESH_INTERVAL) - await self._refresh_tee() - - - @staticmethod - def _resolve_tee( - tee_endpoint_override: Optional[str], - og_rpc_url: Optional[str], - tee_registry_address: Optional[str], - ) -> tuple: - """Resolve TEE endpoint and metadata from the on-chain registry or explicit URL. - - Returns: - (endpoint, tls_cert_der, tee_id, payment_address) - """ - if tee_endpoint_override is not None: - return tee_endpoint_override, None, None, None - if og_rpc_url is None or tee_registry_address is None: - raise ValueError("Either llm_server_url or both rpc_url and tee_registry_address must be provided.") - - try: - registry = TEERegistry(rpc_url=og_rpc_url, registry_address=tee_registry_address) - tee = registry.get_llm_tee() - except Exception as e: - raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {og_rpc_url}): {e}. ") from e - - if tee is None: - raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") - - logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) - return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address + self._tee = TEEConnection( + x402_client=x402_client, + registry=registry, + llm_server_url=llm_server_url, + ) # ── Lifecycle ─────────────────────────────────────────────────────── async def close(self) -> None: - """Close the underlying HTTP client.""" - await self._http_client.aclose() + """Cancel the background refresh loop and close the HTTP client.""" + await self._tee.close() # ── Request helpers ───────────────────────────────────────────────── @@ -215,13 +145,6 @@ def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool payload["tool_choice"] = params.tool_choice or "auto" return payload - def _tee_metadata(self) -> Dict: - return dict( - tee_id=self._tee_id, - tee_endpoint=self._tee_endpoint, - tee_payment_address=self._tee_payment_address, - ) - async def _call_with_tee_retry( self, operation_name: str, @@ -232,7 +155,7 @@ async def _call_with_tee_retry( Only retries when the request never reached the server (no HTTP response). Server-side errors (4xx/5xx) are not retried. """ - await self._maybe_refresh_tee() + self._tee.ensure_refresh_loop() try: return await call() except httpx.HTTPStatusError: @@ -243,7 +166,7 @@ async def _call_with_tee_retry( operation_name, exc, ) - await self._refresh_tee() + await self._tee.reconnect() return await call() # ── Public API ────────────────────────────────────────────────────── @@ -316,8 +239,9 @@ async def completion( payload["stop"] = stop_sequence async def _request() -> TextGenerationOutput: - response = await self._http_client.post( - self._tee_endpoint + _COMPLETION_ENDPOINT, + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _COMPLETION_ENDPOINT, json=payload, headers=self._headers(x402_settlement_mode), timeout=_REQUEST_TIMEOUT, @@ -330,7 +254,7 @@ async def _request() -> TextGenerationOutput: completion_output=result.get("completion"), tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + **tee.metadata(), ) try: @@ -405,8 +329,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text payload = self._chat_payload(params, messages) async def _request() -> TextGenerationOutput: - response = await self._http_client.post( - self._tee_endpoint + _CHAT_ENDPOINT, + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=self._headers(params.x402_settlement_mode), timeout=_REQUEST_TIMEOUT, @@ -432,7 +357,7 @@ async def _request() -> TextGenerationOutput: chat_output=message, tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + **tee.metadata(), ) try: @@ -469,15 +394,16 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: """Async SSE streaming implementation.""" - await self._maybe_refresh_tee() + self._tee.ensure_refresh_loop() headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) chunks_yielded = False try: - async with self._http_client.stream( + tee = self._tee.get() + async with tee.http_client.stream( "POST", - self._tee_endpoint + _CHAT_ENDPOINT, + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=headers, timeout=_REQUEST_TIMEOUT, @@ -496,11 +422,12 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async exc, ) - await self._refresh_tee() + await self._tee.reconnect() + tee = self._tee.get() headers = self._headers(params.x402_settlement_mode) - async with self._http_client.stream( + async with tee.http_client.stream( "POST", - self._tee_endpoint + _CHAT_ENDPOINT, + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=headers, timeout=_REQUEST_TIMEOUT, @@ -546,7 +473,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non chunk = StreamChunk.from_sse_data(data) if chunk.is_final: - chunk.tee_id = self._tee_id - chunk.tee_endpoint = self._tee_endpoint - chunk.tee_payment_address = self._tee_payment_address + tee = self._tee.get() + chunk.tee_id = tee.tee_id + chunk.tee_endpoint = tee.endpoint + chunk.tee_payment_address = tee.payment_address yield chunk diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py new file mode 100644 index 00000000..5403c530 --- /dev/null +++ b/src/opengradient/client/tee_connection.py @@ -0,0 +1,172 @@ +"""Manages the lifecycle of a connection to a TEE endpoint.""" + +import asyncio +import logging +import ssl +from dataclasses import dataclass +from typing import Dict, Optional, Union + +from x402 import x402Client +from x402.http.clients import x402HttpxClient + +from .tee_registry import TEE_TYPE_LLM_PROXY, TEERegistry, build_ssl_context_from_der + +logger = logging.getLogger(__name__) + +_TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes + + +@dataclass(frozen=True) +class ActiveTEE: + """Snapshot of the currently connected TEE.""" + + endpoint: str + http_client: x402HttpxClient + tee_id: Optional[str] + payment_address: Optional[str] + + def metadata(self) -> Dict: + """Return TEE metadata dict for decorating responses.""" + return dict( + tee_id=self.tee_id, + tee_endpoint=self.endpoint, + tee_payment_address=self.payment_address, + ) + + +class TEEConnection: + """Maintains a verified connection to a single TEE endpoint. + + Handles initial resolution from the on-chain registry (or an explicit URL), + TLS certificate pinning, background health checks, and automatic failover + when the current TEE becomes unavailable. + + Use ``get()`` to obtain the current ``ActiveTEE`` snapshot for making requests. + + Args: + x402_client: Configured x402 payment client for creating HTTP clients. + registry: TEERegistry for looking up active TEEs. None when using an explicit URL. + llm_server_url: Bypass the registry and connect directly to this URL. + """ + + def __init__( + self, + x402_client: x402Client, + registry: Optional[TEERegistry] = None, + llm_server_url: Optional[str] = None, + ): + self._x402_client = x402_client + self._registry = registry + self._llm_server_url = llm_server_url + + self._active: Optional[ActiveTEE] = None + self._refresh_lock = asyncio.Lock() + self._refresh_task: Optional[asyncio.Task] = None + + self._connect() + + # ── Public API ────────────────────────────────────────────────────── + + def get(self) -> ActiveTEE: + """Return a snapshot of the current TEE connection.""" + return self._active + + # ── Connection management ─────────────────────────────────────────── + + def _connect(self) -> None: + """Resolve TEE from registry and create a secure HTTP client.""" + endpoint, tls_cert_der, tee_id, payment_address = self._resolve_tee( + self._llm_server_url, + self._registry, + ) + + ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None + tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) + + self._active = ActiveTEE( + endpoint=endpoint, + http_client=x402HttpxClient(self._x402_client, verify=tls_verify), + tee_id=tee_id, + payment_address=payment_address, + ) + + async def reconnect(self) -> None: + """Connect to a new TEE from the registry and rebuild the HTTP client.""" + async with self._refresh_lock: + old_client = self._active.http_client + self._connect() + try: + await old_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + + # ── Background health check ───────────────────────────────────────── + + def ensure_refresh_loop(self) -> None: + """Start the background TEE refresh loop if not already running. + + No-op when ``llm_server_url`` is set (bypasses the registry). + Called lazily from async request methods since ``__init__`` is synchronous. + """ + if self._llm_server_url is not None: + return + if self._refresh_task is not None and not self._refresh_task.done(): + return + self._refresh_task = asyncio.create_task(self._tee_refresh_loop()) + + async def _tee_refresh_loop(self) -> None: + """Periodically check that the current TEE is still active in the registry. + + If the current TEE is no longer active, performs a full refresh to pick + a new one. Does nothing when the current TEE is still healthy. + """ + while True: + await asyncio.sleep(_TEE_REFRESH_INTERVAL) + try: + active_tees = self._registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if any(t.tee_id == self._active.tee_id for t in active_tees): + logger.debug("Current TEE %s still active; no refresh needed.", self._active.tee_id) + continue + logger.info("Current TEE %s no longer active; switching to a new one.", self._active.tee_id) + await self.reconnect() + except Exception: + logger.warning("Background TEE health check failed; will retry next cycle.", exc_info=True) + + # ── Lifecycle ─────────────────────────────────────────────────────── + + async def close(self) -> None: + """Cancel the background refresh loop and close the HTTP client.""" + if self._refresh_task is not None: + self._refresh_task.cancel() + self._refresh_task = None + if self._active is not None: + await self._active.http_client.aclose() + + # ── Static helpers ────────────────────────────────────────────────── + + @staticmethod + def _resolve_tee( + tee_endpoint_override: Optional[str], + registry: Optional[TEERegistry], + ) -> tuple: + """Resolve TEE endpoint and metadata from the on-chain registry or explicit URL. + + Returns: + (endpoint, tls_cert_der, tee_id, payment_address) + """ + if tee_endpoint_override is not None: + return tee_endpoint_override, None, None, None + + if registry is None: + raise ValueError("Either llm_server_url or a TEERegistry instance must be provided.") + + try: + tee = registry.get_llm_tee() + except Exception as e: + raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e + + if tee is None: + raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") + + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index 9ad3cfd7..571e3712 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -31,7 +31,7 @@ class TEEInfo(NamedTuple): last_heartbeat_at: int -@dataclass +@dataclass(frozen=True) class TEEEndpoint: """A verified TEE with its endpoint URL and TLS certificate from the registry.""" diff --git a/tests/client_test.py b/tests/client_test.py index 28df2bf0..4cc62763 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -72,13 +72,13 @@ class TestLLMInitialization: def test_llm_initialization(self, mock_tee_registry): """Test basic LLM initialization.""" llm = LLM(private_key=FAKE_PRIVATE_KEY) - assert llm._tee_endpoint == "https://test.tee.server" + assert llm._tee.get().endpoint == "https://test.tee.server" def test_llm_initialization_custom_url(self, mock_tee_registry): """Test LLM initialization with custom server URL.""" custom_llm_url = "https://custom.llm.server" llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) - assert llm._tee_endpoint == custom_llm_url + assert llm._tee.get().endpoint == custom_llm_url # --- ModelHub Authentication Tests --- diff --git a/tests/llm_test.py b/tests/llm_test.py index 3953943e..b450a3ec 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -14,6 +14,7 @@ import pytest from src.opengradient.client.llm import LLM +from src.opengradient.client.tee_connection import TEEConnection from src.opengradient.types import TEE_LLM, x402SettlementMode # ── Fake HTTP transport ────────────────────────────────────────────── @@ -107,7 +108,7 @@ async def aread(self) -> bytes: # so LLM.__init__ runs its real code but gets our FakeHTTPClient. _PATCHES = { - "x402_httpx": "src.opengradient.client.llm.x402HttpxClient", + "x402_httpx": "src.opengradient.client.tee_connection.x402HttpxClient", "x402_client": "src.opengradient.client.llm.x402Client", "signer": "src.opengradient.client.llm.EthAccountSigner", "register_exact": "src.opengradient.client.llm.register_exact_evm_client", @@ -138,9 +139,9 @@ def _make_llm( ) -> LLM: """Build an LLM with an explicit server URL (skips registry lookup).""" llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint) - # llm_server_url path sets tee_id/payment_address to None; set them for assertions. - llm._tee_id = "test-tee-id" - llm._tee_payment_address = "0xTestPayment" + # llm_server_url path sets tee_id/payment_address to None; replace with test values. + from dataclasses import replace + llm._tee._active = replace(llm._tee.get(), tee_id="test-tee-id", payment_address="0xTestPayment") return llm @@ -515,7 +516,7 @@ async def test_close_delegates_to_http_client(self, fake_http): class TestResolveTeE: def test_explicit_url_skips_registry(self): - endpoint, cert, tee_id, pay_addr = LLM._resolve_tee("https://explicit.url", None, None) + endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee("https://explicit.url", None) assert endpoint == "https://explicit.url" assert cert is None @@ -524,34 +525,34 @@ def test_explicit_url_skips_registry(self): def test_missing_rpc_and_registry_raises(self): with pytest.raises(ValueError): - LLM._resolve_tee(None, None, None) + TEEConnection._resolve_tee(None, None) def test_missing_registry_address_raises(self): with pytest.raises(ValueError): - LLM._resolve_tee(None, "https://rpc", None) + TEEConnection._resolve_tee(None, None) def test_registry_returns_none_raises(self): - with patch("src.opengradient.client.llm.TEERegistry") as mock_reg: - mock_reg.return_value.get_llm_tee.return_value = None + mock_reg = MagicMock() + mock_reg.get_llm_tee.return_value = None - with pytest.raises(ValueError, match="No active LLM proxy TEE"): - LLM._resolve_tee(None, "https://rpc", "0xRegistry") + with pytest.raises(ValueError, match="No active LLM proxy TEE"): + TEEConnection._resolve_tee(None, mock_reg) def test_registry_success(self): - with patch("src.opengradient.client.llm.TEERegistry") as mock_reg: - mock_tee = MagicMock() - mock_tee.endpoint = "https://registry.tee" - mock_tee.tls_cert_der = b"cert-bytes" - mock_tee.tee_id = "tee-42" - mock_tee.payment_address = "0xPay" - mock_reg.return_value.get_llm_tee.return_value = mock_tee + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://registry.tee" + mock_tee.tls_cert_der = b"cert-bytes" + mock_tee.tee_id = "tee-42" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee - endpoint, cert, tee_id, pay_addr = LLM._resolve_tee(None, "https://rpc", "0xRegistry") + endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee(None, mock_reg) - assert endpoint == "https://registry.tee" - assert cert == b"cert-bytes" - assert tee_id == "tee-42" - assert pay_addr == "0xPay" + assert endpoint == "https://registry.tee" + assert cert == b"cert-bytes" + assert tee_id == "tee-42" + assert pay_addr == "0xPay" # ── TEE retry tests (non-streaming) ────────────────────────────────── @@ -677,13 +678,13 @@ async def aread(self) -> bytes: assert len(fake_http.post_calls) == 1 -# ── _refresh_tee tests ───────────────────────────────────── +# ── TEE reconnect tests ───────────────────────────────────────────── @pytest.mark.asyncio -class TestRefreshTeeAndReset: +class TestReconnect: async def test_replaces_http_client(self): - """After refresh, the http client should be a new instance.""" + """After reconnect, the http client should be a new instance.""" clients_created = [] def make_client(*args, **kwargs): @@ -699,29 +700,29 @@ def make_client(*args, **kwargs): patch(_PATCHES["register_upto"]), ): llm = _make_llm() - old_client = llm._http_client + old_client = llm._tee.get().http_client - await llm._refresh_tee() + await llm._tee.reconnect() - assert llm._http_client is not old_client + assert llm._tee.get().http_client is not old_client assert len(clients_created) == 2 # init + refresh async def test_closes_old_client(self, fake_http): llm = _make_llm() - old_client = llm._http_client + old_client = llm._tee.get().http_client old_client.aclose = AsyncMock() - await llm._refresh_tee() + await llm._tee.reconnect() old_client.aclose.assert_awaited_once() async def test_close_failure_is_swallowed(self, fake_http): llm = _make_llm() - old_client = llm._http_client + old_client = llm._tee.get().http_client old_client.aclose = AsyncMock(side_effect=OSError("already closed")) # Should not raise - await llm._refresh_tee() + await llm._tee.reconnect() # ── TEE cert rotation (crash + re-register) tests ──────────────────── @@ -740,10 +741,10 @@ async def test_ssl_verification_failure_triggers_tee_refresh_completion(self, fa fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") - # _connect_tee was called once during the retry (refresh) + # _connect was called once during the retry (reconnect) spy.assert_called_once() assert result.completion_output == "ok after refresh" assert len(fake_http.post_calls) == 2 @@ -756,7 +757,7 @@ async def test_ssl_verification_failure_triggers_tee_refresh_chat(self, fake_htt fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) spy.assert_called_once() @@ -774,7 +775,7 @@ async def test_ssl_verification_failure_triggers_tee_refresh_streaming(self, fak fake_http.fail_next_stream(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: gen = await llm.chat( model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}], From ed7edea8a26fd7fc718e8387ad4e0fabf8365ee2 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:17:34 -0400 Subject: [PATCH 03/10] fix tests --- tests/llm_test.py | 92 ------- tests/tee_connection_test.py | 509 +++++++++++++++++++++++++++++++++++ 2 files changed, 509 insertions(+), 92 deletions(-) create mode 100644 tests/tee_connection_test.py diff --git a/tests/llm_test.py b/tests/llm_test.py index b450a3ec..b7867126 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -14,7 +14,6 @@ import pytest from src.opengradient.client.llm import LLM -from src.opengradient.client.tee_connection import TEEConnection from src.opengradient.types import TEE_LLM, x402SettlementMode # ── Fake HTTP transport ────────────────────────────────────────────── @@ -511,50 +510,6 @@ async def test_close_delegates_to_http_client(self, fake_http): # FakeHTTPClient.aclose is a no-op; just verify it doesn't blow up. -# ── TEE resolution tests ───────────────────────────────────────────── - - -class TestResolveTeE: - def test_explicit_url_skips_registry(self): - endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee("https://explicit.url", None) - - assert endpoint == "https://explicit.url" - assert cert is None - assert tee_id is None - assert pay_addr is None - - def test_missing_rpc_and_registry_raises(self): - with pytest.raises(ValueError): - TEEConnection._resolve_tee(None, None) - - def test_missing_registry_address_raises(self): - with pytest.raises(ValueError): - TEEConnection._resolve_tee(None, None) - - def test_registry_returns_none_raises(self): - mock_reg = MagicMock() - mock_reg.get_llm_tee.return_value = None - - with pytest.raises(ValueError, match="No active LLM proxy TEE"): - TEEConnection._resolve_tee(None, mock_reg) - - def test_registry_success(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://registry.tee" - mock_tee.tls_cert_der = b"cert-bytes" - mock_tee.tee_id = "tee-42" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee - - endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee(None, mock_reg) - - assert endpoint == "https://registry.tee" - assert cert == b"cert-bytes" - assert tee_id == "tee-42" - assert pay_addr == "0xPay" - - # ── TEE retry tests (non-streaming) ────────────────────────────────── @@ -678,53 +633,6 @@ async def aread(self) -> bytes: assert len(fake_http.post_calls) == 1 -# ── TEE reconnect tests ───────────────────────────────────────────── - - -@pytest.mark.asyncio -class TestReconnect: - async def test_replaces_http_client(self): - """After reconnect, the http client should be a new instance.""" - clients_created = [] - - def make_client(*args, **kwargs): - c = FakeHTTPClient() - clients_created.append(c) - return c - - with ( - patch(_PATCHES["x402_httpx"], side_effect=make_client), - patch(_PATCHES["x402_client"]), - patch(_PATCHES["signer"]), - patch(_PATCHES["register_exact"]), - patch(_PATCHES["register_upto"]), - ): - llm = _make_llm() - old_client = llm._tee.get().http_client - - await llm._tee.reconnect() - - assert llm._tee.get().http_client is not old_client - assert len(clients_created) == 2 # init + refresh - - async def test_closes_old_client(self, fake_http): - llm = _make_llm() - old_client = llm._tee.get().http_client - old_client.aclose = AsyncMock() - - await llm._tee.reconnect() - - old_client.aclose.assert_awaited_once() - - async def test_close_failure_is_swallowed(self, fake_http): - llm = _make_llm() - old_client = llm._tee.get().http_client - old_client.aclose = AsyncMock(side_effect=OSError("already closed")) - - # Should not raise - await llm._tee.reconnect() - - # ── TEE cert rotation (crash + re-register) tests ──────────────────── diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py new file mode 100644 index 00000000..7621d163 --- /dev/null +++ b/tests/tee_connection_test.py @@ -0,0 +1,509 @@ +"""Tests for TEEConnection and ActiveTEE. + +Covers TEE resolution, connection lifecycle, reconnect, background refresh, +and the ActiveTEE data snapshot. +""" + +import asyncio +import ssl +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.opengradient.client.tee_connection import ( + ActiveTEE, + TEEConnection, + _TEE_REFRESH_INTERVAL, +) +from src.opengradient.client.tee_registry import TEE_TYPE_LLM_PROXY + + +# ── Helpers ────────────────────────────────────────────────────────── + + +class FakeHTTPClient: + """Minimal stand-in for x402HttpxClient.""" + + def __init__(self, *_a, **_kw): + self.closed = False + + async def aclose(self): + self.closed = True + + +def _mock_x402_client(): + return MagicMock() + + +def _make_connection( + *, + llm_server_url: str = "https://test.tee", + registry=None, + http_factory=None, +): + """Build a TEEConnection with patched externals.""" + factory = http_factory or FakeHTTPClient + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=factory, + ): + return TEEConnection( + x402_client=_mock_x402_client(), + registry=registry, + llm_server_url=llm_server_url, + ) + + +# ── ActiveTEE tests ───────────────────────────────────────────────── + + +class TestActiveTEE: + def test_metadata_returns_dict(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id="tee-1", + payment_address="0xPay", + ) + meta = tee.metadata() + + assert meta == { + "tee_id": "tee-1", + "tee_endpoint": "https://ep", + "tee_payment_address": "0xPay", + } + + def test_metadata_with_none_values(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id=None, + payment_address=None, + ) + meta = tee.metadata() + + assert meta["tee_id"] is None + assert meta["tee_payment_address"] is None + + def test_frozen_dataclass(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id="tee-1", + payment_address="0xPay", + ) + with pytest.raises(AttributeError): + tee.endpoint = "https://other" + + +# ── _resolve_tee (static) tests ───────────────────────────────────── + + +class TestResolveTee: + def test_explicit_url_skips_registry(self): + endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee( + "https://explicit.url", None + ) + + assert endpoint == "https://explicit.url" + assert cert is None + assert tee_id is None + assert pay_addr is None + + def test_no_url_and_no_registry_raises(self): + with pytest.raises(ValueError, match="Either llm_server_url or"): + TEEConnection._resolve_tee(None, None) + + def test_registry_returns_none_raises(self): + mock_reg = MagicMock() + mock_reg.get_llm_tee.return_value = None + + with pytest.raises(ValueError, match="No active LLM proxy TEE"): + TEEConnection._resolve_tee(None, mock_reg) + + def test_registry_exception_wraps_in_runtime_error(self): + mock_reg = MagicMock() + mock_reg.get_llm_tee.side_effect = Exception("rpc down") + + with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): + TEEConnection._resolve_tee(None, mock_reg) + + def test_registry_success(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://registry.tee" + mock_tee.tls_cert_der = b"cert-bytes" + mock_tee.tee_id = "tee-42" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee( + None, mock_reg + ) + + assert endpoint == "https://registry.tee" + assert cert == b"cert-bytes" + assert tee_id == "tee-42" + assert pay_addr == "0xPay" + + +# ── TEEConnection.__init__ / get / _connect tests ─────────────────── + + +class TestConnectionInit: + def test_get_returns_active_tee(self): + conn = _make_connection() + active = conn.get() + + assert isinstance(active, ActiveTEE) + assert active.endpoint == "https://test.tee" + + def test_explicit_url_sets_none_tee_id(self): + conn = _make_connection(llm_server_url="https://custom.url") + active = conn.get() + + assert active.tee_id is None + assert active.payment_address is None + + def test_ssl_verify_false_for_explicit_url(self): + """When using an explicit URL with no TLS cert, verify should be False.""" + clients = [] + + def capture_client(*args, **kwargs): + c = FakeHTTPClient(*args, **kwargs) + c._verify = kwargs.get("verify") + clients.append(c) + return c + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=capture_client, + ): + TEEConnection( + x402_client=_mock_x402_client(), + llm_server_url="https://custom.url", + ) + + # verify=False because llm_server_url is set and no TLS cert + assert clients[0]._verify is False + + def test_registry_path_creates_ssl_context(self): + """When registry provides a TLS cert, an SSLContext should be built.""" + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://registry.tee" + mock_tee.tls_cert_der = b"fake-der" + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + mock_ssl_ctx = MagicMock(spec=ssl.SSLContext) + + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=mock_ssl_ctx, + ) as mock_build, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + mock_build.assert_called_once_with(b"fake-der") + assert conn.get().endpoint == "https://registry.tee" + assert conn.get().tee_id == "tee-1" + + +# ── reconnect tests ───────────────────────────────────────────────── + + +@pytest.mark.asyncio +class TestReconnect: + async def test_replaces_active_tee(self): + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=make_client, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + llm_server_url="https://test.tee", + ) + old_client = conn.get().http_client + + await conn.reconnect() + + assert conn.get().http_client is not old_client + assert len(clients_created) == 2 + + async def test_closes_old_client(self): + conn = _make_connection() + old_client = conn.get().http_client + old_client.aclose = AsyncMock() + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + await conn.reconnect() + + old_client.aclose.assert_awaited_once() + + async def test_close_failure_is_swallowed(self): + conn = _make_connection() + old_client = conn.get().http_client + old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + # Should not raise + await conn.reconnect() + + async def test_reconnect_is_serialized(self): + """Concurrent reconnect calls should not race.""" + call_order = [] + + original_connect = TEEConnection._connect + + def slow_connect(self): + call_order.append("start") + original_connect(self) + call_order.append("end") + + conn = _make_connection() + + with patch.object(TEEConnection, "_connect", slow_connect): + await asyncio.gather(conn.reconnect(), conn.reconnect()) + + # Both should complete without interleaving (lock serializes them) + assert call_order == ["start", "end", "start", "end"] + + +# ── ensure_refresh_loop tests ──────────────────────────────────────── + + +@pytest.mark.asyncio +class TestEnsureRefreshLoop: + async def test_noop_when_llm_server_url_set(self): + conn = _make_connection(llm_server_url="https://explicit.url") + + conn.ensure_refresh_loop() + + assert conn._refresh_task is None + + async def test_starts_task_when_no_llm_server_url(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://tee.endpoint" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + conn.ensure_refresh_loop() + + assert conn._refresh_task is not None + assert not conn._refresh_task.done() + + # Cleanup + conn._refresh_task.cancel() + try: + await conn._refresh_task + except asyncio.CancelledError: + pass + + async def test_idempotent_when_already_running(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://tee.endpoint" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + conn.ensure_refresh_loop() + first_task = conn._refresh_task + + conn.ensure_refresh_loop() + + assert conn._refresh_task is first_task + + # Cleanup + conn._refresh_task.cancel() + try: + await conn._refresh_task + except asyncio.CancelledError: + pass + + +# ── _tee_refresh_loop tests ────────────────────────────────────────── + + +@pytest.mark.asyncio +class TestTeeRefreshLoop: + async def test_no_reconnect_when_tee_still_active(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://tee.endpoint" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + active_tee = MagicMock() + active_tee.tee_id = "tee-1" + mock_reg.get_active_tees_by_type.return_value = [active_tee] + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, asyncio.CancelledError], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + mock_reconnect.assert_not_called() + + async def test_reconnects_when_tee_no_longer_active(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://tee.endpoint" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + # Registry says a different TEE is now active + other_tee = MagicMock() + other_tee.tee_id = "tee-99" + mock_reg.get_active_tees_by_type.return_value = [other_tee] + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, asyncio.CancelledError], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + mock_reconnect.assert_awaited_once() + + async def test_registry_error_does_not_crash_loop(self): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = "https://tee.endpoint" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "tee-1" + mock_tee.payment_address = "0xPay" + mock_reg.get_llm_tee.return_value = mock_tee + + # First check fails, second check cancels + mock_reg.get_active_tees_by_type.side_effect = [ + RuntimeError("rpc timeout"), + asyncio.CancelledError, + ] + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + conn = TEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, None], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + # The loop survived the first error and ran a second iteration + assert mock_reg.get_active_tees_by_type.call_count == 2 + + +# ── close tests ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +class TestClose: + async def test_closes_http_client(self): + conn = _make_connection() + conn.get().http_client.aclose = AsyncMock() + + await conn.close() + + conn.get().http_client.aclose.assert_awaited_once() + + async def test_cancels_refresh_task(self): + conn = _make_connection() + mock_task = MagicMock() + conn._refresh_task = mock_task + + await conn.close() + + mock_task.cancel.assert_called_once() + assert conn._refresh_task is None + + async def test_close_without_refresh_task(self): + conn = _make_connection() + + # Should not raise when no refresh task exists + await conn.close() + + async def test_close_with_no_active_tee(self): + conn = _make_connection() + conn._active = None + + # Should not raise + await conn.close() From 7df98a063c0bc1f8ac85b6ce9f3d885aa3fd2b5a Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:20:48 -0400 Subject: [PATCH 04/10] err message --- src/opengradient/client/llm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 1467bc01..9dca7b00 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -94,6 +94,11 @@ def __init__( tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, llm_server_url: Optional[str] = None, ): + if not private_key: + raise ValueError( + "A private key is required to use the LLM client. " + "Pass a valid private_key to the constructor." + ) self._wallet_account: LocalAccount = Account.from_key(private_key) # x402 payment stack (created once, reused across TEE refreshes) From 10843887bb13890b2e74430af509bc8b80ce73a6 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:40:11 -0400 Subject: [PATCH 05/10] refactor --- src/opengradient/client/llm.py | 79 +++-- src/opengradient/client/tee_connection.py | 157 +++++---- tests/client_test.py | 2 +- tests/llm_test.py | 5 +- tests/tee_connection_test.py | 395 ++++++++++++---------- 5 files changed, 348 insertions(+), 290 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 80501a15..0a74c9e9 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -15,7 +15,7 @@ from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval -from .tee_connection import TEEConnection +from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface from .tee_registry import TEERegistry logger = logging.getLogger(__name__) @@ -62,29 +62,17 @@ class LLM: below the requested amount. Usage: + # Via on-chain registry (default) llm = og.LLM(private_key="0x...") + # Via hardcoded URL (development / self-hosted) + llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4") + # One-time approval (idempotent — skips if allowance is already sufficient) llm.ensure_opg_approval(opg_amount=5) result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...]) result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello") - - Args: - private_key (str): Ethereum private key for signing x402 payments. - rpc_url (str): RPC URL for the OpenGradient network. Used to fetch the - active TEE endpoint from the on-chain registry when ``llm_server_url`` - is not provided. - tee_registry_address (str): Address of the on-chain TEE registry contract. - llm_server_url (str, optional): Bypass the registry and connect directly - to this TEE endpoint URL (e.g. ``"https://1.2.3.4"``). When set, - TLS certificate verification is disabled automatically because - self-hosted TEE servers typically use self-signed certificates. - - .. warning:: - Using ``llm_server_url`` disables TLS certificate verification, - which removes protection against man-in-the-middle attacks. - Only connect to servers you trust and over secure network paths. """ def __init__( @@ -92,32 +80,49 @@ def __init__( private_key: str, rpc_url: str = DEFAULT_RPC_URL, tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, - llm_server_url: Optional[str] = None, ): if not private_key: - raise ValueError( - "A private key is required to use the LLM client. " - "Pass a valid private_key to the constructor." - ) + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") self._wallet_account: LocalAccount = Account.from_key(private_key) - # x402 payment stack (created once, reused across TEE refreshes) - signer = EthAccountSigner(self._wallet_account) - x402_client = x402Client() - register_exact_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + x402_client = LLM._build_x402_client(private_key) + onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) + self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry) - registry: Optional[TEERegistry] = ( - TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) - if llm_server_url is None - else None - ) + @classmethod + def from_url( + cls, + private_key: str, + llm_server_url: str, + ) -> "LLM": + """**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL. - self._tee = TEEConnection( - x402_client=x402_client, - registry=registry, - llm_server_url=llm_server_url, - ) + Intended for development and self-hosted TEE servers. TLS certificate + verification is disabled because these servers typically use self-signed + certificates. For production use, prefer the default constructor which + resolves TEEs from the on-chain registry. + + Args: + private_key: Ethereum private key for signing x402 payments. + llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``). + """ + instance = cls.__new__(cls) + if not private_key: + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") + instance._wallet_account = Account.from_key(private_key) + x402_client = cls._build_x402_client(private_key) + instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url) + return instance + + @staticmethod + def _build_x402_client(private_key: str) -> x402Client: + """Build the x402 payment stack from a private key.""" + account = Account.from_key(private_key) + signer = EthAccountSigner(account) + client = x402Client() + register_exact_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + return client # ── Lifecycle ─────────────────────────────────────────────────────── diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index 5403c530..1b6a807b 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -4,7 +4,7 @@ import logging import ssl from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Dict, Optional, Protocol, Union from x402 import x402Client from x402.http.clients import x402HttpxClient @@ -34,36 +34,80 @@ def metadata(self) -> Dict: ) -class TEEConnection: - """Maintains a verified connection to a single TEE endpoint. +class TEEConnectionInterface(Protocol): + """Interface for TEE connection implementations.""" - Handles initial resolution from the on-chain registry (or an explicit URL), - TLS certificate pinning, background health checks, and automatic failover - when the current TEE becomes unavailable. + def get(self) -> ActiveTEE: ... + def ensure_refresh_loop(self) -> None: ... + async def reconnect(self) -> None: ... + async def close(self) -> None: ... - Use ``get()`` to obtain the current ``ActiveTEE`` snapshot for making requests. + +class StaticTEEConnection: + """TEE connection with a hardcoded endpoint URL. + + No registry lookup, no background refresh. TLS certificate verification + is disabled because self-hosted TEE servers typically use self-signed certs. Args: x402_client: Configured x402 payment client for creating HTTP clients. - registry: TEERegistry for looking up active TEEs. None when using an explicit URL. - llm_server_url: Bypass the registry and connect directly to this URL. + endpoint: The TEE endpoint URL to connect to. """ - def __init__( - self, - x402_client: x402Client, - registry: Optional[TEERegistry] = None, - llm_server_url: Optional[str] = None, - ): + def __init__(self, x402_client: x402Client, endpoint: str): + self._x402_client = x402_client + self._endpoint = endpoint + self._active: ActiveTEE = self._connect() + + def get(self) -> ActiveTEE: + """Return a snapshot of the current TEE connection.""" + return self._active + + def _connect(self) -> ActiveTEE: + return ActiveTEE( + endpoint=self._endpoint, + http_client=x402HttpxClient(self._x402_client, verify=False), + tee_id=None, + payment_address=None, + ) + + def ensure_refresh_loop(self) -> None: + """No-op — static connections don't refresh.""" + pass + + async def reconnect(self) -> None: + """Rebuild the HTTP client (same endpoint).""" + old_client = self._active.http_client + self._active = self._connect() + try: + await old_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during reconnect.", exc_info=True) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._active.http_client.aclose() + + +class RegistryTEEConnection: + """TEE connection resolved from the on-chain registry. + + Handles TLS certificate pinning, background health checks, and automatic + failover when the current TEE becomes unavailable. + + Args: + x402_client: Configured x402 payment client for creating HTTP clients. + registry: TEERegistry for looking up active TEEs. + """ + + def __init__(self, x402_client: x402Client, registry: TEERegistry): self._x402_client = x402_client self._registry = registry - self._llm_server_url = llm_server_url - self._active: Optional[ActiveTEE] = None self._refresh_lock = asyncio.Lock() self._refresh_task: Optional[asyncio.Task] = None - self._connect() + self._active: ActiveTEE = self._connect() # ── Public API ────────────────────────────────────────────────────── @@ -73,28 +117,46 @@ def get(self) -> ActiveTEE: # ── Connection management ─────────────────────────────────────────── - def _connect(self) -> None: + def _resolve_tee(self): + """Resolve TEE endpoint and metadata from the on-chain registry. + + Returns: + The TEE object from the registry. + + Raises: + RuntimeError: If the registry lookup fails. + ValueError: If no active LLM proxy TEE is found. + """ + try: + tee = self._registry.get_llm_tee() + except Exception as e: + raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e + + if tee is None: + raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") + + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + return tee + + def _connect(self) -> ActiveTEE: """Resolve TEE from registry and create a secure HTTP client.""" - endpoint, tls_cert_der, tee_id, payment_address = self._resolve_tee( - self._llm_server_url, - self._registry, - ) + tee = self._resolve_tee() - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) + ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None + tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True - self._active = ActiveTEE( - endpoint=endpoint, + return ActiveTEE( + endpoint=tee.endpoint, http_client=x402HttpxClient(self._x402_client, verify=tls_verify), - tee_id=tee_id, - payment_address=payment_address, + tee_id=tee.tee_id, + payment_address=tee.payment_address, ) async def reconnect(self) -> None: """Connect to a new TEE from the registry and rebuild the HTTP client.""" async with self._refresh_lock: old_client = self._active.http_client - self._connect() + self._active = self._connect() try: await old_client.aclose() except Exception: @@ -105,11 +167,8 @@ async def reconnect(self) -> None: def ensure_refresh_loop(self) -> None: """Start the background TEE refresh loop if not already running. - No-op when ``llm_server_url`` is set (bypasses the registry). Called lazily from async request methods since ``__init__`` is synchronous. """ - if self._llm_server_url is not None: - return if self._refresh_task is not None and not self._refresh_task.done(): return self._refresh_task = asyncio.create_task(self._tee_refresh_loop()) @@ -139,34 +198,4 @@ async def close(self) -> None: if self._refresh_task is not None: self._refresh_task.cancel() self._refresh_task = None - if self._active is not None: - await self._active.http_client.aclose() - - # ── Static helpers ────────────────────────────────────────────────── - - @staticmethod - def _resolve_tee( - tee_endpoint_override: Optional[str], - registry: Optional[TEERegistry], - ) -> tuple: - """Resolve TEE endpoint and metadata from the on-chain registry or explicit URL. - - Returns: - (endpoint, tls_cert_der, tee_id, payment_address) - """ - if tee_endpoint_override is not None: - return tee_endpoint_override, None, None, None - - if registry is None: - raise ValueError("Either llm_server_url or a TEERegistry instance must be provided.") - - try: - tee = registry.get_llm_tee() - except Exception as e: - raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e - - if tee is None: - raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") - - logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) - return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address + await self._active.http_client.aclose() diff --git a/tests/client_test.py b/tests/client_test.py index 4cc62763..2e2bf14c 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -77,7 +77,7 @@ def test_llm_initialization(self, mock_tee_registry): def test_llm_initialization_custom_url(self, mock_tee_registry): """Test LLM initialization with custom server URL.""" custom_llm_url = "https://custom.llm.server" - llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) + llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) assert llm._tee.get().endpoint == custom_llm_url diff --git a/tests/llm_test.py b/tests/llm_test.py index b7867126..ce0cd48b 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -137,9 +137,10 @@ def _make_llm( endpoint: str = "https://test.tee.server", ) -> LLM: """Build an LLM with an explicit server URL (skips registry lookup).""" - llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint) - # llm_server_url path sets tee_id/payment_address to None; replace with test values. from dataclasses import replace + + llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint) + # from_url sets tee_id/payment_address to None; replace with test values. llm._tee._active = replace(llm._tee.get(), tee_id="test-tee-id", payment_address="0xTestPayment") return llm diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 7621d163..1368f803 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -1,4 +1,4 @@ -"""Tests for TEEConnection and ActiveTEE. +"""Tests for StaticTEEConnection, RegistryTEEConnection, and ActiveTEE. Covers TEE resolution, connection lifecycle, reconnect, background refresh, and the ActiveTEE data snapshot. @@ -12,7 +12,8 @@ from src.opengradient.client.tee_connection import ( ActiveTEE, - TEEConnection, + RegistryTEEConnection, + StaticTEEConnection, _TEE_REFRESH_INTERVAL, ) from src.opengradient.client.tee_registry import TEE_TYPE_LLM_PROXY @@ -35,25 +36,47 @@ def _mock_x402_client(): return MagicMock() -def _make_connection( +def _make_static_connection( *, - llm_server_url: str = "https://test.tee", - registry=None, + endpoint: str = "https://test.tee", http_factory=None, ): - """Build a TEEConnection with patched externals.""" + """Build a StaticTEEConnection with patched externals.""" factory = http_factory or FakeHTTPClient with patch( "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=factory, ): - return TEEConnection( + return StaticTEEConnection( + x402_client=_mock_x402_client(), + endpoint=endpoint, + ) + + +def _make_registry_connection(*, registry=None, http_factory=None): + """Build a RegistryTEEConnection with patched externals.""" + factory = http_factory or FakeHTTPClient + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=factory, + ): + return RegistryTEEConnection( x402_client=_mock_x402_client(), registry=registry, - llm_server_url=llm_server_url, ) +def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, tee_id="tee-1", payment_address="0xPay"): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = endpoint + mock_tee.tls_cert_der = tls_cert_der + mock_tee.tee_id = tee_id + mock_tee.payment_address = payment_address + mock_reg.get_llm_tee.return_value = mock_tee + return mock_reg + + # ── ActiveTEE tests ───────────────────────────────────────────────── @@ -96,77 +119,26 @@ def test_frozen_dataclass(self): tee.endpoint = "https://other" -# ── _resolve_tee (static) tests ───────────────────────────────────── - - -class TestResolveTee: - def test_explicit_url_skips_registry(self): - endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee( - "https://explicit.url", None - ) - - assert endpoint == "https://explicit.url" - assert cert is None - assert tee_id is None - assert pay_addr is None - - def test_no_url_and_no_registry_raises(self): - with pytest.raises(ValueError, match="Either llm_server_url or"): - TEEConnection._resolve_tee(None, None) - - def test_registry_returns_none_raises(self): - mock_reg = MagicMock() - mock_reg.get_llm_tee.return_value = None - - with pytest.raises(ValueError, match="No active LLM proxy TEE"): - TEEConnection._resolve_tee(None, mock_reg) - - def test_registry_exception_wraps_in_runtime_error(self): - mock_reg = MagicMock() - mock_reg.get_llm_tee.side_effect = Exception("rpc down") - - with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): - TEEConnection._resolve_tee(None, mock_reg) - - def test_registry_success(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://registry.tee" - mock_tee.tls_cert_der = b"cert-bytes" - mock_tee.tee_id = "tee-42" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee - - endpoint, cert, tee_id, pay_addr = TEEConnection._resolve_tee( - None, mock_reg - ) - - assert endpoint == "https://registry.tee" - assert cert == b"cert-bytes" - assert tee_id == "tee-42" - assert pay_addr == "0xPay" - +# ── StaticTEEConnection tests ─────────────────────────────────────── -# ── TEEConnection.__init__ / get / _connect tests ─────────────────── - -class TestConnectionInit: +class TestStaticConnectionInit: def test_get_returns_active_tee(self): - conn = _make_connection() + conn = _make_static_connection() active = conn.get() assert isinstance(active, ActiveTEE) assert active.endpoint == "https://test.tee" - def test_explicit_url_sets_none_tee_id(self): - conn = _make_connection(llm_server_url="https://custom.url") + def test_sets_none_tee_id_and_payment(self): + conn = _make_static_connection(endpoint="https://custom.url") active = conn.get() assert active.tee_id is None assert active.payment_address is None - def test_ssl_verify_false_for_explicit_url(self): - """When using an explicit URL with no TLS cert, verify should be False.""" + def test_ssl_verify_false(self): + """Static connections disable TLS verification.""" clients = [] def capture_client(*args, **kwargs): @@ -179,24 +151,147 @@ def capture_client(*args, **kwargs): "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=capture_client, ): - TEEConnection( + StaticTEEConnection( x402_client=_mock_x402_client(), - llm_server_url="https://custom.url", + endpoint="https://custom.url", ) - # verify=False because llm_server_url is set and no TLS cert assert clients[0]._verify is False - def test_registry_path_creates_ssl_context(self): - """When registry provides a TLS cert, an SSLContext should be built.""" + def test_ensure_refresh_loop_is_noop(self): + conn = _make_static_connection() + conn.ensure_refresh_loop() + # No task created, no error raised + + +@pytest.mark.asyncio +class TestStaticReconnect: + async def test_replaces_active_tee(self): + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=make_client, + ): + conn = StaticTEEConnection( + x402_client=_mock_x402_client(), + endpoint="https://test.tee", + ) + old_client = conn.get().http_client + + await conn.reconnect() + + assert conn.get().http_client is not old_client + assert len(clients_created) == 2 + + async def test_closes_old_client(self): + conn = _make_static_connection() + old_client = conn.get().http_client + old_client.aclose = AsyncMock() + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + await conn.reconnect() + + old_client.aclose.assert_awaited_once() + + async def test_close_failure_is_swallowed(self): + conn = _make_static_connection() + old_client = conn.get().http_client + old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + # Should not raise + await conn.reconnect() + + +@pytest.mark.asyncio +class TestStaticClose: + async def test_closes_http_client(self): + conn = _make_static_connection() + conn.get().http_client.aclose = AsyncMock() + + await conn.close() + + conn.get().http_client.aclose.assert_awaited_once() + + +# ── RegistryTEEConnection._resolve_tee tests ──────────────────────── + + +class TestResolveTee: + def test_registry_returns_none_raises(self): + mock_reg = MagicMock() + mock_reg.get_llm_tee.return_value = None + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + with pytest.raises(ValueError, match="No active LLM proxy TEE"): + RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + + def test_registry_exception_wraps_in_runtime_error(self): mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://registry.tee" - mock_tee.tls_cert_der = b"fake-der" - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee + mock_reg.get_llm_tee.side_effect = Exception("rpc down") + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): + RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + + def test_registry_success(self): + mock_reg = _mock_registry_with_tee( + endpoint="https://registry.tee", + tls_cert_der=b"cert-bytes", + tee_id="tee-42", + payment_address="0xPay", + ) + + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), + ): + conn = RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + + assert conn.get().endpoint == "https://registry.tee" + assert conn.get().tee_id == "tee-42" + assert conn.get().payment_address == "0xPay" + + +# ── RegistryTEEConnection init / connect tests ────────────────────── + +class TestRegistryConnectionInit: + def test_get_returns_active_tee(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + active = conn.get() + + assert isinstance(active, ActiveTEE) + assert active.endpoint == "https://tee.endpoint" + + def test_registry_path_creates_ssl_context(self): + """When registry provides a TLS cert, an SSLContext should be built.""" + mock_reg = _mock_registry_with_tee(tls_cert_der=b"fake-der") mock_ssl_ctx = MagicMock(spec=ssl.SSLContext) with ( @@ -209,21 +304,21 @@ def test_registry_path_creates_ssl_context(self): return_value=mock_ssl_ctx, ) as mock_build, ): - conn = TEEConnection( + conn = RegistryTEEConnection( x402_client=_mock_x402_client(), registry=mock_reg, ) mock_build.assert_called_once_with(b"fake-der") - assert conn.get().endpoint == "https://registry.tee" + assert conn.get().endpoint == "https://tee.endpoint" assert conn.get().tee_id == "tee-1" -# ── reconnect tests ───────────────────────────────────────────────── +# ── RegistryTEEConnection reconnect tests ─────────────────────────── @pytest.mark.asyncio -class TestReconnect: +class TestRegistryReconnect: async def test_replaces_active_tee(self): clients_created = [] @@ -232,13 +327,15 @@ def make_client(*args, **kwargs): clients_created.append(c) return c + mock_reg = _mock_registry_with_tee() + with patch( "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=make_client, ): - conn = TEEConnection( + conn = RegistryTEEConnection( x402_client=_mock_x402_client(), - llm_server_url="https://test.tee", + registry=mock_reg, ) old_client = conn.get().http_client @@ -248,7 +345,8 @@ def make_client(*args, **kwargs): assert len(clients_created) == 2 async def test_closes_old_client(self): - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) old_client = conn.get().http_client old_client.aclose = AsyncMock() @@ -261,7 +359,8 @@ async def test_closes_old_client(self): old_client.aclose.assert_awaited_once() async def test_close_failure_is_swallowed(self): - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) old_client = conn.get().http_client old_client.aclose = AsyncMock(side_effect=OSError("already closed")) @@ -276,16 +375,18 @@ async def test_reconnect_is_serialized(self): """Concurrent reconnect calls should not race.""" call_order = [] - original_connect = TEEConnection._connect + original_connect = RegistryTEEConnection._connect def slow_connect(self): call_order.append("start") - original_connect(self) + result = original_connect(self) call_order.append("end") + return result - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) - with patch.object(TEEConnection, "_connect", slow_connect): + with patch.object(RegistryTEEConnection, "_connect", slow_connect): await asyncio.gather(conn.reconnect(), conn.reconnect()) # Both should complete without interleaving (lock serializes them) @@ -297,30 +398,9 @@ def slow_connect(self): @pytest.mark.asyncio class TestEnsureRefreshLoop: - async def test_noop_when_llm_server_url_set(self): - conn = _make_connection(llm_server_url="https://explicit.url") - - conn.ensure_refresh_loop() - - assert conn._refresh_task is None - - async def test_starts_task_when_no_llm_server_url(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://tee.endpoint" - mock_tee.tls_cert_der = None - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - conn = TEEConnection( - x402_client=_mock_x402_client(), - registry=mock_reg, - ) + async def test_starts_task(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) conn.ensure_refresh_loop() @@ -335,22 +415,8 @@ async def test_starts_task_when_no_llm_server_url(self): pass async def test_idempotent_when_already_running(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://tee.endpoint" - mock_tee.tls_cert_der = None - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - conn = TEEConnection( - x402_client=_mock_x402_client(), - registry=mock_reg, - ) + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) conn.ensure_refresh_loop() first_task = conn._refresh_task @@ -373,26 +439,13 @@ async def test_idempotent_when_already_running(self): @pytest.mark.asyncio class TestTeeRefreshLoop: async def test_no_reconnect_when_tee_still_active(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://tee.endpoint" - mock_tee.tls_cert_der = None - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee + mock_reg = _mock_registry_with_tee(tee_id="tee-1") active_tee = MagicMock() active_tee.tee_id = "tee-1" mock_reg.get_active_tees_by_type.return_value = [active_tee] - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - conn = TEEConnection( - x402_client=_mock_x402_client(), - registry=mock_reg, - ) + conn = _make_registry_connection(registry=mock_reg) with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: with patch( @@ -405,27 +458,14 @@ async def test_no_reconnect_when_tee_still_active(self): mock_reconnect.assert_not_called() async def test_reconnects_when_tee_no_longer_active(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://tee.endpoint" - mock_tee.tls_cert_der = None - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee + mock_reg = _mock_registry_with_tee(tee_id="tee-1") # Registry says a different TEE is now active other_tee = MagicMock() other_tee.tee_id = "tee-99" mock_reg.get_active_tees_by_type.return_value = [other_tee] - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - conn = TEEConnection( - x402_client=_mock_x402_client(), - registry=mock_reg, - ) + conn = _make_registry_connection(registry=mock_reg) with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: with patch( @@ -438,13 +478,7 @@ async def test_reconnects_when_tee_no_longer_active(self): mock_reconnect.assert_awaited_once() async def test_registry_error_does_not_crash_loop(self): - mock_reg = MagicMock() - mock_tee = MagicMock() - mock_tee.endpoint = "https://tee.endpoint" - mock_tee.tls_cert_der = None - mock_tee.tee_id = "tee-1" - mock_tee.payment_address = "0xPay" - mock_reg.get_llm_tee.return_value = mock_tee + mock_reg = _mock_registry_with_tee(tee_id="tee-1") # First check fails, second check cancels mock_reg.get_active_tees_by_type.side_effect = [ @@ -452,14 +486,7 @@ async def test_registry_error_does_not_crash_loop(self): asyncio.CancelledError, ] - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - conn = TEEConnection( - x402_client=_mock_x402_client(), - registry=mock_reg, - ) + conn = _make_registry_connection(registry=mock_reg) with patch( "src.opengradient.client.tee_connection.asyncio.sleep", @@ -472,13 +499,14 @@ async def test_registry_error_does_not_crash_loop(self): assert mock_reg.get_active_tees_by_type.call_count == 2 -# ── close tests ────────────────────────────────────────────────────── +# ── RegistryTEEConnection close tests ─────────────────────────────── @pytest.mark.asyncio -class TestClose: +class TestRegistryClose: async def test_closes_http_client(self): - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) conn.get().http_client.aclose = AsyncMock() await conn.close() @@ -486,7 +514,8 @@ async def test_closes_http_client(self): conn.get().http_client.aclose.assert_awaited_once() async def test_cancels_refresh_task(self): - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) mock_task = MagicMock() conn._refresh_task = mock_task @@ -496,14 +525,8 @@ async def test_cancels_refresh_task(self): assert conn._refresh_task is None async def test_close_without_refresh_task(self): - conn = _make_connection() + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) # Should not raise when no refresh task exists await conn.close() - - async def test_close_with_no_active_tee(self): - conn = _make_connection() - conn._active = None - - # Should not raise - await conn.close() From 9282113435375e130a59b5f5a403c8d6224061df Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:44:11 -0400 Subject: [PATCH 06/10] simplify test --- tests/tee_connection_test.py | 235 +++++------------------------------ 1 file changed, 30 insertions(+), 205 deletions(-) diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 1368f803..dcd83c66 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -1,8 +1,4 @@ -"""Tests for StaticTEEConnection, RegistryTEEConnection, and ActiveTEE. - -Covers TEE resolution, connection lifecycle, reconnect, background refresh, -and the ActiveTEE data snapshot. -""" +"""Tests for RegistryTEEConnection and ActiveTEE.""" import asyncio import ssl @@ -13,10 +9,7 @@ from src.opengradient.client.tee_connection import ( ActiveTEE, RegistryTEEConnection, - StaticTEEConnection, - _TEE_REFRESH_INTERVAL, ) -from src.opengradient.client.tee_registry import TEE_TYPE_LLM_PROXY # ── Helpers ────────────────────────────────────────────────────────── @@ -36,23 +29,6 @@ def _mock_x402_client(): return MagicMock() -def _make_static_connection( - *, - endpoint: str = "https://test.tee", - http_factory=None, -): - """Build a StaticTEEConnection with patched externals.""" - factory = http_factory or FakeHTTPClient - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=factory, - ): - return StaticTEEConnection( - x402_client=_mock_x402_client(), - endpoint=endpoint, - ) - - def _make_registry_connection(*, registry=None, http_factory=None): """Build a RegistryTEEConnection with patched externals.""" factory = http_factory or FakeHTTPClient @@ -77,7 +53,7 @@ def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, return mock_reg -# ── ActiveTEE tests ───────────────────────────────────────────────── +# ── Tests ──────────────────────────────────────────────────────────── class TestActiveTEE: @@ -88,9 +64,7 @@ def test_metadata_returns_dict(self): tee_id="tee-1", payment_address="0xPay", ) - meta = tee.metadata() - - assert meta == { + assert tee.metadata() == { "tee_id": "tee-1", "tee_endpoint": "https://ep", "tee_payment_address": "0xPay", @@ -104,7 +78,6 @@ def test_metadata_with_none_values(self): payment_address=None, ) meta = tee.metadata() - assert meta["tee_id"] is None assert meta["tee_payment_address"] is None @@ -119,118 +92,19 @@ def test_frozen_dataclass(self): tee.endpoint = "https://other" -# ── StaticTEEConnection tests ─────────────────────────────────────── - +@pytest.mark.asyncio +class TestRegistryTEEConnection: + # ── init / resolve ─────────────────────────────────────────── -class TestStaticConnectionInit: - def test_get_returns_active_tee(self): - conn = _make_static_connection() + async def test_get_returns_active_tee(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) active = conn.get() assert isinstance(active, ActiveTEE) - assert active.endpoint == "https://test.tee" - - def test_sets_none_tee_id_and_payment(self): - conn = _make_static_connection(endpoint="https://custom.url") - active = conn.get() - - assert active.tee_id is None - assert active.payment_address is None - - def test_ssl_verify_false(self): - """Static connections disable TLS verification.""" - clients = [] - - def capture_client(*args, **kwargs): - c = FakeHTTPClient(*args, **kwargs) - c._verify = kwargs.get("verify") - clients.append(c) - return c - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=capture_client, - ): - StaticTEEConnection( - x402_client=_mock_x402_client(), - endpoint="https://custom.url", - ) - - assert clients[0]._verify is False - - def test_ensure_refresh_loop_is_noop(self): - conn = _make_static_connection() - conn.ensure_refresh_loop() - # No task created, no error raised - - -@pytest.mark.asyncio -class TestStaticReconnect: - async def test_replaces_active_tee(self): - clients_created = [] - - def make_client(*args, **kwargs): - c = FakeHTTPClient() - clients_created.append(c) - return c - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=make_client, - ): - conn = StaticTEEConnection( - x402_client=_mock_x402_client(), - endpoint="https://test.tee", - ) - old_client = conn.get().http_client - - await conn.reconnect() - - assert conn.get().http_client is not old_client - assert len(clients_created) == 2 - - async def test_closes_old_client(self): - conn = _make_static_connection() - old_client = conn.get().http_client - old_client.aclose = AsyncMock() - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - await conn.reconnect() - - old_client.aclose.assert_awaited_once() - - async def test_close_failure_is_swallowed(self): - conn = _make_static_connection() - old_client = conn.get().http_client - old_client.aclose = AsyncMock(side_effect=OSError("already closed")) - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - # Should not raise - await conn.reconnect() - - -@pytest.mark.asyncio -class TestStaticClose: - async def test_closes_http_client(self): - conn = _make_static_connection() - conn.get().http_client.aclose = AsyncMock() - - await conn.close() - - conn.get().http_client.aclose.assert_awaited_once() - - -# ── RegistryTEEConnection._resolve_tee tests ──────────────────────── - + assert active.endpoint == "https://tee.endpoint" -class TestResolveTee: - def test_registry_returns_none_raises(self): + async def test_resolve_none_raises(self): mock_reg = MagicMock() mock_reg.get_llm_tee.return_value = None @@ -241,7 +115,7 @@ def test_registry_returns_none_raises(self): with pytest.raises(ValueError, match="No active LLM proxy TEE"): RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) - def test_registry_exception_wraps_in_runtime_error(self): + async def test_resolve_exception_wraps_in_runtime_error(self): mock_reg = MagicMock() mock_reg.get_llm_tee.side_effect = Exception("rpc down") @@ -252,7 +126,7 @@ def test_registry_exception_wraps_in_runtime_error(self): with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) - def test_registry_success(self): + async def test_resolve_success_with_cert(self): mock_reg = _mock_registry_with_tee( endpoint="https://registry.tee", tls_cert_der=b"cert-bytes", @@ -276,21 +150,7 @@ def test_registry_success(self): assert conn.get().tee_id == "tee-42" assert conn.get().payment_address == "0xPay" - -# ── RegistryTEEConnection init / connect tests ────────────────────── - - -class TestRegistryConnectionInit: - def test_get_returns_active_tee(self): - mock_reg = _mock_registry_with_tee() - conn = _make_registry_connection(registry=mock_reg) - active = conn.get() - - assert isinstance(active, ActiveTEE) - assert active.endpoint == "https://tee.endpoint" - - def test_registry_path_creates_ssl_context(self): - """When registry provides a TLS cert, an SSLContext should be built.""" + async def test_builds_ssl_context_from_der(self): mock_reg = _mock_registry_with_tee(tls_cert_der=b"fake-der") mock_ssl_ctx = MagicMock(spec=ssl.SSLContext) @@ -310,16 +170,11 @@ def test_registry_path_creates_ssl_context(self): ) mock_build.assert_called_once_with(b"fake-der") - assert conn.get().endpoint == "https://tee.endpoint" assert conn.get().tee_id == "tee-1" + # ── reconnect ──────────────────────────────────────────────── -# ── RegistryTEEConnection reconnect tests ─────────────────────────── - - -@pytest.mark.asyncio -class TestRegistryReconnect: - async def test_replaces_active_tee(self): + async def test_reconnect_replaces_active_tee(self): clients_created = [] def make_client(*args, **kwargs): @@ -338,13 +193,12 @@ def make_client(*args, **kwargs): registry=mock_reg, ) old_client = conn.get().http_client - await conn.reconnect() assert conn.get().http_client is not old_client assert len(clients_created) == 2 - async def test_closes_old_client(self): + async def test_reconnect_closes_old_client(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) old_client = conn.get().http_client @@ -358,23 +212,19 @@ async def test_closes_old_client(self): old_client.aclose.assert_awaited_once() - async def test_close_failure_is_swallowed(self): + async def test_reconnect_swallows_close_failure(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) - old_client = conn.get().http_client - old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + conn.get().http_client.aclose = AsyncMock(side_effect=OSError("already closed")) with patch( "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=FakeHTTPClient, ): - # Should not raise - await conn.reconnect() + await conn.reconnect() # should not raise async def test_reconnect_is_serialized(self): - """Concurrent reconnect calls should not race.""" call_order = [] - original_connect = RegistryTEEConnection._connect def slow_connect(self): @@ -389,16 +239,11 @@ def slow_connect(self): with patch.object(RegistryTEEConnection, "_connect", slow_connect): await asyncio.gather(conn.reconnect(), conn.reconnect()) - # Both should complete without interleaving (lock serializes them) assert call_order == ["start", "end", "start", "end"] + # ── refresh loop ───────────────────────────────────────────── -# ── ensure_refresh_loop tests ──────────────────────────────────────── - - -@pytest.mark.asyncio -class TestEnsureRefreshLoop: - async def test_starts_task(self): + async def test_ensure_refresh_loop_starts_task(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) @@ -407,40 +252,30 @@ async def test_starts_task(self): assert conn._refresh_task is not None assert not conn._refresh_task.done() - # Cleanup conn._refresh_task.cancel() try: await conn._refresh_task except asyncio.CancelledError: pass - async def test_idempotent_when_already_running(self): + async def test_ensure_refresh_loop_is_idempotent(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) conn.ensure_refresh_loop() first_task = conn._refresh_task - conn.ensure_refresh_loop() assert conn._refresh_task is first_task - # Cleanup conn._refresh_task.cancel() try: await conn._refresh_task except asyncio.CancelledError: pass - -# ── _tee_refresh_loop tests ────────────────────────────────────────── - - -@pytest.mark.asyncio -class TestTeeRefreshLoop: - async def test_no_reconnect_when_tee_still_active(self): + async def test_refresh_loop_skips_when_tee_still_active(self): mock_reg = _mock_registry_with_tee(tee_id="tee-1") - active_tee = MagicMock() active_tee.tee_id = "tee-1" mock_reg.get_active_tees_by_type.return_value = [active_tee] @@ -457,10 +292,8 @@ async def test_no_reconnect_when_tee_still_active(self): mock_reconnect.assert_not_called() - async def test_reconnects_when_tee_no_longer_active(self): + async def test_refresh_loop_reconnects_when_tee_gone(self): mock_reg = _mock_registry_with_tee(tee_id="tee-1") - - # Registry says a different TEE is now active other_tee = MagicMock() other_tee.tee_id = "tee-99" mock_reg.get_active_tees_by_type.return_value = [other_tee] @@ -477,10 +310,8 @@ async def test_reconnects_when_tee_no_longer_active(self): mock_reconnect.assert_awaited_once() - async def test_registry_error_does_not_crash_loop(self): + async def test_refresh_loop_survives_registry_error(self): mock_reg = _mock_registry_with_tee(tee_id="tee-1") - - # First check fails, second check cancels mock_reg.get_active_tees_by_type.side_effect = [ RuntimeError("rpc timeout"), asyncio.CancelledError, @@ -495,16 +326,11 @@ async def test_registry_error_does_not_crash_loop(self): with pytest.raises(asyncio.CancelledError): await conn._tee_refresh_loop() - # The loop survived the first error and ran a second iteration assert mock_reg.get_active_tees_by_type.call_count == 2 + # ── close ──────────────────────────────────────────────────── -# ── RegistryTEEConnection close tests ─────────────────────────────── - - -@pytest.mark.asyncio -class TestRegistryClose: - async def test_closes_http_client(self): + async def test_close_closes_http_client(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) conn.get().http_client.aclose = AsyncMock() @@ -513,7 +339,7 @@ async def test_closes_http_client(self): conn.get().http_client.aclose.assert_awaited_once() - async def test_cancels_refresh_task(self): + async def test_close_cancels_refresh_task(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) mock_task = MagicMock() @@ -528,5 +354,4 @@ async def test_close_without_refresh_task(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) - # Should not raise when no refresh task exists - await conn.close() + await conn.close() # should not raise From 2cf22c2ad8ae49a0f24cd8abd5596ec42406303c Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 21:47:37 -0400 Subject: [PATCH 07/10] test --- tests/tee_connection_test.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index dcd83c66..1e44f19f 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -53,9 +53,6 @@ def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, return mock_reg -# ── Tests ──────────────────────────────────────────────────────────── - - class TestActiveTEE: def test_metadata_returns_dict(self): tee = ActiveTEE( @@ -126,25 +123,13 @@ async def test_resolve_exception_wraps_in_runtime_error(self): with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) - async def test_resolve_success_with_cert(self): + async def test_resolve_success(self): mock_reg = _mock_registry_with_tee( endpoint="https://registry.tee", - tls_cert_der=b"cert-bytes", tee_id="tee-42", payment_address="0xPay", ) - - with ( - patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ), - patch( - "src.opengradient.client.tee_connection.build_ssl_context_from_der", - return_value=MagicMock(spec=ssl.SSLContext), - ), - ): - conn = RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + conn = _make_registry_connection(registry=mock_reg) assert conn.get().endpoint == "https://registry.tee" assert conn.get().tee_id == "tee-42" From d2d6bf818c955f254988b0a640fe83752caea333 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:02:30 -0400 Subject: [PATCH 08/10] test handling --- src/opengradient/client/llm.py | 5 +++++ src/opengradient/client/tee_connection.py | 9 +++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 0a74c9e9..55f825c8 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import httpx +import asyncio from eth_account import Account from eth_account.account import LocalAccount @@ -170,6 +171,8 @@ async def _call_with_tee_retry( return await call() except httpx.HTTPStatusError: raise + except asyncio.CancelledError: + raise except Exception as exc: logger.warning( "Connection failure during %s; refreshing TEE and retrying once: %s", @@ -424,6 +427,8 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async return except httpx.HTTPStatusError: raise + except asyncio.CancelledError: + raise except Exception as exc: if chunks_yielded: raise diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index 1b6a807b..2a7682d3 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -133,7 +133,7 @@ def _resolve_tee(self): raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e if tee is None: - raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") + raise ValueError("No active LLM proxy TEE found in the registry.") logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) return tee @@ -155,10 +155,8 @@ def _connect(self) -> ActiveTEE: async def reconnect(self) -> None: """Connect to a new TEE from the registry and rebuild the HTTP client.""" async with self._refresh_lock: - old_client = self._active.http_client - self._active = self._connect() try: - await old_client.aclose() + self._active = self._connect() except Exception: logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) @@ -188,6 +186,9 @@ async def _tee_refresh_loop(self) -> None: continue logger.info("Current TEE %s no longer active; switching to a new one.", self._active.tee_id) await self.reconnect() + except asyncio.CancelledError: + logger.debug("Background TEE health check cancelled; exiting loop.") + raise except Exception: logger.warning("Background TEE health check failed; will retry next cycle.", exc_info=True) From 2e65ef6c3cd807c8cebb2a301624251377b5d884 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:03:39 -0400 Subject: [PATCH 09/10] rm test --- tests/tee_connection_test.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 1e44f19f..3f019123 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -183,19 +183,6 @@ def make_client(*args, **kwargs): assert conn.get().http_client is not old_client assert len(clients_created) == 2 - async def test_reconnect_closes_old_client(self): - mock_reg = _mock_registry_with_tee() - conn = _make_registry_connection(registry=mock_reg) - old_client = conn.get().http_client - old_client.aclose = AsyncMock() - - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=FakeHTTPClient, - ): - await conn.reconnect() - - old_client.aclose.assert_awaited_once() async def test_reconnect_swallows_close_failure(self): mock_reg = _mock_registry_with_tee() From dc2d28805febab3061c69280c7a6e84b1de0fcd5 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:06:10 -0400 Subject: [PATCH 10/10] capture --- src/opengradient/client/llm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 55f825c8..5a1e67a9 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -421,7 +421,7 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers=headers, timeout=_REQUEST_TIMEOUT, ) as response: - async for chunk in self._parse_sse_response(response): + async for chunk in self._parse_sse_response(response, tee): chunks_yielded = True yield chunk return @@ -450,10 +450,10 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers=headers, timeout=_REQUEST_TIMEOUT, ) as response: - async for chunk in self._parse_sse_response(response): + async for chunk in self._parse_sse_response(response, tee): yield chunk - async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: + async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" status_code = getattr(response, "status_code", None) if status_code is not None and status_code >= 400: @@ -491,7 +491,6 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non chunk = StreamChunk.from_sse_data(data) if chunk.is_final: - tee = self._tee.get() chunk.tee_id = tee.tee_id chunk.tee_endpoint = tee.endpoint chunk.tee_payment_address = tee.payment_address