diff --git a/internal/cmd/flags.go b/internal/cmd/flags.go index c7180db2..7939d6a1 100644 --- a/internal/cmd/flags.go +++ b/internal/cmd/flags.go @@ -35,8 +35,8 @@ package cmd import ( + "github.com/github/gh-aw-mcpg/internal/config" "github.com/github/gh-aw-mcpg/internal/difc" - "github.com/github/gh-aw-mcpg/internal/guard" "github.com/spf13/cobra" ) @@ -96,7 +96,7 @@ func registerFlagCompletions(cmd *cobra.Command) { cmd.RegisterFlagCompletionFunc("guards-mode", cobra.FixedCompletions( difc.ValidModes, cobra.ShellCompDirectiveNoFileComp)) cmd.RegisterFlagCompletionFunc("allowonly-min-integrity", cobra.FixedCompletions( - guard.AllowedIntegrityLevels, cobra.ShellCompDirectiveNoFileComp)) + config.AllIntegrityLevels(), cobra.ShellCompDirectiveNoFileComp)) // Add ActiveHelp for --config and --config-stdin flags cmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { diff --git a/internal/cmd/flags_test.go b/internal/cmd/flags_test.go index cf29f78f..09c9eb02 100644 --- a/internal/cmd/flags_test.go +++ b/internal/cmd/flags_test.go @@ -4,8 +4,8 @@ import ( "strings" "testing" + "github.com/github/gh-aw-mcpg/internal/config" "github.com/github/gh-aw-mcpg/internal/difc" - "github.com/github/gh-aw-mcpg/internal/guard" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -177,7 +177,7 @@ func TestRegisterFlagCompletions(t *testing.T) { completions, directive := completionFn(cmd, nil, "") assert.Equal(t, cobra.ShellCompDirectiveNoFileComp, directive, "allowonly-min-integrity flag should use NoFileComp directive") - assert.ElementsMatch(t, guard.AllowedIntegrityLevels, completions, + assert.ElementsMatch(t, config.AllIntegrityLevels(), completions, "allowonly-min-integrity should complete with all valid integrity levels") }) diff --git a/internal/config/guard_policy_test.go b/internal/config/guard_policy_test.go index 80df7598..38f8cc78 100644 --- a/internal/config/guard_policy_test.go +++ b/internal/config/guard_policy_test.go @@ -1421,7 +1421,7 @@ func TestValidateAndNormalizeIntegrityField(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := validateAndNormalizeIntegrityField(tt.fieldPath, tt.raw, tt.optional) + got, err := ValidateAndNormalizeIntegrityField(tt.fieldPath, tt.raw, tt.optional) if tt.wantErrContains != "" { require.Error(t, err) assert.ErrorContains(t, err, tt.wantErrContains) diff --git a/internal/config/guard_policy_validation.go b/internal/config/guard_policy_validation.go index 37ed0191..03dca6c9 100644 --- a/internal/config/guard_policy_validation.go +++ b/internal/config/guard_policy_validation.go @@ -98,7 +98,7 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { return nil, fmt.Errorf("policy must include allow-only") } - integrity, err := validateAndNormalizeIntegrityField("allow-only.min-integrity", policy.AllowOnly.MinIntegrity, false) + integrity, err := ValidateAndNormalizeIntegrityField("allow-only.min-integrity", policy.AllowOnly.MinIntegrity, false) if err != nil { return nil, err } @@ -144,14 +144,14 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { // Validate and normalize disapproval-integrity (optional; empty means feature // uses Rust-side default of "none" when endorsement/disapproval is evaluated). - normalized.DisapprovalIntegrity, err = validateAndNormalizeIntegrityField("allow-only.disapproval-integrity", policy.AllowOnly.DisapprovalIntegrity, true) + normalized.DisapprovalIntegrity, err = ValidateAndNormalizeIntegrityField("allow-only.disapproval-integrity", policy.AllowOnly.DisapprovalIntegrity, true) if err != nil { return nil, err } // Validate and normalize endorser-min-integrity (optional; empty means feature // uses Rust-side default of "approved" when evaluating reactor eligibility). - normalized.EndorserMinIntegrity, err = validateAndNormalizeIntegrityField("allow-only.endorser-min-integrity", policy.AllowOnly.EndorserMinIntegrity, true) + normalized.EndorserMinIntegrity, err = ValidateAndNormalizeIntegrityField("allow-only.endorser-min-integrity", policy.AllowOnly.EndorserMinIntegrity, true) if err != nil { return nil, err } @@ -204,7 +204,9 @@ func NormalizeGuardPolicy(policy *GuardPolicy) (*NormalizedGuardPolicy, error) { } } -func validateAndNormalizeIntegrityField(fieldPath, raw string, optional bool) (string, error) { +// ValidateAndNormalizeIntegrityField validates and normalizes a named integrity-level field. +// It wraps NormalizeIntegrityLevel and prefixes the field path in any error message. +func ValidateAndNormalizeIntegrityField(fieldPath, raw string, optional bool) (string, error) { v, err := NormalizeIntegrityLevel(raw, optional) if err != nil { return "", fmt.Errorf("%s %w", fieldPath, err) diff --git a/internal/guard/guard_test.go b/internal/guard/guard_test.go index b9af9e8d..efc29f17 100644 --- a/internal/guard/guard_test.go +++ b/internal/guard/guard_test.go @@ -858,7 +858,7 @@ func TestBuildStrictLabelAgentPayload(t *testing.T) { _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.Equal(t, "invalid integrity value: expected one of none|unapproved|approved|merged", err.Error()) + assert.Equal(t, "integrity must be one of: none, unapproved, approved, merged", err.Error()) }) } @@ -976,7 +976,7 @@ func TestBuildStrictLabelAgentPayloadExtendedGuard(t *testing.T) { } _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.ErrorContains(t, err, "invalid integrity value") + assert.ErrorContains(t, err, "integrity must be one of") }) t.Run("blocked-users validation - not an array", func(t *testing.T) { @@ -1144,7 +1144,7 @@ func TestBuildStrictLabelAgentPayloadExtendedGuard(t *testing.T) { input["allow-only"].(map[string]interface{})["disapproval-integrity"] = 99 _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.ErrorContains(t, err, "invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") + assert.ErrorContains(t, err, "disapproval-integrity must be one of") }) t.Run("disapproval-integrity validation - invalid string value", func(t *testing.T) { @@ -1152,7 +1152,7 @@ func TestBuildStrictLabelAgentPayloadExtendedGuard(t *testing.T) { input["allow-only"].(map[string]interface{})["disapproval-integrity"] = "high" _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.ErrorContains(t, err, "invalid disapproval-integrity value: expected one of none|unapproved|approved|merged") + assert.ErrorContains(t, err, "disapproval-integrity must be one of") }) t.Run("disapproval-integrity validation - valid", func(t *testing.T) { @@ -1168,7 +1168,7 @@ func TestBuildStrictLabelAgentPayloadExtendedGuard(t *testing.T) { input["allow-only"].(map[string]interface{})["endorser-min-integrity"] = true _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.ErrorContains(t, err, "invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") + assert.ErrorContains(t, err, "endorser-min-integrity must be one of") }) t.Run("endorser-min-integrity validation - invalid string value", func(t *testing.T) { @@ -1176,7 +1176,7 @@ func TestBuildStrictLabelAgentPayloadExtendedGuard(t *testing.T) { input["allow-only"].(map[string]interface{})["endorser-min-integrity"] = "critical" _, err := buildStrictLabelAgentPayload(input) require.Error(t, err) - assert.ErrorContains(t, err, "invalid endorser-min-integrity value: expected one of none|unapproved|approved|merged") + assert.ErrorContains(t, err, "endorser-min-integrity must be one of") }) t.Run("endorser-min-integrity validation - valid", func(t *testing.T) { diff --git a/internal/guard/wasm_test.go b/internal/guard/wasm_test.go index 0f302eea..8264c171 100644 --- a/internal/guard/wasm_test.go +++ b/internal/guard/wasm_test.go @@ -513,7 +513,7 @@ func TestBuildStrictLabelAgentPayloadExtended(t *testing.T) { _, err := buildStrictLabelAgentPayload(policy) require.Error(t, err) - assert.ErrorContains(t, err, "invalid min-integrity value") + assert.ErrorContains(t, err, "min-integrity must be one of") }) t.Run("valid allow-only policy succeeds", func(t *testing.T) { diff --git a/internal/guard/wasm_validate.go b/internal/guard/wasm_validate.go index c98245cc..0d2fab00 100644 --- a/internal/guard/wasm_validate.go +++ b/internal/guard/wasm_validate.go @@ -7,28 +7,16 @@ import ( "github.com/github/gh-aw-mcpg/internal/config" ) -// AllowedIntegrityLevels is derived from the canonical integrity levels in config. -var AllowedIntegrityLevels = config.AllIntegrityLevels() - -func invalidIntegrityFieldError(fieldName string) error { - return fmt.Errorf( - "invalid %s value: expected one of %s", - fieldName, - strings.Join(AllowedIntegrityLevels, "|"), - ) -} - // validateIntegrityField returns an error if raw is not a valid integrity-level // string. fieldName is used in the error message (e.g. "disapproval-integrity"). +// It delegates to config.ValidateAndNormalizeIntegrityField for validation. func validateIntegrityField(fieldName string, raw interface{}) error { s, ok := raw.(string) if !ok { - return invalidIntegrityFieldError(fieldName) - } - if _, err := config.NormalizeIntegrityLevel(s, false); err == nil { - return nil + s = "" } - return invalidIntegrityFieldError(fieldName) + _, err := config.ValidateAndNormalizeIntegrityField(fieldName, s, false) + return err } // checkBoolFailure returns a non-nil error if the given raw response map diff --git a/internal/guard/wasm_validate_test.go b/internal/guard/wasm_validate_test.go index 8812f8c4..c865a83b 100644 --- a/internal/guard/wasm_validate_test.go +++ b/internal/guard/wasm_validate_test.go @@ -211,28 +211,28 @@ func TestValidateIntegrityField(t *testing.T) { fieldName: "disapproval-integrity", raw: 42, wantErr: true, - wantErrContains: "invalid disapproval-integrity value", + wantErrContains: "disapproval-integrity must be one of", }, { name: "bool returns error", fieldName: "endorser-min-integrity", raw: true, wantErr: true, - wantErrContains: "invalid endorser-min-integrity value", + wantErrContains: "endorser-min-integrity must be one of", }, { name: "nil returns error", fieldName: "min-integrity", raw: nil, wantErr: true, - wantErrContains: "invalid min-integrity value", + wantErrContains: "min-integrity must be one of", }, { name: "slice returns error", fieldName: "disapproval-integrity", raw: []string{"none"}, wantErr: true, - wantErrContains: "invalid disapproval-integrity value", + wantErrContains: "disapproval-integrity must be one of", }, // Invalid string value { @@ -240,21 +240,21 @@ func TestValidateIntegrityField(t *testing.T) { fieldName: "disapproval-integrity", raw: "invalid", wantErr: true, - wantErrContains: "invalid disapproval-integrity value", + wantErrContains: "disapproval-integrity must be one of", }, { name: "empty string returns error", fieldName: "endorser-min-integrity", raw: "", wantErr: true, - wantErrContains: "invalid endorser-min-integrity value", + wantErrContains: "endorser-min-integrity must be one of", }, { name: "whitespace-only string returns error", fieldName: "min-integrity", raw: " ", wantErr: true, - wantErrContains: "invalid min-integrity value", + wantErrContains: "min-integrity must be one of", }, // Valid integrity levels { @@ -306,7 +306,7 @@ func TestValidateIntegrityField(t *testing.T) { fieldName: "disapproval-integrity", raw: "bad", wantErr: true, - wantErrContains: "none|unapproved|approved|merged", + wantErrContains: "must be one of", }, } @@ -324,10 +324,3 @@ func TestValidateIntegrityField(t *testing.T) { }) } } - -func TestInvalidIntegrityFieldError(t *testing.T) { - err := invalidIntegrityFieldError("test-field") - require.Error(t, err) - assert.ErrorContains(t, err, "test-field") - assert.ErrorContains(t, err, "none|unapproved|approved|merged") -} diff --git a/internal/logger/doc.go b/internal/logger/doc.go index 3a523fcd..6c143c23 100644 --- a/internal/logger/doc.go +++ b/internal/logger/doc.go @@ -8,4 +8,20 @@ // // These APIs target different sinks and can be used together when a message should // appear in multiple outputs. +// +// # Per-type setup and error-handler functions +// +// Each logger type has its own setup* and handleError* functions (e.g. +// setupFileLogger / handleFileLoggerError, setupMarkdownLogger / +// handleMarkdownLoggerError). These are intentionally not collapsed into a +// single generic helper because each type has unique initialization behaviour: +// +// - JSONLLogger has no fallback path (a failure returns an error directly). +// - ToolsLogger writes a one-time header and closes the file immediately after +// opening, so its setup function owns that lifecycle step. +// - FileLogger, MarkdownLogger, and RPCLogger each open a persistent file and +// wire up different formatters. +// +// The per-type functions are bundled via the loggerFactory[T] generic defined in +// common.go, which handles the shared open-file / call-setup / call-onError flow. package logger diff --git a/internal/server/session.go b/internal/server/session.go index 4b180ddb..c5a24700 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -165,18 +165,6 @@ func (us *UnifiedServer) sysInitHandler(ctx context.Context, req *sdk.CallToolRe return us.callAndLogSysTool(sessionID, "session initialization", "sys_init") } -func (us *UnifiedServer) sysListServersHandler(ctx context.Context, _ *sdk.CallToolRequest, _ interface{}) (*sdk.CallToolResult, interface{}, error) { - sessionID := us.getSessionID(ctx) - logger.LogInfo("client", "MCP sys_list_servers request, session=%s", truncateSessionID(sessionID)) - - if err := us.requireSession(ctx); err != nil { - logger.LogError("client", "MCP sys_list_servers failed: session not initialized, session=%s", sessionID) - return mcp.NewErrorCallToolResult(err) - } - - return us.callAndLogSysTool(truncateSessionID(sessionID), "sys_list_servers", "sys_list_servers") -} - // getSessionKeys returns a list of active session IDs for debugging func (us *UnifiedServer) getSessionKeys() []string { us.sessionMu.RLock() diff --git a/internal/server/system_tools.go b/internal/server/system_tools.go index a2527e3c..3d24a544 100644 --- a/internal/server/system_tools.go +++ b/internal/server/system_tools.go @@ -1,10 +1,12 @@ package server import ( + "context" "fmt" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/mcp" + sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) var logSys = logger.New("server:system_tools") @@ -40,3 +42,17 @@ func (s *SysServer) ListServers() (interface{}, error) { return mcp.BuildMCPTextResponse(fmt.Sprintf("Configured MCP Servers:\n%s", serverList)), nil } + +// sysListServersHandler handles sys___list_servers tool calls. +// It validates that a session exists and delegates to callAndLogSysTool. +func (us *UnifiedServer) sysListServersHandler(ctx context.Context, _ *sdk.CallToolRequest, _ interface{}) (*sdk.CallToolResult, interface{}, error) { + sessionID := us.getSessionID(ctx) + logger.LogInfo("client", "MCP sys_list_servers request, session=%s", truncateSessionID(sessionID)) + + if err := us.requireSession(ctx); err != nil { + logger.LogError("client", "MCP sys_list_servers failed: session not initialized, session=%s", sessionID) + return mcp.NewErrorCallToolResult(err) + } + + return us.callAndLogSysTool(truncateSessionID(sessionID), "sys_list_servers", "sys_list_servers") +} diff --git a/internal/strutil/json_clone.go b/internal/strutil/json_clone.go deleted file mode 100644 index 0bf648df..00000000 --- a/internal/strutil/json_clone.go +++ /dev/null @@ -1,21 +0,0 @@ -package strutil - -// DeepCloneJSON creates a deep copy of a JSON-compatible value. -func DeepCloneJSON(v interface{}) interface{} { - switch val := v.(type) { - case map[string]interface{}: - clone := make(map[string]interface{}, len(val)) - for k, v := range val { - clone[k] = DeepCloneJSON(v) - } - return clone - case []interface{}: - clone := make([]interface{}, len(val)) - for i, v := range val { - clone[i] = DeepCloneJSON(v) - } - return clone - default: - return v - } -} diff --git a/internal/strutil/session_suffix.go b/internal/strutil/session_suffix.go deleted file mode 100644 index 91514dc7..00000000 --- a/internal/strutil/session_suffix.go +++ /dev/null @@ -1,12 +0,0 @@ -package strutil - -import "fmt" - -// SessionSuffix returns a formatted session suffix for log messages. -// Returns " for session ''" when sessionID is non-empty, or "" otherwise. -func SessionSuffix(sessionID string) string { - if sessionID == "" { - return "" - } - return fmt.Sprintf(" for session '%s'", sessionID) -} diff --git a/internal/strutil/strutil.go b/internal/strutil/strutil.go index 9a0e2171..726363e0 100644 --- a/internal/strutil/strutil.go +++ b/internal/strutil/strutil.go @@ -1,6 +1,7 @@ package strutil import ( + "fmt" "sort" "strings" ) @@ -65,3 +66,32 @@ func CopyTrimmedStringIntMap(input map[string]int) map[string]int { } return out } + +// SessionSuffix returns a formatted session suffix for log messages. +// Returns " for session ''" when sessionID is non-empty, or "" otherwise. +func SessionSuffix(sessionID string) string { + if sessionID == "" { + return "" + } + return fmt.Sprintf(" for session '%s'", sessionID) +} + +// DeepCloneJSON creates a deep copy of a JSON-compatible value. +func DeepCloneJSON(v interface{}) interface{} { + switch val := v.(type) { + case map[string]interface{}: + clone := make(map[string]interface{}, len(val)) + for k, v := range val { + clone[k] = DeepCloneJSON(v) + } + return clone + case []interface{}: + clone := make([]interface{}, len(val)) + for i, v := range val { + clone[i] = DeepCloneJSON(v) + } + return clone + default: + return v + } +}