diff --git a/codex-rs/code-mode-protocol/src/runtime.rs b/codex-rs/code-mode-protocol/src/runtime.rs index 2848b9ed7e16..a0c2bf3fe86e 100644 --- a/codex-rs/code-mode-protocol/src/runtime.rs +++ b/codex-rs/code-mode-protocol/src/runtime.rs @@ -14,6 +14,7 @@ pub const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL: usize = 10_000; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct CreateCellRequest { + pub idempotency_key: String, pub tool_call_id: String, pub enabled_tools: Vec, pub source: String, @@ -21,6 +22,7 @@ pub struct CreateCellRequest { #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct ObserveRequest { + pub idempotency_key: String, pub cell_id: CellId, pub yield_time_ms: u64, } diff --git a/codex-rs/code-mode-protocol/src/runtime_tests.rs b/codex-rs/code-mode-protocol/src/runtime_tests.rs index 997cb6d13e85..e0e615cf81e1 100644 --- a/codex-rs/code-mode-protocol/src/runtime_tests.rs +++ b/codex-rs/code-mode-protocol/src/runtime_tests.rs @@ -34,6 +34,7 @@ fn image(image_url: &str, detail: Option) -> FunctionCallOutputCont #[test] fn requests_round_trip_with_exact_tool_and_namespace_fields() { let create = CreateCellRequest { + idempotency_key: "thread-3:response-call-7".to_string(), tool_call_id: "response-call-7".to_string(), enabled_tools: vec![ ToolDefinition { @@ -61,6 +62,7 @@ fn requests_round_trip_with_exact_tool_and_namespace_fields() { assert_json_round_trip( &create, json!({ + "idempotency_key": "thread-3:response-call-7", "tool_call_id": "response-call-7", "enabled_tools": [ { @@ -89,10 +91,15 @@ fn requests_round_trip_with_exact_tool_and_namespace_fields() { assert_json_round_trip( &ObserveRequest { + idempotency_key: "thread-3:wait-call-2".to_string(), cell_id: CellId::new("cell-a7".to_string()), yield_time_ms: 250, }, - json!({"cell_id": "cell-a7", "yield_time_ms": 250}), + json!({ + "idempotency_key": "thread-3:wait-call-2", + "cell_id": "cell-a7", + "yield_time_ms": 250, + }), ); } diff --git a/codex-rs/code-mode/src/cell_actor/conversions.rs b/codex-rs/code-mode/src/cell_actor/conversions.rs index 324002a79afc..807423f20fc3 100644 --- a/codex-rs/code-mode/src/cell_actor/conversions.rs +++ b/codex-rs/code-mode/src/cell_actor/conversions.rs @@ -12,6 +12,7 @@ use crate::session_runtime::ToolKind as CellToolKind; pub(super) fn runtime_request(request: CellRequest) -> CreateCellRequest { CreateCellRequest { + idempotency_key: request.idempotency_key, tool_call_id: request.tool_call_id, enabled_tools: request .enabled_tools diff --git a/codex-rs/code-mode/src/cell_actor/tests.rs b/codex-rs/code-mode/src/cell_actor/tests.rs index 5933f15c6cad..c236aabe36d8 100644 --- a/codex-rs/code-mode/src/cell_actor/tests.rs +++ b/codex-rs/code-mode/src/cell_actor/tests.rs @@ -121,6 +121,7 @@ fn spawn_cell_actor_harness_with_policy( let (runtime_tx, runtime_pause_tx, runtime_terminate_handle) = spawn_runtime( HashMap::new(), CreateCellRequest { + idempotency_key: "cell-actor-harness".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: Vec::new(), source: "await new Promise(() => {});".to_string(), diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs index 03e0bfa3c8e2..c0c9a7fb54e5 100644 --- a/codex-rs/code-mode/src/runtime/mod.rs +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -351,6 +351,7 @@ mod tests { fn execute_request(source: &str) -> CreateCellRequest { CreateCellRequest { + idempotency_key: format!("call_1:{source}"), tool_call_id: "call_1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index 68119590fd19..dbfbc4fa39ee 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -22,6 +22,7 @@ use codex_code_mode_protocol::TerminateOutcome; use codex_code_mode_protocol::ToolInvocationFuture; use serde_json::Value as JsonValue; use tokio::sync::Mutex; +use tokio::sync::watch; use tokio_util::sync::CancellationToken; use crate::session_runtime as runtime; @@ -94,10 +95,17 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider { } pub struct CodeModeService { - runtime: SessionRuntime, + runtime: Arc>, + observations: Mutex>, + #[cfg(test)] pending_generations: Mutex>, } +struct ObservationRecord { + request: ObserveRequest, + result_rx: watch::Receiver>>, +} + impl CodeModeService { pub fn new() -> Self { Self::with_delegate(Arc::new(NoopCodeModeSessionDelegate)) @@ -105,7 +113,9 @@ impl CodeModeService { pub fn with_delegate(delegate: Arc) -> Self { Self { - runtime: SessionRuntime::new(Arc::new(ProtocolDelegate { delegate })), + runtime: Arc::new(SessionRuntime::new(Arc::new(ProtocolDelegate { delegate }))), + observations: Mutex::new(HashMap::new()), + #[cfg(test)] pending_generations: Mutex::new(HashMap::new()), } } @@ -130,44 +140,51 @@ impl CodeModeService { } pub async fn observe(&self, request: ObserveRequest) -> Result { - self.begin_observe(request).await.await + let idempotency_key = request.idempotency_key.clone(); + let mut result_rx = { + let mut observations = self.observations.lock().await; + if let Some(existing) = observations.get(&idempotency_key) { + if existing.request != request { + return Err(format!( + "observation idempotency key `{idempotency_key}` was reused for a different request" + )); + } + existing.result_rx.clone() + } else { + let (result_tx, result_rx) = watch::channel(None); + observations.insert( + idempotency_key, + ObservationRecord { + request: request.clone(), + result_rx: result_rx.clone(), + }, + ); + let runtime = Arc::clone(&self.runtime); + tokio::spawn(async move { + let result = begin_observe_runtime(runtime, request).await.await; + result_tx.send_replace(Some(result)); + }); + result_rx + } + }; + + result_rx + .wait_for(Option::is_some) + .await + .map_err(|_| "observation ended before producing a result".to_string())?; + + result_rx + .borrow() + .clone() + .ok_or_else(|| "observation ended before producing a result".to_string())? } + #[allow(dead_code)] async fn begin_observe( &self, request: ObserveRequest, ) -> CodeModeSessionResultFuture<'static, ObserveOutcome> { - let ObserveRequest { - cell_id, - yield_time_ms, - } = request; - let runtime_cell_id = runtime_cell_id(&cell_id); - let cell = match self.runtime.cell(&runtime_cell_id).await { - Ok(cell) => cell, - Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { - return missing_observation(cell_id); - } - Err(error) => return Box::pin(async move { Err(error.to_string()) }), - }; - match self - .runtime - .begin_wait(&cell, Duration::from_millis(yield_time_ms)) - .await - { - Ok(pending_event) => Box::pin(async move { - match pending_event.event().await { - Ok(event) => Ok(observe_outcome(&cell_id, event)), - Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { - Ok(ObserveOutcome::Missing { cell_id }) - } - Err(error) => Err(error.to_string()), - } - }), - Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { - missing_observation(cell_id) - } - Err(error) => Box::pin(async move { Err(error.to_string()) }), - } + begin_observe_runtime(Arc::clone(&self.runtime), request).await } pub async fn terminate(&self, cell_id: CellId) -> Result { @@ -179,6 +196,7 @@ impl CodeModeService { } Err(error) => Err(error.to_string()), }; + #[cfg(test)] self.pending_generations .lock() .await @@ -254,11 +272,49 @@ impl CodeModeService { .shutdown() .await .map_err(|error| error.to_string()); + #[cfg(test)] self.pending_generations.lock().await.clear(); result } } +async fn begin_observe_runtime( + runtime: Arc>, + request: ObserveRequest, +) -> CodeModeSessionResultFuture<'static, ObserveOutcome> { + let ObserveRequest { + idempotency_key: _, + cell_id, + yield_time_ms, + } = request; + let runtime_cell_id = runtime_cell_id(&cell_id); + let cell = match runtime.cell(&runtime_cell_id).await { + Ok(cell) => cell, + Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { + return missing_observation(cell_id); + } + Err(error) => return Box::pin(async move { Err(error.to_string()) }), + }; + match runtime + .begin_wait(&cell, Duration::from_millis(yield_time_ms)) + .await + { + Ok(pending_event) => Box::pin(async move { + match pending_event.event().await { + Ok(event) => Ok(observe_outcome(&cell_id, event)), + Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { + Ok(ObserveOutcome::Missing { cell_id }) + } + Err(error) => Err(error.to_string()), + } + }), + Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { + missing_observation(cell_id) + } + Err(error) => Box::pin(async move { Err(error.to_string()) }), + } +} + impl Default for CodeModeService { fn default() -> Self { Self::new() @@ -349,9 +405,8 @@ impl runtime::SessionRuntimeDelegate for ProtocolDelegate { } fn runtime_request(request: CreateCellRequest) -> runtime::CreateCellRequest { - let idempotency_key = format!("{}:{}", request.tool_call_id, request.source); runtime::CreateCellRequest { - idempotency_key, + idempotency_key: request.idempotency_key, tool_call_id: request.tool_call_id, enabled_tools: request .enabled_tools diff --git a/codex-rs/code-mode/src/service_contract_tests.rs b/codex-rs/code-mode/src/service_contract_tests.rs index fddb9bbe96ef..3a507cafe8bc 100644 --- a/codex-rs/code-mode/src/service_contract_tests.rs +++ b/codex-rs/code-mode/src/service_contract_tests.rs @@ -156,6 +156,7 @@ fn cell_id(value: &str) -> CellId { fn execute_request(source: &str) -> CreateCellRequest { CreateCellRequest { + idempotency_key: format!("call-1:{source}"), tool_call_id: "call-1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), @@ -174,6 +175,7 @@ async fn execute_with_yield_time( let cell_id = service.create_cell(request).await.unwrap(); service .observe(ObserveRequest { + idempotency_key: format!("observe:{cell_id}"), cell_id, yield_time_ms, }) @@ -214,6 +216,74 @@ async fn next_event(events_rx: &mut mpsc::UnboundedReceiver) -> D .expect("delegate event channel closed") } +#[tokio::test] +async fn create_retry_returns_the_cell_from_the_original_ambiguous_request() { + let service = CodeModeService::new(); + let request = execute_request("await new Promise(() => {});"); + + let original_cell_id = service.create_cell(request.clone()).await.unwrap(); + let retry_cell_id = service.create_cell(request).await.unwrap(); + + assert_eq!(retry_cell_id, original_cell_id); + service.terminate(original_cell_id).await.unwrap(); +} + +#[tokio::test] +async fn cancelled_observation_is_replayed_by_its_idempotency_key() { + let (delegate, mut events_rx) = BlockingDelegate::new(); + let service = Arc::new(CodeModeService::with_delegate(delegate.clone())); + let created_cell_id = service + .create_cell(CreateCellRequest { + idempotency_key: "ambiguous-observation-cell".to_string(), + enabled_tools: vec![blocking_tool()], + source: r#"await tools.block({}); text("done");"#.to_string(), + ..execute_request("") + }) + .await + .unwrap(); + assert_eq!(next_event(&mut events_rx).await, DelegateEvent::ToolStarted); + let request = ObserveRequest { + idempotency_key: "lost-observation-response".to_string(), + cell_id: created_cell_id, + yield_time_ms: 60_000, + }; + + let first_attempt = tokio::spawn({ + let service = Arc::clone(&service); + let request = request.clone(); + async move { service.observe(request).await } + }); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if service + .observations + .lock() + .await + .contains_key("lost-observation-response") + { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("observation registration timed out"); + first_attempt.abort(); + assert!(first_attempt.await.unwrap_err().is_cancelled()); + delegate.release_tool(); + + assert_eq!( + service.observe(request).await, + Ok(ObserveOutcome::Completed { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "done".to_string(), + }], + error_text: None, + }) + ); +} + #[tokio::test] async fn yields_and_resumes() { let service = CodeModeService::new(); @@ -239,6 +309,7 @@ async fn yields_and_resumes() { assert_eq!( service .observe(ObserveRequest { + idempotency_key: "observe-after-yield".to_string(), cell_id: cell_id("1"), yield_time_ms: 60_000, }) @@ -326,7 +397,7 @@ async fn observed_natural_completion_wins_over_termination() { let response = create_and_observe_to_pending( &service, CreateCellRequest { - tool_call_id: format!("completion-probe-{probe_generation}"), + idempotency_key: format!("completion-probe-{probe_generation}"), ..execute_request(r#"text(String(load("finished")));"#) }, ) @@ -547,6 +618,7 @@ async fn second_observer_is_rejected_without_displacing_the_first() { let first_observer = service .begin_observe(ObserveRequest { + idempotency_key: "first-observer".to_string(), cell_id: cell_id("1"), yield_time_ms: 60_000, }) @@ -554,6 +626,7 @@ async fn second_observer_is_rejected_without_displacing_the_first() { assert_eq!( service .observe(ObserveRequest { + idempotency_key: "second-observer".to_string(), cell_id: cell_id("1"), yield_time_ms: 60_000, }) @@ -598,6 +671,7 @@ async fn natural_completion_cleans_up_callbacks_before_responding() { assert_eq!( service .observe(ObserveRequest { + idempotency_key: "observe-created-cell".to_string(), cell_id: created_cell_id, yield_time_ms: 60_000, }) diff --git a/codex-rs/code-mode/src/service_tests.rs b/codex-rs/code-mode/src/service_tests.rs index 78ca0426c1ef..14a579a8eb26 100644 --- a/codex-rs/code-mode/src/service_tests.rs +++ b/codex-rs/code-mode/src/service_tests.rs @@ -64,6 +64,7 @@ impl CodeModeSessionDelegate for ReleasableToolDelegate { fn execute_request(source: &str) -> CreateCellRequest { CreateCellRequest { + idempotency_key: format!("call_1:{source}"), tool_call_id: "call_1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), @@ -97,6 +98,7 @@ async fn execute_with_yield_time( let cell_id = service.create_cell(request).await.unwrap(); service .observe(ObserveRequest { + idempotency_key: format!("observe:{cell_id}"), cell_id, yield_time_ms, }) @@ -152,6 +154,7 @@ async fn stored_values_are_shared_between_cells_but_not_sessions() { let write_response = execute( &first_session, CreateCellRequest { + idempotency_key: "write-shared-value".to_string(), source: r#"store("key", "visible");"#.to_string(), ..execute_request("") }, @@ -161,6 +164,7 @@ async fn stored_values_are_shared_between_cells_but_not_sessions() { let same_session = execute( &first_session, CreateCellRequest { + idempotency_key: "read-shared-value".to_string(), source: r#"text(String(load("key")));"#.to_string(), ..execute_request("") }, @@ -169,6 +173,7 @@ async fn stored_values_are_shared_between_cells_but_not_sessions() { let other_session = execute( &second_session, CreateCellRequest { + idempotency_key: "read-other-session-value".to_string(), source: r#"text(String(load("key")));"#.to_string(), ..execute_request("") }, @@ -980,6 +985,7 @@ async fn observe_reports_missing_cell_separately_from_runtime_results() { let response = service .observe(ObserveRequest { + idempotency_key: "observe-missing".to_string(), cell_id: cell_id("missing"), yield_time_ms: 1, }) diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs index 8410e08c33b9..98fd326f6a33 100644 --- a/codex-rs/core/src/tools/code_mode/execute_handler.rs +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -40,11 +40,13 @@ impl CodeModeExecuteHandler { let enabled_tools = codex_tools::collect_code_mode_tool_definitions(&self.nested_tool_specs); let started_at = std::time::Instant::now(); + let idempotency_key = format!("{}:{call_id}", exec.session.thread_id()); let cell_id = exec .session .services .code_mode_service .create_cell(codex_code_mode::CreateCellRequest { + idempotency_key: idempotency_key.clone(), tool_call_id: call_id.clone(), enabled_tools, source: args.code.clone(), @@ -77,6 +79,7 @@ impl CodeModeExecuteHandler { .services .code_mode_service .observe(codex_code_mode::ObserveRequest { + idempotency_key, cell_id: cell_id.clone(), yield_time_ms: args .yield_time_ms diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index 04e162bb6e00..ce98d736c323 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -66,6 +66,7 @@ impl CodeModeWaitHandler { let ToolInvocation { session, turn, + call_id, tool_name, payload, .. @@ -76,6 +77,7 @@ impl CodeModeWaitHandler { if tool_name.namespace.is_none() && tool_name.name.as_str() == WAIT_TOOL_NAME => { let args: ExecWaitArgs = parse_arguments(&arguments)?; + let idempotency_key = format!("{}:{call_id}", session.thread_id()); let exec = ExecContext { session, turn }; let started_at = std::time::Instant::now(); let cell_id = codex_code_mode::CellId::new(args.cell_id); @@ -99,6 +101,7 @@ impl CodeModeWaitHandler { .services .code_mode_service .observe(codex_code_mode::ObserveRequest { + idempotency_key, cell_id, yield_time_ms: args.yield_time_ms, })