diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 9f3e69e3ccfd..bd194ece89d6 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2173,9 +2173,11 @@ dependencies = [ "futures", "gethostname", "hmac 0.12.1", + "httpdate", "jsonwebtoken", "owo-colors", "pretty_assertions", + "rand 0.9.3", "serde", "serde_json", "sha2 0.10.9", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 0824f7db7b7f..34097b56aad5 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -318,6 +318,7 @@ glob = "0.3" globset = "0.4" hmac = "0.12.1" http = "1.3.1" +httpdate = "1.0.3" iana-time-zone = "0.1.64" icu_decimal = "2.1" icu_locale_core = "2.1" diff --git a/codex-rs/app-server-transport/Cargo.toml b/codex-rs/app-server-transport/Cargo.toml index d4c0a83739c4..af1f0e960357 100644 --- a/codex-rs/app-server-transport/Cargo.toml +++ b/codex-rs/app-server-transport/Cargo.toml @@ -35,8 +35,10 @@ constant_time_eq = { workspace = true } futures = { workspace = true } gethostname = { workspace = true } hmac = { workspace = true } +httpdate = { workspace = true } jsonwebtoken = { workspace = true } owo-colors = { workspace = true, features = ["supports-colors"] } +rand = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = { workspace = true } diff --git a/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs b/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs index e327503a937f..be640724765d 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs @@ -1,8 +1,4 @@ -use super::auth::RemoteControlConnectionAuth; use super::pairing_unavailable_error; -use super::protocol::EnrollRemoteServerRequest; -use super::protocol::EnrollRemoteServerResponse; -use super::protocol::RefreshRemoteServerRequest; use super::protocol::RemoteControlPairingStatusRequest; use super::protocol::RemoteControlPairingStatusResponse as BackendRemoteControlPairingStatusResponse; use super::protocol::RemoteControlTarget; @@ -14,8 +10,6 @@ use codex_app_server_protocol::RemoteControlPairingStatusResponse; use codex_login::default_client::build_reqwest_client; use codex_state::RemoteControlEnrollmentRecord; use codex_state::StateRuntime; -use serde::Serialize; -use serde::de::DeserializeOwned; use std::io; use std::io::ErrorKind; use time::OffsetDateTime; @@ -23,15 +17,13 @@ use time::format_description::well_known::Rfc3339; use tracing::info; use tracing::warn; -const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); const REMOTE_CONTROL_PAIRING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096; -const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS: i64 = 30; +const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS: i64 = 5 * 60; const REQUEST_ID_HEADER: &str = "x-request-id"; const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id"; const CF_RAY_HEADER: &str = "cf-ray"; -pub(super) const REMOTE_CONTROL_INSTALLATION_ID_HEADER: &str = "x-codex-installation-id"; #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct RemoteControlEnrollment { @@ -42,6 +34,14 @@ pub(super) struct RemoteControlEnrollment { pub(super) server_name: String, pub(super) remote_control_token: Option, pub(super) expires_at: Option, + pub(super) next_refresh_at: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(super) enum RemoteControlServerTokenRefreshRequirement { + Required, + Proactive, + NotNeeded, } impl RemoteControlEnrollment { @@ -49,7 +49,9 @@ impl RemoteControlEnrollment { &self, request: StartRemoteControlPairingRequest, ) -> io::Result { - if self.should_refresh_server_token() { + if self.server_token_refresh_requirement() + == RemoteControlServerTokenRefreshRequirement::Required + { return Err(pairing_unavailable_error()); } let remote_control_token = self @@ -142,7 +144,9 @@ impl RemoteControlEnrollment { &self, request: RemoteControlPairingStatusRequest, ) -> io::Result { - if self.should_refresh_server_token() { + if self.server_token_refresh_requirement() + == RemoteControlServerTokenRefreshRequirement::Required + { return Err(pairing_unavailable_error()); } let remote_control_token = self @@ -201,13 +205,35 @@ impl RemoteControlEnrollment { }) } + pub(super) fn server_token_refresh_requirement( + &self, + ) -> RemoteControlServerTokenRefreshRequirement { + self.server_token_refresh_requirement_at(OffsetDateTime::now_utc()) + } + pub(super) fn should_refresh_server_token(&self) -> bool { - self.remote_control_token.is_none() - || self.expires_at.is_none_or(|expires_at| { - expires_at.unix_timestamp() - <= OffsetDateTime::now_utc().unix_timestamp() - + REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS - }) + self.server_token_refresh_requirement() + != RemoteControlServerTokenRefreshRequirement::NotNeeded + } + + pub(super) fn server_token_refresh_requirement_at( + &self, + now: OffsetDateTime, + ) -> RemoteControlServerTokenRefreshRequirement { + let Some(expires_at) = self.remote_control_token.as_ref().and(self.expires_at) else { + return RemoteControlServerTokenRefreshRequirement::Required; + }; + if expires_at <= now { + return RemoteControlServerTokenRefreshRequirement::Required; + } + if expires_at > now + time::Duration::seconds(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS) + || self + .next_refresh_at + .is_some_and(|next_refresh_at| next_refresh_at > now) + { + return RemoteControlServerTokenRefreshRequirement::NotNeeded; + } + RemoteControlServerTokenRefreshRequirement::Proactive } pub(super) fn clear_server_token(&mut self) { @@ -267,6 +293,7 @@ pub(super) async fn load_persisted_remote_control_enrollment( server_name: enrollment.server_name, remote_control_token: None, expires_at: None, + next_refresh_at: None, })) } None => { @@ -398,164 +425,12 @@ pub(crate) fn format_headers(headers: &HeaderMap) -> String { format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}") } -pub(super) async fn enroll_remote_control_server( - remote_control_target: &RemoteControlTarget, - auth: &RemoteControlConnectionAuth, - installation_id: &str, - server_name: &str, -) -> io::Result { - let enroll_url = &remote_control_target.enroll_url; - let request = EnrollRemoteServerRequest { - name: server_name.to_string(), - os: std::env::consts::OS, - arch: std::env::consts::ARCH, - app_server_version: env!("CARGO_PKG_VERSION"), - installation_id: installation_id.to_string(), - }; - let enrollment_response = send_remote_control_server_request::<_, EnrollRemoteServerResponse>( - enroll_url, - auth, - installation_id, - &request, - "enroll", - "server enrollment", - ) - .await?; - let mut enrollment = RemoteControlEnrollment { - remote_control_target: remote_control_target.clone(), - account_id: auth.account_id.clone(), - environment_id: enrollment_response.environment_id, - server_id: enrollment_response.server_id, - server_name: server_name.to_string(), - remote_control_token: None, - expires_at: None, - }; - update_remote_control_server_token( - &mut enrollment, - enroll_url, - enrollment_response.remote_control_token, - enrollment_response.expires_at, - )?; - Ok(enrollment) -} - -pub(super) async fn refresh_remote_control_server( - auth: &RemoteControlConnectionAuth, - installation_id: &str, - enrollment: &mut RemoteControlEnrollment, -) -> io::Result<()> { - let refresh_url = enrollment.remote_control_target.refresh_url.clone(); - let request = RefreshRemoteServerRequest { - server_id: enrollment.server_id.clone(), - installation_id: installation_id.to_string(), - }; - let refreshed = send_remote_control_server_request::<_, EnrollRemoteServerResponse>( - &refresh_url, - auth, - installation_id, - &request, - "refresh", - "server refresh", - ) - .await?; - if refreshed.server_id != enrollment.server_id - || refreshed.environment_id != enrollment.environment_id - { - return Err(io::Error::other(format!( - "remote control server refresh returned mismatched enrollment: expected server_id={}, environment_id={}; got server_id={}, environment_id={}", - enrollment.server_id, - enrollment.environment_id, - refreshed.server_id, - refreshed.environment_id - ))); - } - - update_remote_control_server_token( - enrollment, - &refresh_url, - refreshed.remote_control_token, - refreshed.expires_at, - ) -} - -async fn send_remote_control_server_request( - url: &str, - auth: &RemoteControlConnectionAuth, - installation_id: &str, - request: &Request, - action: &str, - response_kind: &str, -) -> io::Result -where - Request: Serialize, - Response: DeserializeOwned, -{ - let client = build_reqwest_client(); - let auth_headers = auth.request_headers()?; - let response = client - .post(url) - .timeout(REMOTE_CONTROL_ENROLL_TIMEOUT) - .headers(auth_headers) - .header(REMOTE_CONTROL_INSTALLATION_ID_HEADER, installation_id) - .json(request) - .send() - .await - .map_err(|err| { - io::Error::other(format!( - "failed to {action} remote control server at `{url}`: {err}" - )) - })?; - let headers = response.headers().clone(); - let status = response.status(); - let body = response.bytes().await.map_err(|err| { - io::Error::other(format!( - "failed to read remote control {response_kind} response from `{url}`: {err}" - )) - })?; - let body_preview = preview_remote_control_response_body(&body); - if !status.is_success() { - let headers_str = format_headers(&headers); - let error_kind = match status.as_u16() { - 401 | 403 => ErrorKind::PermissionDenied, - 404 => ErrorKind::NotFound, - _ => ErrorKind::Other, - }; - return Err(io::Error::new( - error_kind, - format!( - "remote control {response_kind} failed at `{url}`: HTTP {status}, {headers_str}, body: {body_preview}" - ), - )); - } - - serde_json::from_slice::(&body).map_err(|err| { - let headers_str = format_headers(&headers); - io::Error::other(format!( - "failed to parse remote control {response_kind} response from `{url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}" - )) - }) -} - -fn update_remote_control_server_token( - enrollment: &mut RemoteControlEnrollment, - url: &str, - token: String, - expires_at: String, -) -> io::Result<()> { - let expires_at = OffsetDateTime::parse(&expires_at, &Rfc3339).map_err(|err| { - io::Error::other(format!( - "failed to parse remote control server token expiry from `{url}`: {err}" - )) - })?; - enrollment.remote_control_token = Some(token); - enrollment.expires_at = Some(expires_at); - Ok(()) -} - #[cfg(test)] mod tests { use super::*; + use crate::transport::remote_control::auth::RemoteControlConnectionAuth; use crate::transport::remote_control::protocol::normalize_remote_control_url; + use crate::transport::remote_control::server_api::enroll_remote_control_server; use codex_state::StateRuntime; use pretty_assertions::assert_eq; use serde_json::json; @@ -575,28 +450,6 @@ mod tests { .expect("state runtime should initialize") } - #[test] - fn remote_control_enrollment_refreshes_server_token_before_expiry() { - let expires_soon = RemoteControlEnrollment { - remote_control_target: normalize_remote_control_url("http://localhost/backend-api/") - .expect("target should normalize"), - account_id: "account-a".to_string(), - environment_id: "env_first".to_string(), - server_id: "srv_e_first".to_string(), - server_name: "first-server".to_string(), - remote_control_token: Some("expires-soon".to_string()), - expires_at: Some(OffsetDateTime::now_utc() + time::Duration::seconds(29)), - }; - let expires_later = RemoteControlEnrollment { - expires_at: Some(OffsetDateTime::now_utc() + time::Duration::seconds(31)), - remote_control_token: Some("expires-later".to_string()), - ..expires_soon.clone() - }; - - assert!(expires_soon.should_refresh_server_token()); - assert!(!expires_later.should_refresh_server_token()); - } - #[test] fn preview_remote_control_response_body_redacts_server_token() { assert_eq!( @@ -630,6 +483,7 @@ mod tests { server_name: "first-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; let second_enrollment = RemoteControlEnrollment { remote_control_target: second_target.clone(), @@ -639,6 +493,7 @@ mod tests { server_name: "second-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( @@ -714,6 +569,7 @@ mod tests { server_name: "first-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; let second_enrollment = RemoteControlEnrollment { remote_control_target: second_target.clone(), @@ -723,6 +579,7 @@ mod tests { server_name: "second-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( diff --git a/codex-rs/app-server-transport/src/transport/remote_control/mod.rs b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs index d0e934d90bb7..7422c92c42cb 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs @@ -5,6 +5,7 @@ mod desired_state; mod enroll; mod protocol; mod segment; +mod server_api; mod websocket; use self::auth::load_remote_control_auth; @@ -12,10 +13,10 @@ use self::auth::recover_remote_control_auth; use self::desired_state::RemoteControlDesiredState; use self::desired_state::acquire_persistence_lock; use self::enroll::RemoteControlEnrollment; -use self::enroll::enroll_remote_control_server; use self::enroll::load_persisted_remote_control_enrollment; -use self::enroll::refresh_remote_control_server; use self::enroll::update_persisted_remote_control_enrollment; +use self::server_api::enroll_remote_control_server; +use self::server_api::refresh_remote_control_server; use crate::transport::remote_control::websocket::RemoteControlChannels; use crate::transport::remote_control::websocket::RemoteControlStatusPublisher; use crate::transport::remote_control::websocket::RemoteControlWebsocket; @@ -825,27 +826,39 @@ async fn refresh_pairing_enrollment( installation_id: &str, enrollment: &mut RemoteControlEnrollment, ) -> io::Result<()> { - if let Err(err) = refresh_remote_control_server(auth, installation_id, enrollment).await { - if err.kind() != io::ErrorKind::PermissionDenied { - return Err(err); - } + let mut refresh_result = refresh_remote_control_server(auth, installation_id, enrollment).await; + if refresh_result + .as_ref() + .is_err_and(|err| err.kind() == io::ErrorKind::PermissionDenied) + { let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut auth_change_rx = auth_manager.auth_change_receiver(); - if !recover_remote_control_auth(&mut auth_recovery, &mut auth_change_rx).await { - return Err(err); - } - *auth = load_remote_control_auth(auth_manager) - .await - .map_err(|_| pairing_unavailable_error())?; - if auth.account_id != enrollment.account_id { - return Err(pairing_unavailable_error()); + if recover_remote_control_auth(&mut auth_recovery, &mut auth_change_rx).await { + match load_remote_control_auth(auth_manager).await { + Ok(recovered_auth) if recovered_auth.account_id == enrollment.account_id => { + *auth = recovered_auth; + refresh_result = + refresh_remote_control_server(auth, installation_id, enrollment).await; + } + Ok(_) | Err(_) => { + enrollment.clear_server_token(); + refresh_result = Err(pairing_unavailable_error()); + } + } + } else { + enrollment.clear_server_token(); } - refresh_remote_control_server(auth, installation_id, enrollment).await? } - if replace_current_enrollment(current_enrollment, enrollment) { - Ok(()) - } else { + if refresh_result + .as_ref() + .is_err_and(|err| err.kind() == io::ErrorKind::PermissionDenied) + { + enrollment.clear_server_token(); + } + if !replace_current_enrollment(current_enrollment, enrollment) { Err(pairing_unavailable_error()) + } else { + refresh_result } } diff --git a/codex-rs/app-server-transport/src/transport/remote_control/server_api.rs b/codex-rs/app-server-transport/src/transport/remote_control/server_api.rs new file mode 100644 index 000000000000..def38907c72f --- /dev/null +++ b/codex-rs/app-server-transport/src/transport/remote_control/server_api.rs @@ -0,0 +1,339 @@ +use super::auth::RemoteControlConnectionAuth; +use super::enroll::RemoteControlEnrollment; +use super::enroll::RemoteControlServerTokenRefreshRequirement; +use super::enroll::format_headers; +use super::enroll::preview_remote_control_response_body; +use super::protocol::EnrollRemoteServerRequest; +use super::protocol::EnrollRemoteServerResponse; +use super::protocol::RefreshRemoteServerRequest; +use super::protocol::RemoteControlTarget; +use axum::http::HeaderMap; +use axum::http::StatusCode; +use codex_login::default_client::build_reqwest_client; +use rand::Rng; +use serde::Serialize; +use serde::de::DeserializeOwned; +use std::fmt; +use std::io; +use std::io::ErrorKind; +use std::time::Duration; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; +use tracing::warn; + +const REMOTE_CONTROL_ENROLL_TIMEOUT: Duration = Duration::from_secs(30); +const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MIN_SECS: u64 = 24; +const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MAX_SECS: u64 = 36; + +pub(super) const REMOTE_CONTROL_INSTALLATION_ID_HEADER: &str = "x-codex-installation-id"; + +#[derive(Debug)] +struct RemoteControlServerRequestError { + message: String, + status: Option, + retry_at: Option, +} + +impl RemoteControlServerRequestError { + fn io_error( + message: String, + status: Option, + retry_at: Option, + timed_out: bool, + ) -> io::Error { + let kind = match status { + Some(StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) => ErrorKind::PermissionDenied, + Some(StatusCode::NOT_FOUND) => ErrorKind::NotFound, + Some(status) if timed_out && !status.is_client_error() => ErrorKind::TimedOut, + None if timed_out => ErrorKind::TimedOut, + Some(_) | None => ErrorKind::Other, + }; + io::Error::new( + kind, + Self { + message, + status, + retry_at, + }, + ) + } + + fn is_transient(&self, kind: ErrorKind) -> bool { + kind == ErrorKind::TimedOut + || self.status.is_none() + || self.status.is_some_and(|status| { + status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() + }) + } +} + +impl fmt::Display for RemoteControlServerRequestError { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str(&self.message) + } +} + +impl std::error::Error for RemoteControlServerRequestError {} + +pub(super) async fn enroll_remote_control_server( + remote_control_target: &RemoteControlTarget, + auth: &RemoteControlConnectionAuth, + installation_id: &str, + server_name: &str, +) -> io::Result { + let enroll_url = &remote_control_target.enroll_url; + let request = EnrollRemoteServerRequest { + name: server_name.to_string(), + os: std::env::consts::OS, + arch: std::env::consts::ARCH, + app_server_version: env!("CARGO_PKG_VERSION"), + installation_id: installation_id.to_string(), + }; + let enrollment_response = send_remote_control_server_request::<_, EnrollRemoteServerResponse>( + enroll_url, + auth, + installation_id, + &request, + "enroll", + "server enrollment", + REMOTE_CONTROL_ENROLL_TIMEOUT, + ) + .await?; + let mut enrollment = RemoteControlEnrollment { + remote_control_target: remote_control_target.clone(), + account_id: auth.account_id.clone(), + environment_id: enrollment_response.environment_id, + server_id: enrollment_response.server_id, + server_name: server_name.to_string(), + remote_control_token: None, + expires_at: None, + next_refresh_at: None, + }; + update_remote_control_server_token( + &mut enrollment, + enroll_url, + enrollment_response.remote_control_token, + enrollment_response.expires_at, + )?; + Ok(enrollment) +} + +pub(super) async fn refresh_remote_control_server( + auth: &RemoteControlConnectionAuth, + installation_id: &str, + enrollment: &mut RemoteControlEnrollment, +) -> io::Result<()> { + let now = OffsetDateTime::now_utc(); + let refresh_requirement = enrollment.server_token_refresh_requirement_at(now); + if refresh_requirement == RemoteControlServerTokenRefreshRequirement::NotNeeded { + return Ok(()); + } + if refresh_requirement == RemoteControlServerTokenRefreshRequirement::Required + && let Some(next_refresh_at) = enrollment.next_refresh_at + && next_refresh_at > now + { + return Err(io::Error::new( + ErrorKind::WouldBlock, + format!("remote control server token refresh deferred until {next_refresh_at}"), + )); + } + let refresh_url = enrollment.remote_control_target.refresh_url.clone(); + let request = RefreshRemoteServerRequest { + server_id: enrollment.server_id.clone(), + installation_id: installation_id.to_string(), + }; + let refreshed = match send_remote_control_server_request::<_, EnrollRemoteServerResponse>( + &refresh_url, + auth, + installation_id, + &request, + "refresh", + "server refresh", + REMOTE_CONTROL_ENROLL_TIMEOUT, + ) + .await + { + Ok(refreshed) => refreshed, + Err(err) => { + let Some(refresh_error) = remote_control_server_request_error(&err) else { + return Err(err); + }; + if !refresh_error.is_transient(err.kind()) { + return Err(err); + } + let now = OffsetDateTime::now_utc(); + let refresh_is_required = enrollment.server_token_refresh_requirement_at(now) + == RemoteControlServerTokenRefreshRequirement::Required; + let (refresh_delay, next_refresh_at) = refresh_deferral(refresh_error.retry_at, now); + enrollment.next_refresh_at = Some(next_refresh_at); + if refresh_is_required { + warn!( + refresh_url, + server_id = %enrollment.server_id, + environment_id = %enrollment.environment_id, + error = %err, + ?refresh_delay, + %next_refresh_at, + "required remote control server token refresh failed; deferring next attempt" + ); + return Err(err); + } + warn!( + refresh_url, + server_id = %enrollment.server_id, + environment_id = %enrollment.environment_id, + error = %err, + ?refresh_delay, + %next_refresh_at, + "proactive remote control server token refresh failed; continuing with valid token" + ); + return Ok(()); + } + }; + if refreshed.server_id != enrollment.server_id + || refreshed.environment_id != enrollment.environment_id + { + return Err(io::Error::other(format!( + "remote control server refresh returned mismatched enrollment: expected server_id={}, environment_id={}; got server_id={}, environment_id={}", + enrollment.server_id, + enrollment.environment_id, + refreshed.server_id, + refreshed.environment_id + ))); + } + + update_remote_control_server_token( + enrollment, + &refresh_url, + refreshed.remote_control_token, + refreshed.expires_at, + ) +} + +async fn send_remote_control_server_request( + url: &str, + auth: &RemoteControlConnectionAuth, + installation_id: &str, + request: &Request, + action: &str, + response_kind: &str, + timeout: Duration, +) -> io::Result +where + Request: Serialize, + Response: DeserializeOwned, +{ + let client = build_reqwest_client(); + let auth_headers = auth.request_headers()?; + let response = client + .post(url) + .timeout(timeout) + .headers(auth_headers) + .header(REMOTE_CONTROL_INSTALLATION_ID_HEADER, installation_id) + .json(request) + .send() + .await + .map_err(|err| { + let timed_out = err.is_timeout(); + RemoteControlServerRequestError::io_error( + format!("failed to {action} remote control server at `{url}`: {err}"), + /*status*/ None, + /*retry_at*/ None, + timed_out, + ) + })?; + let headers = response.headers().clone(); + let status = response.status(); + let retry_at = parse_retry_after(&headers, OffsetDateTime::now_utc()); + let body = response.bytes().await.map_err(|err| { + let timed_out = err.is_timeout(); + RemoteControlServerRequestError::io_error( + format!("failed to read remote control {response_kind} response from `{url}`: {err}"), + Some(status), + retry_at, + timed_out, + ) + })?; + let body_preview = preview_remote_control_response_body(&body); + if !status.is_success() { + let headers_str = format_headers(&headers); + return Err(RemoteControlServerRequestError::io_error( + format!( + "remote control {response_kind} failed at `{url}`: HTTP {status}, {headers_str}, body: {body_preview}" + ), + Some(status), + retry_at, + /*timed_out*/ false, + )); + } + + serde_json::from_slice::(&body).map_err(|err| { + let headers_str = format_headers(&headers); + io::Error::other(format!( + "failed to parse remote control {response_kind} response from `{url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}" + )) + }) +} + +fn update_remote_control_server_token( + enrollment: &mut RemoteControlEnrollment, + url: &str, + token: String, + expires_at: String, +) -> io::Result<()> { + let expires_at = OffsetDateTime::parse(&expires_at, &Rfc3339).map_err(|err| { + io::Error::other(format!( + "failed to parse remote control server token expiry from `{url}`: {err}" + )) + })?; + enrollment.remote_control_token = Some(token); + enrollment.expires_at = Some(expires_at); + enrollment.next_refresh_at = None; + Ok(()) +} + +fn remote_control_server_request_error( + err: &io::Error, +) -> Option<&RemoteControlServerRequestError> { + err.get_ref()?.downcast_ref() +} + +fn parse_retry_after(headers: &HeaderMap, received_at: OffsetDateTime) -> Option { + let retry_after = headers + .get(axum::http::header::RETRY_AFTER)? + .to_str() + .ok()?; + let retry_at = if let Ok(seconds) = retry_after.parse::() { + let seconds = i64::try_from(seconds).ok()?; + received_at.checked_add(time::Duration::seconds(seconds))? + } else { + OffsetDateTime::from(httpdate::parse_http_date(retry_after).ok()?) + }; + (retry_at > received_at).then_some(retry_at) +} + +fn refresh_deferral( + retry_at: Option, + now: OffsetDateTime, +) -> (Duration, OffsetDateTime) { + if let Some(retry_at) = retry_at + && let Ok(delay) = Duration::try_from(retry_at - now) + && !delay.is_zero() + { + return (delay, retry_at); + } + let delay = remote_control_server_token_refresh_backoff(); + let next_refresh_at = now + time::Duration::seconds(delay.as_secs() as i64); + (delay, next_refresh_at) +} + +fn remote_control_server_token_refresh_backoff() -> Duration { + Duration::from_secs(rand::rng().random_range( + REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MIN_SECS + ..=REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MAX_SECS, + )) +} + +#[cfg(test)] +#[path = "server_api_tests.rs"] +mod tests; diff --git a/codex-rs/app-server-transport/src/transport/remote_control/server_api_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/server_api_tests.rs new file mode 100644 index 000000000000..786dc65a906a --- /dev/null +++ b/codex-rs/app-server-transport/src/transport/remote_control/server_api_tests.rs @@ -0,0 +1,284 @@ +use super::*; +use crate::transport::remote_control::protocol::normalize_remote_control_url; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::time::SystemTime; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +const TEST_REQUEST_TIMEOUT: Duration = Duration::from_millis(100); + +fn auth() -> RemoteControlConnectionAuth { + RemoteControlConnectionAuth { + auth_provider: codex_model_provider::unauthenticated_auth_provider(), + account_id: "account-a".to_string(), + } +} + +fn assert_transient_timeout(err: &io::Error, expected_status: Option) { + let request_error = remote_control_server_request_error(err) + .expect("request error should preserve refresh metadata"); + assert_eq!( + ( + err.kind(), + request_error.status, + request_error.is_transient(err.kind()), + ), + (ErrorKind::TimedOut, expected_status, true) + ); +} + +async fn timed_out_request(partial_response: Option<&'static [u8]>) -> io::Error { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let url = format!( + "http://{}/backend-api/wham/remote/control/server/refresh", + listener + .local_addr() + .expect("listener should have a local address") + ); + let (request_done_tx, request_done_rx) = oneshot::channel(); + let server_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("request should connect"); + if let Some(partial_response) = partial_response { + stream + .write_all(partial_response) + .await + .expect("partial response should write"); + } + request_done_rx + .await + .expect("test should report request completion"); + }); + + let err = send_remote_control_server_request::<_, serde_json::Value>( + &url, + &auth(), + "installation-id", + &json!({"server_id": "server-id"}), + "refresh", + "server refresh", + TEST_REQUEST_TIMEOUT, + ) + .await + .expect_err("incomplete response should time out"); + request_done_tx + .send(()) + .expect("server should wait for request completion"); + server_task.await.expect("server task should finish"); + err +} + +fn enrollment(now: OffsetDateTime) -> RemoteControlEnrollment { + RemoteControlEnrollment { + remote_control_target: normalize_remote_control_url("http://localhost/backend-api/") + .expect("target should normalize"), + account_id: "account-a".to_string(), + environment_id: "env_first".to_string(), + server_id: "srv_e_first".to_string(), + server_name: "first-server".to_string(), + remote_control_token: Some("token".to_string()), + expires_at: Some(now + time::Duration::seconds(300)), + next_refresh_at: None, + } +} + +#[test] +fn remote_control_enrollment_classifies_server_token_refresh_requirement() { + let now = + OffsetDateTime::from_unix_timestamp(1_700_000_000).expect("test timestamp should parse"); + let enrollment = enrollment(now); + let cases = [ + ( + enrollment.clone(), + RemoteControlServerTokenRefreshRequirement::Proactive, + ), + ( + RemoteControlEnrollment { + expires_at: Some(now + time::Duration::seconds(301)), + ..enrollment.clone() + }, + RemoteControlServerTokenRefreshRequirement::NotNeeded, + ), + ( + RemoteControlEnrollment { + next_refresh_at: Some(now + time::Duration::seconds(30)), + ..enrollment.clone() + }, + RemoteControlServerTokenRefreshRequirement::NotNeeded, + ), + ( + RemoteControlEnrollment { + next_refresh_at: Some(now), + ..enrollment.clone() + }, + RemoteControlServerTokenRefreshRequirement::Proactive, + ), + ( + RemoteControlEnrollment { + remote_control_token: None, + ..enrollment.clone() + }, + RemoteControlServerTokenRefreshRequirement::Required, + ), + ( + RemoteControlEnrollment { + expires_at: None, + ..enrollment.clone() + }, + RemoteControlServerTokenRefreshRequirement::Required, + ), + ( + RemoteControlEnrollment { + expires_at: Some(now), + next_refresh_at: Some(now + time::Duration::hours(1)), + ..enrollment + }, + RemoteControlServerTokenRefreshRequirement::Required, + ), + ]; + + for (enrollment, expected) in cases { + assert_eq!( + enrollment.server_token_refresh_requirement_at(now), + expected + ); + } +} + +#[test] +fn remote_control_server_request_error_classifies_status_before_timeout() { + let cases = [ + (None, true, ErrorKind::TimedOut, true), + (Some(StatusCode::OK), true, ErrorKind::TimedOut, true), + ( + Some(StatusCode::TOO_MANY_REQUESTS), + false, + ErrorKind::Other, + true, + ), + (Some(StatusCode::BAD_GATEWAY), false, ErrorKind::Other, true), + ( + Some(StatusCode::UNAUTHORIZED), + true, + ErrorKind::PermissionDenied, + false, + ), + ( + Some(StatusCode::FORBIDDEN), + true, + ErrorKind::PermissionDenied, + false, + ), + ( + Some(StatusCode::NOT_FOUND), + true, + ErrorKind::NotFound, + false, + ), + (Some(StatusCode::BAD_REQUEST), true, ErrorKind::Other, false), + (None, false, ErrorKind::Other, true), + ]; + + for (status, timed_out, expected_kind, expected_transient) in cases { + let err = RemoteControlServerRequestError::io_error( + String::new(), + status, + /*retry_at*/ None, + timed_out, + ); + let request_error = remote_control_server_request_error(&err) + .expect("request error should preserve refresh metadata"); + assert_eq!( + (err.kind(), request_error.is_transient(err.kind())), + (expected_kind, expected_transient) + ); + } +} + +#[tokio::test] +async fn request_timeout_before_response_headers_is_transient() { + let err = timed_out_request(/*partial_response*/ None).await; + assert_transient_timeout(&err, /*expected_status*/ None); +} + +#[tokio::test] +async fn response_body_timeout_is_transient() { + let err = timed_out_request(Some(b"HTTP/1.1 200 OK\r\nContent-Length: 20\r\n\r\n{")).await; + assert_transient_timeout(&err, Some(StatusCode::OK)); +} + +#[test] +fn retry_after_supports_delta_seconds_and_http_dates() { + let now = + OffsetDateTime::from_unix_timestamp(1_700_000_000).expect("test timestamp should parse"); + let mut headers = HeaderMap::new(); + headers.insert( + axum::http::header::RETRY_AFTER, + axum::http::HeaderValue::from_static("120"), + ); + assert_eq!( + parse_retry_after(&headers, now), + Some(now + time::Duration::seconds(120)) + ); + + let retry_at = now + time::Duration::seconds(90); + let retry_at_system = SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_090); + headers.insert( + axum::http::header::RETRY_AFTER, + httpdate::fmt_http_date(retry_at_system) + .parse() + .expect("HTTP date should be a valid header value"), + ); + assert_eq!(parse_retry_after(&headers, now), Some(retry_at)); +} + +#[test] +fn invalid_or_expired_retry_after_uses_bounded_fallback() { + let now = + OffsetDateTime::from_unix_timestamp(1_700_000_000).expect("test timestamp should parse"); + let mut headers = HeaderMap::new(); + headers.insert( + axum::http::header::RETRY_AFTER, + axum::http::HeaderValue::from_static("invalid"), + ); + assert_eq!(parse_retry_after(&headers, now), None); + + headers.insert( + axum::http::header::RETRY_AFTER, + httpdate::fmt_http_date(SystemTime::UNIX_EPOCH + Duration::from_secs(1_699_999_999)) + .parse() + .expect("HTTP date should be a valid header value"), + ); + assert_eq!(parse_retry_after(&headers, now), None); + + let expired_while_reading_body = Some(now + time::Duration::seconds(1)); + for retry_at in [None, expired_while_reading_body] { + let deferred_at = now + time::Duration::seconds(2); + let (delay, next_refresh_at) = refresh_deferral(retry_at, deferred_at); + assert!( + (Duration::from_secs(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MIN_SECS) + ..=Duration::from_secs(REMOTE_CONTROL_SERVER_TOKEN_REFRESH_BACKOFF_MAX_SECS,)) + .contains(&delay) + ); + assert_eq!( + next_refresh_at, + deferred_at + time::Duration::seconds(delay.as_secs() as i64) + ); + } +} + +#[test] +fn http_date_retry_after_preserves_absolute_deadline() { + let received_at = + OffsetDateTime::from_unix_timestamp(1_700_000_000).expect("test timestamp should parse"); + let retry_at = received_at + time::Duration::seconds(120); + let body_read_at = received_at + time::Duration::seconds(30); + + assert_eq!( + refresh_deferral(Some(retry_at), body_read_at), + (Duration::from_secs(90), retry_at) + ); +} diff --git a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs index 3a2d844654a7..465c6b13ec95 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs @@ -1,5 +1,4 @@ use super::auth::REMOTE_CONTROL_ACCOUNT_ID_HEADER; -use super::enroll::REMOTE_CONTROL_INSTALLATION_ID_HEADER; use super::enroll::RemoteControlEnrollment; use super::enroll::load_persisted_remote_control_enrollment; use super::enroll::update_persisted_remote_control_enrollment; @@ -8,6 +7,7 @@ use super::protocol::ClientEvent; use super::protocol::ClientId; use super::protocol::StreamId; use super::protocol::normalize_remote_control_url; +use super::server_api::REMOTE_CONTROL_INSTALLATION_ID_HEADER; use super::websocket::REMOTE_CONTROL_PROTOCOL_VERSION; use super::websocket::RemoteControlWebsocket; use super::websocket::RemoteControlWebsocketConfig; @@ -372,7 +372,7 @@ fn test_server_name() -> String { gethostname().to_string_lossy().trim().to_string() } -fn remote_control_handle_with_current_enrollment( +pub(super) fn remote_control_handle_with_current_enrollment( remote_control_url: &str, auth_manager: Arc, ) -> RemoteControlHandle { @@ -400,6 +400,7 @@ fn remote_control_handle_with_current_enrollment( OffsetDateTime::from_unix_timestamp(33_336_362_096) .expect("future timestamp should parse"), ), + next_refresh_at: None, }, ))); RemoteControlHandle { @@ -1682,6 +1683,7 @@ async fn remote_control_http_mode_refreshes_persisted_enrollment_before_connecti server_name: "persisted-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -1802,6 +1804,7 @@ async fn remote_control_stdio_mode_waits_for_client_name_before_connecting() { server_name: "persisted-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -1899,6 +1902,7 @@ async fn remote_control_waits_for_account_id_before_enrolling() { server_name: expected_server_name, remote_control_token: None, expires_at: None, + next_refresh_at: None, }; let (transport_event_tx, _transport_event_rx) = @@ -1995,6 +1999,7 @@ async fn persisted_enable_does_not_follow_auth_to_an_account_without_a_preferenc server_name: "server-a".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -2102,6 +2107,7 @@ async fn remote_control_http_mode_reenrolls_when_refresh_reports_stale_enrollmen server_name: "stale-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; let refreshed_enrollment = RemoteControlEnrollment { remote_control_target: remote_control_target.clone(), @@ -2111,6 +2117,7 @@ async fn remote_control_http_mode_reenrolls_when_refresh_reports_stale_enrollmen server_name: expected_server_name, remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -2225,6 +2232,7 @@ async fn remote_control_http_mode_reenrolls_after_explicit_missing_server_404() server_name: "stale-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; let refreshed_enrollment = RemoteControlEnrollment { remote_control_target: remote_control_target.clone(), @@ -2234,6 +2242,7 @@ async fn remote_control_http_mode_reenrolls_after_explicit_missing_server_404() server_name: expected_server_name, remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -2371,6 +2380,7 @@ async fn remote_control_http_mode_preserves_stale_enrollment_when_reenrollment_f server_name: test_server_name(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -2421,6 +2431,7 @@ async fn remote_control_http_mode_preserves_stale_enrollment_when_reenrollment_f retry_refresh_request.request_line, "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" ); + let refresh_failed_at = OffsetDateTime::now_utc(); respond_with_status( retry_refresh_request.stream, "500 Internal Server Error", @@ -2428,9 +2439,26 @@ async fn remote_control_http_mode_preserves_stale_enrollment_when_reenrollment_f ) .await; + let current_enrollment = remote_handle + .current_enrollment + .lock() + .await + .clone() + .expect("stale enrollment should remain available"); + let next_refresh_at = current_enrollment + .next_refresh_at + .expect("required refresh failure should set a retry deadline"); + assert!( + (refresh_failed_at + time::Duration::seconds(24) + ..=OffsetDateTime::now_utc() + time::Duration::seconds(36)) + .contains(&next_refresh_at) + ); assert_eq!( - *remote_handle.current_enrollment.lock().await, - Some(stale_enrollment.clone()) + current_enrollment, + RemoteControlEnrollment { + next_refresh_at: Some(next_refresh_at), + ..stale_enrollment.clone() + } ); assert_eq!( state_db @@ -2474,6 +2502,7 @@ async fn remote_control_http_mode_preserves_enrollment_after_generic_websocket_4 server_name: "stale-server".to_string(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), diff --git a/codex-rs/app-server-transport/src/transport/remote_control/tests/pairing_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests/pairing_tests.rs index 380f13984eec..6f8b9878ea35 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/tests/pairing_tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/tests/pairing_tests.rs @@ -21,6 +21,69 @@ fn remote_control_enrollment( OffsetDateTime::from_unix_timestamp(33_336_362_096) .expect("future timestamp should parse"), ), + next_refresh_at: None, + } +} + +async fn auth_manager_with_replacement( + codex_home: &TempDir, + replacement_account_id: &str, +) -> Arc { + let mut stale_auth = remote_control_auth_dot_json(Some("account_id")); + stale_auth + .tokens + .as_mut() + .expect("stale auth should include tokens") + .access_token = "stale-token".to_string(); + save_auth( + codex_home.path(), + &stale_auth, + AuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + ) + .expect("stale auth should save"); + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + /*forced_chatgpt_workspace_id*/ None, + /*chatgpt_base_url*/ None, + AuthKeyringBackendKind::default(), + /*auth_route_config*/ None, + ) + .await; + let mut replacement_auth = remote_control_auth_dot_json(Some(replacement_account_id)); + replacement_auth + .tokens + .as_mut() + .expect("replacement auth should include tokens") + .access_token = "fresh-token".to_string(); + save_auth( + codex_home.path(), + &replacement_auth, + AuthCredentialsStoreMode::File, + AuthKeyringBackendKind::default(), + ) + .expect("replacement auth should save"); + auth_manager +} + +fn pairing_response_json(server_id: &str, environment_id: &str) -> serde_json::Value { + json!({ + "pairing_code": "pairing-code", + "manual_pairing_code": "ABCD-EFGH", + "server_id": server_id, + "environment_id": environment_id, + "expires_at": "3026-05-22T12:34:56Z", + }) +} + +fn pairing_response(environment_id: &str) -> RemoteControlPairingStartResponse { + RemoteControlPairingStartResponse { + pairing_code: "pairing-code".to_string(), + manual_pairing_code: Some("ABCD-EFGH".to_string()), + environment_id: environment_id.to_string(), + expires_at: 33_336_362_096, } } @@ -147,13 +210,7 @@ async fn remote_control_handle_starts_pairing_before_websocket_connects() { ); respond_with_json( pairing_request.stream, - json!({ - "pairing_code": "pairing-code", - "manual_pairing_code": "ABCD-EFGH", - "server_id": "srv_e_test", - "environment_id": "env_test", - "expires_at": "3026-05-22T12:34:56Z", - }), + pairing_response_json("srv_e_test", "env_test"), ) .await; }); @@ -178,15 +235,139 @@ async fn remote_control_handle_starts_pairing_before_websocket_connects() { .expect("pairing should use the current server before websocket connect"); server_task.await.expect("server task should finish"); - assert_eq!( - response, - RemoteControlPairingStartResponse { - pairing_code: "pairing-code".to_string(), - manual_pairing_code: Some("ABCD-EFGH".to_string()), - environment_id: "env_test".to_string(), - expires_at: 33_336_362_096, - } + assert_eq!(response, pairing_response("env_test")); +} + +#[tokio::test] +async fn proactive_refresh_rate_limit_uses_valid_token_for_pairing() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let server_task = tokio::spawn(async move { + let refresh_request = accept_http_request(&listener).await; + assert_eq!( + refresh_request.request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_status_and_headers( + refresh_request.stream, + "429 Too Many Requests", + &[], + "rate limited", + ) + .await; + + let pairing_request = accept_http_request(&listener).await; + assert_eq!( + pairing_request.request_line, + "POST /backend-api/wham/remote/control/server/pair HTTP/1.1" + ); + assert_eq!( + pairing_request.headers.get("authorization"), + Some(&format!("Bearer {TEST_REMOTE_CONTROL_SERVER_TOKEN}")) + ); + respond_with_json( + pairing_request.stream, + pairing_response_json("srv_e_test", "env_test"), + ) + .await; + }); + let remote_handle = remote_control_handle_with_current_enrollment( + &remote_control_url, + remote_control_auth_manager(), + ); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(OffsetDateTime::now_utc() + time::Duration::minutes(4)); + + let response = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect("valid token should allow pairing after proactive refresh failure"); + server_task.await.expect("server task should finish"); + + assert_eq!(response, pairing_response("env_test")); + assert!( + remote_handle + .current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .is_some() + ); +} + +#[tokio::test] +async fn required_refresh_deadline_blocks_pairing_without_request() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let server_task = tokio::spawn(async move { + let refresh_request = accept_http_request(&listener).await; + assert_eq!( + refresh_request.request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_status_and_headers( + refresh_request.stream, + "502 Bad Gateway", + &[("retry-after", "120")], + "upstream unavailable", + ) + .await; + listener + }); + let remote_handle = remote_control_handle_with_current_enrollment( + &remote_control_url, + remote_control_auth_manager(), ); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(OffsetDateTime::now_utc() - time::Duration::seconds(1)); + + let refresh_err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("required refresh failure should block pairing"); + let listener = server_task.await.expect("server task should finish"); + let next_refresh_at = remote_handle + .current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .expect("required pairing refresh should preserve the retry deadline"); + let deferred_err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("required refresh deadline should block pairing"); + + assert!(refresh_err.to_string().contains("HTTP 502 Bad Gateway")); + assert_eq!(deferred_err.kind(), io::ErrorKind::WouldBlock); + assert!( + deferred_err + .to_string() + .contains(&next_refresh_at.to_string()) + ); + timeout(Duration::from_millis(100), listener.accept()) + .await + .expect_err("pairing should not issue a request before the refresh deadline"); } #[tokio::test] @@ -423,13 +604,7 @@ async fn remote_control_handle_refreshes_after_pairing_auth_failure() { ); respond_with_json( refreshed_pairing_request.stream, - json!({ - "pairing_code": "pairing-code", - "manual_pairing_code": "ABCD-EFGH", - "server_id": "srv_e_test", - "environment_id": "env_test", - "expires_at": "3026-05-22T12:34:56Z", - }), + pairing_response_json("srv_e_test", "env_test"), ) .await; }); @@ -447,14 +622,54 @@ async fn remote_control_handle_refreshes_after_pairing_auth_failure() { .expect("pairing should refresh after server token auth failure"); server_task.await.expect("server task should finish"); + assert_eq!(response, pairing_response("env_test")); +} + +#[tokio::test] +async fn pairing_auth_failure_preserves_refresh_deadline() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let server_task = tokio::spawn(async move { + let pairing_request = accept_http_request(&listener).await; + assert_eq!( + pairing_request.request_line, + "POST /backend-api/wham/remote/control/server/pair HTTP/1.1" + ); + respond_with_status(pairing_request.stream, "401 Unauthorized", "").await; + }); + let remote_handle = remote_control_handle_with_current_enrollment( + &remote_control_url, + remote_control_auth_manager(), + ); + let next_refresh_at = OffsetDateTime::now_utc() + time::Duration::minutes(2); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .next_refresh_at = Some(next_refresh_at); + let mut expected_enrollment = remote_handle + .current_enrollment + .snapshot() + .expect("current enrollment should exist"); + expected_enrollment.clear_server_token(); + + let err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("refresh deadline should throttle recovery after token rejection"); + server_task.await.expect("server task should finish"); + + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); assert_eq!( - response, - RemoteControlPairingStartResponse { - pairing_code: "pairing-code".to_string(), - manual_pairing_code: Some("ABCD-EFGH".to_string()), - environment_id: "env_test".to_string(), - expires_at: 33_336_362_096, - } + remote_handle.current_enrollment.snapshot(), + Some(expected_enrollment) ); } @@ -508,53 +723,12 @@ async fn remote_control_handle_recovers_auth_before_refreshing_pairing() { ); respond_with_json( pairing_request.stream, - json!({ - "pairing_code": "pairing-code", - "manual_pairing_code": "ABCD-EFGH", - "server_id": "srv_e_test", - "environment_id": "env_test", - "expires_at": "3026-05-22T12:34:56Z", - }), + pairing_response_json("srv_e_test", "env_test"), ) .await; }); let codex_home = TempDir::new().expect("temp dir should create"); - let mut stale_auth = remote_control_auth_dot_json(Some("account_id")); - stale_auth - .tokens - .as_mut() - .expect("stale auth should include tokens") - .access_token = "stale-token".to_string(); - save_auth( - codex_home.path(), - &stale_auth, - AuthCredentialsStoreMode::File, - AuthKeyringBackendKind::default(), - ) - .expect("stale auth should save"); - let auth_manager = AuthManager::shared( - codex_home.path().to_path_buf(), - /*enable_codex_api_key_env*/ false, - AuthCredentialsStoreMode::File, - /*forced_chatgpt_workspace_id*/ None, - /*chatgpt_base_url*/ None, - AuthKeyringBackendKind::default(), - /*auth_route_config*/ None, - ) - .await; - let mut fresh_auth = remote_control_auth_dot_json(Some("account_id")); - fresh_auth - .tokens - .as_mut() - .expect("fresh auth should include tokens") - .access_token = "fresh-token".to_string(); - save_auth( - codex_home.path(), - &fresh_auth, - AuthCredentialsStoreMode::File, - AuthKeyringBackendKind::default(), - ) - .expect("fresh auth should save"); + let auth_manager = auth_manager_with_replacement(&codex_home, "account_id").await; let remote_handle = remote_control_handle_with_current_enrollment(&remote_control_url, auth_manager); remote_handle @@ -574,14 +748,128 @@ async fn remote_control_handle_recovers_auth_before_refreshing_pairing() { .expect("pairing should refresh after auth recovery"); server_task.await.expect("server task should finish"); + assert_eq!(response, pairing_response("env_test")); +} + +#[tokio::test] +async fn pairing_publishes_refresh_deferral_after_auth_recovery() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let server_task = tokio::spawn(async move { + let stale_refresh_request = accept_http_request(&listener).await; + assert_eq!( + stale_refresh_request.headers.get("authorization"), + Some(&"Bearer stale-token".to_string()) + ); + respond_with_status(stale_refresh_request.stream, "401 Unauthorized", "").await; + + let recovered_refresh_request = accept_http_request(&listener).await; + assert_eq!( + recovered_refresh_request.headers.get("authorization"), + Some(&"Bearer fresh-token".to_string()) + ); + respond_with_status_and_headers( + recovered_refresh_request.stream, + "502 Bad Gateway", + &[("retry-after", "120")], + "upstream unavailable", + ) + .await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let auth_manager = auth_manager_with_replacement(&codex_home, "account_id").await; + let remote_handle = + remote_control_handle_with_current_enrollment(&remote_control_url, auth_manager); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(OffsetDateTime::now_utc() - time::Duration::seconds(1)); + + let refresh_started_at = OffsetDateTime::now_utc(); + let refresh_err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("required refresh should remain strict after auth recovery"); + let refresh_completed_at = OffsetDateTime::now_utc(); + let deferred_err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("published deadline should throttle the next pairing refresh"); + server_task.await.expect("server task should finish"); + + assert!(refresh_err.to_string().contains("HTTP 502 Bad Gateway")); + assert_eq!(deferred_err.kind(), io::ErrorKind::WouldBlock); + let next_refresh_at = remote_handle + .current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .expect("required refresh failure should publish its retry deadline"); + assert!( + (refresh_started_at + time::Duration::seconds(120) + ..=refresh_completed_at + time::Duration::seconds(120)) + .contains(&next_refresh_at) + ); +} + +#[tokio::test] +async fn pairing_auth_recovery_failure_publishes_cleared_server_token() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let server_task = tokio::spawn(async move { + let stale_refresh_request = accept_http_request(&listener).await; + assert_eq!( + stale_refresh_request.request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + assert_eq!( + stale_refresh_request.headers.get("authorization"), + Some(&"Bearer stale-token".to_string()) + ); + respond_with_status(stale_refresh_request.stream, "401 Unauthorized", "").await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let auth_manager = auth_manager_with_replacement(&codex_home, "different_account_id").await; + let remote_handle = + remote_control_handle_with_current_enrollment(&remote_control_url, auth_manager); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(OffsetDateTime::now_utc() + time::Duration::seconds(29)); + let mut expected_enrollment = remote_handle + .current_enrollment + .snapshot() + .expect("current enrollment should exist"); + expected_enrollment.clear_server_token(); + + let err = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect_err("pairing should fail after auth changes account"); + server_task.await.expect("server task should finish"); + + assert_eq!(err.kind(), io::ErrorKind::PermissionDenied); assert_eq!( - response, - RemoteControlPairingStartResponse { - pairing_code: "pairing-code".to_string(), - manual_pairing_code: Some("ABCD-EFGH".to_string()), - environment_id: "env_test".to_string(), - expires_at: 33_336_362_096, - } + remote_handle.current_enrollment.snapshot(), + Some(expected_enrollment) ); } @@ -688,6 +976,7 @@ async fn remote_control_handle_reenrolls_after_stale_pairing_enrollment() { server_name: test_server_name(), remote_control_token: None, expires_at: None, + next_refresh_at: None, }; update_persisted_remote_control_enrollment( Some(state_db.as_ref()), @@ -745,13 +1034,10 @@ async fn remote_control_handle_reenrolls_after_stale_pairing_enrollment() { ); respond_with_json( refreshed_pairing_request.stream, - json!({ - "pairing_code": "pairing-code", - "manual_pairing_code": "ABCD-EFGH", - "server_id": server_refreshed_enrollment.server_id, - "environment_id": server_refreshed_enrollment.environment_id, - "expires_at": "3026-05-22T12:34:56Z", - }), + pairing_response_json( + &server_refreshed_enrollment.server_id, + &server_refreshed_enrollment.environment_id, + ), ) .await; }); @@ -764,15 +1050,7 @@ async fn remote_control_handle_reenrolls_after_stale_pairing_enrollment() { .expect("pairing should re-enroll after stale enrollment"); server_task.await.expect("server task should finish"); - assert_eq!( - response, - RemoteControlPairingStartResponse { - pairing_code: "pairing-code".to_string(), - manual_pairing_code: Some("ABCD-EFGH".to_string()), - environment_id: "env_refreshed".to_string(), - expires_at: 33_336_362_096, - } - ); + assert_eq!(response, pairing_response("env_refreshed")); assert_eq!( state_db .get_remote_control_enrollment( diff --git a/codex-rs/app-server-transport/src/transport/remote_control/websocket.rs b/codex-rs/app-server-transport/src/transport/remote_control/websocket.rs index e34cd3409153..743445449989 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/websocket.rs @@ -23,12 +23,12 @@ use crate::transport::remote_control::auth::recover_remote_control_auth; use crate::transport::remote_control::client_tracker::ClientTracker; use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL; use crate::transport::remote_control::enroll::RemoteControlEnrollment; -use crate::transport::remote_control::enroll::enroll_remote_control_server; use crate::transport::remote_control::enroll::format_headers; use crate::transport::remote_control::enroll::load_persisted_remote_control_enrollment; use crate::transport::remote_control::enroll::preview_remote_control_response_body; -use crate::transport::remote_control::enroll::refresh_remote_control_server; use crate::transport::remote_control::enroll::update_persisted_remote_control_enrollment; +use crate::transport::remote_control::server_api::enroll_remote_control_server; +use crate::transport::remote_control::server_api::refresh_remote_control_server; use axum::http::HeaderValue; use base64::Engine; use codex_app_server_protocol::RemoteControlConnectionStatus; @@ -1566,17 +1566,19 @@ async fn prepare_remote_control_enrollment( ) .await?; } - Err(err) - if err.kind() == ErrorKind::PermissionDenied - && recover_remote_control_auth( - auth_context.auth_recovery, - auth_context.auth_change_rx, - ) - .await => - { - return Err(io::Error::other(format!( - "{err}; retrying after auth recovery" - ))); + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + if recover_remote_control_auth( + auth_context.auth_recovery, + auth_context.auth_change_rx, + ) + .await + { + return Err(io::Error::other(format!( + "{err}; retrying after auth recovery" + ))); + } + enrollment_ref.clear_server_token(); + return Err(err); } Err(err) => return Err(err), } @@ -1681,13 +1683,15 @@ async fn clear_remote_control_server_token_if_matches( enrollment: &RemoteControlEnrollment, ) -> io::Result<()> { let mut current_enrollment = current_enrollment.lock().await; - current_enrollment + let current_enrollment = current_enrollment .as_mut() .filter(|current| same_remote_control_enrollment(current, enrollment)) .ok_or_else(|| { io::Error::other("missing remote control enrollment after websocket auth failure") - })? - .clear_server_token(); + })?; + if current_enrollment.remote_control_token == enrollment.remote_control_token { + current_enrollment.clear_server_token(); + } Ok(()) } @@ -1804,6 +1808,10 @@ fn format_remote_control_websocket_connect_error( message } +#[cfg(test)] +#[path = "websocket_refresh_tests.rs"] +mod refresh_tests; + #[cfg(test)] mod tests { use super::*; @@ -1844,13 +1852,15 @@ mod tests { // Windows Bazel CI can take longer than a few seconds for the websocket // client connection attempt to reach the local test listener. #[cfg(windows)] - const TEST_HTTP_ACCEPT_TIMEOUT: Duration = Duration::from_secs(30); + pub(super) const TEST_HTTP_ACCEPT_TIMEOUT: Duration = Duration::from_secs(30); #[cfg(not(windows))] - const TEST_HTTP_ACCEPT_TIMEOUT: Duration = Duration::from_secs(5); - const TEST_INSTALLATION_ID: &str = "11111111-1111-4111-8111-111111111111"; - const TEST_REMOTE_CONTROL_SERVER_TOKEN: &str = "Remote Control Token"; + pub(super) const TEST_HTTP_ACCEPT_TIMEOUT: Duration = Duration::from_secs(5); + pub(super) const TEST_INSTALLATION_ID: &str = "11111111-1111-4111-8111-111111111111"; + pub(super) const TEST_REMOTE_CONTROL_SERVER_TOKEN: &str = "Remote Control Token"; - fn remote_control_enrollment(remote_control_token: Option<&str>) -> RemoteControlEnrollment { + pub(super) fn remote_control_enrollment( + remote_control_token: Option<&str>, + ) -> RemoteControlEnrollment { RemoteControlEnrollment { remote_control_target: normalize_remote_control_url("http://localhost/backend-api/") .expect("target should normalize"), @@ -1861,10 +1871,11 @@ mod tests { remote_control_token: remote_control_token.map(str::to_string), expires_at: remote_control_token .map(|_| time::OffsetDateTime::now_utc() + time::Duration::hours(1)), + next_refresh_at: None, } } - fn test_current_enrollment( + pub(super) fn test_current_enrollment( enrollment: Option, ) -> CurrentRemoteControlEnrollment { Arc::new(RemoteControlEnrollmentState::new(enrollment)) @@ -1930,7 +1941,7 @@ mod tests { )); } - fn remote_control_status_channel() -> ( + pub(super) fn remote_control_status_channel() -> ( RemoteControlStatusPublisher, watch::Receiver, ) { @@ -1943,7 +1954,7 @@ mod tests { (RemoteControlStatusPublisher::new(status_tx), status_rx) } - fn enabled_desired_state_sender() -> watch::Sender { + pub(super) fn enabled_desired_state_sender() -> watch::Sender { watch::channel(RemoteControlDesiredState::Enabled { persistence_preference: None, }) @@ -1981,17 +1992,17 @@ mod tests { ); } - async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + pub(super) async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) .await .expect("state runtime should initialize") } - fn remote_control_auth_manager() -> Arc { + pub(super) fn remote_control_auth_manager() -> Arc { auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) } - fn remote_control_url_for_listener(listener: &TcpListener) -> String { + pub(super) fn remote_control_url_for_listener(listener: &TcpListener) -> String { let addr = listener .local_addr() .expect("listener should have a local addr"); @@ -2126,9 +2137,10 @@ mod tests { let auth_manager = remote_control_auth_manager(); let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut auth_change_rx = auth_manager.auth_change_receiver(); - let current_enrollment = test_current_enrollment(Some(remote_control_enrollment(Some( - TEST_REMOTE_CONTROL_SERVER_TOKEN, - )))); + let next_refresh_at = time::OffsetDateTime::now_utc() + time::Duration::minutes(2); + let mut enrollment = remote_control_enrollment(Some(TEST_REMOTE_CONTROL_SERVER_TOKEN)); + enrollment.next_refresh_at = Some(next_refresh_at); + let current_enrollment = test_current_enrollment(Some(enrollment)); let (status_publisher, status_rx) = remote_control_status_channel(); let server_task = tokio::spawn(async move { @@ -2178,6 +2190,7 @@ mod tests { ); let mut expected_enrollment = remote_control_enrollment(/*remote_control_token*/ None); expected_enrollment.remote_control_target = remote_control_target; + expected_enrollment.next_refresh_at = Some(next_refresh_at); assert_eq!(*current_enrollment.lock().await, Some(expected_enrollment)); } @@ -2318,9 +2331,12 @@ mod tests { .await; let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut auth_change_rx = auth_manager.auth_change_receiver(); - let current_enrollment = test_current_enrollment(Some(remote_control_enrollment( - /*remote_control_token*/ None, - ))); + let mut expected_enrollment = + remote_control_enrollment(Some(TEST_REMOTE_CONTROL_SERVER_TOKEN)); + expected_enrollment.remote_control_target = remote_control_target.clone(); + expected_enrollment.expires_at = + Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(4)); + let current_enrollment = test_current_enrollment(Some(expected_enrollment.clone())); let (status_publisher, status_rx) = remote_control_status_channel(); save_auth( codex_home.path(), @@ -2377,6 +2393,7 @@ mod tests { .expect("token should be readable"), "fresh-token" ); + assert_eq!(current_enrollment.snapshot(), Some(expected_enrollment)); assert!( !auth_change_rx .has_changed() @@ -3411,7 +3428,7 @@ mod tests { state.observe_client_message(envelope, wire_size_bytes) } - async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { + pub(super) async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept()) .await .expect("HTTP request should arrive in time") @@ -3482,7 +3499,7 @@ mod tests { serde_json::from_str(text.as_ref()).expect("server event should deserialize") } - async fn respond_with_status_and_headers( + pub(super) async fn respond_with_status_and_headers( mut stream: TcpStream, status: &str, headers: &[(&str, &str)], diff --git a/codex-rs/app-server-transport/src/transport/remote_control/websocket_refresh_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/websocket_refresh_tests.rs new file mode 100644 index 000000000000..ee4ea15f8361 --- /dev/null +++ b/codex-rs/app-server-transport/src/transport/remote_control/websocket_refresh_tests.rs @@ -0,0 +1,467 @@ +use super::tests::TEST_HTTP_ACCEPT_TIMEOUT; +use super::tests::TEST_INSTALLATION_ID; +use super::tests::TEST_REMOTE_CONTROL_SERVER_TOKEN; +use super::tests::accept_http_request; +use super::tests::enabled_desired_state_sender; +use super::tests::remote_control_auth_manager; +use super::tests::remote_control_enrollment; +use super::tests::remote_control_state_runtime; +use super::tests::remote_control_status_channel; +use super::tests::remote_control_url_for_listener; +use super::tests::respond_with_status_and_headers; +use super::tests::test_current_enrollment; +use super::*; +use crate::transport::remote_control::protocol::normalize_remote_control_url; +use crate::transport::remote_control::tests::remote_control_handle_with_current_enrollment; +use codex_app_server_protocol::RemoteControlPairingStartParams; +use codex_app_server_protocol::RemoteControlPairingStartResponse; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::time::Duration; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::accept_async; + +async fn connect_test_websocket( + remote_control_target: &RemoteControlTarget, + state_db: &StateRuntime, + auth_manager: &Arc, + current_enrollment: &CurrentRemoteControlEnrollment, +) -> io::Result<()> { + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut auth_change_rx = auth_manager.auth_change_receiver(); + let (status_publisher, _) = remote_control_status_channel(); + let desired_state_tx = enabled_desired_state_sender(); + let desired_state_persistence_lock = Semaphore::new(1); + connect_remote_control_websocket( + remote_control_target, + Some(state_db), + RemoteControlAuthContext { + auth_manager, + auth_recovery: &mut auth_recovery, + auth_change_rx: &mut auth_change_rx, + }, + current_enrollment, + RemoteControlConnectOptions { + installation_id: TEST_INSTALLATION_ID, + server_name: "test-server", + subscribe_cursor: None, + app_server_client_name: None, + desired_state_tx: &desired_state_tx, + desired_state_persistence_lock: &desired_state_persistence_lock, + }, + &status_publisher, + ) + .await + .map(|_| ()) +} + +#[tokio::test] +async fn proactive_refresh_failure_uses_valid_token_for_websocket_connect() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_status_and_headers(stream, "502 Bad Gateway", &[], "upstream unavailable") + .await; + accept_test_websocket(&listener).await + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut enrollment = remote_control_enrollment(Some(TEST_REMOTE_CONTROL_SERVER_TOKEN)); + enrollment.expires_at = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(4)); + let current_enrollment = test_current_enrollment(Some(enrollment)); + + let refresh_started_at = time::OffsetDateTime::now_utc(); + connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect("valid token should allow websocket connect after proactive refresh failure"); + let refresh_completed_at = time::OffsetDateTime::now_utc(); + let server_websocket = server_task.await.expect("server task should succeed"); + + let enrollment = current_enrollment + .lock() + .await + .clone() + .expect("enrollment should remain available"); + assert_eq!( + enrollment.remote_control_token.as_deref(), + Some(TEST_REMOTE_CONTROL_SERVER_TOKEN) + ); + let next_refresh_at = enrollment + .next_refresh_at + .expect("transient refresh should set a retry deadline"); + assert!( + (refresh_started_at + time::Duration::seconds(24) + ..=refresh_completed_at + time::Duration::seconds(36)) + .contains(&next_refresh_at) + ); + drop(server_websocket); +} + +#[tokio::test] +async fn proactive_refresh_connection_failure_uses_valid_token_for_websocket_connect() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + drop(stream); + accept_test_websocket(&listener).await + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut enrollment = remote_control_enrollment(Some(TEST_REMOTE_CONTROL_SERVER_TOKEN)); + enrollment.expires_at = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(4)); + let current_enrollment = test_current_enrollment(Some(enrollment)); + + connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect("valid token should allow websocket connect after refresh connection failure"); + let server_websocket = server_task.await.expect("server task should succeed"); + + assert!( + current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .is_some(), + "connection failure should set a retry deadline" + ); + drop(server_websocket); +} + +#[tokio::test] +async fn websocket_retry_after_throttles_pairing_refresh() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_status_and_headers( + stream, + "502 Bad Gateway", + &[("retry-after", "120")], + "upstream unavailable", + ) + .await; + let first_websocket = accept_test_websocket(&listener).await; + let (pairing_stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/pair HTTP/1.1" + ); + respond_with_status_and_headers( + pairing_stream, + "200 OK", + &[], + r#"{"pairing_code":"pairing-code","manual_pairing_code":"ABCD-EFGH","server_id":"srv_e_test","environment_id":"env_test","expires_at":"3026-05-22T12:34:56Z"}"#, + ) + .await; + first_websocket + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut remote_handle = + remote_control_handle_with_current_enrollment(&remote_control_url, auth_manager.clone()); + remote_handle.state_db = Some(state_db.clone()); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(4)); + let current_enrollment = remote_handle.current_enrollment.clone(); + let refresh_started_at = time::OffsetDateTime::now_utc(); + connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect("first websocket should connect after deferred refresh"); + let refresh_completed_at = time::OffsetDateTime::now_utc(); + let next_refresh_at = current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .expect("Retry-After should set a retry deadline"); + assert!( + (refresh_started_at + time::Duration::seconds(120) + ..=refresh_completed_at + time::Duration::seconds(120)) + .contains(&next_refresh_at) + ); + + let pairing_response = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect("websocket Retry-After should throttle pairing refresh"); + let first_server_websocket = server_task.await.expect("server task should succeed"); + + assert_eq!( + pairing_response, + RemoteControlPairingStartResponse { + pairing_code: "pairing-code".to_string(), + manual_pairing_code: Some("ABCD-EFGH".to_string()), + environment_id: "env_test".to_string(), + expires_at: 33_336_362_096, + } + ); + drop(first_server_websocket); +} + +#[tokio::test] +async fn pairing_http_date_retry_after_throttles_websocket_refresh() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let retry_after = + httpdate::fmt_http_date(std::time::SystemTime::now() + Duration::from_secs(120)); + let expected_next_refresh_at = time::OffsetDateTime::from( + httpdate::parse_http_date(&retry_after).expect("Retry-After date should parse"), + ); + let server_task = tokio::spawn(async move { + let (refresh_stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + respond_with_status_and_headers( + refresh_stream, + "502 Bad Gateway", + &[("retry-after", &retry_after)], + "upstream unavailable", + ) + .await; + let (pairing_stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/pair HTTP/1.1" + ); + respond_with_status_and_headers( + pairing_stream, + "200 OK", + &[], + r#"{"pairing_code":"pairing-code","manual_pairing_code":"ABCD-EFGH","server_id":"srv_e_test","environment_id":"env_test","expires_at":"3026-05-22T12:34:56Z"}"#, + ) + .await; + accept_test_websocket(&listener).await + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut remote_handle = + remote_control_handle_with_current_enrollment(&remote_control_url, auth_manager.clone()); + remote_handle.state_db = Some(state_db.clone()); + remote_handle + .current_enrollment + .lock() + .await + .as_mut() + .expect("current enrollment should exist") + .expires_at = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(4)); + let current_enrollment = remote_handle.current_enrollment.clone(); + + let pairing_response = remote_handle + .start_pairing( + RemoteControlPairingStartParams::default(), + /*app_server_client_name*/ None, + ) + .await + .expect("pairing should continue after proactive refresh failure"); + assert_eq!( + current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at), + Some(expected_next_refresh_at) + ); + connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect("pairing Retry-After should throttle websocket refresh"); + let server_websocket = server_task.await.expect("server task should succeed"); + + assert_eq!( + pairing_response, + RemoteControlPairingStartResponse { + pairing_code: "pairing-code".to_string(), + manual_pairing_code: Some("ABCD-EFGH".to_string()), + environment_id: "env_test".to_string(), + expires_at: 33_336_362_096, + } + ); + drop(server_websocket); +} + +async fn assert_refresh_failure_blocks_websocket( + expires_in: time::Duration, + response_delay: Duration, +) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let (connects_done_tx, connects_done_rx) = oneshot::channel(); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/refresh HTTP/1.1" + ); + tokio::time::sleep(response_delay).await; + respond_with_status_and_headers( + stream, + "502 Bad Gateway", + &[("retry-after", "120")], + "upstream unavailable", + ) + .await; + assert_no_connection_until_connect_finishes(&listener, connects_done_rx).await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut enrollment = remote_control_enrollment(Some(TEST_REMOTE_CONTROL_SERVER_TOKEN)); + enrollment.expires_at = Some(time::OffsetDateTime::now_utc() + expires_in); + let current_enrollment = test_current_enrollment(Some(enrollment)); + + let refresh_started_at = time::OffsetDateTime::now_utc(); + let refresh_err = connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect_err("required refresh failure should block websocket connect"); + let refresh_completed_at = time::OffsetDateTime::now_utc(); + let deferred_err = connect_test_websocket( + &remote_control_target, + state_db.as_ref(), + &auth_manager, + ¤t_enrollment, + ) + .await + .expect_err("required refresh deadline should block websocket reconnect"); + connects_done_tx + .send(()) + .expect("server should wait for connect attempts to finish"); + + server_task.await.expect("server task should succeed"); + assert!(refresh_err.to_string().contains("HTTP 502 Bad Gateway")); + assert_eq!(deferred_err.kind(), io::ErrorKind::WouldBlock); + assert!(deferred_err.to_string().contains("refresh deferred until")); + let next_refresh_at = current_enrollment + .snapshot() + .and_then(|enrollment| enrollment.next_refresh_at) + .expect("required refresh failure should set a retry deadline"); + assert!( + (refresh_started_at + time::Duration::seconds(120) + ..=refresh_completed_at + time::Duration::seconds(120)) + .contains(&next_refresh_at) + ); +} + +#[tokio::test] +async fn expired_token_refresh_failure_throttles_reconnect_without_websocket() { + assert_refresh_failure_blocks_websocket(-time::Duration::seconds(1), Duration::ZERO).await; +} + +#[tokio::test] +async fn token_expiring_during_refresh_failure_throttles_reconnect_without_websocket() { + assert_refresh_failure_blocks_websocket( + time::Duration::seconds(1), + Duration::from_millis(1_200), + ) + .await; +} + +#[tokio::test] +async fn websocket_auth_failure_does_not_clear_rotated_server_token() { + let attempted_enrollment = remote_control_enrollment(Some("old-token")); + let mut rotated_enrollment = attempted_enrollment.clone(); + rotated_enrollment.remote_control_token = Some("new-token".to_string()); + rotated_enrollment.expires_at = + Some(time::OffsetDateTime::now_utc() + time::Duration::hours(1)); + let current_enrollment = test_current_enrollment(Some(rotated_enrollment.clone())); + + clear_remote_control_server_token_if_matches(¤t_enrollment, &attempted_enrollment) + .await + .expect("matching enrollment identity should remain available"); + + assert_eq!(current_enrollment.snapshot(), Some(rotated_enrollment)); +} + +async fn accept_test_websocket(listener: &TcpListener) -> WebSocketStream { + let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept()) + .await + .expect("websocket request should arrive in time") + .expect("listener accept should succeed"); + accept_async(stream) + .await + .expect("websocket handshake should succeed") +} + +async fn assert_no_connection_until_connect_finishes( + listener: &TcpListener, + mut connect_done_rx: oneshot::Receiver<()>, +) { + tokio::select! { + accepted = listener.accept() => { + accepted.expect("unexpected websocket connection should be accepted"); + panic!("required refresh failure must not proceed to websocket connect"); + } + connect_done = &mut connect_done_rx => { + connect_done.expect("connect completion should be reported"); + } + } +}