From 5f3a89fbe4ae5ceaaef5c0488b6ec9436dfaaa81 Mon Sep 17 00:00:00 2001 From: Taku Suzuki Date: Sat, 4 Apr 2026 23:01:52 +0900 Subject: [PATCH 1/3] remove session cost processor, set explicit cost token parameters --- .../io/flightdeck/api/PipelineConsumer.java | 1 - frontend/src/tabs/ExecutionTab.tsx | 2 - frontend/src/tabs/PipelineEventsTab.tsx | 7 +- .../flightdeck/monitoring/MonitoringApp.java | 1 - .../streams/FlightDeckStreamsApp.java | 18 +- .../io/flightdeck/streams/config/Topics.java | 3 - .../flightdeck/streams/model/SessionCost.java | 28 --- .../streams/model/ThinkResponse.java | 9 +- .../streams/model/UserResponse.java | 1 - .../streams/processors/EndTurnProcessor.java | 17 +- .../EnrichInputMessageProcessor.java | 40 +--- .../processors/MemoirSessionEndProcessor.java | 4 +- .../SessionCostAggregationProcessor.java | 166 --------------- .../streams/MemoirEnabledTopologyTest.java | 2 +- .../flightdeck/streams/config/TopicsTest.java | 2 - .../processors/EndTurnProcessorTest.java | 61 +++++- .../EnrichInputMessageProcessorTest.java | 51 +++-- .../SessionCostAggregationProcessorTest.java | 190 ------------------ .../processors/SessionEndProcessorTest.java | 2 +- .../io/flightdeck/think/StandaloneRunner.java | 6 +- .../think/consumer/ThinkConsumer.java | 27 ++- .../flightdeck/think/model/ThinkResponse.java | 9 +- .../think/service/ClaudeApiService.java | 5 +- .../think/service/GeminiApiService.java | 5 +- .../think/consumer/CompactionTest.java | 170 +++++++++++++--- 25 files changed, 284 insertions(+), 543 deletions(-) delete mode 100644 processor-apps/processing/src/main/java/io/flightdeck/streams/model/SessionCost.java delete mode 100644 processor-apps/processing/src/main/java/io/flightdeck/streams/processors/SessionCostAggregationProcessor.java delete mode 100644 processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionCostAggregationProcessorTest.java diff --git a/api/chat-api/src/main/java/io/flightdeck/api/PipelineConsumer.java 
b/api/chat-api/src/main/java/io/flightdeck/api/PipelineConsumer.java index b61a037..1939c98 100644 --- a/api/chat-api/src/main/java/io/flightdeck/api/PipelineConsumer.java +++ b/api/chat-api/src/main/java/io/flightdeck/api/PipelineConsumer.java @@ -37,7 +37,6 @@ public class PipelineConsumer implements Runnable { P + "tool-use-result", P + "tool-use-all-complete", P + "tool-use-latency", - P + "session-cost", P + "message-output" ); diff --git a/frontend/src/tabs/ExecutionTab.tsx b/frontend/src/tabs/ExecutionTab.tsx index 3502593..1a11607 100644 --- a/frontend/src/tabs/ExecutionTab.tsx +++ b/frontend/src/tabs/ExecutionTab.tsx @@ -146,7 +146,6 @@ function buildRows(events: PipelineEvent[]): TableRow[] { // Attach cost info to the enriched sub-row if available if (currentMsgRow && v) { const cost = v.cost != null ? Number(v.cost) : null; - const llmCalls = v.llm_calls != null ? Number(v.llm_calls) : null; if (cost != null) { // Update the enriched sub-row comment to include cost const enrichedSub = currentMsgRow.subRows!.find( @@ -156,7 +155,6 @@ function buildRows(events: PipelineEvent[]): TableRow[] { const costStr = cost < 0.01 ? `$${cost.toFixed(6)}` : `$${cost.toFixed(4)}`; const parts = [enrichedSub.comment]; parts.push(`cost: ${costStr}`); - if (llmCalls != null) parts.push(`${llmCalls} LLM call${llmCalls !== 1 ? 
"s" : ""}`); enrichedSub.comment = parts.join(", "); } } diff --git a/frontend/src/tabs/PipelineEventsTab.tsx b/frontend/src/tabs/PipelineEventsTab.tsx index 9db69cf..e03e9f4 100644 --- a/frontend/src/tabs/PipelineEventsTab.tsx +++ b/frontend/src/tabs/PipelineEventsTab.tsx @@ -18,7 +18,6 @@ const TOPIC_LABELS: Record = { "tool-use-result": "Tool Result", "tool-use-all-complete": "Tools Complete", "tool-use-latency": "Tool Latency", - "session-cost": "Session Cost", "message-output": "Final Output", }; @@ -32,7 +31,6 @@ const TOPIC_ICONS: Record = { "tool-use-result": "RES", "tool-use-all-complete": "ALL", "tool-use-latency": "LAT", - "session-cost": "CST", "message-output": "OUT", }; @@ -46,7 +44,6 @@ const TOPIC_COLORS: Record = { "tool-use-result": "#10b981", "tool-use-all-complete": "#14b8a6", "tool-use-latency": "#06b6d4", - "session-cost": "#f97316", "message-output": "#22c55e", }; @@ -88,9 +85,7 @@ function eventSummary(event: PipelineEvent): string { case "enriched-message-input": return `history: ${Array.isArray(v.history) ? v.history.length : 0} items`; case "session-context": - return `turns: ${v.llm_calls ?? 0}, cost: ${v.cost != null ? formatDollars(Number(v.cost)) : "-"}`; - case "session-cost": - return `cost: ${v.total_cost != null ? formatDollars(Number(v.total_cost)) : JSON.stringify(v).slice(0, 60)}`; + return `cost: ${v.cost != null ? 
formatDollars(Number(v.cost)) : "-"}`; default: return truncate(JSON.stringify(v), 80); } diff --git a/monitoring/logging-consumer/src/main/java/io/flightdeck/monitoring/MonitoringApp.java b/monitoring/logging-consumer/src/main/java/io/flightdeck/monitoring/MonitoringApp.java index 97da752..ec51681 100644 --- a/monitoring/logging-consumer/src/main/java/io/flightdeck/monitoring/MonitoringApp.java +++ b/monitoring/logging-consumer/src/main/java/io/flightdeck/monitoring/MonitoringApp.java @@ -33,7 +33,6 @@ public class MonitoringApp { P + "tool-use-result", P + "tool-use-all-complete", P + "tool-use-latency", - P + "session-cost", P + "message-output", P + "session-end", P + "memoir-context", diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java index c72cf86..d40e479 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java @@ -5,11 +5,9 @@ import io.flightdeck.streams.processors.EndTurnProcessor; import io.flightdeck.streams.processors.EnrichInputMessageProcessor; import io.flightdeck.streams.processors.ExtractToolUseItemsProcessor; -import io.flightdeck.streams.processors.SessionCostAggregationProcessor; import io.flightdeck.streams.processors.MemoirSessionEndProcessor; import io.flightdeck.streams.processors.SessionEndProcessor; import io.flightdeck.streams.processors.TransformToolUseDoneProcessor; -import io.flightdeck.streams.model.SessionCost; import io.flightdeck.streams.model.ThinkResponse; import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.clients.admin.AdminClient; @@ -47,7 +45,7 @@ public class FlightDeckStreamsApp { static final String MEMOIR_CONTEXT_STORE = "memoir-context-store"; static final String THINK_RESPONSE_STORE = "think-response-store"; - static final 
String SESSION_COST_TABLE_STORE = "session-cost-table-store"; + public static void main(String[] args) { Properties props = buildConfig(); @@ -120,20 +118,9 @@ static Topology buildTopology(boolean memoirEnabled) { .withValueSerde(JsonSerde.of(ThinkResponse.class)) ); - // ── Shared KTable: session-cost (aggregated cost per session) ──────── - KTable sessionCostTable = builder.table( - Topics.SESSION_COST, - Consumed.with(Serdes.String(), JsonSerde.of(SessionCost.class)), - Materialized.as( - Stores.persistentKeyValueStore(SESSION_COST_TABLE_STORE)) - .withKeySerde(Serdes.String()) - .withValueSerde(JsonSerde.of(SessionCost.class)) - ); - // ── Register each processor fragment ────────────────────────────────── - EnrichInputMessageProcessor.register(builder, memoirTable, thinkTable, sessionCostTable); + EnrichInputMessageProcessor.register(builder, memoirTable, thinkTable); ExtractToolUseItemsProcessor.register(builder, thinkStream); - SessionCostAggregationProcessor.register(builder, thinkStream); EndTurnProcessor.register(builder, thinkStream); AggregateToolExecutionResultProcessor.register(builder); TransformToolUseDoneProcessor.register(builder); @@ -169,7 +156,6 @@ private static void ensureTopicsExist(Properties streamsProps) { Topics.TOOL_USE_DLQ, Topics.TOOL_USE_RESULT, Topics.TOOL_USE_ALL_COMPLETE, - Topics.SESSION_COST, Topics.TOOL_USE_LATENCY, Topics.MESSAGE_OUTPUT )); diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java index 2e90767..011f7be 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java @@ -49,9 +49,6 @@ private static String requireEnv(String key) { public static final String TOOL_USE_ALL_COMPLETE = PREFIX + "tool-use-all-complete"; // ── Observability 
───────────────────────────────────────────────────────── - /** Aggregated cost (tokens × pricing) per conversation session */ - public static final String SESSION_COST = PREFIX + "session-cost"; - /** Per-tool latency metrics, keyed by tool_name */ public static final String TOOL_USE_LATENCY = PREFIX + "tool-use-latency"; diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/SessionCost.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/SessionCost.java deleted file mode 100644 index b69844f..0000000 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/SessionCost.java +++ /dev/null @@ -1,28 +0,0 @@ -package io.flightdeck.streams.model; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * Running cost aggregate for a single conversation session. - * Published onto {@code session-cost} by {@code SessionCostAggregationProcessor} - * every time a new {@link ThinkResponse} is processed. - * - *

The diagram annotation: "Aggregate cost per conversation" and - * "Emit Tombstone when aggregated". - */ -@JsonIgnoreProperties(ignoreUnknown = true) -public record SessionCost( - @JsonProperty("session_id") String sessionId, - @JsonProperty("user_id") String userId, - @JsonProperty("llm_calls") int llmCalls, - @JsonProperty("total_input_tokens") int totalInputTokens, - @JsonProperty("total_output_tokens") int totalOutputTokens, - @JsonProperty("estimated_cost_usd") Double estimatedCostUsd, - @JsonProperty("timestamp") String timestamp -) { - /** Zero-value initialiser used by the Kafka Streams aggregator. */ - public static SessionCost zero(String sessionId, String userId) { - return new SessionCost(sessionId, userId, 0, 0, 0, null, null); - } -} \ No newline at end of file diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java index 137e875..a33232f 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java @@ -20,10 +20,11 @@ public record ThinkResponse( @JsonProperty("session_id") String sessionId, @JsonProperty("user_id") String userId, - @JsonProperty("cost") Double cost, - @JsonProperty("prev_session_cost") Double prevSessionCost, - @JsonProperty("input_tokens") int inputTokens, - @JsonProperty("output_tokens") int outputTokens, + @JsonProperty("total_session_cost") Double totalSessionCost, + @JsonProperty("previous_session_cost") Double previousSessionCost, + @JsonProperty("think_cost") Double thinkCost, + @JsonProperty("think_input_tokens") int thinkInputTokens, + @JsonProperty("think_output_tokens") int thinkOutputTokens, @JsonProperty("previous_messages") List previousMessages, @JsonProperty("last_input_message") MessageInput lastInputMessage, @JsonProperty("last_input_response") List 
lastInputResponse, diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java index 2b8bc52..f2e3d38 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java @@ -16,7 +16,6 @@ public record UserResponse( @JsonProperty("session_id") String sessionId, @JsonProperty("user_id") String userId, @JsonProperty("content") String content, - @JsonProperty("llm_calls") int llmCalls, @JsonProperty("input_tokens") int inputTokens, @JsonProperty("output_tokens") int outputTokens, @JsonProperty("cost") Double cost, diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java index 9dda8f3..0a04f2c 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java @@ -133,10 +133,9 @@ static UserResponse toUserResponse(String sessionId, ThinkResponse response) { sessionId, response.userId(), content, - 1, // this response represents one LLM call - response.inputTokens(), - response.outputTokens(), - totalCost(response.prevSessionCost(), response.cost()), + response.thinkInputTokens(), + response.thinkOutputTokens(), + response.totalSessionCost(), sourceAgent, Instant.now().toString() ); @@ -147,16 +146,6 @@ static UserResponse toUserResponse(String sessionId, ThinkResponse response) { * Returns an empty string if the list is null, empty, or contains no * assistant messages. */ - /** - * Computes total session cost: prev_session_cost + current call cost. - * Returns null if both are null. 
- */ - static Double totalCost(Double prevSessionCost, Double callCost) { - if (prevSessionCost == null && callCost == null) return null; - return (prevSessionCost != null ? prevSessionCost : 0.0) - + (callCost != null ? callCost : 0.0); - } - static String assembleContent(List messages) { if (messages == null || messages.isEmpty()) return ""; diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EnrichInputMessageProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EnrichInputMessageProcessor.java index 5ebf74a..7247c7a 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EnrichInputMessageProcessor.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EnrichInputMessageProcessor.java @@ -3,7 +3,6 @@ import io.flightdeck.streams.config.Topics; import io.flightdeck.streams.model.MessageInput; import io.flightdeck.streams.model.FullSessionContext; -import io.flightdeck.streams.model.SessionCost; import io.flightdeck.streams.model.ThinkResponse; import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.common.serialization.Serdes; @@ -27,13 +26,10 @@ * │◄────────────── think-request-response (KTable — previous turn's full state) * │ * │ leftJoin - * │◄────────────── session-cost (KTable — aggregated cost per session) - * │ - * │ leftJoin * │◄────────────── memoir-context (KTable — long-term memoir, shared) * │ * ▼ - * enriched-message-input (KStream — history + cost + memoir + latest input) + * enriched-message-input (KStream — history + memoir + latest input) * * *

History is reconstructed from the previous ThinkResponse: @@ -46,12 +42,10 @@ public class EnrichInputMessageProcessor { /** * @param memoirTable shared KTable for memoir-context (keyed by userId) * @param thinkTable shared KTable for think-request-response (keyed by sessionId) - * @param sessionCostTable shared KTable for session-cost (keyed by sessionId) */ public static void register(StreamsBuilder builder, KTable memoirTable, - KTable thinkTable, - KTable sessionCostTable) { + KTable thinkTable) { // ── Left side: incoming user messages ──────────────────────────────── KStream inputStream = builder.stream( @@ -75,18 +69,6 @@ public static void register(StreamsBuilder builder, ) ); - // ── Join: enriched ⟕ session-cost (attach aggregated cost) ────────── - enriched = enriched - .leftJoin( - sessionCostTable, - EnrichInputMessageProcessor::enrichWithCost, - Joined.with( - Serdes.String(), - JsonSerde.of(FullSessionContext.class), - JsonSerde.of(SessionCost.class) - ) - ); - // If memoir is enabled, re-key by userId, join with memoir, re-key back if (memoirTable != null) { enriched = enriched @@ -144,7 +126,7 @@ static FullSessionContext enrichWithThinkResponse(MessageInput message, ThinkRes return new FullSessionContext( message.sessionId(), userId, - null, + (prevResponse != null) ? prevResponse.totalSessionCost() : null, history, message, null, @@ -152,22 +134,6 @@ static FullSessionContext enrichWithThinkResponse(MessageInput message, ThinkRes ); } - /** - * Join: attach aggregated session cost from session-cost KTable. - */ - static FullSessionContext enrichWithCost(FullSessionContext enriched, SessionCost sessionCost) { - Double cost = (sessionCost != null) ? 
sessionCost.estimatedCostUsd() : null; - return new FullSessionContext( - enriched.sessionId(), - enriched.userId(), - cost, - enriched.history(), - enriched.latestInput(), - enriched.memoirContext(), - enriched.timestamp() - ); - } - /** * Join: attach memoir context to the already-enriched session context. */ diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/MemoirSessionEndProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/MemoirSessionEndProcessor.java index c6391f0..7e2c1d8 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/MemoirSessionEndProcessor.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/MemoirSessionEndProcessor.java @@ -106,8 +106,8 @@ static MemoirSessionEnd buildSnapshot(ThinkResponse think, String memoirCtx) { if (think.lastInputResponse() != null) fullHistory.addAll(think.lastInputResponse()); ThinkResponse asResponse = new ThinkResponse( - think.sessionId(), think.userId(), think.cost(), think.prevSessionCost(), - 0, 0, + think.sessionId(), think.userId(), think.totalSessionCost(), think.previousSessionCost(), + think.thinkCost(), 0, 0, fullHistory, // previousMessages = full history for memoir null, null, // lastInputMessage, lastInputResponse not needed for memoir null, true, false, 0, 0, 0.0, Instant.now().toString()); diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/SessionCostAggregationProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/SessionCostAggregationProcessor.java deleted file mode 100644 index 9f0564d..0000000 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/SessionCostAggregationProcessor.java +++ /dev/null @@ -1,166 +0,0 @@ -package io.flightdeck.streams.processors; - -import io.flightdeck.streams.config.Topics; -import io.flightdeck.streams.model.SessionCost; -import 
io.flightdeck.streams.model.ThinkResponse; -import io.flightdeck.streams.serdes.JsonSerde; -import org.apache.kafka.common.serialization.Serdes; -import org.apache.kafka.streams.StreamsBuilder; -import org.apache.kafka.streams.kstream.*; -import org.apache.kafka.streams.state.Stores; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.time.Instant; -import java.util.Map; - -/** - *

Session Cost Aggregation Processor

- * - *

Implements the "Aggregate cost per conversation" beige node - * visible in the architecture diagram, sitting between - * {@code think-request-response} and {@code session-cost}. - * - *

Topology fragment

- *
- *   think-request-response  (KStream)
- *       │
- *       ▼  groupByKey(session_id)
- *   aggregate()
- *       ├─ llm_calls            += 1
- *       ├─ total_input_tokens   += response.inputTokens
- *       ├─ total_output_tokens  += response.outputTokens
- *       └─ estimated_cost_usd   += response.cost
- *       │
- *       ▼  Materialized KTable  ("session-cost-store")
- *       │
- *       ▼  toStream()
- *   session-cost  (compacted changelog topic)
- * 
- * - *

Pricing model

- * The processor trusts the {@code cost} field already present on each - * {@link ThinkResponse} (set by the upstream Think consumer which has - * access to the exact token pricing via {@code INPUT_TOKEN_PRICE} and - * {@code OUTPUT_TOKEN_PRICE} environment variables at call time). - * This keeps pricing logic in one place and makes the aggregator - * a pure accumulator. - * - *

Tombstone note

- * The diagram annotates "Emit Tombstone when aggregated". - * A tombstone (null-value record) is emitted on {@code session-cost} when a - * session is explicitly closed by sending a {@link ThinkResponse} whose - * {@code endTurn} flag is {@code true} AND whose {@code cost} is exactly - * {@code -1.0} (the sentinel close signal). Downstream consumers can use - * this to evict session state from their own stores. - */ -public class SessionCostAggregationProcessor { - - private static final Logger log = LoggerFactory.getLogger(SessionCostAggregationProcessor.class); - - /** Name of the persistent RocksDB store backing the session-cost KTable. */ - public static final String SESSION_COST_STORE = "session-cost-store"; - - /** Sentinel cost value that signals a session-close / tombstone event. */ - static final double SESSION_CLOSE_SENTINEL = -1.0; - - public static void register(StreamsBuilder builder, KStream thinkStream) { - - // ── Separate close signals from normal LLM responses ───────────────── - // KStream.branch() was removed in Kafka Streams 4.x; use split() + Branched instead. - // Named.as("split-") provides a prefix; the full map key becomes "split-" + branch name. 
- Map> branches = thinkStream - .split(Named.as("split-")) - .branch( - (sid, r) -> r != null && r.cost() != null && r.cost() == SESSION_CLOSE_SENTINEL, - Branched.as("close-signals") - ) - .branch( - (sid, r) -> r != null, - Branched.as("normal-responses") - ) - .noDefaultBranch(); - - KStream closeSignals = branches.get("split-close-signals"); - KStream normalResponses = branches.get("split-normal-responses"); - - // ── Aggregate normal responses into a running SessionCost ───────────── - KTable costTable = normalResponses - .groupByKey(Grouped.with(Serdes.String(), JsonSerde.of(ThinkResponse.class))) - .aggregate( - // Initialiser - () -> SessionCost.zero("unknown", null), - - // Aggregator - (sessionId, response, current) -> { - // Accumulate cost: null + null = null, null + value = value, value + value = sum - Double callCost = response.cost(); - Double totalCost; - if (current.estimatedCostUsd() == null && callCost == null) { - totalCost = null; - } else { - totalCost = (current.estimatedCostUsd() != null ? current.estimatedCostUsd() : 0.0) - + (callCost != null ? callCost : 0.0); - } - - SessionCost updated = new SessionCost( - sessionId, - resolveUserId(current.userId(), response.userId()), - current.llmCalls() + 1, - current.totalInputTokens() + response.inputTokens(), - current.totalOutputTokens() + response.outputTokens(), - totalCost, - Instant.now().toString() - ); - - log.info("[{}] Cost updated — calls={} input_tok={} output_tok={} total_usd={}", - sessionId, - updated.llmCalls(), - updated.totalInputTokens(), - updated.totalOutputTokens(), - updated.estimatedCostUsd() != null - ? 
String.format("$%.6f", updated.estimatedCostUsd()) : "null"); - - return updated; - }, - - // Materialized persistent store - Materialized.as( - Stores.persistentKeyValueStore(SESSION_COST_STORE)) - .withKeySerde(Serdes.String()) - .withValueSerde(JsonSerde.of(SessionCost.class)) - ); - - // ── Publish running totals to session-cost topic ────────────────────── - costTable - .toStream() - .peek((sid, cost) -> log.debug("[{}] → {} usd={}", - sid, Topics.SESSION_COST, - cost != null && cost.estimatedCostUsd() != null - ? String.format("$%.6f", cost.estimatedCostUsd()) : "null")) - .to(Topics.SESSION_COST, - Produced.with(Serdes.String(), JsonSerde.of(SessionCost.class))); - - // ── Emit tombstone on session close ─────────────────────────────────── - // A null value on a compacted topic signals downstream consumers to - // delete the key — standard Kafka tombstone pattern. - closeSignals - .peek((sid, r) -> log.info("[{}] Session closed — emitting tombstone on {}", - sid, Topics.SESSION_COST)) - .mapValues(r -> (SessionCost) null) // explicit null = tombstone - .to(Topics.SESSION_COST, - Produced.with(Serdes.String(), JsonSerde.of(SessionCost.class))); - } - - // ───────────────────────────────────────────────────────────────────────── - // Package-private helpers (also used by tests) - // ───────────────────────────────────────────────────────────────────────── - - /** - * Returns {@code incoming} if non-blank, otherwise preserves - * {@code existing} to maintain user identity across turns. - */ - static String resolveUserId(String existing, String incoming) { - return (incoming != null && !incoming.isBlank()) ? 
incoming : existing; - } -} \ No newline at end of file diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/MemoirEnabledTopologyTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/MemoirEnabledTopologyTest.java index 95efdaf..9bd105d 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/MemoirEnabledTopologyTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/MemoirEnabledTopologyTest.java @@ -157,7 +157,7 @@ void memoirContextIsNull() { @Test @DisplayName("Enriched message still includes session history from ThinkResponse") void sessionHistoryStillWorks() { - ThinkResponse prevResponse = new ThinkResponse("sess-2", "user-2", 0.01, null, 100, 50, + ThinkResponse prevResponse = new ThinkResponse("sess-2", "user-2", 0.01, null, 0.01, 100, 50, null, null, List.of(assistantMsg("sess-2", "user-2", "Prior reply.")), List.of(), true, false, 0, 0, 0.0, TS); diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/config/TopicsTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/config/TopicsTest.java index bb5e2da..4d43b77 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/config/TopicsTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/config/TopicsTest.java @@ -37,7 +37,6 @@ void allTopicsArePrefixed() { assertThat(Topics.TOOL_USE_DLQ).startsWith(prefix); assertThat(Topics.TOOL_USE_RESULT).startsWith(prefix); assertThat(Topics.TOOL_USE_ALL_COMPLETE).startsWith(prefix); - assertThat(Topics.SESSION_COST).startsWith(prefix); assertThat(Topics.TOOL_USE_LATENCY).startsWith(prefix); assertThat(Topics.SESSION_END).startsWith(prefix); assertThat(Topics.MEMOIR_CONTEXT).startsWith(prefix); @@ -58,7 +57,6 @@ void topicBaseNames() { assertThat(Topics.TOOL_USE_DLQ).isEqualTo(p + "tool-use-dlq"); assertThat(Topics.TOOL_USE_RESULT).isEqualTo(p + "tool-use-result"); 
assertThat(Topics.TOOL_USE_ALL_COMPLETE).isEqualTo(p + "tool-use-all-complete"); - assertThat(Topics.SESSION_COST).isEqualTo(p + "session-cost"); assertThat(Topics.TOOL_USE_LATENCY).isEqualTo(p + "tool-use-latency"); assertThat(Topics.SESSION_END).isEqualTo(p + "session-end"); assertThat(Topics.MEMOIR_CONTEXT).isEqualTo(p + "memoir-context"); diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java index 2671732..1cab7b5 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java @@ -96,7 +96,7 @@ void endTurn_withToolCalls_isDropped() { @DisplayName("endTurn=true with null tool_uses list is forwarded") void endTurn_nullToolUses_isForwarded() { ThinkResponse response = new ThinkResponse("sess-4", "user-D", 0.01, null, - 100, 50, + 0.01, 100, 50, null, null, List.of(assistantMsg("sess-4", "user-D", "Done.")), null, true, false, 0, 0, 0.0, TS); @@ -122,7 +122,6 @@ void userResponse_fieldMapping() { assertThat(r.inputTokens()).isEqualTo(200); assertThat(r.outputTokens()).isEqualTo(75); assertThat(r.cost()).isCloseTo(0.0042, within(0.000001)); - assertThat(r.llmCalls()).isEqualTo(1); } @Test @@ -233,13 +232,65 @@ void toUserResponse_contentFromAssistant() { @Test @DisplayName("toUserResponse: empty content when ThinkResponse has no messages") void toUserResponse_noMessages_emptyContent() { - ThinkResponse resp = new ThinkResponse("s", "u", 0.0, null, 0, 0, + ThinkResponse resp = new ThinkResponse("s", "u", 0.0, null, 0.0, 0, 0, null, null, List.of(), List.of(), true, false, 0, 0, 0.0, TS); UserResponse result = toUserResponse("s", resp); assertThat(result.content()).isEmpty(); } + // ── total_session_cost ──────────────────────────────────────────────────── + + 
@Test + @DisplayName("UserResponse.cost carries totalSessionCost (previous + think + compaction)") + void totalSessionCost_flowsToUserResponse() { + // totalSessionCost = 0.05 (previous) + 0.01 (think) + 0.002 (compaction) = 0.062 + ThinkResponse resp = new ThinkResponse("sess-t", "user-T", + 0.062, 0.05, 0.01, 200, 75, + null, null, + List.of(assistantMsg("sess-t", "user-T", "Answer.")), + null, true, + true, 100, 20, 0.002, TS); + + thinkInput.pipeInput("sess-t", resp); + + UserResponse r = messageOutput.readRecord().value(); + assertThat(r.cost()).isCloseTo(0.062, within(0.000001)); + } + + @Test + @DisplayName("UserResponse.cost is null when totalSessionCost is null") + void totalSessionCost_null_flowsAsNull() { + ThinkResponse resp = new ThinkResponse("sess-n", "user-N", + null, null, null, 100, 50, + null, null, + List.of(assistantMsg("sess-n", "user-N", "Answer.")), + null, true, + false, 0, 0, 0.0, TS); + + thinkInput.pipeInput("sess-n", resp); + + UserResponse r = messageOutput.readRecord().value(); + assertThat(r.cost()).isNull(); + } + + @Test + @DisplayName("thinkInputTokens and thinkOutputTokens flow to UserResponse") + void thinkTokens_flowToUserResponse() { + ThinkResponse resp = new ThinkResponse("sess-tk", "user-TK", + 0.01, null, 0.01, 350, 120, + null, null, + List.of(assistantMsg("sess-tk", "user-TK", "Response.")), + null, true, + false, 0, 0, 0.0, TS); + + thinkInput.pipeInput("sess-tk", resp); + + UserResponse r = messageOutput.readRecord().value(); + assertThat(r.inputTokens()).isEqualTo(350); + assertThat(r.outputTokens()).isEqualTo(120); + } + // ── Helpers ─────────────────────────────────────────────────────────────── private static final String TS = "2026-03-10T12:00:00Z"; @@ -254,14 +305,14 @@ private static ThinkResponse endTurnResponse(String sessionId, String userId, List messages, List tools, double cost, int inputTokens, int outputTokens) { - return new ThinkResponse(sessionId, userId, cost, null, inputTokens, outputTokens, + 
return new ThinkResponse(sessionId, userId, cost, null, cost, inputTokens, outputTokens, null, null, messages, tools, true, false, 0, 0, 0.0, TS); } private static ThinkResponse midTurnResponse(String sessionId, String userId, List messages, List tools) { - return new ThinkResponse(sessionId, userId, 0.005, null, 150, 60, + return new ThinkResponse(sessionId, userId, 0.005, null, 0.005, 150, 60, null, null, messages, tools, false, false, 0, 0, 0.0, TS); } diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EnrichInputMessageProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EnrichInputMessageProcessorTest.java index a4017f4..526ccc6 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EnrichInputMessageProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EnrichInputMessageProcessorTest.java @@ -1,7 +1,9 @@ package io.flightdeck.streams.processors; import io.flightdeck.streams.config.Topics; -import io.flightdeck.streams.model.*; +import io.flightdeck.streams.model.FullSessionContext; +import io.flightdeck.streams.model.MessageInput; +import io.flightdeck.streams.model.ThinkResponse; import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.*; @@ -23,8 +25,6 @@ * message-input (KStream) * leftJoin * think-request-response (KTable) - * leftJoin - * session-cost (KTable) * ▼ * enriched-message-input (KStream) */ @@ -55,11 +55,7 @@ void setUp() { KTable thinkTable = builder.table( Topics.THINK_REQUEST_RESPONSE, Consumed.with(Serdes.String(), JsonSerde.of(ThinkResponse.class))); - KTable sessionCostTable = builder.table( - Topics.SESSION_COST, - Consumed.with(Serdes.String(), JsonSerde.of(SessionCost.class))); - - EnrichInputMessageProcessor.register(builder, memoirTable, thinkTable, sessionCostTable); + 
EnrichInputMessageProcessor.register(builder, memoirTable, thinkTable); Properties props = new Properties(); props.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-enrich"); @@ -111,7 +107,7 @@ void firstMessage_noContext_emptyHistory() { @DisplayName("Message with existing ThinkResponse carries the reconstructed history") void messageWithThinkResponse_includesHistory() { // Seed the KTable with a previous ThinkResponse - ThinkResponse prevResponse = new ThinkResponse("sess-h", "user-A", 0.01, null, 100, 50, + ThinkResponse prevResponse = new ThinkResponse("sess-h", "user-A", 0.01, null, 0.01, 100, 50, List.of(assistantMsg("sess-h", "user-A", "First reply.")), userMsg("sess-h", "user-A", "Second question"), List.of(assistantMsg("sess-h", "user-A", "Second reply.")), @@ -133,7 +129,7 @@ void messageWithThinkResponse_includesHistory() { @Test @DisplayName("latestInput is always the incoming message, not part of history") void latestInput_isNotInHistory() { - ThinkResponse prevResponse = new ThinkResponse("sess-li", "u", 0.01, null, 100, 50, + ThinkResponse prevResponse = new ThinkResponse("sess-li", "u", 0.01, null, 0.01, 100, 50, null, null, List.of(assistantMsg("sess-li", "u", "Prior turn.")), List.of(), true, false, 0, 0, 0.0, TS); @@ -163,7 +159,7 @@ void outputKey_isSessionId() { @Test @DisplayName("userId is taken from the incoming message when present") void userId_fromMessage() { - ThinkResponse prevResponse = new ThinkResponse("sess-u1", "old-user", 0.01, null, 100, 50, + ThinkResponse prevResponse = new ThinkResponse("sess-u1", "old-user", 0.01, null, 0.01, 100, 50, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); thinkInput.pipeInput("sess-u1", prevResponse); messageInput.pipeInput("sess-u1", userMsg("sess-u1", "new-user", "hi")); @@ -174,7 +170,7 @@ void userId_fromMessage() { @Test @DisplayName("userId falls back to ThinkResponse value when message carries none") void userId_fallsBackToThinkResponse() { - ThinkResponse prevResponse = new 
ThinkResponse("sess-u2", "ctx-user", 0.01, null, 100, 50, + ThinkResponse prevResponse = new ThinkResponse("sess-u2", "ctx-user", 0.01, null, 0.01, 100, 50, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); thinkInput.pipeInput("sess-u2", prevResponse); // Message has null userId (e.g. scheduler-triggered input) @@ -199,12 +195,12 @@ void userId_nullWhenNoSource() { @Test @DisplayName("Two sessions receive their own independent histories") void sessionIsolation() { - ThinkResponse respA = new ThinkResponse("A", "u1", 0.01, null, 100, 50, + ThinkResponse respA = new ThinkResponse("A", "u1", 0.01, null, 0.01, 100, 50, List.of(assistantMsg("A", "u1", "a1")), null, List.of(assistantMsg("A", "u1", "a2")), List.of(), true, false, 0, 0, 0.0, TS); - ThinkResponse respB = new ThinkResponse("B", "u2", 0.01, null, 100, 50, + ThinkResponse respB = new ThinkResponse("B", "u2", 0.01, null, 0.01, 100, 50, null, null, List.of(assistantMsg("B", "u2", "b1")), List.of(), true, false, 0, 0, 0.0, TS); @@ -242,7 +238,7 @@ void enrich_nullThinkResponse_emptyHistory() { @DisplayName("enrichWithThinkResponse(): ThinkResponse with null fields is treated as empty history") void enrich_thinkResponseWithNullFields() { MessageInput msg = userMsg("s", "u", "hello"); - ThinkResponse resp = new ThinkResponse("s", "u", 0.0, null, 0, 0, + ThinkResponse resp = new ThinkResponse("s", "u", 0.0, null, 0.0, 0, 0, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); FullSessionContext result = EnrichInputMessageProcessor.enrichWithThinkResponse(msg, resp); assertThat(result.history()).isEmpty(); @@ -255,7 +251,7 @@ void enrich_historyReconstructed() { MessageInput prior = assistantMsg("s", "u", "old"); MessageInput inputMsg = userMsg("s", "u", "question"); MessageInput response = assistantMsg("s", "u", "answer"); - ThinkResponse resp = new ThinkResponse("s", "u", 0.01, null, 100, 50, + ThinkResponse resp = new ThinkResponse("s", "u", 0.01, null, 0.01, 100, 50, List.of(prior), inputMsg, 
List.of(response), List.of(), true, false, 0, 0, 0.0, TS); FullSessionContext result = EnrichInputMessageProcessor.enrichWithThinkResponse(msg, resp); assertThat(result.history()).containsExactly(prior, inputMsg, response); @@ -266,7 +262,7 @@ void enrich_historyReconstructed() { @DisplayName("enrichWithThinkResponse(): userId prefers message over ThinkResponse") void enrich_userIdFromMessage() { MessageInput msg = userMsg("s", "msg-user", "hi"); - ThinkResponse resp = new ThinkResponse("s", "ctx-user", 0.01, null, 100, 50, + ThinkResponse resp = new ThinkResponse("s", "ctx-user", 0.01, null, 0.01, 100, 50, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); assertThat(EnrichInputMessageProcessor.enrichWithThinkResponse(msg, resp).userId()).isEqualTo("msg-user"); } @@ -275,11 +271,30 @@ void enrich_userIdFromMessage() { @DisplayName("enrichWithThinkResponse(): userId falls back to ThinkResponse when message userId is blank") void enrich_userIdFallback() { MessageInput msg = new MessageInput("s", " ", "user", "hi", TS, Map.of()); - ThinkResponse resp = new ThinkResponse("s", "ctx-user", 0.01, null, 100, 50, + ThinkResponse resp = new ThinkResponse("s", "ctx-user", 0.01, null, 0.01, 100, 50, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); assertThat(EnrichInputMessageProcessor.enrichWithThinkResponse(msg, resp).userId()).isEqualTo("ctx-user"); } + // ── Cost from ThinkResponse ───────────────────────────────────────────── + + @Test + @DisplayName("Cost flows from ThinkResponse.totalSessionCost into FullSessionContext.cost") + void costFlowsFromThinkResponseTotalSessionCost() { + // Seed the KTable with a ThinkResponse that has totalSessionCost=0.05 + ThinkResponse prevResponse = new ThinkResponse("sess-cost", "user-C", 0.05, null, 0.01, 100, 50, + null, null, + List.of(assistantMsg("sess-cost", "user-C", "Prior reply.")), + List.of(), true, false, 0, 0, 0.0, TS); + thinkInput.pipeInput("sess-cost", prevResponse); + + // Send a new message-input + 
messageInput.pipeInput("sess-cost", userMsg("sess-cost", "user-C", "Follow-up?")); + + FullSessionContext full = fullContextOutput.readRecord().value(); + assertThat(full.cost()).isEqualTo(0.05); + } + // ── Helpers ─────────────────────────────────────────────────────────────── private static final String TS = "2026-03-10T12:00:00Z"; diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionCostAggregationProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionCostAggregationProcessorTest.java deleted file mode 100644 index 371fa94..0000000 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionCostAggregationProcessorTest.java +++ /dev/null @@ -1,190 +0,0 @@ -package io.flightdeck.streams.processors; - -import io.flightdeck.streams.config.Topics; -import io.flightdeck.streams.model.MessageInput; -import io.flightdeck.streams.model.SessionCost; -import io.flightdeck.streams.model.ThinkResponse; -import io.flightdeck.streams.model.ToolUseItem; -import io.flightdeck.streams.serdes.JsonSerde; -import org.apache.kafka.common.serialization.Serdes; -import org.apache.kafka.streams.*; -import org.apache.kafka.streams.kstream.Consumed; -import org.apache.kafka.streams.kstream.KStream; -import org.apache.kafka.streams.test.TestRecord; -import org.junit.jupiter.api.*; - -import java.util.List; -import java.util.Properties; - -import static io.flightdeck.streams.processors.SessionCostAggregationProcessor.*; -import static org.assertj.core.api.Assertions.*; - -class SessionCostAggregationProcessorTest { - - private TopologyTestDriver driver; - private TestInputTopic thinkInput; - private TestOutputTopic costOutput; - - @BeforeEach - void setUp() { - StreamsBuilder builder = new StreamsBuilder(); - KStream thinkStream = builder.stream( - Topics.THINK_REQUEST_RESPONSE, - Consumed.with(Serdes.String(), JsonSerde.of(ThinkResponse.class))); - 
SessionCostAggregationProcessor.register(builder, thinkStream); - - Properties props = new Properties(); - props.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-session-cost"); - props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:9092"); - - driver = new TopologyTestDriver(builder.build(), props); - - thinkInput = driver.createInputTopic( - Topics.THINK_REQUEST_RESPONSE, - Serdes.String().serializer(), - JsonSerde.of(ThinkResponse.class).serializer()); - - costOutput = driver.createOutputTopic( - Topics.SESSION_COST, - Serdes.String().deserializer(), - JsonSerde.of(SessionCost.class).deserializer()); - } - - @AfterEach - void tearDown() { driver.close(); } - - // ── Aggregation ─────────────────────────────────────────────────────────── - - @Test - @DisplayName("First ThinkResponse creates a SessionCost with llmCalls=1") - void firstResponse_createsSessionCost() { - thinkInput.pipeInput("sess-1", response("sess-1", "user-A", 0.005, 200, 80)); - - SessionCost cost = costOutput.readRecord().value(); - assertThat(cost.sessionId()).isEqualTo("sess-1"); - assertThat(cost.userId()).isEqualTo("user-A"); - assertThat(cost.llmCalls()).isEqualTo(1); - assertThat(cost.totalInputTokens()).isEqualTo(200); - assertThat(cost.totalOutputTokens()).isEqualTo(80); - // Cost is 0 when INPUT_TOKEN_PRICE / OUTPUT_TOKEN_PRICE env vars are not set - assertThat(cost.estimatedCostUsd()).isCloseTo(0.005, within(0.000001)); - } - - @Test - @DisplayName("Subsequent responses accumulate tokens and cost") - void multipleResponses_accumulate() { - thinkInput.pipeInput("sess-2", response("sess-2", "u", 0.003, 100, 50)); - thinkInput.pipeInput("sess-2", response("sess-2", "u", 0.007, 300, 120)); - thinkInput.pipeInput("sess-2", response("sess-2", "u", 0.002, 80, 30)); - - List> records = costOutput.readRecordsToList(); - assertThat(records).hasSize(3); - - SessionCost latest = records.get(2).value(); - assertThat(latest.llmCalls()).isEqualTo(3); - 
assertThat(latest.totalInputTokens()).isEqualTo(480); - assertThat(latest.totalOutputTokens()).isEqualTo(200); - assertThat(latest.estimatedCostUsd()).isCloseTo(0.012, within(0.000001)); - } - - @Test - @DisplayName("Different sessions maintain independent cost accumulators") - void independentSessions() { - thinkInput.pipeInput("sess-A", response("sess-A", "u1", 0.01, 100, 50)); - thinkInput.pipeInput("sess-B", response("sess-B", "u2", 0.02, 200, 80)); - thinkInput.pipeInput("sess-A", response("sess-A", "u1", 0.03, 150, 60)); - - List> all = costOutput.readRecordsToList(); - - SessionCost latestA = all.stream().filter(r -> "sess-A".equals(r.key())) - .reduce((a, b) -> b).orElseThrow().value(); - SessionCost latestB = all.stream().filter(r -> "sess-B".equals(r.key())) - .reduce((a, b) -> b).orElseThrow().value(); - - assertThat(latestA.llmCalls()).isEqualTo(2); - assertThat(latestA.estimatedCostUsd()).isCloseTo(0.04, within(0.000001)); - assertThat(latestB.llmCalls()).isEqualTo(1); - assertThat(latestB.estimatedCostUsd()).isCloseTo(0.02, within(0.000001)); - } - - @Test - @DisplayName("Output record key is the session_id") - void outputKey_isSessionId() { - thinkInput.pipeInput("sess-key", response("sess-key", "u", 0.001, 50, 20)); - assertThat(costOutput.readRecord().key()).isEqualTo("sess-key"); - } - - // ── Null cost passthrough ─────────────────────────────────────────────────── - - @Test - @DisplayName("When response.cost is null, aggregated cost remains null") - void nullCostField_passesThrough() { - thinkInput.pipeInput("sess-est", response("sess-est", "u", null, 1_000_000, 1_000_000)); - - SessionCost cost = costOutput.readRecord().value(); - assertThat(cost.estimatedCostUsd()).isNull(); - } - - // ── Tombstone / session close ───────────────────────────────────────────── - - @Test - @DisplayName("Close-sentinel response emits a tombstone (null value) on session-cost") - void closeSignal_emitsTombstone() { - // Seed a real cost first - 
thinkInput.pipeInput("sess-close", response("sess-close", "u", 0.01, 100, 50)); - costOutput.readRecord(); // drain the normal update - - // Send close sentinel - thinkInput.pipeInput("sess-close", closeSignal("sess-close")); - - TestRecord tombstone = costOutput.readRecord(); - assertThat(tombstone.key()).isEqualTo("sess-close"); - assertThat(tombstone.value()).isNull(); - } - - @Test - @DisplayName("Close-sentinel for unknown session still emits tombstone") - void closeSignalWithNoHistory_stillEmitsTombstone() { - thinkInput.pipeInput("sess-unknown", closeSignal("sess-unknown")); - - TestRecord tombstone = costOutput.readRecord(); - assertThat(tombstone.key()).isEqualTo("sess-unknown"); - assertThat(tombstone.value()).isNull(); - } - - // ── userId resolution ───────────────────────────────────────────────────── - - @Test - @DisplayName("resolveUserId preserves existing when incoming is null") - void resolveUserId_null() { - assertThat(resolveUserId("existing", null)).isEqualTo("existing"); - } - - @Test - @DisplayName("resolveUserId preserves existing when incoming is blank") - void resolveUserId_blank() { - assertThat(resolveUserId("existing", " ")).isEqualTo("existing"); - } - - @Test - @DisplayName("resolveUserId uses incoming when provided") - void resolveUserId_incoming() { - assertThat(resolveUserId("old", "new")).isEqualTo("new"); - } - - // ── Helpers ─────────────────────────────────────────────────────────────── - - private static final String TS = "2026-03-10T12:00:00Z"; - - private static ThinkResponse response(String sessionId, String userId, - Double cost, int inputTokens, int outputTokens) { - return new ThinkResponse(sessionId, userId, cost, null, inputTokens, outputTokens, - null, null, null, List.of(), false, false, 0, 0, 0.0, TS); - } - - /** Builds the sentinel close signal recognised by the processor. 
*/ - private static ThinkResponse closeSignal(String sessionId) { - return new ThinkResponse(sessionId, null, SESSION_CLOSE_SENTINEL, null, - 0, 0, null, null, null, List.of(), true, false, 0, 0, 0.0, TS); - } -} \ No newline at end of file diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionEndProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionEndProcessorTest.java index 0e6fa9f..e0bca36 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionEndProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/SessionEndProcessorTest.java @@ -263,7 +263,7 @@ private static Properties testProps() { } private static ThinkResponse thinkResponse(String sessionId) { - return new ThinkResponse(sessionId, "user-1", 0.01, null, 100, 50, + return new ThinkResponse(sessionId, "user-1", 0.01, null, 0.01, 100, 50, null, null, null, null, true, false, 0, 0, 0.0, TS); } } diff --git a/think/think-consumer/src/main/java/io/flightdeck/think/StandaloneRunner.java b/think/think-consumer/src/main/java/io/flightdeck/think/StandaloneRunner.java index 8dfcde9..2a23b67 100644 --- a/think/think-consumer/src/main/java/io/flightdeck/think/StandaloneRunner.java +++ b/think/think-consumer/src/main/java/io/flightdeck/think/StandaloneRunner.java @@ -94,9 +94,9 @@ private static void sendAndPrint(ClaudeApiService claudeApi, ObjectMapper mapper // Print metadata System.out.println(); System.out.printf("--- tokens: %d in / %d out | cost: $%.6f | end_turn: %s ---%n", - response.inputTokens(), - response.outputTokens(), - response.cost(), + response.thinkInputTokens(), + response.thinkOutputTokens(), + response.thinkCost(), response.endTurn()); } catch (Exception e) { diff --git a/think/think-consumer/src/main/java/io/flightdeck/think/consumer/ThinkConsumer.java 
b/think/think-consumer/src/main/java/io/flightdeck/think/consumer/ThinkConsumer.java index e0903af..cf9449f 100644 --- a/think/think-consumer/src/main/java/io/flightdeck/think/consumer/ThinkConsumer.java +++ b/think/think-consumer/src/main/java/io/flightdeck/think/consumer/ThinkConsumer.java @@ -187,7 +187,7 @@ void processRecord(ConsumerRecord record) throws Exception { AppConfig.BUDGET_PRICE_PER_SESSION); ThinkResponse budgetResponse = new ThinkResponse( - sessionId, userId, null, context.cost(), 0, 0, + sessionId, userId, context.cost(), context.cost(), null, 0, 0, context.history(), context.latestInput(), List.of(new MessageInput(sessionId, userId, "assistant", budgetMessage, @@ -249,18 +249,18 @@ void processRecord(ConsumerRecord record) throws Exception { log.info("[{}] Compaction LLM response: input_tokens={} output_tokens={} cost={} lastInputResponse={}", sessionId, - summaryResponse.inputTokens(), - summaryResponse.outputTokens(), - summaryResponse.cost() != null ? String.format("$%.6f", summaryResponse.cost()) : "null", + summaryResponse.thinkInputTokens(), + summaryResponse.thinkOutputTokens(), + summaryResponse.thinkCost() != null ? String.format("$%.6f", summaryResponse.thinkCost()) : "null", summaryResponse.lastInputResponse() != null ? mapper.writeValueAsString(summaryResponse.lastInputResponse()) : "null"); // Capture compaction metrics compacted = true; - compactionInputTokens = summaryResponse.inputTokens(); - compactionOutputTokens = summaryResponse.outputTokens(); - compactionCost = summaryResponse.cost() != null ? summaryResponse.cost() : 0.0; + compactionInputTokens = summaryResponse.thinkInputTokens(); + compactionOutputTokens = summaryResponse.thinkOutputTokens(); + compactionCost = summaryResponse.thinkCost() != null ? summaryResponse.thinkCost() : 0.0; String summaryText = extractTextFromMessages( summaryResponse.lastInputResponse()); @@ -297,14 +297,19 @@ void processRecord(ConsumerRecord record) throws Exception { // 7. 
Build final ThinkResponse Double prevSessionCost = context.cost(); + Double thinkCost = thinkResponse.thinkCost(); + Double totalSessionCost = (prevSessionCost != null ? prevSessionCost : 0.0) + + (thinkCost != null ? thinkCost : 0.0) + + compactionCost; thinkResponse = new ThinkResponse( sessionId, userId, - thinkResponse.cost(), + totalSessionCost, prevSessionCost, - thinkResponse.inputTokens(), - thinkResponse.outputTokens(), + thinkCost, + thinkResponse.thinkInputTokens(), + thinkResponse.thinkOutputTokens(), effectiveHistory, context.latestInput(), thinkResponse.lastInputResponse(), @@ -389,7 +394,7 @@ void emitErrorResponse(ConsumerRecord record, Exception e) { static ThinkResponse buildErrorResponse(String sessionId, String userId, Exception e) { String errorMessage = "Sorry, an error occurred while processing your request: " + e.getMessage(); return new ThinkResponse( - sessionId, userId, null, null, 0, 0, + sessionId, userId, null, null, null, 0, 0, null, null, List.of(new MessageInput(sessionId, userId, "assistant", errorMessage, java.time.Instant.now().toString(), null)), diff --git a/think/think-consumer/src/main/java/io/flightdeck/think/model/ThinkResponse.java b/think/think-consumer/src/main/java/io/flightdeck/think/model/ThinkResponse.java index 8365cb9..7d178a8 100644 --- a/think/think-consumer/src/main/java/io/flightdeck/think/model/ThinkResponse.java +++ b/think/think-consumer/src/main/java/io/flightdeck/think/model/ThinkResponse.java @@ -9,10 +9,11 @@ public record ThinkResponse( @JsonProperty("session_id") String sessionId, @JsonProperty("user_id") String userId, - @JsonProperty("cost") Double cost, - @JsonProperty("prev_session_cost") Double prevSessionCost, - @JsonProperty("input_tokens") int inputTokens, - @JsonProperty("output_tokens") int outputTokens, + @JsonProperty("total_session_cost") Double totalSessionCost, + @JsonProperty("previous_session_cost") Double previousSessionCost, + @JsonProperty("think_cost") Double thinkCost, + 
@JsonProperty("think_input_tokens") int thinkInputTokens, + @JsonProperty("think_output_tokens") int thinkOutputTokens, @JsonProperty("previous_messages") List previousMessages, @JsonProperty("last_input_message") MessageInput lastInputMessage, @JsonProperty("last_input_response") List lastInputResponse, diff --git a/think/think-consumer/src/main/java/io/flightdeck/think/service/ClaudeApiService.java b/think/think-consumer/src/main/java/io/flightdeck/think/service/ClaudeApiService.java index 2e30f7f..fb5e29c 100644 --- a/think/think-consumer/src/main/java/io/flightdeck/think/service/ClaudeApiService.java +++ b/think/think-consumer/src/main/java/io/flightdeck/think/service/ClaudeApiService.java @@ -277,8 +277,9 @@ ThinkResponse parseResponse(String responseBody, String sessionId, String userId return new ThinkResponse( sessionId, userId, - cost, - null, // prevSessionCost — set by ThinkConsumer + null, // totalSessionCost — set by ThinkConsumer + null, // previousSessionCost — set by ThinkConsumer + cost, // thinkCost inputTokens, outputTokens, null, // previousMessages — set by ThinkConsumer diff --git a/think/think-consumer/src/main/java/io/flightdeck/think/service/GeminiApiService.java b/think/think-consumer/src/main/java/io/flightdeck/think/service/GeminiApiService.java index a98b059..d775bcb 100644 --- a/think/think-consumer/src/main/java/io/flightdeck/think/service/GeminiApiService.java +++ b/think/think-consumer/src/main/java/io/flightdeck/think/service/GeminiApiService.java @@ -290,8 +290,9 @@ ThinkResponse parseResponse(String responseBody, String sessionId, String userId return new ThinkResponse( sessionId, userId, - cost, - null, // prevSessionCost — set by ThinkConsumer + null, // totalSessionCost — set by ThinkConsumer + null, // previousSessionCost — set by ThinkConsumer + cost, // thinkCost inputTokens, outputTokens, null, // previousMessages — set by ThinkConsumer diff --git 
a/think/think-consumer/src/test/java/io/flightdeck/think/consumer/CompactionTest.java b/think/think-consumer/src/test/java/io/flightdeck/think/consumer/CompactionTest.java index 7c98955..039de41 100644 --- a/think/think-consumer/src/test/java/io/flightdeck/think/consumer/CompactionTest.java +++ b/think/think-consumer/src/test/java/io/flightdeck/think/consumer/CompactionTest.java @@ -119,7 +119,7 @@ void compaction_exactlyAtTrigger_withToolInteractions() throws Exception { // --- Mock: compaction call --- ThinkResponse summaryResponse = new ThinkResponse( - SESSION, USER, 0.001, null, 50, 30, + SESSION, USER, null, null, 0.001, 50, 30, null, null, List.of(assistantMsg("User listed topics. Found 5 topics including topicA and topicB.")), null, true, false, 0, 0, 0.0, TS); @@ -133,7 +133,7 @@ void compaction_exactlyAtTrigger_withToolInteractions() throws Exception { when(mockLlm.toApiMessages(anyList(), eq(latestInput))) .thenReturn(mainApiMessages); ThinkResponse mainResponse = new ThinkResponse( - SESSION, USER, 0.01, null, 200, 100, + SESSION, USER, null, null, 0.01, 200, 100, null, null, List.of(assistantMsg("All brokers healthy.")), null, true, false, 0, 0, 0.0, TS); @@ -178,6 +178,12 @@ void compaction_exactlyAtTrigger_withToolInteractions() throws Exception { assertThat(produced.compactionInputTokens()).isEqualTo(50); assertThat(produced.compactionOutputTokens()).isEqualTo(30); assertThat(produced.compactionCost()).isGreaterThan(0.0); + + // total_session_cost = previous_session_cost(0.05) + think_cost(0.01) + compaction_cost(0.001) + assertThat(produced.previousSessionCost()).isEqualTo(0.05); + assertThat(produced.thinkCost()).isEqualTo(0.01); + assertThat(produced.totalSessionCost()).isCloseTo(0.061, within(0.000001)); + assertThat(produced.lastInputMessage().contentAsString()).isEqualTo("show broker health"); assertThat(produced.lastInputResponse()).hasSize(1); assertThat(produced.lastInputResponse().get(0).contentAsString()).isEqualTo("All brokers 
healthy."); @@ -234,7 +240,7 @@ void compaction_includesToolUseAndToolResults() throws Exception { .thenReturn(oldApiMessages); ThinkResponse summaryResponse = new ThinkResponse( - SESSION, USER, 0.002, null, 80, 40, + SESSION, USER, null, null, 0.002, 80, 40, null, null, List.of(assistantMsg("User searched flights to NYC. Found AA123, UA456, DL789.")), null, true, false, 0, 0, 0.0, TS); @@ -253,7 +259,7 @@ void compaction_includesToolUseAndToolResults() throws Exception { .thenReturn(mainApiMessages); ThinkResponse mainResponse = new ThinkResponse( - SESSION, USER, 0.01, null, 200, 100, + SESSION, USER, null, null, 0.01, 200, 100, null, null, List.of(assistantMsg("You're welcome!")), null, true, false, 0, 0, 0.0, TS); @@ -310,7 +316,7 @@ void noCompaction_userMessagesBelowTrigger() throws Exception { .thenReturn(apiMessages); ThinkResponse mainResponse = new ThinkResponse( - SESSION, USER, 0.01, null, 100, 50, + SESSION, USER, null, null, 0.01, 100, 50, null, null, List.of(assistantMsg("Nice!")), null, true, false, 0, 0, 0.0, TS); @@ -339,6 +345,13 @@ void noCompaction_userMessagesBelowTrigger() throws Exception { String producedJson = capturingProducer.records.get(0).value(); ThinkResponse produced = mapper.readValue(producedJson, ThinkResponse.class); assertThat(produced.previousMessages()).hasSize(history.size()); + + // total_session_cost = previous_session_cost(0.02) + think_cost(0.01) + compaction_cost(0.0) + assertThat(produced.previousSessionCost()).isEqualTo(0.02); + assertThat(produced.thinkCost()).isEqualTo(0.01); + assertThat(produced.totalSessionCost()).isCloseTo(0.03, within(0.000001)); + assertThat(produced.compaction()).isFalse(); + assertThat(produced.compactionCost()).isEqualTo(0.0); } @Test @@ -369,7 +382,7 @@ void noCompaction_midToolLoop() throws Exception { .thenReturn(apiMessages); ThinkResponse mainResponse = new ThinkResponse( - SESSION, USER, 0.01, null, 100, 50, + SESSION, USER, null, null, 0.01, 100, 50, null, null, 
List.of(assistantMsg("Here are the results.")), null, true, false, 0, 0, 0.0, TS); @@ -432,7 +445,7 @@ void realWorld_userLatestInput_compactsWithToolHistory() throws Exception { when(mockLlm.callWithoutTools( eq(io.flightdeck.think.config.AppConfig.COMPACTION_PROMPT), anyList(), eq(sessionId), eq(userId))) - .thenReturn(new ThinkResponse(sessionId, userId, 0.001, null, 20, 15, + .thenReturn(new ThinkResponse(sessionId, userId, null, null, 0.001, 20, 15, null, null, List.of(new MessageInput(sessionId, userId, "assistant", "User greeted the Kafka assistant.", TS, null)), @@ -442,7 +455,7 @@ void realWorld_userLatestInput_compactsWithToolHistory() throws Exception { List> mainApiMsgs = List.of(Map.of("role", "user", "content", "main")); when(mockLlm.toApiMessages(anyList(), eq(context.latestInput()))).thenReturn(mainApiMsgs); when(mockLlm.call(anyString(), eq(mainApiMsgs), eq(sessionId), eq(userId))) - .thenReturn(new ThinkResponse(sessionId, userId, 0.005, null, 200, 100, + .thenReturn(new ThinkResponse(sessionId, userId, null, null, 0.005, 200, 100, null, null, List.of(new MessageInput(sessionId, userId, "assistant", "think-consumer-group has 0 lag.", TS, null)), @@ -504,7 +517,7 @@ void realWorld_toolLatestInput_skipsCompaction() throws Exception { when(mockLlm.toApiMessages(eq(context.history()), eq(toolLatestInput))) .thenReturn(List.of(Map.of("role", "user", "content", "test"))); when(mockLlm.call(anyString(), anyList(), eq(sessionId), eq(userId))) - .thenReturn(new ThinkResponse(sessionId, userId, 0.01, null, 100, 50, + .thenReturn(new ThinkResponse(sessionId, userId, null, null, 0.01, 100, 50, null, null, List.of(new MessageInput(sessionId, userId, "assistant", "The group has 0 lag.", TS, null)), @@ -545,7 +558,7 @@ void compactionPrompt_usedForSummarizationCall() throws Exception { .thenReturn(List.of(Map.of("role", "user", "content", "msg1"))); when(mockLlm.callWithoutTools(eq(io.flightdeck.think.config.AppConfig.COMPACTION_PROMPT), anyList(), eq(SESSION), 
eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.0, null, 10, 10, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.0, 10, 10, null, null, List.of(assistantMsg("summary")), null, true, false, 0, 0, 0.0, TS)); @@ -553,7 +566,7 @@ void compactionPrompt_usedForSummarizationCall() throws Exception { when(mockLlm.toApiMessages(anyList(), eq(latestInput))) .thenReturn(List.of(Map.of("role", "user", "content", "msg4"))); when(mockLlm.call(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.01, null, 50, 50, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.01, 50, 50, null, null, List.of(assistantMsg("done")), null, true, false, 0, 0, 0.0, TS)); @@ -609,7 +622,7 @@ void noCompaction_userMsgsEqualUntil() throws Exception { when(mockLlm.toApiMessages(eq(history), eq(latestInput))) .thenReturn(List.of(Map.of("role", "user", "content", "test"))); when(mockLlm.call(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.01, null, 100, 50, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.01, 100, 50, null, null, List.of(assistantMsg("You're welcome!")), null, true, false, 0, 0, 0.0, TS)); @@ -659,14 +672,14 @@ void compaction_manyUserMessages_keepsLastTwo() throws Exception { String summaryText = "User asked 6 questions (q1-q6) using various tools. 
All returned results."; when(mockLlm.callWithoutTools(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.005, null, 500, 100, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.005, 500, 100, null, null, List.of(assistantMsg(summaryText)), null, true, false, 0, 0, 0.0, TS)); List> mainApiMsgs = List.of(Map.of("role", "user", "content", "main")); when(mockLlm.toApiMessages(anyList(), eq(latestInput))).thenReturn(mainApiMsgs); when(mockLlm.call(anyString(), eq(mainApiMsgs), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.01, null, 200, 50, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.01, 200, 50, null, null, List.of(assistantMsg("answer 9")), null, true, false, 0, 0, 0.0, TS)); @@ -755,14 +768,14 @@ void compaction_secondRound_recompactsOldSummary() throws Exception { // The summary should incorporate the old summary + new info String newSummary = "User asked about topics. Then checked broker health (all OK) and it was healthy."; when(mockLlm.callWithoutTools(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.002, null, 80, 40, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.002, 80, 40, null, null, List.of(assistantMsg(newSummary)), null, true, false, 0, 0, 0.0, TS)); List> mainApiMsgs = List.of(Map.of("role", "user", "content", "main")); when(mockLlm.toApiMessages(anyList(), eq(latestInput))).thenReturn(mainApiMsgs); when(mockLlm.call(anyString(), eq(mainApiMsgs), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.01, null, 200, 80, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.01, 200, 80, null, null, List.of(assistantMsg("Nothing else to report.")), null, true, false, 0, 0, 0.0, TS)); @@ -820,7 +833,7 @@ void noCompaction_previousMessagesEqualsHistory_fieldsZero() throws Exception { when(mockLlm.toApiMessages(eq(history), eq(latestInput))) .thenReturn(List.of(Map.of("role", 
"user", "content", "test"))); when(mockLlm.call(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.005, null, 100, 30, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.005, 100, 30, null, null, List.of(assistantMsg("No problem!")), null, true, false, 0, 0, 0.0, TS)); @@ -913,14 +926,14 @@ void compactionFields_exactValues() throws Exception { // Compaction returns specific token counts and cost when(mockLlm.callWithoutTools(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.00234, null, 150, 42, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.00234, 150, 42, null, null, List.of(assistantMsg("Summary of q1.")), null, true, false, 0, 0, 0.0, TS)); when(mockLlm.toApiMessages(anyList(), eq(latestInput))) .thenReturn(List.of(Map.of("role", "user", "content", "test"))); when(mockLlm.call(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.008, null, 300, 80, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.008, 300, 80, null, null, List.of(assistantMsg("a4")), null, true, false, 0, 0, 0.0, TS)); @@ -938,9 +951,13 @@ void compactionFields_exactValues() throws Exception { assertThat(produced.compactionCost()).isEqualTo(0.00234); // Main response fields - assertThat(produced.cost()).isEqualTo(0.008); - assertThat(produced.inputTokens()).isEqualTo(300); - assertThat(produced.outputTokens()).isEqualTo(80); + assertThat(produced.thinkCost()).isEqualTo(0.008); + assertThat(produced.thinkInputTokens()).isEqualTo(300); + assertThat(produced.thinkOutputTokens()).isEqualTo(80); + + // total_session_cost = previous_session_cost(0.05) + think_cost(0.008) + compaction_cost(0.00234) + assertThat(produced.previousSessionCost()).isEqualTo(0.05); + assertThat(produced.totalSessionCost()).isCloseTo(0.06034, within(0.000001)); // Summary is correct 
assertThat(produced.previousMessages().get(0).contentAsString()).isEqualTo( @@ -973,7 +990,7 @@ void noCompaction_noUserMessages() throws Exception { when(mockLlm.toApiMessages(eq(history), eq(latestInput))) .thenReturn(List.of(Map.of("role", "user", "content", "test"))); when(mockLlm.call(anyString(), anyList(), eq(SESSION), eq(USER))) - .thenReturn(new ThinkResponse(SESSION, USER, 0.005, null, 100, 30, + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.005, 100, 30, null, null, List.of(assistantMsg("Final result.")), null, true, false, 0, 0, 0.0, TS)); @@ -1053,6 +1070,113 @@ void splitIndex_emptyHistory() { assertThat(ThinkConsumer.findCompactionSplitIndex(List.of(), 2)).isEqualTo(-1); } + // ── Cost field tests ───────────────────────────────────────────────────── + + @Test + @DisplayName("First turn (previousSessionCost=null): totalSessionCost = thinkCost") + void firstTurn_totalSessionCostEqualsThinkCost() throws Exception { + List history = List.of(); + MessageInput latestInput = userMsg("Hello"); + + FullSessionContext context = new FullSessionContext( + SESSION, USER, null, history, latestInput, null, TS); + + List> apiMessages = List.of( + Map.of("role", "user", "content", "Hello")); + when(mockLlm.toApiMessages(eq(history), eq(latestInput))) + .thenReturn(apiMessages); + when(mockLlm.call(anyString(), eq(apiMessages), eq(SESSION), eq(USER))) + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.005, 100, 40, + null, null, List.of(assistantMsg("Hi there!")), + null, true, false, 0, 0, 0.0, TS)); + + thinkConsumer.processRecord(new ConsumerRecord<>( + "test-enriched-message-input", 0, 0, SESSION, + mapper.writeValueAsString(context))); + + ThinkResponse produced = mapper.readValue( + capturingProducer.records.get(0).value(), ThinkResponse.class); + + assertThat(produced.previousSessionCost()).isNull(); + assertThat(produced.thinkCost()).isEqualTo(0.005); + assertThat(produced.totalSessionCost()).isCloseTo(0.005, within(0.000001)); + 
assertThat(produced.thinkInputTokens()).isEqualTo(100); + assertThat(produced.thinkOutputTokens()).isEqualTo(40); + assertThat(produced.compaction()).isFalse(); + assertThat(produced.compactionCost()).isEqualTo(0.0); + } + + @Test + @DisplayName("Both costs > 0: totalSessionCost = previousSessionCost + thinkCost") + void bothCostsPositive_totalSessionCostAddsUp() throws Exception { + List history = List.of( + userMsg("Hello"), assistantMsg("Hi!")); + MessageInput latestInput = userMsg("What's 2+2?"); + + FullSessionContext context = new FullSessionContext( + SESSION, USER, 0.04, history, latestInput, null, TS); + + List> apiMessages = List.of( + Map.of("role", "user", "content", "What's 2+2?")); + when(mockLlm.toApiMessages(eq(history), eq(latestInput))) + .thenReturn(apiMessages); + when(mockLlm.call(anyString(), eq(apiMessages), eq(SESSION), eq(USER))) + .thenReturn(new ThinkResponse(SESSION, USER, null, null, 0.012, 250, 90, + null, null, List.of(assistantMsg("4")), + null, true, false, 0, 0, 0.0, TS)); + + thinkConsumer.processRecord(new ConsumerRecord<>( + "test-enriched-message-input", 0, 0, SESSION, + mapper.writeValueAsString(context))); + + ThinkResponse produced = mapper.readValue( + capturingProducer.records.get(0).value(), ThinkResponse.class); + + assertThat(produced.previousSessionCost()).isEqualTo(0.04); + assertThat(produced.thinkCost()).isEqualTo(0.012); + assertThat(produced.totalSessionCost()).isCloseTo(0.052, within(0.000001)); + assertThat(produced.thinkInputTokens()).isEqualTo(250); + assertThat(produced.thinkOutputTokens()).isEqualTo(90); + assertThat(produced.compaction()).isFalse(); + assertThat(produced.compactionCost()).isEqualTo(0.0); + } + + @Test + @DisplayName("Budget exceeded: cost fields set correctly, LLM not called") + void budgetExceeded_costFieldsSetCorrectly() throws Exception { + if (io.flightdeck.think.config.AppConfig.BUDGET_PRICE_PER_SESSION == null) { + return; // Budget not configured in test env + } + + double 
budgetLimit = io.flightdeck.think.config.AppConfig.BUDGET_PRICE_PER_SESSION; + double overBudgetCost = budgetLimit + 0.01; + + List history = List.of( + userMsg("Hello"), assistantMsg("Hi!")); + MessageInput latestInput = userMsg("One more question"); + + FullSessionContext context = new FullSessionContext( + SESSION, USER, overBudgetCost, history, latestInput, null, TS); + + thinkConsumer.processRecord(new ConsumerRecord<>( + "test-enriched-message-input", 0, 0, SESSION, + mapper.writeValueAsString(context))); + + ThinkResponse produced = mapper.readValue( + capturingProducer.records.get(0).value(), ThinkResponse.class); + + verify(mockLlm, never()).call(anyString(), anyList(), anyString(), anyString()); + + assertThat(produced.totalSessionCost()).isEqualTo(overBudgetCost); + assertThat(produced.previousSessionCost()).isEqualTo(overBudgetCost); + assertThat(produced.thinkCost()).isNull(); + assertThat(produced.thinkInputTokens()).isEqualTo(0); + assertThat(produced.thinkOutputTokens()).isEqualTo(0); + assertThat(produced.endTurn()).isTrue(); + assertThat(produced.lastInputResponse().get(0).contentAsString()) + .contains("budget"); + } + // ── Helpers ─────────────────────────────────────────────────────────────── private static MessageInput userMsg(String content) { From fb6bb2ed5c166953bf94b73c87c01822faa0d3d4 Mon Sep 17 00:00:00 2001 From: Taku Suzuki Date: Sat, 4 Apr 2026 23:14:40 +0900 Subject: [PATCH 2/3] remove session-cost, add cost to think consumer --- .../flightdeck_sdk/think_consumer_runner.py | 130 ++++++++++-------- 1 file changed, 74 insertions(+), 56 deletions(-) diff --git a/sdk/python/flightdeck_sdk/think_consumer_runner.py b/sdk/python/flightdeck_sdk/think_consumer_runner.py index fe96947..f4a7bb6 100644 --- a/sdk/python/flightdeck_sdk/think_consumer_runner.py +++ b/sdk/python/flightdeck_sdk/think_consumer_runner.py @@ -30,9 +30,12 @@ class ThinkConsumerConfig: gemini_model: str = "gemini-2.5-flash" gemini_max_tokens: int = 4096 gemini_api_url: 
str = "https://generativelanguage.googleapis.com/v1beta" - compaction_user_message_threshold: int = -1 + compaction_user_message_trigger: int = -1 + compaction_user_message_until: int = 2 compaction_prompt: str = ( "Summarize the following conversation concisely. " + "If the conversation starts with a previous summary, incorporate and extend it " + "rather than re-summarizing it. " "Preserve key facts, decisions, user preferences, and any context needed " "to continue the conversation naturally. Output only the summary." ) @@ -155,7 +158,9 @@ def _process_record(self, key: str | None, value: str | None, topic: str, partit "prevSessionCost": cumulative_cost, "inputTokens": 0, "outputTokens": 0, - "messages": [ + "previousMessages": history, + "lastInputMessage": latest_input, + "lastInputResponse": [ { "sessionId": session_id, "userId": user_id, @@ -178,50 +183,51 @@ def _process_record(self, key: str | None, value: str | None, topic: str, partit self._consumer.store_offsets(offsets=[tp]) return - # Compact history if user message count exceeds threshold + # Compact history if user message count exceeds trigger effective_history = history compacted_history = None - threshold = self._config.compaction_user_message_threshold + trigger = self._config.compaction_user_message_trigger + keep_last = self._config.compaction_user_message_until - if threshold > 0 and len(effective_history) > 2: + if trigger > 0 and effective_history: user_msg_count = sum(1 for m in effective_history if m.get("role") == "user") - if user_msg_count >= threshold: - last_two = effective_history[-2:] - mid_tool_loop = any( - m.get("role") == "tool" or self._has_tool_use_content(m) - for m in last_two - ) - if not mid_tool_loop: - logger.info( - "[%s] Compacting history: %d user messages >= threshold %d", - session_id, user_msg_count, threshold, - ) - old_messages = effective_history[:-2] - recent_messages = list(last_two) - - provider = self._config.llm_provider.lower() - if provider == "gemini": 
- summary_input = self._to_gemini_messages(old_messages, {}) - summary_resp = self._call_gemini(self._config.compaction_prompt, summary_input) - summary_text = self._extract_gemini_text(summary_resp) - else: - summary_input = self._to_claude_messages(old_messages, {}) - summary_resp = self._call_claude(self._config.compaction_prompt, summary_input) - summary_text = self._extract_claude_text(summary_resp) - - summary_msg = { - "sessionId": session_id, - "userId": user_id, - "role": "assistant", - "content": f"[Conversation Summary]\n{summary_text}", - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - } - compacted_history = [summary_msg] + recent_messages - effective_history = compacted_history - logger.info( - "[%s] History compacted: %d messages → %d", - session_id, len(history), len(compacted_history), - ) + if user_msg_count >= trigger: + split_idx = self._find_compaction_split_index(effective_history, keep_last) + if split_idx > 0: + recent_messages = effective_history[split_idx:] + # Skip compaction if we're in an active tool loop + # (latest_input is a tool result) + mid_tool_loop = latest_input.get("role") == "tool" if latest_input else False + if not mid_tool_loop: + logger.info( + "[%s] Compacting history: %d user messages >= trigger %d, keeping from index %d", + session_id, user_msg_count, trigger, split_idx, + ) + old_messages = effective_history[:split_idx] + + provider = self._config.llm_provider.lower() + if provider == "gemini": + summary_input = self._to_gemini_messages(old_messages, {}) + summary_resp = self._call_gemini(self._config.compaction_prompt, summary_input, include_tools=False) + summary_text = self._extract_gemini_text(summary_resp) + else: + summary_input = self._to_claude_messages(old_messages, {}) + summary_resp = self._call_claude(self._config.compaction_prompt, summary_input, include_tools=False) + summary_text = self._extract_claude_text(summary_resp) + + summary_msg = { + "sessionId": session_id, + "userId": 
user_id, + "role": "assistant", + "content": f"[Conversation Summary]\n{summary_text}", + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + compacted_history = [summary_msg] + recent_messages + effective_history = compacted_history + logger.info( + "[%s] History compacted: %d messages → %d", + session_id, len(history), len(compacted_history), + ) # Build system prompt system_prompt = self._build_system_prompt(memoir_context, context) @@ -237,8 +243,9 @@ def _process_record(self, key: str | None, value: str | None, topic: str, partit response = self._call_claude(system_prompt, messages) think_response = self._parse_response(response, session_id, user_id, latest_input) think_response["prevSessionCost"] = cumulative_cost - if compacted_history is not None: - think_response["compactedHistory"] = compacted_history + think_response["previousMessages"] = effective_history + think_response["lastInputMessage"] = latest_input + think_response["lastInputResponse"] = think_response.pop("messages", []) # Produce to output topic self._producer.produce( @@ -315,7 +322,7 @@ def _append_or_merge(self, messages: list[dict], role: str, text: str) -> None: else: messages.append({"role": role, "content": text}) - def _call_claude(self, system_prompt: str, messages: list[dict]) -> dict: + def _call_claude(self, system_prompt: str, messages: list[dict], *, include_tools: bool = True) -> dict: body: dict[str, Any] = { "model": self._config.claude_model, "max_tokens": self._config.claude_max_tokens, @@ -323,7 +330,7 @@ def _call_claude(self, system_prompt: str, messages: list[dict]) -> dict: "messages": messages, } - if self._config.tools: + if include_tools and self._config.tools: body["tools"] = self._config.tools data = json.dumps(body).encode() @@ -362,10 +369,6 @@ def _parse_response(self, response: dict, session_id: str, user_id: str, latest_ messages: list[dict] = [] tool_uses: list[dict] = [] - # Prepend latest input for downstream request-response pairing - if 
latest_input: - messages.append(latest_input) - has_tool_use = any(b.get("type") == "tool_use" for b in content_blocks) if has_tool_use: @@ -495,14 +498,14 @@ def _build_function_response_parts(self, content: Any, tool_id_to_name: dict[str parts.append({"functionResponse": {"name": name, "response": res_data}}) return parts - def _call_gemini(self, system_prompt: str, contents: list[dict]) -> dict: + def _call_gemini(self, system_prompt: str, contents: list[dict], *, include_tools: bool = True) -> dict: body: dict[str, Any] = { "system_instruction": {"parts": [{"text": system_prompt}]}, "contents": contents, "generationConfig": {"maxOutputTokens": self._config.gemini_max_tokens}, } - if self._config.tools: + if include_tools and self._config.tools: func_decls = [] for tool in self._config.tools: decl: dict[str, Any] = { @@ -557,10 +560,6 @@ def _parse_gemini_response(self, response: dict, session_id: str, user_id: str, tool_uses: list[dict] = [] now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - # Prepend latest input - if latest_input: - messages.append(latest_input) - # Build content blocks in Claude-compatible format for history preservation content_blocks: list[dict] = [] @@ -617,6 +616,25 @@ def _parse_gemini_response(self, response: dict, session_id: str, user_id: str, "timestamp": now, } + @staticmethod + def _find_compaction_split_index(history: list[dict], keep_last: int) -> int: + """Find the index where to split history for compaction. + Everything before this index is summarized; from this index onward is kept. 
+ Returns -1 if nothing to compact.""" + if not history or keep_last <= 0: + return -1 + total_user = sum(1 for m in history if m.get("role") == "user") + if total_user <= keep_last: + return -1 + target = total_user - keep_last + seen = 0 + for i, m in enumerate(history): + if m.get("role") == "user": + seen += 1 + if seen > target: + return i + return -1 + @staticmethod def _has_tool_use_content(msg: dict) -> bool: content = msg.get("content") From f85919e4db1ebe7d0ce130ffe052808619dfb6bf Mon Sep 17 00:00:00 2001 From: Taku Suzuki Date: Sat, 4 Apr 2026 23:21:56 +0900 Subject: [PATCH 3/3] fix python test --- sdk/python/tests/test_think_consumer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sdk/python/tests/test_think_consumer.py b/sdk/python/tests/test_think_consumer.py index 318a3e7..b6936ab 100644 --- a/sdk/python/tests/test_think_consumer.py +++ b/sdk/python/tests/test_think_consumer.py @@ -146,10 +146,9 @@ def test_text_only_response_without_pricing(self): assert result["cost"] is None # No pricing env vars set assert len(result["toolUses"]) == 0 - # messages[0] = latest_input, messages[1] = assistant text - assert result["messages"][0] == {"content": "check order"} - assert result["messages"][1]["content"] == "Your order is shipped." - assert result["messages"][1]["role"] == "assistant" + assert len(result["messages"]) == 1 + assert result["messages"][0]["content"] == "Your order is shipped." 
+ assert result["messages"][0]["role"] == "assistant" def test_tool_use_response(self): runner = make_runner() @@ -334,9 +333,9 @@ def test_text_only_response(self): assert result["inputTokens"] == 100 assert result["outputTokens"] == 50 assert len(result["toolUses"]) == 0 - # messages[0] = latest_input, messages[1] = assistant text - assert result["messages"][1]["role"] == "assistant" - assert result["messages"][1]["content"] == "Hello from Gemini" + assert len(result["messages"]) == 1 + assert result["messages"][0]["role"] == "assistant" + assert result["messages"][0]["content"] == "Hello from Gemini" def test_function_call_response(self): runner = make_runner(llm_provider="gemini", gemini_api_key="test-key")