diff --git a/internal/middleware/README.md b/internal/middleware/README.md index 0e40ce68..d9df8e3a 100644 --- a/internal/middleware/README.md +++ b/internal/middleware/README.md @@ -91,28 +91,38 @@ with open(payload_path) as f: ## Implementation Details -### jq Schema Filter +### Schema Walk: Native Go Implementation -The middleware uses the same jq filter logic as the gh-aw jqschema utility: +The middleware implements `walk_schema` as a native Go function (`inferSchema`) registered +via `gojq.WithFunction`. This means all recursive traversal runs directly in Go instead of +going through the jq interpreter, giving better performance on deeply-nested payloads. + +The transformation replaces every leaf value with its jq type name: + +| Go type | Schema value | +|------------------------|--------------| +| `nil` | `"null"` | +| `bool` | `"boolean"` | +| `float64`, `int`, `json.Number` | `"number"` | +| `string` | `"string"` | +| `map[string]interface{}` | recursed object | +| `[]interface{}` | recursed array (first element only, or `[]`) | + +The jq filter invokes the native function with a single call: ```jq -def walk_schema: - . as $in | - if type == "object" then - reduce keys[] as $k ({}; . + {($k): ($in[$k] | walk_schema)}) - elif type == "array" then - if length == 0 then [] else [.[0] | walk_schema] end - else - type - end; walk_schema ``` -This recursively walks the JSON structure and replaces values with their type names. The function is named `walk_schema` to avoid shadowing gojq's built-in `walk/1`. +This replaces the previous pure-jq recursive `def walk_schema: ...` definition. Behaviour +is identical but the hot recursive path is now compiled Go rather than jq bytecode. ### Go Implementation -The middleware is implemented using [gojq](https://github.com/itchyny/gojq), a pure Go implementation of jq, eliminating the need to spawn external processes. 
+The middleware is implemented using [gojq](https://github.com/itchyny/gojq), a pure Go +implementation of jq, eliminating the need to spawn external processes. The +`walk_schema` function is bound to the native `inferSchema` Go function via +`gojq.WithFunction` at compile time. ## Configuration diff --git a/internal/middleware/jqschema.go b/internal/middleware/jqschema.go index cf855618..b428a040 100644 --- a/internal/middleware/jqschema.go +++ b/internal/middleware/jqschema.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "reflect" "strings" "time" "unicode/utf8" @@ -41,57 +42,87 @@ type PayloadMetadata struct { QueryID string `json:"-"` // Internal use only, not serialized to clients } -// jqSchemaFilter is the jq filter that transforms JSON to schema -// This filter leverages gojq v0.12.19 features including: -// - Enhanced array handling (supports up to 536,870,912 elements / 2^29) -// - Improved concurrent execution performance -// - Better error messages for type errors +// jqSchemaFilter is the jq entry point that invokes the native Go walk_schema function. +// The recursive schema-walk logic is implemented in inferSchema (see below) and registered +// via gojq.WithFunction, so the filter itself is a single call. // -// The filter recursively walks JSON structures and replaces values with their type names: +// The transformation replaces every leaf value with its jq type name: // // Input: {"name": "test", "count": 42, "items": [{"id": 1}]} // Output: {"name": "string", "count": "number", "items": [{"id": "number"}]} // // For arrays, only the first element's schema is retained to represent the array structure. // Empty arrays are preserved as []. -// -// NOTE: This defines a custom walk_schema function rather than using gojq's built-in walk(f). -// The built-in walk(f) applies f to every node but preserves the original structure. -// Our custom walk_schema does two things the built-in cannot: -// 1. 
Replaces leaf values with their type name (e.g., "test" → "string") -// 2. Collapses arrays to only the first element for schema inference -// -// These behaviors are incompatible with standard walk(f) semantics, which would -// apply f post-recursion without structural changes to arrays. -// Using a distinct name avoids shadowing gojq's built-in walk/1. -const jqSchemaFilter = ` -def walk_schema: - . as $in | - if type == "object" then - reduce keys[] as $k ({}; . + {($k): ($in[$k] | walk_schema)}) - elif type == "array" then - if length == 0 then [] else [.[0] | walk_schema] end - else - type - end; -walk_schema -` - -// Pre-compiled jq query code for performance -// This is compiled once at package initialization and reused for all requests +const jqSchemaFilter = `walk_schema` + +// Pre-compiled jq query code for performance. +// This is compiled once at package initialization and reused for all requests. var ( jqSchemaCode *gojq.Code jqSchemaCompileErr error ) -// init compiles the jq schema filter at startup for better performance and validation -// Following gojq best practices: compile once, run many times +// inferSchema recursively walks a JSON-compatible Go value and replaces every leaf +// with its jq type name ("null", "boolean", "number", "string"). Objects are +// traversed key-by-key; arrays are collapsed to a single representative element (or +// [] when empty). The output mirrors what the previous pure-jq walk_schema filter +// produced, but runs entirely in Go, bypassing jq interpreter overhead for recursion. 
//
// Type mapping (matches jq's built-in type function):
//   - nil → "null"
//   - bool → "boolean"
//   - any integer or floating-point numeric type → "number"
//     (float32/64, int/8/16/32/64, uint/8/16/32/64, json.Number)
//   - string → "string"
//   - map[string]interface{} → recursed object
//   - []interface{} → recursed array (first element only)
//
// Values that are not one of the JSON-native Go types above (named scalar types,
// typed maps and slices, pointers) are classified by inferSchemaReflect according
// to their underlying reflect.Kind instead of being reported as "string".
func inferSchema(v interface{}) interface{} {
	switch val := v.(type) {
	case map[string]interface{}:
		result := make(map[string]interface{}, len(val))
		for k, child := range val {
			result[k] = inferSchema(child)
		}
		return result
	case []interface{}:
		if len(val) == 0 {
			return []interface{}{}
		}
		// Only the first element is inspected; it stands in for the whole array.
		return []interface{}{inferSchema(val[0])}
	case nil:
		return "null"
	case bool:
		return "boolean"
	case float64, float32,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		json.Number:
		return "number"
	case string:
		return "string"
	default:
		// v is non-nil here (an untyped nil matches the nil case above), so
		// reflect.ValueOf yields a valid Value.
		return inferSchemaReflect(reflect.ValueOf(v))
	}
}

// inferSchemaReflect is the slow-path classifier used by inferSchema for values
// that are not plain JSON-native Go types. It maps the value's underlying kind to
// the same schema vocabulary as the fast path:
//   - bool kinds → "boolean", integer/float kinds → "number"
//   - nil pointers/interfaces → "null"; non-nil ones are classified by their element
//   - maps with string-kind keys → recursed object
//   - byte slices → "string" (encoding/json encodes []byte as a base64 string)
//   - other slices and arrays → recursed array (first element only, or [])
//   - anything else → "string", so the schema output always stays valid
func inferSchemaReflect(rv reflect.Value) interface{} {
	switch rv.Kind() {
	case reflect.Bool:
		return "boolean"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return "number"
	case reflect.Ptr, reflect.Interface:
		if rv.IsNil() {
			return "null"
		}
		return inferSchema(rv.Elem().Interface())
	case reflect.Map:
		if rv.Type().Key().Kind() != reflect.String {
			// Not representable as a JSON object keyed by plain strings.
			return "string"
		}
		result := make(map[string]interface{}, rv.Len())
		iter := rv.MapRange()
		for iter.Next() {
			result[iter.Key().String()] = inferSchema(iter.Value().Interface())
		}
		return result
	case reflect.Slice, reflect.Array:
		if rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Uint8 {
			return "string"
		}
		if rv.Len() == 0 {
			return []interface{}{}
		}
		return []interface{}{inferSchema(rv.Index(0).Interface())}
	default:
		return "string"
	}
}

// init compiles the jq schema filter at startup for better performance and validation.
// Following gojq best practices: compile once, run many times.
//
// The walk_schema function is registered as a native Go implementation via
// gojq.WithFunction so that the recursive schema walk runs entirely in Go,
// avoiding jq interpreter overhead for deeply-nested payloads.
//
// This provides fail-fast behavior - if the jq query is invalid, the application
// will fail at startup rather than at runtime during a tool call.
-// -// Performance benefit: Compiling once and reusing the code provides 10-100x speedup -// compared to parsing and compiling on every request. func init() { query, err := gojq.Parse(jqSchemaFilter) if err != nil { @@ -101,7 +132,11 @@ func init() { return } - jqSchemaCode, jqSchemaCompileErr = gojq.Compile(query) + jqSchemaCode, jqSchemaCompileErr = gojq.Compile(query, + gojq.WithFunction("walk_schema", 0, 0, func(v interface{}, _ []interface{}) interface{} { + return inferSchema(v) + }), + ) if jqSchemaCompileErr != nil { logMiddleware.Printf("FATAL: Failed to compile jq schema filter at init: %v", jqSchemaCompileErr) logger.LogError("startup", "Failed to compile jq schema filter at init (application will not start): %v", jqSchemaCompileErr) @@ -109,7 +144,7 @@ func init() { } logMiddleware.Printf("Successfully compiled jq schema filter at init") - logger.LogInfo("startup", "jq schema filter compiled successfully - array limit: 2^29 elements, timeout: %v", DefaultJqTimeout) + logger.LogInfo("startup", "jq schema filter compiled successfully - native Go walk_schema, array limit: 2^29 elements, timeout: %v", DefaultJqTimeout) } // generateRandomID generates a random ID for payload storage @@ -350,9 +385,20 @@ func WrapToolHandler( // Prepare data for jq processing. If data is already a native JSON-compatible // Go type, use it directly to avoid a redundant JSON round-trip. var jsonData interface{} - switch data.(type) { + switch v := data.(type) { case map[string]interface{}, []interface{}, string, float64, bool: jsonData = data + case nil: + // nil data produces a nil schema; nothing meaningful to infer. + return nil + case json.Number: + // json.Number is emitted by decoders using UseNumber(); convert to + // float64 so jq sees a plain number value rather than an opaque type. 
+ f, convErr := v.Float64() + if convErr != nil { + return fmt.Errorf("failed to convert json.Number to float64: %w", convErr) + } + jsonData = f default: if err := json.Unmarshal(payloadJSON, &jsonData); err != nil { return fmt.Errorf("failed to unmarshal for schema: %w", err) diff --git a/internal/middleware/jqschema_bench_test.go b/internal/middleware/jqschema_bench_test.go index 511b9b57..4149154b 100644 --- a/internal/middleware/jqschema_bench_test.go +++ b/internal/middleware/jqschema_bench_test.go @@ -82,90 +82,15 @@ func BenchmarkApplyJqSchema_CompiledCode(b *testing.B) { } } -// BenchmarkApplyJqSchema_ParseEveryTime benchmarks the old implementation -// that parses the query on every invocation (for comparison) +// BenchmarkApplyJqSchema_ParseEveryTime used to benchmark parsing the query on every +// invocation to quantify the compile-once speedup. +// +// NOTE: This benchmark is no longer valid. Since walk_schema is now a native Go +// function registered via gojq.WithFunction, running the parsed query without the +// corresponding function registration will fail at runtime with an undefined-function +// error. Skipping to avoid misleading benchmark output. 
func BenchmarkApplyJqSchema_ParseEveryTime(b *testing.B) { - tests := []struct { - name string - input interface{} - }{ - { - name: "small object", - input: map[string]interface{}{"name": "test", "count": 42, "active": true}, - }, - { - name: "medium object", - input: map[string]interface{}{ - "total_count": 1000, - "items": []interface{}{ - map[string]interface{}{"id": 1, "name": "item1", "price": 10.5}, - map[string]interface{}{"id": 2, "name": "item2", "price": 20.5}, - map[string]interface{}{"id": 3, "name": "item3", "price": 30.5}, - }, - }, - }, - { - name: "large nested object", - input: map[string]interface{}{ - "user": map[string]interface{}{ - "id": 123, - "login": "testuser", - "verified": true, - "profile": map[string]interface{}{ - "bio": "Test bio", - "location": "Test location", - "website": "https://example.com", - }, - }, - "repositories": []interface{}{ - map[string]interface{}{ - "id": 1, - "name": "repo1", - "stars": 100, - "description": "First repo", - "owner": map[string]interface{}{ - "login": "owner1", - "id": 999, - }, - }, - map[string]interface{}{ - "id": 2, - "name": "repo2", - "stars": 200, - "description": "Second repo", - "owner": map[string]interface{}{ - "login": "owner2", - "id": 888, - }, - }, - }, - }, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ctx := context.Background() - b.ResetTimer() - for i := 0; i < b.N; i++ { - // Simulate old implementation: Parse on every call - query, err := gojq.Parse(jqSchemaFilter) - if err != nil { - b.Fatalf("Parse failed: %v", err) - } - - iter := query.RunWithContext(ctx, tt.input) - v, ok := iter.Next() - if !ok { - b.Fatal("No results") - } - - if err, ok := v.(error); ok { - b.Fatalf("Query error: %v", err) - } - } - }) - } + b.Skip("invalid benchmark: walk_schema requires gojq.Compile with gojq.WithFunction registration; parse-only path no longer produces meaningful results") } // BenchmarkCompileVsParse compares the time to compile vs parse the jq query 
diff --git a/internal/middleware/jqschema_test.go b/internal/middleware/jqschema_test.go index a3fb08a0..9f31939c 100644 --- a/internal/middleware/jqschema_test.go +++ b/internal/middleware/jqschema_test.go @@ -1418,3 +1418,153 @@ func TestWrapToolHandler_PreviewUTF8Boundary(t *testing.T) { assert.True(t, utf8.ValidString(preview), "preview must be valid UTF-8") assert.True(t, strings.HasSuffix(preview, "..."), "truncated preview must end with '...'") } + +// TestInferSchema verifies that the native Go inferSchema function produces the same +// output as the previous pure-jq walk_schema filter for all supported value types. +func TestInferSchema(t *testing.T) { + tests := []struct { + name string + input interface{} + expected interface{} + }{ + { + name: "nil leaf", + input: nil, + expected: "null", + }, + { + name: "bool leaf", + input: true, + expected: "boolean", + }, + { + name: "float64 leaf", + input: float64(3.14), + expected: "number", + }, + { + name: "int leaf", + input: 42, + expected: "number", + }, + { + name: "json.Number leaf", + input: json.Number("99"), + expected: "number", + }, + { + name: "int64 leaf", + input: int64(100), + expected: "number", + }, + { + name: "float32 leaf", + input: float32(2.5), + expected: "number", + }, + { + name: "string leaf", + input: "hello", + expected: "string", + }, + { + name: "empty object", + input: map[string]interface{}{}, + expected: map[string]interface{}{}, + }, + { + name: "flat object", + input: map[string]interface{}{"name": "alice", "age": float64(30), "active": false}, + expected: map[string]interface{}{"name": "string", "age": "number", "active": "boolean"}, + }, + { + name: "nested object", + input: map[string]interface{}{"user": map[string]interface{}{"id": float64(1), "verified": true}}, + expected: map[string]interface{}{"user": map[string]interface{}{"id": "number", "verified": "boolean"}}, + }, + { + name: "empty array", + input: []interface{}{}, + expected: []interface{}{}, + }, + { + name: 
"array collapses to first element schema", + input: []interface{}{map[string]interface{}{"id": float64(1)}, map[string]interface{}{"id": float64(2)}}, + expected: []interface{}{map[string]interface{}{"id": "number"}}, + }, + { + name: "object with nil value", + input: map[string]interface{}{"value": nil}, + expected: map[string]interface{}{"value": "null"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := inferSchema(tt.input) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestInferSchema_MatchesJqOutput verifies that inferSchema (called directly) produces +// the same output as applyJqSchema (which invokes inferSchema via the gojq runtime). +// This validates the gojq.WithFunction wiring: that the compiled jq code correctly +// dispatches to the native Go implementation for all supported input shapes. +func TestInferSchema_MatchesJqOutput(t *testing.T) { + inputs := []interface{}{ + map[string]interface{}{"name": "test", "count": 42}, + map[string]interface{}{"user": map[string]interface{}{"id": 123, "active": true}}, + map[string]interface{}{"items": []interface{}{map[string]interface{}{"id": 1, "name": "test"}}}, + map[string]interface{}{"items": []interface{}{}}, + map[string]interface{}{"value": nil}, + } + + for _, input := range inputs { + inputJSON, _ := json.Marshal(input) + t.Run(string(inputJSON), func(t *testing.T) { + jqResult, err := applyJqSchema(context.Background(), input) + require.NoError(t, err, "applyJqSchema must not error") + + goResult := inferSchema(input) + + jqJSON, err := json.Marshal(jqResult) + require.NoError(t, err) + goJSON, err := json.Marshal(goResult) + require.NoError(t, err) + + assert.JSONEq(t, string(jqJSON), string(goJSON), + "inferSchema output must match applyJqSchema output") + }) + } +} + +// TestWrapToolHandler_JsonNumberData verifies that a json.Number value in the data +// field is handled correctly by the type switch (converted to float64, not falling +// through to the 
json.Unmarshal path). +func TestWrapToolHandler_JsonNumberData(t *testing.T) { + baseDir := t.TempDir() + + // Return json.Number as the top-level data value so the type switch hits the + // direct json.Number case (the switch is on `data`, not on individual map values). + mockHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + return &sdk.CallToolResult{IsError: false}, json.Number("42"), nil + } + + // Use a threshold of 0 to force the large-payload path. + wrapped := WrapToolHandler(mockHandler, "test_tool", baseDir, "", 0, testGetSessionID) + result, _, err := wrapped(context.Background(), &sdk.CallToolRequest{}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + // The schema for a json.Number (a number leaf) should be "number". + require.NotEmpty(t, result.Content) + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok) + + var meta map[string]interface{} + require.NoError(t, json.Unmarshal([]byte(textContent.Text), &meta)) + assert.Equal(t, "number", meta["payloadSchema"], "schema for json.Number data should be 'number'") +}