diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 53ab77249206..7a17e1fb86ba 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -411,6 +411,82 @@ } ] }, + "HookEntryToml": { + "description": "Single hook entry from configuration.", + "properties": { + "command": { + "description": "The command to execute as argv (program + args).", + "items": { + "type": "string" + }, + "type": "array" + }, + "matcher": { + "default": null, + "description": "Optional matcher pattern for tool-use hooks.\n\nSupported patterns: - `\"*\"` matches any tool name - `\"prefix*\"` matches tool names starting with `prefix` - `\"exact\"` matches only that exact tool name\n\nNote: suffix patterns like `\"*shell\"` and infix patterns like `\"read_*_file\"` are **not** supported.", + "type": "string" + }, + "timeout": { + "default": 30, + "description": "Optional timeout in seconds (default: 30).", + "format": "uint64", + "minimum": 0.0, + "type": "integer" + } + }, + "required": [ + "command" + ], + "type": "object" + }, + "HooksConfigToml": { + "description": "All hook entries grouped by event type.", + "properties": { + "after_agent": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + }, + "notification": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + }, + "post_tool_use": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + }, + "pre_tool_use": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + }, + "stop": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + }, + "user_prompt_submit": { + "default": [], + "items": { + "$ref": "#/definitions/HookEntryToml" + }, + "type": "array" + } + }, + "type": "object" + }, "ModeKind": { "description": "Initial collaboration mode to use when the TUI starts.", "enum": [ @@ -1379,6 +1455,22 @@ "default": null, "description": "Settings that govern if and what will be written to `~/.codex/history.jsonl`." }, + "hooks": { + "allOf": [ + { + "$ref": "#/definitions/HooksConfigToml" + } + ], + "default": { + "after_agent": [], + "notification": [], + "post_tool_use": [], + "pre_tool_use": [], + "stop": [], + "user_prompt_submit": [] + }, + "description": "Hook definitions grouped by event type. Each hook specifies a command to execute and optionally a matcher pattern for tool-use events." + }, "instructions": { "description": "System instructions.", "type": "string" diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index dee21376cbd9..1f81c3fcef4f 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -205,6 +205,9 @@ pub struct Config { /// If unset the feature is disabled. pub notify: Option>, + /// Hook definitions loaded from config.toml `[hooks]` section. + pub hooks: crate::hooks::config::HooksConfigToml, + /// TUI notifications preference. When set, the TUI will send terminal notifications on /// approvals and turn completions when not focused. pub tui_notifications: Notifications, @@ -869,6 +872,11 @@ pub struct ConfigToml { #[serde(default)] pub notify: Option>, + /// Hook definitions grouped by event type. Each hook specifies a command + /// to execute and optionally a matcher pattern for tool-use events. + #[serde(default)] + pub hooks: crate::hooks::config::HooksConfigToml, + /// System instructions. pub instructions: Option, @@ -1694,6 +1702,7 @@ impl Config { forced_auto_mode_downgraded_on_windows, shell_environment_policy, notify: cfg.notify, + hooks: cfg.hooks, user_instructions, base_instructions, personality, @@ -4008,6 +4017,7 @@ model_verbosity = "high" shell_environment_policy: ShellEnvironmentPolicy::default(), user_instructions: None, notify: None, + hooks: Default::default(), cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), mcp_servers: Constrained::allow_any(HashMap::new()), @@ -4096,6 +4106,7 @@ model_verbosity = "high" shell_environment_policy: ShellEnvironmentPolicy::default(), user_instructions: None, notify: None, + hooks: Default::default(), cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), mcp_servers: Constrained::allow_any(HashMap::new()), @@ -4199,6 +4210,7 @@ model_verbosity = "high" shell_environment_policy: ShellEnvironmentPolicy::default(), user_instructions: None, notify: None, + hooks: Default::default(), cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), mcp_servers: Constrained::allow_any(HashMap::new()), @@ -4288,6 +4300,7 @@ model_verbosity = "high" shell_environment_policy: ShellEnvironmentPolicy::default(), user_instructions: None, notify: None, + hooks: Default::default(), cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), mcp_servers: Constrained::allow_any(HashMap::new()), diff --git a/codex-rs/core/src/function_tool.rs b/codex-rs/core/src/function_tool.rs index 240e04361cd4..df7c7dca5650 100644 --- a/codex-rs/core/src/function_tool.rs +++ b/codex-rs/core/src/function_tool.rs @@ -8,4 +8,6 @@ pub enum FunctionCallError { MissingLocalShellCallId, #[error("Fatal error: {0}")] Fatal(String), + #[error("Tool call blocked: {0}")] + ToolCallBlocked(String), } diff --git a/codex-rs/core/src/hooks/config.rs b/codex-rs/core/src/hooks/config.rs new file mode 100644 index 000000000000..40ce53c163ed --- /dev/null +++ b/codex-rs/core/src/hooks/config.rs @@ -0,0 +1,341 @@ +use std::time::Duration; + +use schemars::JsonSchema; +use serde::Deserialize; +use serde::Serialize; + +use super::executor::command_hook; +use super::types::Hook; + +/// Single hook entry from configuration. +#[derive(Debug, Clone, Default, PartialEq, Deserialize, Serialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct HookEntryToml { + /// The command to execute as argv (program + args). + pub command: Vec, + + /// Optional timeout in seconds (default: 30). + #[serde(default = "default_timeout_secs")] + pub timeout: u64, + + /// Optional matcher pattern for tool-use hooks. + /// + /// Supported patterns: + /// - `"*"` matches any tool name + /// - `"prefix*"` matches tool names starting with `prefix` + /// - `"exact"` matches only that exact tool name + /// + /// Note: suffix patterns like `"*shell"` and infix patterns like + /// `"read_*_file"` are **not** supported. + #[serde(default)] + pub matcher: Option, +} + +fn default_timeout_secs() -> u64 { + 30 +} + +/// All hook entries grouped by event type. +#[derive(Debug, Clone, Default, PartialEq, Deserialize, Serialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct HooksConfigToml { + #[serde(default)] + pub after_agent: Vec, + + #[serde(default)] + pub pre_tool_use: Vec, + + #[serde(default)] + pub post_tool_use: Vec, + + #[serde(default)] + pub notification: Vec, + + #[serde(default)] + pub stop: Vec, + + #[serde(default)] + pub user_prompt_submit: Vec, +} + +/// Convert a single HookEntryToml into a Hook via the command executor. +/// +/// If the entry has a matcher pattern, the hook will only execute for events +/// whose tool name matches the pattern. Non-tool events always match. +pub(super) fn hook_from_entry(entry: &HookEntryToml) -> Hook { + let timeout = Duration::from_secs(entry.timeout); + let inner = command_hook(entry.command.clone(), timeout); + match &entry.matcher { + None => inner, + Some(pattern) => { + let pattern = pattern.clone(); + Hook { + func: std::sync::Arc::new(move |payload| { + let tool_name = match &payload.hook_event { + super::types::HookEvent::PreToolUse { event } => Some(&event.tool_name), + super::types::HookEvent::PostToolUse { event } => Some(&event.tool_name), + _ => None, // Non-tool events always match + }; + + if let Some(name) = tool_name + && !matches_pattern(&pattern, name) + { + return Box::pin(async { super::types::HookOutcome::Proceed }); + } + + inner.func.clone()(payload) + }), + } + } + } +} + +/// Check if a tool name matches a simple pattern. +/// +/// Supports three forms: +/// - `"*"` matches any tool name. +/// - `"prefix*"` (trailing wildcard) matches tool names starting with `prefix`. +/// - Any other string is compared as an exact match. +/// +/// Full glob semantics (suffix, infix wildcards) are intentionally not +/// supported to keep the matching logic trivial and predictable. +fn matches_pattern(pattern: &str, tool_name: &str) -> bool { + if pattern == "*" { + return true; + } + if let Some(prefix) = pattern.strip_suffix('*') { + return tool_name.starts_with(prefix); + } + pattern == tool_name +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + /// Check if a tool name matches a hook entry's matcher pattern. + /// If matcher is None, the hook matches all tools. + fn matches_tool(entry: &HookEntryToml, tool_name: &str) -> bool { + match &entry.matcher { + None => true, + Some(pattern) => matches_pattern(pattern, tool_name), + } + } + + #[test] + fn test_hook_entry_deserialize_minimal() { + let toml_str = r#" + command = ["./hook.sh"] + "#; + let entry: HookEntryToml = toml::from_str(toml_str).unwrap(); + assert_eq!(entry.command, vec!["./hook.sh"]); + assert_eq!(entry.timeout, 30); // default + assert_eq!(entry.matcher, None); // default + } + + #[test] + fn test_hook_entry_deserialize_full() { + let toml_str = r#" + command = ["./pre-tool.sh", "--verbose"] + timeout = 60 + matcher = "shell*" + "#; + let entry: HookEntryToml = toml::from_str(toml_str).unwrap(); + assert_eq!(entry.command, vec!["./pre-tool.sh", "--verbose"]); + assert_eq!(entry.timeout, 60); + assert_eq!(entry.matcher, Some("shell*".to_string())); + } + + #[test] + fn test_hooks_config_deserialize_empty() { + let toml_str = ""; + let config: HooksConfigToml = toml::from_str(toml_str).unwrap(); + assert!(config.after_agent.is_empty()); + assert!(config.pre_tool_use.is_empty()); + assert!(config.post_tool_use.is_empty()); + assert!(config.notification.is_empty()); + assert!(config.stop.is_empty()); + assert!(config.user_prompt_submit.is_empty()); + } + + #[test] + fn test_hooks_config_deserialize_after_agent() { + let toml_str = r#" + [[after_agent]] + command = ["./hook1.sh"] + timeout = 45 + + [[after_agent]] + command = ["./hook2.sh"] + "#; + let config: HooksConfigToml = toml::from_str(toml_str).unwrap(); + assert_eq!(config.after_agent.len(), 2); + assert_eq!(config.after_agent[0].command, vec!["./hook1.sh"]); + assert_eq!(config.after_agent[0].timeout, 45); + assert_eq!(config.after_agent[1].command, vec!["./hook2.sh"]); + assert_eq!(config.after_agent[1].timeout, 30); // default + } + + #[test] + fn test_matches_tool_none_matches_all() { + let entry = HookEntryToml { + command: vec!["./hook.sh".to_string()], + timeout: 30, + matcher: None, + }; + assert!(matches_tool(&entry, "shell")); + assert!(matches_tool(&entry, "read")); + assert!(matches_tool(&entry, "write")); + } + + #[test] + fn test_matches_tool_exact() { + let entry = HookEntryToml { + command: vec!["./hook.sh".to_string()], + timeout: 30, + matcher: Some("shell".to_string()), + }; + assert!(matches_tool(&entry, "shell")); + assert!(!matches_tool(&entry, "shell_exec")); + assert!(!matches_tool(&entry, "read")); + } + + #[test] + fn test_matches_tool_glob_prefix() { + let entry = HookEntryToml { + command: vec!["./hook.sh".to_string()], + timeout: 30, + matcher: Some("shell*".to_string()), + }; + assert!(matches_tool(&entry, "shell")); + assert!(matches_tool(&entry, "shell_exec")); + assert!(matches_tool(&entry, "shell_command")); + assert!(!matches_tool(&entry, "read")); + } + + #[test] + fn test_matches_tool_wildcard() { + let entry = HookEntryToml { + command: vec!["./hook.sh".to_string()], + timeout: 30, + matcher: Some("*".to_string()), + }; + assert!(matches_tool(&entry, "shell")); + assert!(matches_tool(&entry, "read")); + assert!(matches_tool(&entry, "write")); + assert!(matches_tool(&entry, "anything")); + } + + #[test] + fn test_matches_tool_no_match() { + let entry = HookEntryToml { + command: vec!["./hook.sh".to_string()], + timeout: 30, + matcher: Some("read".to_string()), + }; + assert!(matches_tool(&entry, "read")); + assert!(!matches_tool(&entry, "write")); + assert!(!matches_tool(&entry, "read_file")); + } + + #[test] + fn test_hooks_config_deserialize_pre_tool_use() { + let toml_str = r#" + [[pre_tool_use]] + command = ["./validate-tool.sh"] + timeout = 10 + matcher = "bash*" + + [[pre_tool_use]] + command = ["./log-tool.sh", "--verbose"] + matcher = "*" + "#; + let config: HooksConfigToml = toml::from_str(toml_str).unwrap(); + assert_eq!(config.pre_tool_use.len(), 2); + assert_eq!(config.pre_tool_use[0].command, vec!["./validate-tool.sh"]); + assert_eq!(config.pre_tool_use[0].timeout, 10); + assert_eq!(config.pre_tool_use[0].matcher, Some("bash*".to_string())); + assert_eq!( + config.pre_tool_use[1].command, + vec!["./log-tool.sh", "--verbose"] + ); + assert_eq!(config.pre_tool_use[1].matcher, Some("*".to_string())); + } + + #[test] + fn test_hooks_config_full_deserialize() { + let toml_str = r#" + [[after_agent]] + command = ["./notify.sh"] + + [[pre_tool_use]] + command = ["./pre-tool.sh"] + matcher = "bash" + + [[post_tool_use]] + command = ["./post-tool.sh"] + + [[notification]] + command = ["./notify-desktop.sh"] + + [[stop]] + command = ["./cleanup.sh"] + + [[user_prompt_submit]] + command = ["./log-prompt.sh"] + "#; + let config: HooksConfigToml = toml::from_str(toml_str).unwrap(); + assert_eq!(config.after_agent.len(), 1); + assert_eq!(config.pre_tool_use.len(), 1); + assert_eq!(config.post_tool_use.len(), 1); + assert_eq!(config.notification.len(), 1); + assert_eq!(config.stop.len(), 1); + assert_eq!(config.user_prompt_submit.len(), 1); + } + + #[tokio::test] + async fn test_hook_from_entry_creates_working_hook() { + let entry = HookEntryToml { + command: vec!["echo".to_string(), "test".to_string()], + timeout: 5, + matcher: None, + }; + + let hook = hook_from_entry(&entry); + + // Create a minimal payload to test hook execution + use super::super::types::HookEvent; + use super::super::types::HookEventAfterAgent; + use super::super::types::HookPayload; + use chrono::TimeZone; + use chrono::Utc; + use codex_protocol::ThreadId; + use std::path::PathBuf; + + let payload = HookPayload { + session_id: ThreadId::new(), + cwd: PathBuf::from("/tmp"), + triggered_at: Utc + .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) + .single() + .expect("valid timestamp"), + hook_event: HookEvent::AfterAgent { + event: HookEventAfterAgent { + thread_id: ThreadId::new(), + turn_id: "test".to_string(), + input_messages: vec!["test".to_string()], + last_assistant_message: None, + }, + }, + }; + + // Hook should execute without panicking + let outcome = hook.execute(&payload).await; + + // command_hook returns Proceed on success + use super::super::types::HookOutcome; + assert_eq!(outcome, HookOutcome::Proceed); + } +} diff --git a/codex-rs/core/src/hooks/executor.rs b/codex-rs/core/src/hooks/executor.rs new file mode 100644 index 000000000000..4397e0603b51 --- /dev/null +++ b/codex-rs/core/src/hooks/executor.rs @@ -0,0 +1,639 @@ +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + +use serde::Deserialize; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; + +use super::types::Hook; +use super::types::HookOutcome; +use super::types::HookPayload; + +/// Maximum bytes to read from a hook command's stdout to prevent unbounded memory usage. +const MAX_STDOUT_BYTES: usize = 1_048_576; // 1MB + +/// Maximum bytes to read from a hook command's stderr to prevent unbounded memory usage. +const MAX_STDERR_BYTES: usize = 1_048_576; // 1MB + +/// Decision returned by a hook command. +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub(super) enum HookDecision { + Proceed, + Block, + Modify, +} + +/// Result structure returned by a hook command via stdout JSON. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "snake_case")] +pub(super) struct HookCommandResult { + pub decision: HookDecision, + #[serde(default)] + pub message: Option, + #[serde(default)] + pub content: Option, +} + +impl From for HookOutcome { + fn from(result: HookCommandResult) -> Self { + match result.decision { + HookDecision::Proceed => HookOutcome::Proceed, + HookDecision::Block => HookOutcome::Block { + message: result.message, + }, + HookDecision::Modify => match result.content { + Some(content) => HookOutcome::Modify { content }, + None => { + tracing::warn!( + "hook returned modify decision without content field; \ + treating as block to prevent empty input substitution" + ); + HookOutcome::Block { + message: Some( + "hook returned modify without content field".to_string(), + ), + } + } + }, + } + } +} + +/// Creates a hook that executes a command via stdin/stdout JSON protocol. +/// +/// The hook serializes the payload to JSON, pipes it to the command's stdin, +/// reads the command's stdout, and interprets the result as a HookOutcome. +/// +/// # Interpretation Rules +/// +/// - Exit code 0 + empty stdout → `HookOutcome::Proceed` +/// - Exit code 0 + stdout JSON with `{"decision": "block", "message": "..."}` → `HookOutcome::Block` +/// - Exit code 0 + stdout JSON with `{"decision": "modify", "content": "..."}` → `HookOutcome::Modify` +/// - Non-zero exit code → `HookOutcome::Block { message: Some(stderr_or_default) }` +/// - Timeout → `HookOutcome::Block { message: Some("hook timed out") }` +/// - Spawn failure → log warning and return `HookOutcome::Proceed` (fail-open) +pub(super) fn command_hook(argv: Vec, timeout: Duration) -> Hook { + Hook { + func: Arc::new(move |payload: &HookPayload| { + let argv = argv.clone(); + let payload = payload.clone(); + Box::pin(async move { + let Some(mut command) = super::registry::command_from_argv(&argv) else { + tracing::warn!("hook command argv is empty, skipping"); + return HookOutcome::Proceed; + }; + + command + .current_dir(&payload.cwd) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = match command.spawn() { + Ok(child) => child, + Err(err) => { + tracing::warn!("failed to spawn hook command: {err}"); + return HookOutcome::Proceed; + } + }; + + let Some(mut stdin) = child.stdin.take() else { + tracing::warn!("hook child process has no stdin handle"); + return HookOutcome::Proceed; + }; + let Some(mut stdout) = child.stdout.take() else { + tracing::warn!("hook child process has no stdout handle"); + return HookOutcome::Proceed; + }; + let Some(mut stderr) = child.stderr.take() else { + tracing::warn!("hook child process has no stderr handle"); + return HookOutcome::Proceed; + }; + + // Serialize payload to JSON before entering the timed block. + let payload_json = match serde_json::to_vec(&payload) { + Ok(json) => json, + Err(err) => { + tracing::warn!("failed to serialize hook payload: {err}"); + return HookOutcome::Proceed; + } + }; + + // Wrap the entire IO sequence (stdin write, stdout + stderr + // read) in a single timeout so that a misbehaving hook cannot + // hang any individual phase indefinitely. Stdout and stderr + // are drained concurrently to avoid pipe deadlocks when a hook + // produces verbose output on both streams. + let io_result = tokio::time::timeout(timeout, async { + // Write payload to stdin. If the hook closes stdin + // early (e.g. a short script that ignores input), we + // still need to read its stdout/stderr and exit status + // so that block/modify decisions are not silently lost. + if let Err(err) = stdin.write_all(&payload_json).await { + tracing::warn!("failed to write payload to hook stdin: {err}"); + } + drop(stdin); // Close stdin to signal EOF + + // Drain stdout and stderr concurrently to prevent pipe + // deadlocks (a full stderr buffer can block the child + // before it closes stdout, causing a false timeout). + let read_stdout = async { + let mut bytes = Vec::new(); + let mut buf = [0u8; 4096]; + let mut capped = false; + loop { + match stdout.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + if capped { + continue; // drain but discard + } + if bytes.len() + n > MAX_STDOUT_BYTES { + // Keep as many bytes as still fit + // before switching to drain mode. + let remaining = MAX_STDOUT_BYTES - bytes.len(); + bytes.extend_from_slice(&buf[..remaining]); + tracing::warn!( + "hook stdout exceeded max size of {MAX_STDOUT_BYTES} bytes" + ); + capped = true; + continue; + } + bytes.extend_from_slice(&buf[..n]); + } + Err(err) => { + tracing::warn!("failed to read hook stdout: {err}"); + break; + } + } + } + (bytes, capped) + }; + + let read_stderr = async { + let mut bytes = Vec::new(); + let mut buf = [0u8; 4096]; + let mut capped = false; + loop { + match stderr.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + if capped { + continue; // drain but discard + } + if bytes.len() + n > MAX_STDERR_BYTES { + bytes.extend_from_slice( + &buf[..MAX_STDERR_BYTES - bytes.len()], + ); + tracing::warn!( + "hook stderr exceeded max size of {MAX_STDERR_BYTES} bytes, truncated" + ); + capped = true; + continue; + } + bytes.extend_from_slice(&buf[..n]); + } + Err(_) => break, + } + } + String::from_utf8_lossy(&bytes).to_string() + }; + + let ((stdout_bytes, stdout_capped), stderr_string) = + tokio::join!(read_stdout, read_stderr); + + (stdout_bytes, stdout_capped, stderr_string) + }) + .await; + + // Handle IO timeout: kill the child and return Block. + let (stdout_bytes, stdout_capped, stderr_string) = match io_result { + Err(_elapsed) => { + let _ = child.kill().await; + return HookOutcome::Block { + message: Some("hook timed out".to_string()), + }; + } + Ok(data) => data, + }; + + // Wait for process exit. Once stdout and stderr are fully + // consumed the process should exit promptly; apply a generous + // grace period to guard against pathological cases. + const WAIT_GRACE: Duration = Duration::from_secs(5); + let status = match tokio::time::timeout(WAIT_GRACE, child.wait()).await { + Ok(Ok(status)) => status, + Ok(Err(err)) => { + tracing::warn!("failed to wait for hook command: {err}"); + return HookOutcome::Proceed; + } + Err(_elapsed) => { + let _ = child.kill().await; + return HookOutcome::Block { + message: Some("hook timed out".to_string()), + }; + } + }; + + // Non-zero exit code → block with stderr message + if !status.success() { + let message = if stderr_string.is_empty() { + format!("hook command failed with exit code {status}") + } else { + stderr_string + }; + return HookOutcome::Block { + message: Some(message), + }; + } + + // Exit code 0: parse stdout or default to Proceed + if stdout_bytes.is_empty() { + return HookOutcome::Proceed; + } + + // If stdout was truncated, the JSON is likely corrupted. + // Block rather than falling through to Proceed, which would + // silently bypass the hook's intended decision. + if stdout_capped { + return HookOutcome::Block { + message: Some( + "hook stdout exceeded size limit; output truncated and cannot be trusted".to_string(), + ), + }; + } + + match serde_json::from_slice::(&stdout_bytes) { + Ok(result) => result.into(), + Err(err) => { + tracing::warn!("failed to parse hook command result: {err}"); + HookOutcome::Proceed + } + } + }) + }), + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::super::types::HookOutcome; + use super::HookCommandResult; + use super::HookDecision; + + #[test] + fn test_hook_command_result_deserialize_proceed() { + let json = json!({"decision": "proceed"}); + let result: HookCommandResult = serde_json::from_value(json).unwrap(); + assert_eq!(result.decision, HookDecision::Proceed); + assert_eq!(result.message, None); + assert_eq!(result.content, None); + } + + #[test] + fn test_hook_command_result_deserialize_block() { + let json = json!({"decision": "block", "message": "denied"}); + let result: HookCommandResult = serde_json::from_value(json).unwrap(); + assert_eq!(result.decision, HookDecision::Block); + assert_eq!(result.message, Some("denied".to_string())); + assert_eq!(result.content, None); + } + + #[test] + fn test_hook_command_result_deserialize_modify() { + let json = json!({"decision": "modify", "content": "new text"}); + let result: HookCommandResult = serde_json::from_value(json).unwrap(); + assert_eq!(result.decision, HookDecision::Modify); + assert_eq!(result.message, None); + assert_eq!(result.content, Some("new text".to_string())); + } + + // ---- command_hook() integration tests (Unix only) ---- + + #[cfg(not(windows))] + mod command_hook_integration { + use std::path::PathBuf; + use std::time::Duration; + + use chrono::TimeZone; + use chrono::Utc; + use codex_protocol::ThreadId; + use pretty_assertions::assert_eq; + + use super::super::super::types::HookEvent; + use super::super::super::types::HookEventAfterAgent; + use super::super::super::types::HookOutcome; + use super::super::super::types::HookPayload; + use super::super::command_hook; + + fn test_payload() -> HookPayload { + HookPayload { + session_id: ThreadId::new(), + cwd: PathBuf::from("/tmp"), + triggered_at: Utc + .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) + .single() + .expect("valid timestamp"), + hook_event: HookEvent::AfterAgent { + event: HookEventAfterAgent { + thread_id: ThreadId::new(), + turn_id: "test".to_string(), + input_messages: vec!["hello".to_string()], + last_assistant_message: None, + }, + }, + } + } + + #[tokio::test] + async fn command_hook_empty_stdout_returns_proceed() { + // Command reads stdin but produces no stdout → Proceed + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null".to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_stdout_proceed_json() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + r#"cat > /dev/null; echo '{"decision":"proceed"}'"#.to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_stdout_block_json() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + r#"cat > /dev/null; echo '{"decision":"block","message":"denied by policy"}'"# + .to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!( + outcome, + HookOutcome::Block { + message: Some("denied by policy".to_string()) + } + ); + } + + #[tokio::test] + async fn command_hook_stdout_modify_json() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + r#"cat > /dev/null; echo '{"decision":"modify","content":"new content"}'"# + .to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!( + outcome, + HookOutcome::Modify { + content: "new content".to_string() + } + ); + } + + #[tokio::test] + async fn command_hook_nonzero_exit_returns_block() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null; echo 'error msg' >&2; exit 1".to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + match outcome { + HookOutcome::Block { message } => { + let msg = message.expect("should have error message"); + assert!( + msg.contains("error msg"), + "stderr should be in message: {msg}" + ); + } + other => panic!("expected Block, got {other:?}"), + } + } + + #[tokio::test] + async fn command_hook_nonzero_exit_empty_stderr_uses_exit_code() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null; exit 42".to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + match outcome { + HookOutcome::Block { message } => { + let msg = message.expect("should have error message"); + assert!( + msg.contains("exit"), + "message should mention exit code: {msg}" + ); + } + other => panic!("expected Block, got {other:?}"), + } + } + + #[tokio::test] + async fn command_hook_timeout_returns_block() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null; sleep 60".to_string(), + ], + Duration::from_millis(100), // Very short timeout + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!( + outcome, + HookOutcome::Block { + message: Some("hook timed out".to_string()) + } + ); + } + + #[tokio::test] + async fn command_hook_invalid_json_stdout_returns_proceed() { + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null; echo 'not valid json'".to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + // Invalid JSON → fail-open → Proceed + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_nonexistent_command_returns_proceed() { + let hook = command_hook( + vec!["/nonexistent/command/path/xxxxx".to_string()], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + // Spawn failure → fail-open → Proceed + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_empty_argv_returns_proceed() { + let hook = command_hook(vec![], Duration::from_secs(5)); + let outcome = hook.execute(&test_payload()).await; + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_receives_payload_on_stdin() { + // Verify the hook receives the JSON payload on stdin by having + // the script parse it and echo back a field from the payload. + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + // Read stdin, check it's valid JSON with jq-like approach, + // then return proceed. We just verify it doesn't fail. + "cat > /dev/null; echo '{\"decision\":\"proceed\"}'".to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + assert_eq!(outcome, HookOutcome::Proceed); + } + + #[tokio::test] + async fn command_hook_runs_in_payload_cwd() { + // Verify that the hook command runs in the payload's cwd directory + // by having the script print its working directory via `pwd`. + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "cat > /dev/null; pwd".to_string(), + ], + Duration::from_secs(5), + ); + // test_payload() sets cwd to /tmp + let outcome = hook.execute(&test_payload()).await; + // pwd outputs the working directory; since it's not valid JSON, + // the executor falls through to Proceed (fail-open on invalid JSON). + // The important thing is that it doesn't fail to spawn, proving + // the command runs. We verify cwd more precisely below. + assert_eq!(outcome, HookOutcome::Proceed); + + // Now verify with a JSON response that includes the cwd + let hook = command_hook( + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + r#"cat > /dev/null; CWD=$(pwd); echo "{\"decision\":\"block\",\"message\":\"$CWD\"}""# + .to_string(), + ], + Duration::from_secs(5), + ); + let outcome = hook.execute(&test_payload()).await; + match outcome { + HookOutcome::Block { message } => { + let msg = message.expect("should have cwd message"); + assert_eq!( + msg, "/tmp", + "hook should run in payload.cwd (/tmp), got: {msg}" + ); + } + other => panic!("expected Block with cwd message, got {other:?}"), + } + } + } + + #[test] + fn test_hook_command_result_to_outcome() { + let result = HookCommandResult { + decision: HookDecision::Proceed, + message: None, + content: None, + }; + assert_eq!(HookOutcome::from(result), HookOutcome::Proceed); + + let result = HookCommandResult { + decision: HookDecision::Block, + message: Some("blocked".to_string()), + content: None, + }; + assert_eq!( + HookOutcome::from(result), + HookOutcome::Block { + message: Some("blocked".to_string()) + } + ); + + let result = HookCommandResult { + decision: HookDecision::Modify, + message: None, + content: Some("modified content".to_string()), + }; + assert_eq!( + HookOutcome::from(result), + HookOutcome::Modify { + content: "modified content".to_string() + } + ); + + // Modify with explicit empty content is allowed + let result = HookCommandResult { + decision: HookDecision::Modify, + message: None, + content: Some(String::new()), + }; + assert_eq!( + HookOutcome::from(result), + HookOutcome::Modify { + content: String::new() + } + ); + + // Modify without content field → Block (malformed response) + let result = HookCommandResult { + decision: HookDecision::Modify, + message: None, + content: None, + }; + assert!( + matches!(HookOutcome::from(result), HookOutcome::Block { .. }), + "modify without content should be treated as Block" + ); + } +} diff --git a/codex-rs/core/src/hooks/mod.rs b/codex-rs/core/src/hooks/mod.rs index 2c0612825dec..5e9ee7bd51d2 100644 --- a/codex-rs/core/src/hooks/mod.rs +++ b/codex-rs/core/src/hooks/mod.rs @@ -1,3 +1,5 @@ +pub(crate) mod config; +mod executor; mod registry; mod types; mod user_notification; @@ -5,4 +7,7 @@ mod user_notification; pub(crate) use registry::Hooks; pub(crate) use types::HookEvent; pub(crate) use types::HookEventAfterAgent; +pub(crate) use types::HookEventPostToolUse; +pub(crate) use types::HookEventPreToolUse; +pub(crate) use types::HookOutcome; pub(crate) use types::HookPayload; diff --git a/codex-rs/core/src/hooks/registry.rs b/codex-rs/core/src/hooks/registry.rs index 6bccee85c614..1def53924b4f 100644 --- a/codex-rs/core/src/hooks/registry.rs +++ b/codex-rs/core/src/hooks/registry.rs @@ -1,5 +1,6 @@ use tokio::process::Command; +use super::config::hook_from_entry; use super::types::Hook; use super::types::HookEvent; use super::types::HookOutcome; @@ -10,6 +11,11 @@ use crate::config::Config; #[derive(Default, Clone)] pub(crate) struct Hooks { after_agent: Vec, + pre_tool_use: Vec, + post_tool_use: Vec, + stop: Vec, + user_prompt_submit: Vec, + notification: Vec, } fn get_notify_hook(config: &Config) -> Option { @@ -25,26 +31,79 @@ fn get_notify_hook(config: &Config) -> Option { impl Hooks { // new creates a new Hooks instance from config. // For legacy compatibility, if config.notify is set, it will be added to - // the after_agent hooks. + // the after_agent hooks. New-style hooks from [hooks] config section are + // appended after legacy hooks. pub(crate) fn new(config: &Config) -> Self { - let after_agent = get_notify_hook(config).into_iter().collect(); - Self { after_agent } + let hooks_config = &config.hooks; + + let mut after_agent: Vec = get_notify_hook(config).into_iter().collect(); + after_agent.extend(hooks_config.after_agent.iter().map(hook_from_entry)); + + let pre_tool_use = hooks_config + .pre_tool_use + .iter() + .map(hook_from_entry) + .collect(); + let post_tool_use = hooks_config + .post_tool_use + .iter() + .map(hook_from_entry) + .collect(); + let stop = hooks_config.stop.iter().map(hook_from_entry).collect(); + let user_prompt_submit = hooks_config + .user_prompt_submit + .iter() + .map(hook_from_entry) + .collect(); + let notification = hooks_config + .notification + .iter() + .map(hook_from_entry) + .collect(); + + Self { + after_agent, + pre_tool_use, + post_tool_use, + stop, + user_prompt_submit, + notification, + } } fn hooks_for_event(&self, hook_event: &HookEvent) -> &[Hook] { match hook_event { HookEvent::AfterAgent { .. } => &self.after_agent, + HookEvent::PreToolUse { .. } => &self.pre_tool_use, + HookEvent::PostToolUse { .. } => &self.post_tool_use, + HookEvent::Stop { .. } => &self.stop, + HookEvent::UserPromptSubmit { .. } => &self.user_prompt_submit, + HookEvent::Notification { .. } => &self.notification, } } - pub(crate) async fn dispatch(&self, hook_payload: HookPayload) { - // TODO(gt): support interrupting program execution by returning a result here. + /// Dispatch hooks for the given event and return the aggregate outcome. + /// + /// - If any hook returns `Block`, dispatching stops immediately and + /// `Block` is returned. + /// - If any hook returns `Modify`, the last `Modify` result wins and + /// is returned after all hooks run. Note: subsequent hooks still + /// see the *original* payload (modifications are not carried forward + /// between hooks in the current implementation). + /// - Otherwise `Proceed` is returned. + pub(crate) async fn dispatch(&self, hook_payload: HookPayload) -> HookOutcome { + let mut result = HookOutcome::Proceed; for hook in self.hooks_for_event(&hook_payload.hook_event) { let outcome = hook.execute(&hook_payload).await; - if matches!(outcome, HookOutcome::Stop) { - break; + match &outcome { + HookOutcome::Block { .. } => return outcome, + HookOutcome::Modify { .. } => { + result = outcome; + } + HookOutcome::Proceed => {} } } + result } } @@ -115,6 +174,7 @@ mod tests { Hook { func: Arc::new(move |_| { let calls = Arc::clone(&calls); + let outcome = outcome.clone(); Box::pin(async move { calls.fetch_add(1, Ordering::SeqCst); outcome @@ -124,7 +184,62 @@ mod tests { } fn hooks_for_after_agent(hooks: Vec) -> Hooks { - Hooks { after_agent: hooks } + Hooks { + after_agent: hooks, + ..Default::default() + } + } + + fn hooks_for_pre_tool_use(hooks: Vec) -> Hooks { + Hooks { + pre_tool_use: hooks, + ..Default::default() + } + } + + fn hooks_for_post_tool_use(hooks: Vec) -> Hooks { + Hooks { + post_tool_use: hooks, + ..Default::default() + } + } + + fn hook_payload_pre_tool_use(label: &str) -> HookPayload { + use super::super::types::HookEventPreToolUse; + + HookPayload { + session_id: ThreadId::new(), + cwd: PathBuf::from(CWD), + triggered_at: Utc + .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) + .single() + .expect("valid timestamp"), + hook_event: HookEvent::PreToolUse { + event: HookEventPreToolUse { + tool_name: format!("tool-{label}"), + tool_input: r#"{"arg": "value"}"#.to_string(), + }, + }, + } + } + + fn hook_payload_post_tool_use(label: &str) -> HookPayload { + use super::super::types::HookEventPostToolUse; + + HookPayload { + session_id: ThreadId::new(), + cwd: PathBuf::from(CWD), + triggered_at: Utc + .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) + .single() + .expect("valid timestamp"), + hook_event: HookEvent::PostToolUse { + event: HookEventPostToolUse { + tool_name: format!("tool-{label}"), + tool_output: "success".to_string(), + }, + }, + } } #[test] @@ -170,25 +285,25 @@ mod tests { #[tokio::test] async fn dispatch_executes_hook() { let calls = Arc::new(AtomicUsize::new(0)); - let hooks = hooks_for_after_agent(vec![counting_hook(&calls, HookOutcome::Continue)]); + let hooks = hooks_for_after_agent(vec![counting_hook(&calls, HookOutcome::Proceed)]); hooks.dispatch(hook_payload("1")).await; assert_eq!(calls.load(Ordering::SeqCst), 1); } #[tokio::test] - async fn default_hook_is_noop_and_continues() { + async fn default_hook_is_noop_and_proceeds() { let payload = hook_payload("d"); let outcome = Hook::default().execute(&payload).await; - assert_eq!(outcome, HookOutcome::Continue); + assert_eq!(outcome, HookOutcome::Proceed); } #[tokio::test] async fn dispatch_executes_multiple_hooks_for_same_event() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = hooks_for_after_agent(vec![ - counting_hook(&calls, HookOutcome::Continue), - counting_hook(&calls, HookOutcome::Continue), + counting_hook(&calls, HookOutcome::Proceed), + counting_hook(&calls, HookOutcome::Proceed), ]); hooks.dispatch(hook_payload("2")).await; @@ -196,11 +311,11 @@ mod tests { } #[tokio::test] - async fn dispatch_stops_when_hook_returns_stop() { + async fn dispatch_stops_when_hook_returns_block() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = hooks_for_after_agent(vec![ - counting_hook(&calls, HookOutcome::Stop), - counting_hook(&calls, HookOutcome::Continue), + counting_hook(&calls, HookOutcome::Block { message: None }), + counting_hook(&calls, HookOutcome::Proceed), ]); hooks.dispatch(hook_payload("3")).await; @@ -228,7 +343,7 @@ mod tests { ]) .expect("build command"); command.status().await.expect("run hook command"); - HookOutcome::Continue + HookOutcome::Proceed }) }), }; @@ -286,7 +401,7 @@ mod tests { ]) .expect("build command"); command.status().await.expect("run hook command"); - HookOutcome::Continue + HookOutcome::Proceed }) }), }; @@ -312,4 +427,107 @@ mod tests { assert_eq!(contents, expected); Ok(()) } + + #[tokio::test] + async fn dispatch_pre_tool_use_hooks_for_pre_tool_use_event() { + let calls = Arc::new(AtomicUsize::new(0)); + let hooks = hooks_for_pre_tool_use(vec![counting_hook(&calls, HookOutcome::Proceed)]); + + hooks.dispatch(hook_payload_pre_tool_use("1")).await; + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn dispatch_post_tool_use_hooks_for_post_tool_use_event() { + let calls = Arc::new(AtomicUsize::new(0)); + let hooks = hooks_for_post_tool_use(vec![counting_hook(&calls, HookOutcome::Proceed)]); + + hooks.dispatch(hook_payload_post_tool_use("1")).await; + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn dispatch_does_not_fire_hooks_for_different_event_type() { + let calls_after = Arc::new(AtomicUsize::new(0)); + let calls_pre = Arc::new(AtomicUsize::new(0)); + + let hooks = Hooks { + after_agent: vec![counting_hook(&calls_after, HookOutcome::Proceed)], + pre_tool_use: vec![counting_hook(&calls_pre, HookOutcome::Proceed)], + ..Default::default() + }; + + // Dispatch PreToolUse event should not fire after_agent hooks + hooks.dispatch(hook_payload_pre_tool_use("1")).await; + assert_eq!(calls_after.load(Ordering::SeqCst), 0); + assert_eq!(calls_pre.load(Ordering::SeqCst), 1); + + // Dispatch AfterAgent event should not fire pre_tool_use hooks + hooks.dispatch(hook_payload("2")).await; + assert_eq!(calls_after.load(Ordering::SeqCst), 1); + assert_eq!(calls_pre.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn dispatch_modify_outcome_is_carried_forward() { + let hooks = hooks_for_after_agent(vec![ + Hook { + func: Arc::new(|_| { + Box::pin(async { + HookOutcome::Modify { + content: "first".to_string(), + } + }) + }), + }, + Hook { + func: Arc::new(|_| Box::pin(async { HookOutcome::Proceed })), + }, + Hook { + func: Arc::new(|_| { + Box::pin(async { + HookOutcome::Modify { + content: "second".to_string(), + } + }) + }), + }, + ]); + + let outcome = hooks.dispatch(hook_payload("1")).await; + // Last Modify wins + assert_eq!( + outcome, + HookOutcome::Modify { + content: "second".to_string() + } + ); + } + + #[tokio::test] + async fn dispatch_modify_returned_after_all_hooks_run() { + let calls = Arc::new(AtomicUsize::new(0)); + let hooks = hooks_for_after_agent(vec![ + Hook { + func: Arc::new(|_| { + Box::pin(async { + HookOutcome::Modify { + content: "modified".to_string(), + } + }) + }), + }, + counting_hook(&calls, HookOutcome::Proceed), + counting_hook(&calls, HookOutcome::Proceed), + ]); + + let outcome = hooks.dispatch(hook_payload("1")).await; + assert_eq!(calls.load(Ordering::SeqCst), 2); // Both subsequent hooks ran + assert_eq!( + outcome, + HookOutcome::Modify { + content: "modified".to_string() + } + ); + } } diff --git a/codex-rs/core/src/hooks/types.rs b/codex-rs/core/src/hooks/types.rs index 3b22d031b64f..42362b800136 100644 --- a/codex-rs/core/src/hooks/types.rs +++ b/codex-rs/core/src/hooks/types.rs @@ -20,7 +20,7 @@ pub(crate) struct Hook { impl Default for Hook { fn default() -> Self { Self { - func: Arc::new(|_| Box::pin(async { HookOutcome::Continue })), + func: Arc::new(|_| Box::pin(async { HookOutcome::Proceed })), } } } @@ -50,6 +50,39 @@ pub(crate) struct HookEventAfterAgent { pub last_assistant_message: Option, } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct HookEventPreToolUse { + pub tool_name: String, + pub tool_input: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct HookEventPostToolUse { + pub tool_name: String, + pub tool_output: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct HookEventStop { + pub reason: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct HookEventUserPromptSubmit { + pub user_message: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct HookEventNotification { + pub message: String, + pub level: String, +} + fn serialize_triggered_at(value: &DateTime, serializer: S) -> Result where S: Serializer, @@ -64,13 +97,40 @@ pub(crate) enum HookEvent { #[serde(flatten)] event: HookEventAfterAgent, }, + PreToolUse { + #[serde(flatten)] + event: HookEventPreToolUse, + }, + PostToolUse { + #[serde(flatten)] + event: HookEventPostToolUse, + }, + #[allow(dead_code)] // Integration point in codex.rs agent loop requires separate PR. + Stop { + #[serde(flatten)] + event: HookEventStop, + }, + #[allow(dead_code)] // Integration point requires architectural changes. + UserPromptSubmit { + #[serde(flatten)] + event: HookEventUserPromptSubmit, + }, + #[allow(dead_code)] // Integration point requires architectural changes. + Notification { + #[serde(flatten)] + event: HookEventNotification, + }, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Outcome of a hook execution that determines how the agent should proceed. +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum HookOutcome { - Continue, - #[allow(dead_code)] - Stop, + /// Hook completed; proceed with the operation normally. + Proceed, + /// Hook requests blocking the operation (e.g. deny a tool call). + Block { message: Option }, + /// Hook requests modifying the input or output content. + Modify { content: String }, } #[cfg(test)] @@ -124,4 +184,105 @@ mod tests { assert_eq!(actual, expected); } + + #[test] + fn hook_event_pre_tool_use_serializes_with_flattened_fields() { + use super::HookEventPreToolUse; + + let hook_event = HookEvent::PreToolUse { + event: HookEventPreToolUse { + tool_name: "bash".to_string(), + tool_input: r#"{"command": "ls"}"#.to_string(), + }, + }; + + let actual = serde_json::to_value(&hook_event).expect("serialize pre_tool_use event"); + let expected = json!({ + "event_type": "pre_tool_use", + "tool_name": "bash", + "tool_input": r#"{"command": "ls"}"#, + }); + + assert_eq!(actual, expected); + } + + #[test] + fn hook_event_post_tool_use_serializes_correctly() { + use super::HookEventPostToolUse; + + let hook_event = HookEvent::PostToolUse { + event: HookEventPostToolUse { + tool_name: "bash".to_string(), + tool_output: "file1.txt\nfile2.txt".to_string(), + }, + }; + + let actual = serde_json::to_value(&hook_event).expect("serialize post_tool_use event"); + let expected = json!({ + "event_type": "post_tool_use", + "tool_name": "bash", + "tool_output": "file1.txt\nfile2.txt", + }); + + assert_eq!(actual, expected); + } + + #[test] + fn hook_event_stop_serializes_correctly() { + use super::HookEventStop; + + let hook_event = HookEvent::Stop { + event: HookEventStop { + reason: "max_tokens_reached".to_string(), + }, + }; + + let actual = serde_json::to_value(&hook_event).expect("serialize stop event"); + let expected = json!({ + "event_type": "stop", + "reason": "max_tokens_reached", + }); + + assert_eq!(actual, expected); + } + + #[test] + fn hook_event_user_prompt_submit_serializes_correctly() { + use super::HookEventUserPromptSubmit; + + let hook_event = HookEvent::UserPromptSubmit { + event: HookEventUserPromptSubmit { + user_message: "Help me debug this code".to_string(), + }, + }; + + let actual = serde_json::to_value(&hook_event).expect("serialize user_prompt_submit event"); + let expected = json!({ + "event_type": "user_prompt_submit", + "user_message": "Help me debug this code", + }); + + assert_eq!(actual, expected); + } + + #[test] + fn hook_event_notification_serializes_correctly() { + use super::HookEventNotification; + + let hook_event = HookEvent::Notification { + event: HookEventNotification { + message: "Build completed successfully".to_string(), + level: "info".to_string(), + }, + }; + + let actual = serde_json::to_value(&hook_event).expect("serialize notification event"); + let expected = json!({ + "event_type": "notification", + "message": "Build completed successfully", + "level": "info", + }); + + assert_eq!(actual, expected); + } } diff --git a/codex-rs/core/src/hooks/user_notification.rs b/codex-rs/core/src/hooks/user_notification.rs index de1317f9c350..29dcb3680a3a 100644 --- a/codex-rs/core/src/hooks/user_notification.rs +++ b/codex-rs/core/src/hooks/user_notification.rs @@ -32,7 +32,7 @@ pub(super) fn legacy_notify_json( hook_event: &HookEvent, cwd: &Path, ) -> Result { - serde_json::to_string(&match hook_event { + let notification = match hook_event { HookEvent::AfterAgent { event } => UserNotification::AgentTurnComplete { thread_id: event.thread_id.to_string(), turn_id: event.turn_id.clone(), @@ -40,7 +40,11 @@ pub(super) fn legacy_notify_json( input_messages: event.input_messages.clone(), last_assistant_message: event.last_assistant_message.clone(), }, - }) + // Legacy notification format only supports AfterAgent events. + // Other events use the new stdin/stdout JSON protocol. + _ => return serde_json::to_string(hook_event), + }; + serde_json::to_string(¬ification) } pub(super) fn notify_hook(argv: Vec) -> Hook { @@ -51,7 +55,7 @@ pub(super) fn notify_hook(argv: Vec) -> Hook { Box::pin(async move { let mut command = match command_from_argv(&argv) { Some(command) => command, - None => return HookOutcome::Continue, + None => return HookOutcome::Proceed, }; if let Ok(notify_payload) = legacy_notify_json(&payload.hook_event, &payload.cwd) { command.arg(notify_payload); @@ -64,7 +68,7 @@ pub(super) fn notify_hook(argv: Vec) -> Hook { .stderr(Stdio::null()); let _ = command.spawn(); - HookOutcome::Continue + HookOutcome::Proceed }) }), } @@ -129,4 +133,75 @@ mod tests { Ok(()) } + + #[test] + fn legacy_notify_json_for_pre_tool_use_returns_event_serialized_directly() -> Result<()> { + use super::super::types::HookEventPreToolUse; + + let hook_event = HookEvent::PreToolUse { + event: HookEventPreToolUse { + tool_name: "bash".to_string(), + tool_input: r#"{"command": "ls"}"#.to_string(), + }, + }; + + let serialized = legacy_notify_json(&hook_event, Path::new("/tmp"))?; + let actual: Value = serde_json::from_str(&serialized)?; + + // PreToolUse events use new protocol, not legacy format + let expected = json!({ + "event_type": "pre_tool_use", + "tool_name": "bash", + "tool_input": r#"{"command": "ls"}"#, + }); + + assert_eq!(actual, expected); + Ok(()) + } + + #[test] + fn legacy_notify_json_for_post_tool_use_returns_event_serialized_directly() -> Result<()> { + use super::super::types::HookEventPostToolUse; + + let hook_event = HookEvent::PostToolUse { + event: HookEventPostToolUse { + tool_name: "bash".to_string(), + tool_output: "file1.txt\nfile2.txt".to_string(), + }, + }; + + let serialized = legacy_notify_json(&hook_event, Path::new("/tmp"))?; + let actual: Value = serde_json::from_str(&serialized)?; + + let expected = json!({ + "event_type": "post_tool_use", + "tool_name": "bash", + "tool_output": "file1.txt\nfile2.txt", + }); + + assert_eq!(actual, expected); + Ok(()) + } + + #[test] + fn legacy_notify_json_for_stop_returns_event_serialized_directly() -> Result<()> { + use super::super::types::HookEventStop; + + let hook_event = HookEvent::Stop { + event: HookEventStop { + reason: "max_tokens".to_string(), + }, + }; + + let serialized = legacy_notify_json(&hook_event, Path::new("/tmp"))?; + let actual: Value = serde_json::from_str(&serialized)?; + + let expected = json!({ + "event_type": "stop", + "reason": "max_tokens", + }); + + assert_eq!(actual, expected); + Ok(()) + } } diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 394f4b93297d..d4360cb8e186 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -128,12 +128,67 @@ pub(crate) async fn handle_output_item_done( } // The tool request should be answered directly (or was denied); push that response into the transcript. Err(FunctionCallError::RespondToModel(message)) => { - let response = ResponseInputItem::FunctionCallOutput { - call_id: String::new(), - output: FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text(message), - ..Default::default() - }, + let is_custom = matches!(&item, ResponseItem::CustomToolCall { .. }); + let respond_call_id = match &item { + ResponseItem::FunctionCall { call_id, .. } + | ResponseItem::CustomToolCall { call_id, .. } => call_id.clone(), + ResponseItem::LocalShellCall { call_id, id, .. } => { + call_id.clone().or(id.clone()).unwrap_or_default() + } + _ => String::new(), + }; + let response = if is_custom { + ResponseInputItem::CustomToolCallOutput { + call_id: respond_call_id, + output: message, + } + } else { + ResponseInputItem::FunctionCallOutput { + call_id: respond_call_id, + output: FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text(message), + ..Default::default() + }, + } + }; + ctx.sess + .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) + .await; + if let Some(response_item) = response_input_to_response_item(&response) { + ctx.sess + .record_conversation_items( + &ctx.turn_context, + std::slice::from_ref(&response_item), + ) + .await; + } + + output.needs_follow_up = true; + } + // A tool call was blocked by a pre_tool_use hook; surface the block message back into the transcript. + Err(FunctionCallError::ToolCallBlocked(message)) => { + let is_custom = matches!(&item, ResponseItem::CustomToolCall { .. }); + let blocked_call_id = match &item { + ResponseItem::FunctionCall { call_id, .. } + | ResponseItem::CustomToolCall { call_id, .. } => call_id.clone(), + ResponseItem::LocalShellCall { call_id, id, .. } => { + call_id.clone().or(id.clone()).unwrap_or_default() + } + _ => String::new(), + }; + let response = if is_custom { + ResponseInputItem::CustomToolCallOutput { + call_id: blocked_call_id, + output: message, + } + } else { + ResponseInputItem::FunctionCallOutput { + call_id: blocked_call_id, + output: FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text(message), + ..Default::default() + }, + } }; ctx.sess .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index e9edd7db4605..a8d986f9069b 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -53,6 +53,32 @@ impl ToolPayload { ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments), } } + + /// Returns a structured representation suitable for hook consumption. + /// + /// Unlike `log_payload()` which is for human-readable logging, + /// this preserves full argument structure for `LocalShell` payloads + /// (command as JSON array + workdir override) so hooks can enforce + /// accurate security policies. + pub fn hook_input(&self) -> String { + match self { + ToolPayload::Function { arguments } => arguments.clone(), + ToolPayload::Custom { input } => input.clone(), + ToolPayload::LocalShell { params } => { + let mut obj = serde_json::json!({ + "command": params.command, + }); + if let Some(workdir) = ¶ms.workdir { + obj["workdir"] = serde_json::Value::String(workdir.clone()); + } + if let Some(timeout_ms) = params.timeout_ms { + obj["timeout_ms"] = serde_json::Value::Number(timeout_ms.into()); + } + obj.to_string() + } + ToolPayload::Mcp { raw_arguments, .. } => raw_arguments.clone(), + } + } } #[derive(Clone)] diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 1eb6190bcbbf..bab5fe0604e6 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -2,6 +2,11 @@ use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; +use crate::hooks::HookEvent; +use crate::hooks::HookEventPostToolUse; +use crate::hooks::HookEventPreToolUse; +use crate::hooks::HookOutcome; +use crate::hooks::HookPayload; use crate::sandboxing::SandboxPermissions; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolInvocation; @@ -142,21 +147,80 @@ impl ToolRouter { let ToolCall { tool_name, call_id, - payload, + mut payload, } = call; let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. }); let failure_call_id = call_id.clone(); + // Extract structured tool input for hooks (preserves shell arg + // boundaries and workdir overrides, unlike log_payload()). + let tool_input = payload.hook_input(); + + // --- PreToolUse hook --- + let pre_outcome = session + .hooks() + .dispatch(HookPayload { + session_id: session.conversation_id, + cwd: turn.cwd.clone(), + triggered_at: chrono::Utc::now(), + hook_event: HookEvent::PreToolUse { + event: HookEventPreToolUse { + tool_name: tool_name.clone(), + tool_input: tool_input.clone(), + }, + }, + }) + .await; + + match pre_outcome { + HookOutcome::Proceed => {} + HookOutcome::Block { message } => { + let block_msg = + message.unwrap_or_else(|| "Blocked by pre_tool_use hook".to_string()); + return Ok(Self::failure_response( + failure_call_id, + payload_outputs_custom, + FunctionCallError::ToolCallBlocked(block_msg), + )); + } + HookOutcome::Modify { content } => { + // Apply the modified content to the tool arguments. + match &mut payload { + ToolPayload::Function { arguments } => { + *arguments = content; + } + ToolPayload::Mcp { raw_arguments, .. } => { + *raw_arguments = content; + } + ToolPayload::Custom { input } => { + *input = content; + } + ToolPayload::LocalShell { .. } => { + // Modifying shell command structure from a hook is + // not safely supported. Block the call so the + // hook's policy intent is not silently bypassed. + return Ok(Self::failure_response( + failure_call_id, + payload_outputs_custom, + FunctionCallError::ToolCallBlocked( + "pre_tool_use hook returned Modify for local_shell which is not supported; blocking execution".to_string(), + ), + )); + } + } + } + } + let invocation = ToolInvocation { - session, - turn, + session: Arc::clone(&session), + turn: Arc::clone(&turn), tracker, call_id, - tool_name, + tool_name: tool_name.clone(), payload, }; - match self.registry.dispatch(invocation).await { + let result = match self.registry.dispatch(invocation).await { Ok(response) => Ok(response), Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)), Err(err) => Ok(Self::failure_response( @@ -164,6 +228,51 @@ impl ToolRouter { payload_outputs_custom, err, )), + }; + + // --- PostToolUse hook (fire-and-forget, does not alter the result) --- + // Spawned as a background task so that slow/hung post-hooks do not + // add latency to the tool response path. + if let Ok(ref response) = result { + let tool_output = Self::extract_output_text(response); + let hooks = session.hooks().clone(); + let cwd = turn.cwd.clone(); + let conversation_id = session.conversation_id; + tokio::spawn(async move { + hooks + .dispatch(HookPayload { + session_id: conversation_id, + cwd, + triggered_at: chrono::Utc::now(), + hook_event: HookEvent::PostToolUse { + event: HookEventPostToolUse { + tool_name, + tool_output, + }, + }, + }) + .await; + }); + } + + result + } + + /// Extract a textual preview from a `ResponseInputItem` for the PostToolUse hook. + fn extract_output_text(item: &ResponseInputItem) -> String { + match item { + ResponseInputItem::FunctionCallOutput { output, .. } => { + output.body.to_text().unwrap_or_default() + } + ResponseInputItem::McpToolCallOutput { result, .. } => match result { + Ok(ctr) => { + let payload: codex_protocol::models::FunctionCallOutputPayload = ctr.into(); + payload.body.to_text().unwrap_or_default() + } + Err(err) => err.clone(), + }, + ResponseInputItem::CustomToolCallOutput { output, .. } => output.clone(), + _ => String::new(), } }