Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/code-mode-protocol/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ pub const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL: usize = 10_000;

/// Protocol request asking the code-mode runtime to create a cell.
///
/// Serializes to and from JSON using the field names below unchanged.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct CreateCellRequest {
/// Caller-supplied key identifying this logical create request.
/// NOTE(review): presumably a retried create carrying the same key resolves
/// to the originally created cell — confirm against the session runtime.
pub idempotency_key: String,
/// Identifier of the tool call that triggered this request.
pub tool_call_id: String,
/// Tool definitions made available to the cell.
pub enabled_tools: Vec<ToolDefinition>,
/// Source text the cell runs.
pub source: String,
}

/// Protocol request asking to observe a previously created cell.
///
/// Serializes to and from JSON using the field names below unchanged.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct ObserveRequest {
/// Caller-supplied key identifying this logical observation, so a repeated
/// request can be matched to an earlier one.
pub idempotency_key: String,
/// Identifier of the cell to observe.
pub cell_id: CellId,
/// Wait budget in milliseconds.
/// NOTE(review): presumably how long to wait before yielding back to the
/// caller — confirm against the runtime's wait semantics.
pub yield_time_ms: u64,
}
Expand Down
9 changes: 8 additions & 1 deletion codex-rs/code-mode-protocol/src/runtime_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fn image(image_url: &str, detail: Option<ImageDetail>) -> FunctionCallOutputCont
#[test]
fn requests_round_trip_with_exact_tool_and_namespace_fields() {
let create = CreateCellRequest {
idempotency_key: "thread-3:response-call-7".to_string(),
tool_call_id: "response-call-7".to_string(),
enabled_tools: vec![
ToolDefinition {
Expand Down Expand Up @@ -61,6 +62,7 @@ fn requests_round_trip_with_exact_tool_and_namespace_fields() {
assert_json_round_trip(
&create,
json!({
"idempotency_key": "thread-3:response-call-7",
"tool_call_id": "response-call-7",
"enabled_tools": [
{
Expand Down Expand Up @@ -89,10 +91,15 @@ fn requests_round_trip_with_exact_tool_and_namespace_fields() {

assert_json_round_trip(
&ObserveRequest {
idempotency_key: "thread-3:wait-call-2".to_string(),
cell_id: CellId::new("cell-a7".to_string()),
yield_time_ms: 250,
},
json!({"cell_id": "cell-a7", "yield_time_ms": 250}),
json!({
"idempotency_key": "thread-3:wait-call-2",
"cell_id": "cell-a7",
"yield_time_ms": 250,
}),
);
}

Expand Down
1 change: 1 addition & 0 deletions codex-rs/code-mode/src/cell_actor/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::session_runtime::ToolKind as CellToolKind;

pub(super) fn runtime_request(request: CellRequest) -> CreateCellRequest {
CreateCellRequest {
idempotency_key: request.idempotency_key,
tool_call_id: request.tool_call_id,
enabled_tools: request
.enabled_tools
Expand Down
1 change: 1 addition & 0 deletions codex-rs/code-mode/src/cell_actor/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ fn spawn_cell_actor_harness_with_policy<H: CellHost>(
let (runtime_tx, runtime_pause_tx, runtime_terminate_handle) = spawn_runtime(
HashMap::new(),
CreateCellRequest {
idempotency_key: "cell-actor-harness".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: Vec::new(),
source: "await new Promise(() => {});".to_string(),
Expand Down
1 change: 1 addition & 0 deletions codex-rs/code-mode/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ mod tests {

fn execute_request(source: &str) -> CreateCellRequest {
CreateCellRequest {
idempotency_key: format!("call_1:{source}"),
tool_call_id: "call_1".to_string(),
enabled_tools: Vec::new(),
source: source.to_string(),
Expand Down
127 changes: 91 additions & 36 deletions codex-rs/code-mode/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use codex_code_mode_protocol::TerminateOutcome;
use codex_code_mode_protocol::ToolInvocationFuture;
use serde_json::Value as JsonValue;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;

use crate::session_runtime as runtime;
Expand Down Expand Up @@ -94,18 +95,27 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider {
}

pub struct CodeModeService {
runtime: SessionRuntime<ProtocolDelegate>,
runtime: Arc<SessionRuntime<ProtocolDelegate>>,
observations: Mutex<HashMap<String, ObservationRecord>>,
#[cfg(test)]
pending_generations: Mutex<HashMap<runtime::CellId, runtime::PendingGeneration>>,
}

/// Bookkeeping for one observation, stored per idempotency key.
struct ObservationRecord {
/// The request that first registered this observation; kept so a later
/// request with the same key can be compared against it.
request: ObserveRequest,
/// Receiving side of the watch channel carrying the observation result.
/// The value is `None` until a result is published, then `Some(result)`.
result_rx: watch::Receiver<Option<Result<ObserveOutcome, String>>>,
}

impl CodeModeService {
pub fn new() -> Self {
Self::with_delegate(Arc::new(NoopCodeModeSessionDelegate))
}

pub fn with_delegate(delegate: Arc<dyn CodeModeSessionDelegate>) -> Self {
Self {
runtime: SessionRuntime::new(Arc::new(ProtocolDelegate { delegate })),
runtime: Arc::new(SessionRuntime::new(Arc::new(ProtocolDelegate { delegate }))),
observations: Mutex::new(HashMap::new()),
#[cfg(test)]
pending_generations: Mutex::new(HashMap::new()),
}
}
Expand All @@ -130,44 +140,51 @@ impl CodeModeService {
}

pub async fn observe(&self, request: ObserveRequest) -> Result<ObserveOutcome, String> {
self.begin_observe(request).await.await
let idempotency_key = request.idempotency_key.clone();
let mut result_rx = {
let mut observations = self.observations.lock().await;
if let Some(existing) = observations.get(&idempotency_key) {
if existing.request != request {
return Err(format!(
"observation idempotency key `{idempotency_key}` was reused for a different request"
));
}
existing.result_rx.clone()
} else {
let (result_tx, result_rx) = watch::channel(None);
observations.insert(
idempotency_key,
ObservationRecord {
request: request.clone(),
result_rx: result_rx.clone(),
},
);
let runtime = Arc::clone(&self.runtime);
tokio::spawn(async move {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This task keeps the runtime alive even after the service and all replay receivers are gone, so dropping the session no longer cancels the cell; it keeps running until the observation deadline finishes.
This is a standard ownership deadlock. It needs a proper teardown path or a weaker hold on the runtime (for example, a `Weak` reference).

let result = begin_observe_runtime(runtime, request).await.await;
result_tx.send_replace(Some(result));
});
result_rx
}
};

result_rx
.wait_for(Option::is_some)
.await
.map_err(|_| "observation ended before producing a result".to_string())?;

result_rx
.borrow()
.clone()
.ok_or_else(|| "observation ended before producing a result".to_string())?
}

#[allow(dead_code)]
async fn begin_observe(
&self,
request: ObserveRequest,
) -> CodeModeSessionResultFuture<'static, ObserveOutcome> {
let ObserveRequest {
cell_id,
yield_time_ms,
} = request;
let runtime_cell_id = runtime_cell_id(&cell_id);
let cell = match self.runtime.cell(&runtime_cell_id).await {
Ok(cell) => cell,
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
return missing_observation(cell_id);
}
Err(error) => return Box::pin(async move { Err(error.to_string()) }),
};
match self
.runtime
.begin_wait(&cell, Duration::from_millis(yield_time_ms))
.await
{
Ok(pending_event) => Box::pin(async move {
match pending_event.event().await {
Ok(event) => Ok(observe_outcome(&cell_id, event)),
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
Ok(ObserveOutcome::Missing { cell_id })
}
Err(error) => Err(error.to_string()),
}
}),
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
missing_observation(cell_id)
}
Err(error) => Box::pin(async move { Err(error.to_string()) }),
}
begin_observe_runtime(Arc::clone(&self.runtime), request).await
}

pub async fn terminate(&self, cell_id: CellId) -> Result<TerminateOutcome, String> {
Expand All @@ -179,6 +196,7 @@ impl CodeModeService {
}
Err(error) => Err(error.to_string()),
};
#[cfg(test)]
self.pending_generations
.lock()
.await
Expand Down Expand Up @@ -254,11 +272,49 @@ impl CodeModeService {
.shutdown()
.await
.map_err(|error| error.to_string());
#[cfg(test)]
self.pending_generations.lock().await.clear();
result
}
}

/// Starts an observation of `request.cell_id` on `runtime` and returns a
/// future that resolves to the observation outcome.
///
/// The work is split in two phases: the cell lookup and `begin_wait` are
/// awaited inside this function, and only the wait for the resulting event is
/// deferred to the returned future.
///
/// `MissingCell` and `ClosedCell` errors are never surfaced as `Err`: during
/// lookup or `begin_wait` they are routed to `missing_observation`, and while
/// waiting for the event they become `Ok(ObserveOutcome::Missing { .. })`.
/// Every other runtime error is converted to its string form and returned as
/// `Err`.
async fn begin_observe_runtime(
runtime: Arc<SessionRuntime<ProtocolDelegate>>,
request: ObserveRequest,
) -> CodeModeSessionResultFuture<'static, ObserveOutcome> {
// The idempotency key is not used here; it is deliberately ignored.
let ObserveRequest {
idempotency_key: _,
cell_id,
yield_time_ms,
} = request;
// The runtime uses its own cell id type; convert before the lookup.
let runtime_cell_id = runtime_cell_id(&cell_id);
let cell = match runtime.cell(&runtime_cell_id).await {
Ok(cell) => cell,
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
// NOTE(review): presumably yields a future resolving to
// `ObserveOutcome::Missing` — confirm in `missing_observation`.
return missing_observation(cell_id);
}
Err(error) => return Box::pin(async move { Err(error.to_string()) }),
};
// `begin_wait` is awaited here, before the returned future exists.
match runtime
.begin_wait(&cell, Duration::from_millis(yield_time_ms))
.await
{
Ok(pending_event) => Box::pin(async move {
match pending_event.event().await {
Ok(event) => Ok(observe_outcome(&cell_id, event)),
// The cell went away while waiting: report it as missing, not as an error.
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
Ok(ObserveOutcome::Missing { cell_id })
}
Err(error) => Err(error.to_string()),
}
}),
Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => {
missing_observation(cell_id)
}
Err(error) => Box::pin(async move { Err(error.to_string()) }),
}
}

impl Default for CodeModeService {
fn default() -> Self {
Self::new()
Expand Down Expand Up @@ -349,9 +405,8 @@ impl runtime::SessionRuntimeDelegate for ProtocolDelegate {
}

fn runtime_request(request: CreateCellRequest) -> runtime::CreateCellRequest {
let idempotency_key = format!("{}:{}", request.tool_call_id, request.source);
runtime::CreateCellRequest {
idempotency_key,
idempotency_key: request.idempotency_key,
tool_call_id: request.tool_call_id,
enabled_tools: request
.enabled_tools
Expand Down
76 changes: 75 additions & 1 deletion codex-rs/code-mode/src/service_contract_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ fn cell_id(value: &str) -> CellId {

fn execute_request(source: &str) -> CreateCellRequest {
CreateCellRequest {
idempotency_key: format!("call-1:{source}"),
tool_call_id: "call-1".to_string(),
enabled_tools: Vec::new(),
source: source.to_string(),
Expand All @@ -174,6 +175,7 @@ async fn execute_with_yield_time(
let cell_id = service.create_cell(request).await.unwrap();
service
.observe(ObserveRequest {
idempotency_key: format!("observe:{cell_id}"),
cell_id,
yield_time_ms,
})
Expand Down Expand Up @@ -214,6 +216,74 @@ async fn next_event(events_rx: &mut mpsc::UnboundedReceiver<DelegateEvent>) -> D
.expect("delegate event channel closed")
}

/// Sending the identical create request twice (same idempotency key) must
/// yield the same cell id both times.
#[tokio::test]
async fn create_retry_returns_the_cell_from_the_original_ambiguous_request() {
let service = CodeModeService::new();
// The source never resolves, so the cell is still running when the retry arrives.
let request = execute_request("await new Promise(() => {});");

let original_cell_id = service.create_cell(request.clone()).await.unwrap();
let retry_cell_id = service.create_cell(request).await.unwrap();

assert_eq!(retry_cell_id, original_cell_id);
// Clean up the still-running cell.
service.terminate(original_cell_id).await.unwrap();
}

/// An observation whose caller was cancelled must still be answerable: a
/// second `observe` with the same request and idempotency key receives the
/// cell's completed outcome.
#[tokio::test]
async fn cancelled_observation_is_replayed_by_its_idempotency_key() {
let (delegate, mut events_rx) = BlockingDelegate::new();
let service = Arc::new(CodeModeService::with_delegate(delegate.clone()));
// The cell blocks inside a tool call until `release_tool` is invoked below.
let created_cell_id = service
.create_cell(CreateCellRequest {
idempotency_key: "ambiguous-observation-cell".to_string(),
enabled_tools: vec![blocking_tool()],
source: r#"await tools.block({}); text("done");"#.to_string(),
..execute_request("")
})
.await
.unwrap();
// Wait until the cell has actually entered the blocking tool call.
assert_eq!(next_event(&mut events_rx).await, DelegateEvent::ToolStarted);
let request = ObserveRequest {
idempotency_key: "lost-observation-response".to_string(),
cell_id: created_cell_id,
yield_time_ms: 60_000,
};

// First attempt runs in its own task so it can be aborted mid-wait.
let first_attempt = tokio::spawn({
let service = Arc::clone(&service);
let request = request.clone();
async move { service.observe(request).await }
});
// Poll until the first attempt has registered its observation under the key;
// aborting earlier would not exercise the replay path.
tokio::time::timeout(Duration::from_secs(1), async {
loop {
if service
.observations
.lock()
.await
.contains_key("lost-observation-response")
{
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("observation registration timed out");
// Cancel the caller of the first attempt, then let the cell finish.
first_attempt.abort();
assert!(first_attempt.await.unwrap_err().is_cancelled());
delegate.release_tool();

// The retry with the same request must see the completed result.
assert_eq!(
service.observe(request).await,
Ok(ObserveOutcome::Completed {
cell_id: cell_id("1"),
content_items: vec![FunctionCallOutputContentItem::InputText {
text: "done".to_string(),
}],
error_text: None,
})
);
}

#[tokio::test]
async fn yields_and_resumes() {
let service = CodeModeService::new();
Expand All @@ -239,6 +309,7 @@ async fn yields_and_resumes() {
assert_eq!(
service
.observe(ObserveRequest {
idempotency_key: "observe-after-yield".to_string(),
cell_id: cell_id("1"),
yield_time_ms: 60_000,
})
Expand Down Expand Up @@ -326,7 +397,7 @@ async fn observed_natural_completion_wins_over_termination() {
let response = create_and_observe_to_pending(
&service,
CreateCellRequest {
tool_call_id: format!("completion-probe-{probe_generation}"),
idempotency_key: format!("completion-probe-{probe_generation}"),
..execute_request(r#"text(String(load("finished")));"#)
},
)
Expand Down Expand Up @@ -547,13 +618,15 @@ async fn second_observer_is_rejected_without_displacing_the_first() {

let first_observer = service
.begin_observe(ObserveRequest {
idempotency_key: "first-observer".to_string(),
cell_id: cell_id("1"),
yield_time_ms: 60_000,
})
.await;
assert_eq!(
service
.observe(ObserveRequest {
idempotency_key: "second-observer".to_string(),
cell_id: cell_id("1"),
yield_time_ms: 60_000,
})
Expand Down Expand Up @@ -598,6 +671,7 @@ async fn natural_completion_cleans_up_callbacks_before_responding() {
assert_eq!(
service
.observe(ObserveRequest {
idempotency_key: "observe-created-cell".to_string(),
cell_id: created_cell_id,
yield_time_ms: 60_000,
})
Expand Down
Loading
Loading