Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions infra/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ services:
# the api can be relayed by the web's /api/auth/oauth/<provider>/callback
# route handler back to the same origin the browser sees.
OAUTH_REDIRECT_BASE: ${OAUTH_REDIRECT_BASE:-http://localhost:7090}
# Circuit breaker for the LiteLLM dispatch gateway. Default
# 'litellm'; set to 'legacy' to halt all gateway dispatches with
# a 502 (used as an emergency switch — actual fix is a deploy
# rollback).
LLM_GATEWAY: ${LLM_GATEWAY:-litellm}
ports:
- "7081:7081"
depends_on:
Expand Down
32 changes: 32 additions & 0 deletions schemas/postgres/migrations/0023_litellm_provider_matrix.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- 0023_litellm_provider_matrix.sql
-- Widen the provider check from {anthropic, openai} to the six-provider
-- matrix LiteLLM dispatches against. Add a default-enabled flag so a
-- workspace admin can mark which credentials should auto-link to new
-- projects. Add the project ↔ credential link table that every
-- LLM-dispatching surface joins through to find a key.

begin;

-- Replace the provider check: drop the old constraint and add the widened
-- one under the same name. The drop is deliberately strict (no "if exists"):
-- if the old check is not present under this name the migration should fail
-- loudly instead of leaving a narrower check in force.
alter table workspace_llm_credential
    drop constraint workspace_llm_credential_provider_check;
alter table workspace_llm_credential
    add constraint workspace_llm_credential_provider_check
    check (provider in (
        'anthropic', 'openai', 'gemini', 'mistral', 'deepseek', 'groq'
    ));

-- Existing credentials are backfilled with false by the column default.
alter table workspace_llm_credential
    add column default_enabled boolean not null default false;

-- Link table: one row per (project, credential) pair that is enabled.
create table project_llm_credential (
    project_id    uuid not null references project (id) on delete cascade,
    credential_id uuid not null references workspace_llm_credential (id) on delete cascade,
    enabled_at    timestamptz not null default now(),
    -- Nullable so the link survives deletion of the user who enabled it.
    enabled_by    uuid references app_user (id) on delete set null,
    primary key (project_id, credential_id)
);

-- The primary key leads with project_id, so it cannot serve lookups by
-- credential_id alone. Index the referencing column so the cascade from
-- workspace_llm_credential (and any "which projects use this key" query)
-- does not have to scan the whole link table.
create index project_llm_credential_credential_idx
    on project_llm_credential (credential_id);

insert into schema_migrations (version) values ('0023_litellm_provider_matrix')
on conflict (version) do nothing;

commit;
38 changes: 38 additions & 0 deletions schemas/postgres/migrations/0024_dispatch_cost.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
-- 0024_dispatch_cost.sql
-- Single source of truth for "what a dispatch cost". Every LLM call
-- writes one row here on success or failure; surface tables
-- (playground_session / comparison_run / studio_branch / eval_score /
-- eval_run) do NOT carry cost columns. surface_ref_id joins back to
-- those tables.

begin;

create table dispatch_cost (
    id                  uuid primary key default gen_random_uuid(),
    project_id          uuid not null references project (id) on delete cascade,
    workspace_id        uuid not null references workspace (id) on delete cascade,
    surface             text not null check (surface in (
        'playground', 'comparisons', 'studio', 'luna', 'eval', 'poll'
    )),
    -- No foreign key is declared here; per the header, this joins back to
    -- the surface tables.
    surface_ref_id      uuid not null,
    provider            text not null,
    model               text not null,
    -- Token counts are nullable.
    -- NOTE(review): presumably unknown when a call fails before usage is
    -- reported — confirm with the writer.
    prompt_tokens       integer,
    completion_tokens   integer,
    cost_usd            numeric(10, 6) not null default 0,
    cost_calculated_via text not null default 'litellm-table',
    dispatched_at       timestamptz not null default now(),
    -- Both null on success rows; failures are recorded in the same table.
    error_code          text,
    error_detail        text
);

-- Per-project timelines, newest first.
create index dispatch_cost_proj_dispatched_idx
    on dispatch_cost (project_id, dispatched_at desc);
-- Per-project timelines filtered by surface.
create index dispatch_cost_proj_surface_idx
    on dispatch_cost (project_id, surface, dispatched_at desc);
-- Join back from a surface row to its dispatch cost rows.
create index dispatch_cost_surface_ref_idx
    on dispatch_cost (surface, surface_ref_id);
-- workspace_id carries an on-delete-cascade foreign key; without an index on
-- the referencing column, deleting a workspace must scan this table, which
-- grows by one row per LLM call. The dispatched_at suffix mirrors the
-- project index so workspace-level timelines can use it too.
create index dispatch_cost_ws_dispatched_idx
    on dispatch_cost (workspace_id, dispatched_at desc);

insert into schema_migrations (version) values ('0024_dispatch_cost')
on conflict (version) do nothing;

commit;
9 changes: 9 additions & 0 deletions services/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ dependencies = [
"python-jose[cryptography]>=3.3",
"itsdangerous>=2.1",
"clickhouse-connect>=0.7",
"litellm==1.52.16",
]

[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"pytest-mock>=3.12",
"httpx>=0.27",
"ruff>=0.4",
"pyright>=1.1",
Expand All @@ -32,3 +34,10 @@ build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["tracebility_api"]

# pytest settings for the api test suite.
[tool.pytest.ini_options]
# NOTE(review): "auto" presumably lets pytest-asyncio run `async def` tests
# without a per-test marker — confirm against the installed pytest-asyncio.
asyncio_mode = "auto"
# Custom markers registered for this suite. `live` is applied by
# tests/live/test_litellm_providers.py, which additionally skips unless
# PROVIDER_LIVE_TEST=1 is set.
markers = [
"live: real-network test against a live LLM provider; opt-in",
]
# Directory pytest searches when invoked without explicit paths.
testpaths = ["tests"]
Empty file added services/api/tests/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions services/api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Shared pytest fixtures for the api test suite.

The unit layer never touches Postgres — it mocks asyncpg.Pool. The
integration layer uses a real local Postgres database; the test runner
is responsible for running migrations against it.
"""

from __future__ import annotations

import os

import pytest

# config.load() needs these variables to exist at import time. Seed harmless
# placeholders so unit tests do not have to supply them; values already
# present in the environment are left untouched.
for _env_name, _placeholder in (
    ("TRACEBILITY_PG_DSN", "postgres://test/test"),
    ("TRACEBILITY_SESSION_SECRET", "x" * 40),
):
    os.environ.setdefault(_env_name, _placeholder)


@pytest.fixture
def fake_pool(mocker):
    """Build a MagicMock standing in for asyncpg.Pool.

    The four query methods (execute, fetch, fetchrow, fetchval) are
    AsyncMocks with canned return values; tests override them as needed.
    """
    canned_results = {
        "execute": "INSERT 0 1",
        "fetch": [],
        "fetchrow": None,
        "fetchval": None,
    }
    double = mocker.MagicMock(name="pool")
    for method_name, result in canned_results.items():
        setattr(double, method_name, mocker.AsyncMock(return_value=result))
    return double


@pytest.fixture
def integration_dsn() -> str:
    """Return the DSN from TRACEBILITY_TEST_DSN, skipping the test if it is unset or empty."""
    configured = os.environ.get("TRACEBILITY_TEST_DSN")
    if configured:
        return configured
    pytest.skip("set TRACEBILITY_TEST_DSN to run integration tests")
Empty file.
57 changes: 57 additions & 0 deletions services/api/tests/integration/test_migration_0023.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Verify migration 0023 widens providers, adds default flag, and creates link table."""

from __future__ import annotations

import asyncpg
import pytest

# Apply the asyncio marker to every test in this module.
# NOTE(review): pyproject.toml sets asyncio_mode = "auto", so this may be
# redundant — confirm before removing.
pytestmark = pytest.mark.asyncio


async def test_provider_check_accepts_six_providers(integration_dsn: str) -> None:
    """The provider check on workspace_llm_credential lists all six providers.

    The constraint definition is fetched once, scoped to the table (constraint
    names are not unique across relations), and each provider is matched as a
    quoted literal so a substring of some other token cannot satisfy the test.
    """
    pool = await asyncpg.create_pool(integration_dsn, min_size=1, max_size=2)
    try:
        definition = await pool.fetchval(
            """
            select pg_get_constraintdef(oid) from pg_constraint
            where conname = 'workspace_llm_credential_provider_check'
              and conrelid = 'workspace_llm_credential'::regclass
            """,
        )
        assert definition is not None, "provider check constraint not found"
        for provider in ("anthropic", "openai", "gemini", "mistral", "deepseek", "groq"):
            assert f"'{provider}'" in definition, f"{provider} missing from provider check"
    finally:
        await pool.close()


async def test_default_enabled_column_exists(integration_dsn: str) -> None:
    """workspace_llm_credential has a default_enabled column after migration 0023."""
    lookup_sql = """
            select column_name from information_schema.columns
            where table_name = 'workspace_llm_credential'
            and column_name = 'default_enabled'
            """
    # The pool's async context manager closes it on exit, as the explicit
    # try/finally form would.
    async with asyncpg.create_pool(integration_dsn, min_size=1, max_size=2) as pool:
        found = await pool.fetchval(lookup_sql)
        assert found == "default_enabled"


async def test_project_llm_credential_table_exists(integration_dsn: str) -> None:
    """project_llm_credential exists with exactly the expected columns, in order."""
    expected = ["project_id", "credential_id", "enabled_at", "enabled_by"]
    columns_sql = """
            select column_name from information_schema.columns
            where table_name = 'project_llm_credential'
            order by ordinal_position
            """
    # The pool's async context manager closes it on exit.
    async with asyncpg.create_pool(integration_dsn, min_size=1, max_size=2) as pool:
        records = await pool.fetch(columns_sql)
        actual = []
        for record in records:
            actual.append(record["column_name"])
        assert actual == expected
51 changes: 51 additions & 0 deletions services/api/tests/integration/test_migration_0024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Verify migration 0024 creates dispatch_cost with the expected shape."""

from __future__ import annotations

import asyncpg
import pytest

# Apply the asyncio marker to every test in this module.
# NOTE(review): pyproject.toml sets asyncio_mode = "auto", so this may be
# redundant — confirm before removing.
pytestmark = pytest.mark.asyncio


async def test_dispatch_cost_columns(integration_dsn: str) -> None:
    """dispatch_cost exposes exactly the columns migration 0024 declares, in order."""
    expected = [
        "id", "project_id", "workspace_id", "surface", "surface_ref_id",
        "provider", "model", "prompt_tokens", "completion_tokens",
        "cost_usd", "cost_calculated_via", "dispatched_at",
        "error_code", "error_detail",
    ]
    columns_sql = """
            select column_name from information_schema.columns
            where table_name = 'dispatch_cost'
            order by ordinal_position
            """
    # The pool's async context manager closes it on exit.
    async with asyncpg.create_pool(integration_dsn, min_size=1, max_size=2) as pool:
        records = await pool.fetch(columns_sql)
        actual = []
        for record in records:
            actual.append(record["column_name"])
        assert actual == expected


async def test_surface_check_constraint_rejects_unknown(integration_dsn: str) -> None:
    """Inserting a dispatch_cost row with an unknown surface violates the check.

    Skips when the database has no project row to satisfy the foreign keys.
    """
    pool = await asyncpg.create_pool(integration_dsn, min_size=1, max_size=2)
    try:
        async with pool.acquire() as conn:
            # Take the project and its workspace from the same row. Picking an
            # arbitrary workspace first can land on one with no projects and
            # skip the test even though a usable project exists elsewhere.
            seeded = await conn.fetchrow("select id, workspace_id from project limit 1")
            if seeded is None:
                pytest.skip("test DB lacks a seeded project; integration scaffolding required")
            proj, ws = seeded["id"], seeded["workspace_id"]

            # Run the insert in a transaction that is always rolled back, so a
            # regression in the constraint cannot leave a junk row behind.
            tx = conn.transaction()
            await tx.start()
            try:
                with pytest.raises(asyncpg.exceptions.CheckViolationError):
                    await conn.execute(
                        """
                        insert into dispatch_cost (
                        project_id, workspace_id, surface, surface_ref_id,
                        provider, model
                        ) values ($1, $2, 'banana', $3, 'openai', 'gpt-4o')
                        """,
                        # surface_ref_id only needs to be some uuid; reuse the project id.
                        proj, ws, proj,
                    )
            finally:
                await tx.rollback()
    finally:
        await pool.close()
Empty file.
46 changes: 46 additions & 0 deletions services/api/tests/live/test_litellm_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Live-network sanity for each provider in the matrix.

Skipped by default. Run with PROVIDER_LIVE_TEST=1 and the per-provider
keys set. One 'say hi' completion (max_tokens=8) per provider; confirms LiteLLM hasn't
broken against the live API.
"""

from __future__ import annotations

import os

import litellm
import pytest

# Module-wide marks: tag every test as `live` and asyncio, and skip the whole
# module unless PROVIDER_LIVE_TEST is exactly "1".
pytestmark = [
pytest.mark.live,
pytest.mark.asyncio,
pytest.mark.skipif(
os.environ.get("PROVIDER_LIVE_TEST") != "1",
reason="set PROVIDER_LIVE_TEST=1 to run",
),
]

# One (model, env_key) pair per provider: the model string passed to
# litellm.acompletion and the environment variable read for its API key.
CASES = [
("openai/gpt-4o-mini", "OPENAI_API_KEY"),
("anthropic/claude-3-5-haiku-20241022", "ANTHROPIC_API_KEY"),
("gemini/gemini-1.5-flash", "GEMINI_API_KEY"),
("mistral/mistral-small-latest", "MISTRAL_API_KEY"),
("deepseek/deepseek-chat", "DEEPSEEK_API_KEY"),
("groq/llama-3.1-8b-instant", "GROQ_API_KEY"),
]


@pytest.mark.parametrize("model,env_key", CASES)
async def test_live_provider_returns_text_and_cost(model: str, env_key: str) -> None:
    """Send one short prompt to a live provider and check text and cost come back.

    Skips the case when the provider's API key variable is unset or empty.
    """
    secret = os.environ.get(env_key)
    if not secret:
        pytest.skip(f"{env_key} not set")

    prompt = [{"role": "user", "content": "say hi"}]
    reply = await litellm.acompletion(
        model=model,
        api_key=secret,
        messages=prompt,
        max_tokens=8,
        num_retries=0,
    )

    assert reply.choices[0].message.content
    # A falsy cost from litellm is treated as zero.
    usd = float(litellm.completion_cost(completion_response=reply) or 0)
    assert usd >= 0
Empty file.
Empty file.
50 changes: 50 additions & 0 deletions services/api/tests/unit/llm/test_audit_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Audit throttle: 1 event per (project, provider, code) per hour."""

from __future__ import annotations

import uuid
from unittest.mock import AsyncMock

import pytest

from tracebility_api.llm.audit_throttle import should_emit_audit

# Apply the asyncio marker to every test in this module.
# NOTE(review): pyproject.toml sets asyncio_mode = "auto", so this may be
# redundant — confirm before removing.
pytestmark = pytest.mark.asyncio


async def test_first_event_emits(fake_pool) -> None:
    """should_emit_audit returns True when the pool lookup yields None."""
    fake_pool.fetchval = AsyncMock(return_value=None)
    project = uuid.uuid4()

    decision = await should_emit_audit(
        fake_pool, project_id=project, provider="openai", action="dispatch.no_credential",
    )

    assert decision is True


async def test_within_window_suppresses(fake_pool) -> None:
    """should_emit_audit returns False when the pool lookup yields a value.

    The unused `mocker` fixture was dropped from the signature; only
    `fake_pool` is needed.
    """
    # A non-None lookup result stands in for "an event already exists".
    fake_pool.fetchval = AsyncMock(return_value=1)
    emit = await should_emit_audit(
        fake_pool,
        project_id=uuid.uuid4(),
        provider="openai",
        action="dispatch.no_credential",
    )
    assert emit is False


async def test_query_uses_one_hour_window(fake_pool) -> None:
    """The throttle issues exactly one fetchval whose SQL uses a one-hour interval.

    Also checks that the project id is the first bind argument after the SQL.
    """
    lookup = AsyncMock(return_value=None)
    fake_pool.fetchval = lookup
    project_id = uuid.uuid4()

    await should_emit_audit(
        fake_pool,
        project_id=project_id,
        provider="anthropic",
        action="dispatch.ceiling_exceeded",
    )

    lookup.assert_awaited_once()
    positional = lookup.await_args.args
    assert "interval '1 hour'" in positional[0]
    assert positional[1] == project_id
Loading
Loading