From 4931e83e8175924f5974e6a30a5fb68c91be1a19 Mon Sep 17 00:00:00 2001 From: Derek Carr Date: Thu, 14 May 2026 00:59:00 -0400 Subject: [PATCH] feat(persistence): implement optimistic concurrency control with CAS Adds Compare-And-Swap (CAS) based optimistic concurrency control to prevent lost updates in concurrent modification scenarios. This implements client-driven CAS for update operations and proper atomic create protection. Database Changes: - Add resource_version field to ObjectMeta proto message - Add resource_version column to objects table (migration 005) - Implement update_message_cas for single-attempt versioned updates - Add WriteCondition::MatchResourceVersion for CAS enforcement - Add PersistenceError::Conflict for version mismatch detection Protected Operations: - AttachSandboxProvider: uses client-driven CAS with expected_resource_version parameter - DetachSandboxProvider: uses client-driven CAS with expected_resource_version parameter - UpdateProvider: extracts resource_version from Provider.metadata, validates against current version - UpdateConfig: uses client-driven CAS for policy backfill path - CreateProvider: uses WriteCondition::MustCreate for atomic creation - SSH session operations: proper CAS protection CAS Modes: - Client-driven (expected_version > 0): Client fetches resource, uses its version for update. Conflict returns ABORTED status. - Server-driven (expected_version = 0): Server uses current DB version. Used for internal operations. CLI Changes: - Update attach/detach operations to fetch sandbox first and use its resource_version for CAS protection - Add clear error messages for ABORTED status on CAS conflicts - Add expected_resource_version: 0 to all UpdateConfig requests Testing: - 12 integration tests for concurrent modification scenarios - Tests verify ABORTED status on version conflicts - Coverage for all protected operations Documentation: - Update architecture/gateway.md with CAS design and semantics - Document expected_version parameter modes - List all client-driven CAS operations Signed-off-by: Derek Carr --- architecture/gateway.md | 52 ++ crates/openshell-cli/src/run.rs | 81 +- .../tests/ensure_providers_integration.rs | 2 + .../tests/provider_commands_integration.rs | 27 +- .../sandbox_create_lifecycle_integration.rs | 3 + .../sandbox_name_fallback_integration.rs | 1 + crates/openshell-core/src/lib.rs | 2 +- crates/openshell-core/src/metadata.rs | 107 +++ crates/openshell-sandbox/src/grpc_client.rs | 1 + .../postgres/005_add_resource_version.sql | 5 + .../sqlite/005_add_resource_version.sql | 5 + crates/openshell-server/src/compute/mod.rs | 406 ++++++--- crates/openshell-server/src/grpc/mod.rs | 27 + crates/openshell-server/src/grpc/policy.rs | 514 ++++++++++-- crates/openshell-server/src/grpc/provider.rs | 582 ++++++++++++- crates/openshell-server/src/grpc/sandbox.rs | 786 +++++++++++++++++- crates/openshell-server/src/grpc/service.rs | 192 ++++- .../openshell-server/src/grpc/validation.rs | 1 + crates/openshell-server/src/inference.rs | 211 ++++- .../openshell-server/src/persistence/mod.rs | 210 ++++- .../src/persistence/postgres.rs | 172 +++- .../src/persistence/sqlite.rs | 159 +++- .../openshell-server/src/persistence/tests.rs | 459 +++++++++- .../openshell-server/src/service_routing.rs | 1 + crates/openshell-server/src/ssh_sessions.rs | 1 + .../src/supervisor_session.rs | 1 + crates/openshell-tui/src/lib.rs | 6 + proto/datamodel.proto | 6 +- proto/openshell.proto | 16 + 29 files changed, 3750 insertions(+), 286 deletions(-) create mode 100644 crates/openshell-server/migrations/postgres/005_add_resource_version.sql create mode 100644 crates/openshell-server/migrations/sqlite/005_add_resource_version.sql diff --git a/architecture/gateway.md b/architecture/gateway.md index 68832d0cf..5e1dcca09 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -83,6 +83,7 @@ The storage schema is intentionally narrow: | `version` | Optional monotonically increasing version for scoped records. | | `status` | Optional workflow state for records such as policy revisions or draft policy chunks. | | `dedup_key` and `hit_count` | Optional policy-advisor fields for coalescing repeated observations. | +| `resource_version` | Monotonically increasing counter for optimistic concurrency control. Incremented atomically on each update. | | `payload` | Prost-encoded protobuf payload for the full domain object. | | `created_at_ms` and `updated_at_ms` | Gateway timestamps used for ordering and list output. | | `labels` | JSON object carrying Kubernetes-style object labels for filtering and organization. | @@ -113,6 +114,57 @@ default WAL journal mode), which mirror the same sensitive contents. Persisted state includes sandboxes, providers, SSH sessions, policy revisions, settings, inference configuration, and deployment records. +### Optimistic Concurrency (CAS) + +Every object row carries a `resource_version` that the database increments +atomically on each write. Concurrent mutations use compare-and-swap (CAS): the +writer reads the current version, applies changes, and writes back with a +`WHERE resource_version = ` guard. If another writer updated the row +in between, the guard fails and the caller receives a `Conflict` error. + +This matters for HA deployments where multiple gateway replicas share the same +Postgres database, and for single-node deployments where concurrent gRPC +handlers or the reconciler mutate the same sandbox. + +**When to use CAS** -- any mutation that merges caller-supplied fields into an +existing object: + +- Provider credential and config updates (merge maps). +- Sandbox provider attach/detach (append/remove from a list). +- Policy version bumps and draft operations. +- Compute status updates (sandbox phase transitions and reconciliation). + +**When CAS is not needed** -- create operations that generate a unique ID +(conflicts are caught by the primary key constraint), unconditional deletes, +and idempotent overwrites where the full payload is self-contained. + +The `update_message_cas` helper makes a single CAS attempt: it fetches the +current object, applies a mutation closure, and writes with a +`MatchResourceVersion` condition. On conflict the persistence layer returns a +`Conflict` error, which gRPC handlers map to `ABORTED` status so clients can +read fresh state and retry. + +The helper accepts an `expected_version` parameter that selects between two +modes: + +- **Server-driven** (`expected_version = 0`): the helper uses the version it + just read from the database. Internal operations (reconciler, policy status + reports, compute phase transitions) use this mode because the caller does + not track versions. +- **Client-driven** (`expected_version != 0`): the helper validates that the + caller's version matches the current database version before applying the + mutation. If they diverge it returns `Conflict` without attempting the + write. Client-facing operations that carry an `expected_resource_version` + field use this mode: `AttachSandboxProvider`, `DetachSandboxProvider`, + `UpdateProvider`, and `UpdateConfig` (policy backfill path). + +Settings updates are an exception: they use a Tokio `Mutex` instead of CAS +because settings operations require multi-step validation that is simpler under +an exclusive lock than within a CAS write. + +The `resource_version` is surfaced to clients through `ObjectMeta` in proto +responses. Database migrations backfill existing rows with version 1. + Policy and runtime settings are delivered together through the effective sandbox config path. A gateway-global policy can override sandbox-scoped policy. The sandbox supervisor polls for config revisions and hot-reloads dynamic policy diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 2962e546b..df7928119 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -2279,6 +2279,11 @@ pub async fn sandbox_get( println!(" {} {}", "Id:".dimmed(), id); println!(" {} {}", "Name:".dimmed(), name); println!(" {} {}", "Phase:".dimmed(), phase_name(sandbox.phase)); + println!( + " {} {}", + "Resource version:".dimmed(), + sandbox.metadata.as_ref().map_or(0, |m| m.resource_version) + ); // Display labels if present if let Some(metadata) = &sandbox.metadata @@ -2888,14 +2893,38 @@ pub async fn sandbox_provider_attach( tls: &TlsOptions, ) -> Result<()> { let mut client = grpc_client(server, tls).await?; - let response = client + + // Fetch current sandbox to get resource_version for CAS + let sandbox = client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox not found"))?; + + let resource_version = sandbox.metadata.as_ref().map_or(0, |m| m.resource_version); + + let response = match client .attach_sandbox_provider(AttachSandboxProviderRequest { sandbox_name: name.to_string(), provider_name: provider.to_string(), + expected_resource_version: resource_version, }) .await - .into_diagnostic()? - .into_inner(); + { + Ok(response) => response.into_inner(), + Err(status) if status.code() == Code::Aborted => { + return Err(miette::miette!( + "Failed to attach provider: sandbox was modified by another operation.\n\ + Please retry the command." + ) + .with_source_code(status.message().to_string())); + } + Err(e) => return Err(e).into_diagnostic(), + }; if response.attached { println!( @@ -2917,14 +2946,38 @@ pub async fn sandbox_provider_detach( tls: &TlsOptions, ) -> Result<()> { let mut client = grpc_client(server, tls).await?; - let response = client + + // Fetch current sandbox to get resource_version for CAS + let sandbox = client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox not found"))?; + + let resource_version = sandbox.metadata.as_ref().map_or(0, |m| m.resource_version); + + let response = match client .detach_sandbox_provider(DetachSandboxProviderRequest { sandbox_name: name.to_string(), provider_name: provider.to_string(), + expected_resource_version: resource_version, }) .await - .into_diagnostic()? - .into_inner(); + { + Ok(response) => response.into_inner(), + Err(status) if status.code() == Code::Aborted => { + return Err(miette::miette!( + "Failed to detach provider: sandbox was modified by another operation.\n\ + Please retry the command." + ) + .with_source_code(status.message().to_string())); + } + Err(e) => return Err(e).into_diagnostic(), + }; if response.detached { println!( @@ -3259,6 +3312,7 @@ async fn auto_create_provider( name: exact_name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), @@ -3299,6 +3353,7 @@ async fn auto_create_provider( name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), @@ -3711,6 +3766,7 @@ pub async fn provider_create( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.clone(), credentials: credential_map, @@ -3755,6 +3811,11 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< println!(" {} {}", "Id:".dimmed(), provider.object_id()); println!(" {} {}", "Name:".dimmed(), provider.object_name()); println!(" {} {}", "Type:".dimmed(), provider.r#type); + println!( + " {} {}", + "Resource version:".dimmed(), + provider.metadata.as_ref().map_or(0, |m| m.resource_version) + ); println!( " {} {}", "Credential keys:".dimmed(), @@ -4211,6 +4272,7 @@ pub async fn provider_update( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: credential_map, @@ -4765,6 +4827,7 @@ pub async fn sandbox_policy_set_global( delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4963,6 +5026,7 @@ pub async fn gateway_setting_set( delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -4997,6 +5061,7 @@ pub async fn sandbox_setting_set( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -5031,6 +5096,7 @@ pub async fn gateway_setting_delete( delete_setting: true, global: true, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -5065,6 +5131,7 @@ pub async fn sandbox_setting_delete( delete_setting: true, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()? @@ -5123,6 +5190,7 @@ pub async fn sandbox_policy_set( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic()?; @@ -5297,6 +5365,7 @@ pub async fn sandbox_policy_update( delete_setting: false, global: false, merge_operations: plan.merge_operations, + expected_resource_version: 0, }) .await .into_diagnostic()? diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index b857699af..a4f4bfb5a 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -113,6 +113,7 @@ impl TestOpenShell { name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: HashMap::new(), @@ -377,6 +378,7 @@ impl OpenShell for TestOpenShell { name: provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index a63d9d310..2bc0908bc 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -15,7 +15,7 @@ use openshell_core::proto::{ HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderProfile, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, + RevokeSshSessionResponse, Sandbox, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use openshell_core::{ObjectId, ObjectName}; @@ -111,9 +111,25 @@ impl OpenShell for TestOpenShell { async fn get_sandbox( &self, - _request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(SandboxResponse::default())) + let name = request.into_inner().name; + // Return a minimal sandbox with metadata for CAS operations + Ok(Response::new(SandboxResponse { + sandbox: Some(Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: format!("sb-{name}"), + name, + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 1, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }), + })) } async fn list_sandboxes( @@ -183,7 +199,7 @@ impl OpenShell for TestOpenShell { providers.push(request.provider_name.clone()); true }; - let sandbox = openshell_core::proto::Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { name: request.sandbox_name, ..Default::default() @@ -220,7 +236,7 @@ impl OpenShell for TestOpenShell { let before_len = providers.len(); providers.retain(|name| name != &request.provider_name); let detached = providers.len() != before_len; - let sandbox = openshell_core::proto::Sandbox { + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { name: request.sandbox_name, ..Default::default() @@ -475,6 +491,7 @@ impl OpenShell for TestOpenShell { name: provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 6e7d66d11..ab085077a 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -121,6 +121,7 @@ impl OpenShell for TestOpenShell { name: sandbox_name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() @@ -140,6 +141,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Ready as i32, ..Sandbox::default() @@ -354,6 +356,7 @@ impl OpenShell for TestOpenShell { name: sandbox_id.trim_start_matches("id-").to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index b957dfd46..7c5080f5e 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -119,6 +119,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: std::collections::HashMap::new(), + resource_version: 0, }), ..Default::default() }), diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index 893b01f5f..edb0afee3 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -23,7 +23,7 @@ pub mod settings; pub use config::{ComputeDriverKind, Config, OidcConfig, TlsConfig}; pub use error::{ComputeDriverError, Error, Result}; -pub use metadata::{ObjectId, ObjectLabels, ObjectName}; +pub use metadata::{GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion}; /// Build version string derived from git metadata. /// diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index 6f7b7b0a4..efae9dacf 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -26,6 +26,16 @@ pub trait ObjectLabels { fn object_labels(&self) -> Option>; } +/// Provides mutable access to set the object's resource version from persistence. +pub trait SetResourceVersion { + fn set_resource_version(&mut self, version: u64); +} + +/// Provides read access to the object's current resource version. +pub trait GetResourceVersion { + fn get_resource_version(&self) -> u64; +} + // Implementations for Sandbox impl ObjectId for Sandbox { fn object_id(&self) -> &str { @@ -45,6 +55,20 @@ impl ObjectLabels for Sandbox { } } +impl SetResourceVersion for Sandbox { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Sandbox { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for Provider impl ObjectId for Provider { fn object_id(&self) -> &str { @@ -64,6 +88,20 @@ impl ObjectLabels for Provider { } } +impl SetResourceVersion for Provider { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Provider { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for StoredProviderProfile impl ObjectId for StoredProviderProfile { fn object_id(&self) -> &str { @@ -83,6 +121,20 @@ impl ObjectLabels for StoredProviderProfile { } } +impl SetResourceVersion for StoredProviderProfile { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for StoredProviderProfile { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for SshSession impl ObjectId for SshSession { fn object_id(&self) -> &str { @@ -102,6 +154,20 @@ impl ObjectLabels for SshSession { } } +impl SetResourceVersion for SshSession { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for SshSession { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for ServiceEndpoint impl ObjectId for ServiceEndpoint { fn object_id(&self) -> &str { @@ -121,6 +187,20 @@ impl ObjectLabels for ServiceEndpoint { } } +impl SetResourceVersion for ServiceEndpoint { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for ServiceEndpoint { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for InferenceRoute impl ObjectId for InferenceRoute { fn object_id(&self) -> &str { @@ -140,6 +220,20 @@ impl ObjectLabels for InferenceRoute { } } +impl SetResourceVersion for InferenceRoute { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for InferenceRoute { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for ObjectForTest (test-only proto type) impl ObjectId for ObjectForTest { fn object_id(&self) -> &str { @@ -158,3 +252,16 @@ impl ObjectLabels for ObjectForTest { None } } + +impl SetResourceVersion for ObjectForTest { + fn set_resource_version(&mut self, _version: u64) { + // ObjectForTest doesn't have metadata, so this is a no-op + } +} + +impl GetResourceVersion for ObjectForTest { + fn get_resource_version(&self) -> u64 { + // ObjectForTest doesn't have metadata + 0 + } +} diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 19a6831e5..b6dafa9ea 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -189,6 +189,7 @@ async fn sync_policy_with_client( delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }) .await .into_diagnostic() diff --git a/crates/openshell-server/migrations/postgres/005_add_resource_version.sql b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql new file mode 100644 index 000000000..e6a294d62 --- /dev/null +++ b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version BIGINT NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in PostgreSQL) diff --git a/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql new file mode 100644 index 000000000..50aacb99d --- /dev/null +++ b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version INTEGER NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in SQLite) diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index d2fd34011..fc6f64081 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -9,7 +9,7 @@ pub use openshell_driver_docker::DockerComputeConfig; pub use vm::VmComputeConfig; use crate::grpc::policy::SANDBOX_SETTINGS_OBJECT_TYPE; -use crate::persistence::{ObjectId, ObjectName, ObjectRecord, ObjectType, Store}; +use crate::persistence::{ObjectId, ObjectName, ObjectRecord, ObjectType, Store, WriteCondition}; use crate::sandbox_index::SandboxIndex; use crate::sandbox_watch::SandboxWatchBus; use crate::supervisor_session::SupervisorSessionRegistry; @@ -422,23 +422,39 @@ impl ComputeRuntime { } pub async fn create_sandbox(&self, sandbox: Sandbox) -> Result { - let existing = self - .store - .get_message_by_name::(sandbox.object_name()) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - if existing.is_some() { - return Err(Status::already_exists(format!( - "sandbox '{}' already exists", - sandbox.object_name() - ))); + // Generate UUID for database row and update metadata.id to match + let sandbox_id = uuid::Uuid::new_v4().to_string(); + let mut sandbox = sandbox; + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.id.clone_from(&sandbox_id); } + // Create with MustCreate condition to prevent duplicate creation race self.sandbox_index.update_from_sandbox(&sandbox); self.store - .put_message(&sandbox) + .put_if( + Sandbox::object_type(), + &sandbox_id, + sandbox.object_name(), + &sandbox.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| { + let err_str = e.to_string(); + if err_str.contains("unique constraint") + || err_str.contains("UNIQUE constraint") + || err_str.contains("unique violation") + { + Status::already_exists(format!( + "sandbox '{}' already exists", + sandbox.object_name() + )) + } else { + Status::internal(format!("persist sandbox failed: {e}")) + } + })?; let driver_sandbox = driver_sandbox_from_public(&sandbox); match self @@ -450,6 +466,13 @@ impl ComputeRuntime { { Ok(_) => { self.sandbox_watch_bus.notify(sandbox.object_id()); + // Read back from DB to get correct resource_version + let sandbox = self + .store + .get_message_by_name::(sandbox.object_name()) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::internal("sandbox disappeared after creation"))?; Ok(sandbox) } Err(status) if status.code() == Code::AlreadyExists => { @@ -483,22 +506,31 @@ impl ComputeRuntime { } pub async fn delete_sandbox(&self, name: &str) -> Result { + // Resolve sandbox ID from name let sandbox = self .store .get_message_by_name::(name) .await .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - let Some(mut sandbox) = sandbox else { + let Some(sandbox) = sandbox else { return Err(Status::not_found("sandbox not found")); }; let id = sandbox.object_id().to_string(); - sandbox.phase = SandboxPhase::Deleting as i32; - self.store - .put_message(&sandbox) + + // Use CAS to set phase to Deleting + // TODO: Accept expected_version from DeleteSandboxRequest for proper client-driven CAS + let sandbox = self + .store + .update_message_cas::(&id, 0, |s| { + s.phase = SandboxPhase::Deleting as i32; + }) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| { + crate::grpc::persistence_error_to_status(e, "set sandbox phase to Deleting") + })?; + self.sandbox_index.update_from_sandbox(&sandbox); self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; @@ -811,86 +843,138 @@ impl ComputeRuntime { .as_ref() .map(decode_sandbox_record) .transpose()?; - let previous = existing.clone(); - - let mut status = incoming.status.as_ref().map(public_status_from_driver); - rewrite_user_facing_conditions( - &mut status, - existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), - ); - let session_connected = self.supervisor_sessions.has_session(&incoming.id); - let mut phase = derive_phase(incoming.status.as_ref()); - let mut sandbox = existing.unwrap_or_else(|| { - use crate::persistence::current_time_ms; + // If no existing record, create initial sandbox (first watch event for this sandbox) + if existing.is_none() { + use crate::persistence::{WriteCondition, current_time_ms}; let now_ms = current_time_ms().unwrap_or(0); - Sandbox { + + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, None); + + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let mut phase = derive_phase(incoming.status.as_ref()); + + let sandbox_name = incoming.name.clone(); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) + { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: incoming.id.clone(), - name: incoming.name.clone(), + name: sandbox_name, created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: None, - status: None, - phase: SandboxPhase::Unknown as i32, + status, + phase: phase as i32, current_policy_version: 0, - } - }); + }; - if session_connected && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) - { - ensure_supervisor_ready_status(&mut status, sandbox.object_name()); - phase = SandboxPhase::Ready; - } + self.store + .put_if( + Sandbox::object_type(), + &incoming.id, + sandbox.object_name(), + &sandbox.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox creation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; - let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase != phase { - info!( - sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); + self.sandbox_index.update_from_sandbox(&sandbox); + self.sandbox_watch_bus.notify(sandbox.object_id()); + return Ok(()); } - if phase == SandboxPhase::Error - && let Some(ref status) = status - { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_reason(&condition.reason) + // Use CAS to update existing sandbox (prevents lost updates in HA deployments with concurrent watch events) + // 5 retries = ~5ms max latency under moderate contention from multiple gateway replicas + // Capture external state once to ensure all retries use the same snapshot + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let sandbox_name = incoming.name.clone(); + + let sandbox = self + .store + .update_message_cas::(&incoming.id, 0, |sandbox| { + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); + + let mut phase = derive_phase(incoming.status.as_ref()); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) { - warn!( + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let old_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase != phase { + info!( sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - reason = %condition.reason, - message = %condition.message, - "Sandbox failed to become ready" + sandbox_name = %sandbox_name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" ); } - } - } - // Update metadata fields - if let Some(metadata) = sandbox.metadata.as_mut() { - metadata.name = incoming.name; - } - // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox - sandbox.status = status; - sandbox.phase = phase as i32; + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_reason(&condition.reason) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); + } + } + } - if previous.as_ref() == Some(&sandbox) { - return Ok(()); - } + // Update metadata fields + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.name.clone_from(&sandbox_name); + } + // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox + sandbox.status = status; + sandbox.phase = phase as i32; + }) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox reconciliation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox.object_id()); Ok(()) } @@ -909,38 +993,51 @@ impl ComputeRuntime { connected: bool, ) -> Result<(), String> { let _guard = self.sync_lock.lock().await; - let Some(record) = self + + // Use CAS to update sandbox phase based on supervisor session state + let result = self .store - .get(Sandbox::object_type(), sandbox_id) - .await - .map_err(|e| e.to_string())? - else { - return Ok(()); - }; + .update_message_cas::(sandbox_id, 0, |sandbox| { + let current_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - let mut sandbox = decode_sandbox_record(&record)?; - let current_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + // Skip if sandbox is in terminal state + if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { + return; + } - if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { - return Ok(()); - } + let sandbox_name = sandbox.object_name().to_string(); + if connected { + ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Ready as i32; + } else if current_phase == SandboxPhase::Ready { + ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Provisioning as i32; + } + }) + .await; - let sandbox_name = sandbox.object_name().to_string(); - if connected { - ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Ready as i32; - } else if current_phase == SandboxPhase::Ready { - ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Provisioning as i32; - } else { - return Ok(()); - } + // Handle not found gracefully (sandbox may have been deleted) + let sandbox = match result { + Ok(s) => s, + Err(crate::persistence::PersistenceError::Database(ref msg)) + if msg.contains("not found") => + { + return Ok(()); + } + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version, + }) => { + return Err(format!( + "concurrent modification detected (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )); + } + Err(e) => return Err(e.to_string()), + }; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox_id); Ok(()) } @@ -1830,6 +1927,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), phase: phase as i32, ..Default::default() @@ -1843,6 +1941,7 @@ mod tests { name: format!("session-{id}"), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), sandbox_id: sandbox_id.to_string(), token: format!("token-{id}"), @@ -2757,4 +2856,105 @@ mod tests { "unset user_namespaces must not produce host_users" ); } + + #[tokio::test] + async fn create_sandbox_returns_resource_version_one() { + let runtime = test_runtime(Arc::new(TestDriver::default())).await; + + let mut sandbox = sandbox_record("sb-new", "test-sandbox", SandboxPhase::Provisioning); + // Clear metadata to simulate incoming request + sandbox.metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-new".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }); + + let created = runtime.create_sandbox(sandbox).await.unwrap(); + + assert_eq!( + created.metadata.as_ref().unwrap().resource_version, + 1, + "create_sandbox should return resource_version: 1 after insert" + ); + + // Verify database also has resource_version: 1 + let created_id = created.metadata.as_ref().unwrap().id.clone(); + let stored = runtime + .store + .get_message::(&created_id) + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.metadata.as_ref().unwrap().resource_version, + 1, + "database should have resource_version: 1 after create" + ); + } + + #[tokio::test] + async fn concurrent_create_sandbox_rejects_duplicate() { + let runtime = Arc::new(test_runtime(Arc::new(TestDriver::default())).await); + + let sandbox = sandbox_record( + "sb-concurrent", + "test-concurrent", + SandboxPhase::Provisioning, + ); + + // Spawn two concurrent creation attempts for the same sandbox + let runtime1 = runtime.clone(); + let sandbox1 = sandbox.clone(); + let handle1 = tokio::spawn(async move { runtime1.create_sandbox(sandbox1).await }); + + let runtime2 = runtime.clone(); + let sandbox2 = sandbox.clone(); + let handle2 = tokio::spawn(async move { runtime2.create_sandbox(sandbox2).await }); + + // Wait for both to complete + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Exactly one should succeed, one should fail with AlreadyExists + let success_count = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let already_exists_count = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == Code::AlreadyExists) + }) + .count(); + + assert_eq!( + success_count, 1, + "exactly one creation should succeed, got results: {result1:?} {result2:?}" + ); + assert_eq!( + already_exists_count, 1, + "exactly one creation should fail with AlreadyExists, got results: {result1:?} {result2:?}" + ); + + // Verify the successful sandbox can be retrieved by name + let created_sandbox = [result1, result2] + .into_iter() + .find_map(Result::ok) + .expect("should have one successful creation"); + let retrieved = runtime + .store + .get_message_by_name::("test-concurrent") + .await + .unwrap(); + assert!( + retrieved.is_some(), + "created sandbox should be retrievable by name" + ); + assert_eq!( + retrieved.unwrap().object_id(), + created_sandbox.object_id(), + "retrieved sandbox should match created sandbox" + ); + } } diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 2605edb81..bdc86c26d 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -65,6 +65,29 @@ pub fn clamp_limit(raw: u32, default: u32, max: u32) -> u32 { if raw == 0 { default } else { raw.min(max) } } +/// Map a `PersistenceError` to an appropriate gRPC `Status`. +/// +/// CAS conflicts (optimistic concurrency failures) are mapped to `ABORTED` +/// to signal that the client should retry with fresh data. Other persistence +/// errors are mapped to `INTERNAL`. +pub fn persistence_error_to_status( + err: crate::persistence::PersistenceError, + operation: &str, +) -> Status { + use crate::persistence::PersistenceError; + + match err { + PersistenceError::Conflict { + current_resource_version, + } => Status::aborted(format!( + "{} failed due to concurrent modification (current resource_version: {})", + operation, + current_resource_version.map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )), + other => Status::internal(format!("{operation} failed: {other}")), + } +} + // --------------------------------------------------------------------------- // Field-level size limits (shared across submodules) // --------------------------------------------------------------------------- @@ -104,6 +127,10 @@ const MAX_PROVIDER_CONFIG_ENTRIES: usize = 64; struct StoredSettings { revision: u64, settings: BTreeMap, + /// Database `resource_version` for CAS. Not persisted in the JSON payload; + /// loaded from `ObjectRecord` and used for optimistic concurrency control. + #[serde(skip)] + resource_version: u64, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 2930ed975..80ce4d44f 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1006,32 +1006,29 @@ pub(super) async fn handle_update_config( validate_static_fields_unchanged(baseline_policy, &new_policy)?; validate_policy_safety(&new_policy)?; } else { + // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = state + let sandbox_id = sandbox.object_id().to_string(); + let new_policy_clone = new_policy.clone(); + state .store - .get_message::(&sandbox_id) + .update_message_cas::( + &sandbox_id, + req.expected_resource_version, + |sandbox| { + if let Some(ref mut spec) = sandbox.spec + && spec.policy.is_none() + { + spec.policy = Some(new_policy_clone.clone()); + } + }, + ) .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? - .ok_or_else(|| Status::not_found("sandbox not found"))?; - let spec = sandbox - .spec - .as_mut() - .ok_or_else(|| Status::internal("sandbox has no spec"))?; - if let Some(baseline_policy) = spec.policy.as_ref() { - validate_static_fields_unchanged(baseline_policy, &new_policy)?; - validate_policy_safety(&new_policy)?; - } else { - spec.policy = Some(new_policy.clone()); - state - .store - .put_message(&sandbox) - .await - .map_err(|e| Status::internal(format!("backfill spec.policy failed: {e}")))?; - info!( - sandbox_id = %sandbox_id, - "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" - ); - } + .map_err(|e| super::persistence_error_to_status(e, "backfill spec.policy"))?; + info!( + sandbox_id = %sandbox_id, + "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" + ); } let latest = state @@ -1228,11 +1225,19 @@ pub(super) async fn handle_report_policy_status( .store .supersede_older_policies(&req.sandbox_id, version) .await; + + // Update current_policy_version using CAS + // TODO: Accept expected_version from UpdateConfigRequest for proper client-driven CAS let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - if let Ok(Some(mut sandbox)) = state.store.get_message::(&req.sandbox_id).await { - sandbox.current_policy_version = req.version; - let _ = state.store.put_message(&sandbox).await; - } + let version_to_set = req.version; + state + .store + .update_message_cas::(&req.sandbox_id, 0, |sandbox| { + sandbox.current_policy_version = version_to_set; + }) + .await + .map_err(|e| super::persistence_error_to_status(e, "update current_policy_version"))?; + state.sandbox_watch_bus.notify(&req.sandbox_id); } @@ -2670,8 +2675,11 @@ async fn load_settings_record( .await .map_err(|e| Status::internal(format!("fetch settings failed: {e}")))?; if let Some(record) = record { - serde_json::from_slice::(&record.payload) - .map_err(|e| Status::internal(format!("decode settings payload failed: {e}"))) + let mut settings = serde_json::from_slice::(&record.payload) + .map_err(|e| Status::internal(format!("decode settings payload failed: {e}")))?; + // Populate resource_version from database record for CAS + settings.resource_version = record.resource_version; + Ok(settings) } else { Ok(StoredSettings::default()) } @@ -2683,18 +2691,43 @@ async fn save_settings_record( name: &str, settings: &StoredSettings, ) -> Result<(), Status> { + use crate::persistence::WriteCondition; + let payload = serde_json::to_vec(settings) .map_err(|e| Status::internal(format!("encode settings payload failed: {e}")))?; - store - .put( - object_type, - &uuid::Uuid::new_v4().to_string(), - name, - &payload, - None, + + let (id, condition) = if settings.resource_version == 0 { + // Create new settings (resource_version 0 means never persisted) + (uuid::Uuid::new_v4().to_string(), WriteCondition::MustCreate) + } else { + // Update existing with CAS on the version from when it was loaded + // Fetch the record to get the stable ID + let existing = store + .get_by_name(object_type, name) + .await + .map_err(|e| Status::internal(format!("fetch settings for CAS failed: {e}")))? + .ok_or_else(|| Status::not_found("settings disappeared since load"))?; + + ( + existing.id, + WriteCondition::MatchResourceVersion(settings.resource_version), ) + }; + + // Single-attempt CAS write + store + .put_if(object_type, &id, name, &payload, None, condition) .await - .map_err(|e| Status::internal(format!("persist settings failed: {e}")))?; + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { .. } => { + Status::aborted("settings were modified concurrently; please retry") + } + crate::persistence::PersistenceError::UniqueViolation { .. } => { + Status::aborted("settings were created concurrently; please retry") + } + other => super::persistence_error_to_status(other, "persist settings"), + })?; + Ok(()) } @@ -2860,6 +2893,7 @@ mod tests { name: "no-policy-sandbox".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -2885,6 +2919,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once(("GITHUB_TOKEN".to_string(), "ghp-test".to_string())) @@ -2926,6 +2961,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: Some(policy), @@ -2945,6 +2981,7 @@ mod tests { StoredSettingValue::Bool(true), )) .collect(), + ..Default::default() }; save_global_settings(state.store.as_ref(), &global_settings) .await @@ -2994,6 +3031,7 @@ mod tests { name: "generic".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "generic".to_string(), @@ -3035,6 +3073,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -3097,6 +3136,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -3523,6 +3563,7 @@ mod tests { Request::new(AttachSandboxProviderRequest { sandbox_name: "attach-lifecycle".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -3558,6 +3599,7 @@ mod tests { Request::new(DetachSandboxProviderRequest { sandbox_name: "attach-lifecycle".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -3681,6 +3723,7 @@ mod tests { Request::new(AttachSandboxProviderRequest { sandbox_name: "custom-attach-lifecycle".to_string(), provider_name: "work-custom".to_string(), + expected_resource_version: 0, }), ) .await @@ -3719,6 +3762,7 @@ mod tests { Request::new(DetachSandboxProviderRequest { sandbox_name: "custom-attach-lifecycle".to_string(), provider_name: "work-custom".to_string(), + expected_resource_version: 0, }), ) .await @@ -3782,6 +3826,7 @@ mod tests { name: "global-profile-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: Some(sandbox_policy), @@ -3823,6 +3868,7 @@ mod tests { ] .into_iter() .collect(), + ..Default::default() }; save_global_settings(state.store.as_ref(), &global_settings) .await @@ -3869,6 +3915,7 @@ mod tests { name: "backfill-sandbox".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3951,6 +3998,7 @@ mod tests { name: "draft-flow".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4161,6 +4209,7 @@ mod tests { name: sandbox_name.clone(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4258,6 +4307,7 @@ mod tests { name: sandbox_name.clone(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4371,6 +4421,7 @@ mod tests { name: sandbox_name.clone(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4470,6 +4521,7 @@ mod tests { name: sandbox_name.clone(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4575,6 +4627,7 @@ mod tests { name: "draft-owner".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4589,6 +4642,7 @@ mod tests { name: "draft-other".to_string(), created_at_ms: 1_000_001, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -5324,6 +5378,7 @@ mod tests { revision: 1, settings: std::iter::once(("policy".to_string(), StoredSettingValue::Bytes(encoded))) .collect(), + ..Default::default() }; let decoded = decode_policy_from_global_settings(&global) @@ -5406,6 +5461,7 @@ mod tests { ] .into_iter() .collect(), + ..Default::default() }; let sandbox = StoredSettings { revision: 1, @@ -5418,6 +5474,7 @@ mod tests { ] .into_iter() .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -5447,6 +5504,7 @@ mod tests { )] .into_iter() .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -5476,6 +5534,7 @@ mod tests { StoredSettingValue::Bytes("deadbeef".to_string()), )) .collect(), + ..Default::default() }; let sandbox = StoredSettings { revision: 1, @@ -5484,6 +5543,7 @@ mod tests { StoredSettingValue::Bytes("cafebabe".to_string()), )) .collect(), + ..Default::default() }; let merged = merge_effective_settings(&global, &sandbox).unwrap(); @@ -5776,24 +5836,53 @@ mod tests { .settings .insert(format!("key_{i}"), StoredSettingValue::Int(i as i64)); settings.revision = settings.revision.wrapping_add(1); - save_global_settings(&store, &settings).await.unwrap(); + save_global_settings(&store, &settings).await })); } + let mut succeeded = 0; + let mut cas_conflicts = 0; for h in handles { - h.await.unwrap(); + match h.await.unwrap() { + Ok(()) => succeeded += 1, + Err(e) if e.code() == Code::Aborted => cas_conflicts += 1, + Err(e) => panic!("unexpected error: {e}"), + } } let final_settings = load_global_settings(&store).await.unwrap(); - let lost = (n as u64).saturating_sub(final_settings.revision); - if lost == 0 { - eprintln!( - "note: no lost writes detected in unlocked test (sequential scheduling); \ - the locked test is the authoritative correctness check" - ); - } else { - eprintln!("unlocked test: {lost} lost writes out of {n} (expected behavior)"); - } + + // With single-attempt CAS (no retry), concurrent modifications are properly detected: + // - All tasks read initial state (revision=0, resource_version=0) + // - First write succeeds with resource_version=1 + // - Subsequent writes fail with ABORTED (CAS conflict) because they all have stale resource_version=0 + // - Only the first write succeeds; all others are rejected + // + // This demonstrates that single-attempt CAS prevents lost writes by rejecting stale updates. + // The caller must retry from a fresh read to incorporate concurrent changes. + assert!( + cas_conflicts > 0, + "most concurrent writes should fail with CAS conflict (succeeded={succeeded}, conflicts={cas_conflicts})" + ); + assert!( + succeeded < n, + "not all writes should succeed due to conflicts (succeeded={succeeded}, total={n})" + ); + assert_eq!( + final_settings.revision as usize, succeeded, + "final revision should match number of successful writes" + ); + assert_eq!( + final_settings.settings.len(), + succeeded, + "final settings should contain exactly the keys from successful writes" + ); + + eprintln!( + "unlocked CAS test: {succeeded} succeeded, {cas_conflicts} CAS conflicts, \ + final revision={} (matches succeeded count, demonstrating proper conflict detection)", + final_settings.revision + ); } // ---- Conflict guard tests ---- @@ -5840,6 +5929,7 @@ mod tests { .await .unwrap(); + // Create initial global settings let mut global = StoredSettings::default(); global.settings.insert( "log_level".to_string(), @@ -5851,6 +5941,8 @@ mod tests { let loaded = load_global_settings(&store).await.unwrap(); assert!(loaded.settings.contains_key("log_level")); + // Load fresh to get current resource_version before updating + let mut global = load_global_settings(&store).await.unwrap(); global.settings.remove("log_level"); global.revision = 2; save_global_settings(&store, &global).await.unwrap(); @@ -5894,4 +5986,330 @@ mod tests { assert_eq!(err.code(), Code::InvalidArgument); assert!(err.message().contains("unknown setting key")); } + + #[tokio::test] + async fn save_settings_detects_concurrent_modification() { + let store = Store::connect("sqlite::memory:").await.unwrap(); + + // Create initial settings + let mut settings = StoredSettings { + revision: 1, + settings: std::iter::once(( + "initial_key".to_string(), + StoredSettingValue::String("initial_value".to_string()), + )) + .collect(), + ..Default::default() + }; + save_global_settings(&store, &settings).await.unwrap(); + + // Load settings (simulating first client read) + let loaded = load_global_settings(&store).await.unwrap(); + assert_eq!(loaded.revision, 1); + + // Simulate concurrent modification: another client updates the settings + let mut concurrent_update = loaded.clone(); + concurrent_update.settings.insert( + "concurrent_key".to_string(), + StoredSettingValue::String("concurrent_value".to_string()), + ); + concurrent_update.revision = 2; + save_global_settings(&store, &concurrent_update) + .await + .unwrap(); + + // Now attempt to save our original modification (which is based on stale revision 1) + settings.settings.insert( + "our_key".to_string(), + StoredSettingValue::String("our_value".to_string()), + ); + settings.revision = 2; // We think we're updating to revision 2 + + let result = save_global_settings(&store, &settings).await; + + // Should fail with ABORTED due to concurrent modification + assert!(result.is_err(), "save with stale revision should fail"); + let err = result.unwrap_err(); + assert_eq!( + err.code(), + Code::Aborted, + "should fail with ABORTED due to version mismatch" + ); + assert!( + err.message().contains("concurrently"), + "error should mention concurrent modification: {}", + err.message() + ); + + // Verify the database contains the concurrent update, not our stale update + let final_settings = load_global_settings(&store).await.unwrap(); + assert_eq!(final_settings.revision, 2); + assert!( + final_settings.settings.contains_key("concurrent_key"), + "concurrent update should be preserved" + ); + assert!( + !final_settings.settings.contains_key("our_key"), + "stale update should NOT be in database" + ); + } + + // ---- CAS (Client-driven optimistic concurrency) tests for UpdateConfig ---- + // These test the policy backfill path where spec.policy is None and UpdateConfig + // uses update_message_cas to atomically set it. + + #[tokio::test] + async fn update_config_policy_backfill_cas_succeeds_with_correct_version() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + + // Create a sandbox WITHOUT a policy (spec.policy = None) + // This simulates a sandbox before the supervisor has discovered and synced a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, // No policy yet - will be backfilled + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // Fetch the sandbox to get its current resource_version + let current = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Backfill the policy with correct expected_resource_version + let new_policy = ProtoSandboxPolicy::default(); + + let response = handle_update_config( + &state, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: vec![], + expected_resource_version: current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + // UpdateConfigResponse contains the policy version + assert_eq!(response.version, 1); + + // Verify the resource_version incremented and policy was backfilled + let updated_sandbox = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1, + "resource_version should increment during CAS backfill" + ); + assert!( + updated_sandbox.spec.as_ref().unwrap().policy.is_some(), + "policy should be backfilled" + ); + } + + #[tokio::test] + async fn update_config_policy_backfill_cas_rejects_stale_version() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + + let state = test_server_state().await; + + // Create a sandbox WITHOUT a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // Get current version + let current = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Try to backfill with a stale version + let new_policy = ProtoSandboxPolicy::default(); + + let err = handle_update_config( + &state, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: vec![], + expected_resource_version: 99, // stale version + }), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified (policy still None) + let unchanged = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged.metadata.as_ref().unwrap().resource_version, + current_version, + "resource_version should not change when CAS fails" + ); + assert!( + unchanged.spec.as_ref().unwrap().policy.is_none(), + "policy should still be None after failed backfill" + ); + } + + #[tokio::test] + async fn update_config_policy_backfill_concurrent_with_stale_versions() { + use openshell_core::proto::{SandboxPhase, SandboxSpec}; + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create a sandbox WITHOUT a policy + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-1".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + policy: None, + providers: Vec::new(), + ..Default::default() + }), + phase: SandboxPhase::Provisioning as i32, + ..Default::default() + }; + state.store.put_message(&sandbox).await.unwrap(); + + // All three clients fetch the sandbox and see the same version + let initial = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + let initial_version = initial.metadata.as_ref().unwrap().resource_version; + + // Launch 3 concurrent policy backfill attempts, all using the same initial version + let mut handles = vec![]; + for _i in 0..3 { + let state_clone = Arc::clone(&state); + let new_policy = ProtoSandboxPolicy::default(); + + let handle = tokio::spawn(async move { + handle_update_config( + &state_clone, + Request::new(UpdateConfigRequest { + name: "test-sandbox".to_string(), + policy: Some(new_policy), + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: vec![], + expected_resource_version: initial_version, + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| r.as_ref().err().is_some_and(|e| e.code() == Code::Aborted)) + .count(); + + assert_eq!( + successes, 1, + "exactly one backfill should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two backfills should fail with ABORTED due to stale version" + ); + + // Final sandbox should have resource_version = initial_version + 1 and policy backfilled + let final_sandbox = state + .store + .get_message_by_name::("test-sandbox") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_sandbox.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + assert!( + final_sandbox.spec.as_ref().unwrap().policy.is_some(), + "policy should be backfilled after one success" + ); + } } diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 2ed4d439d..271c80a4b 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -5,7 +5,9 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> -use crate::persistence::{ObjectName, ObjectType, Store, generate_name}; +use crate::persistence::{ + ObjectLabels, ObjectName, ObjectType, Store, WriteCondition, generate_name, +}; use openshell_core::proto::{Provider, Sandbox}; use prost::Message; use tonic::Status; @@ -44,6 +46,7 @@ pub(super) async fn create_provider_record( name: generate_name(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }); } @@ -72,19 +75,42 @@ pub(super) async fn create_provider_record( // Validate field sizes before any I/O. validate_provider_fields(&provider)?; - let existing = store - .get_message_by_name::(provider.object_name()) - .await - .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?; - - if existing.is_some() { - return Err(Status::already_exists("provider already exists")); + // Generate UUID for database row and update metadata.id to match + let provider_id = uuid::Uuid::new_v4().to_string(); + let mut provider = provider; + if let Some(metadata) = provider.metadata.as_mut() { + metadata.id.clone_from(&provider_id); } + // Create with MustCreate condition to prevent duplicate creation race store - .put_message(&provider) + .put_if( + Provider::object_type(), + &provider_id, + provider.object_name(), + &provider.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) + .await + .map_err(|e| { + let err_str = e.to_string(); + if err_str.contains("unique constraint") + || err_str.contains("UNIQUE constraint") + || err_str.contains("unique violation") + { + Status::already_exists("provider already exists") + } else { + Status::internal(format!("persist provider failed: {e}")) + } + })?; + + // Read back from DB to get correct resource_version + let provider = store + .get_message_by_name::(provider.object_name()) .await - .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| Status::internal("provider disappeared after creation"))?; Ok(redact_provider_credentials(provider)) } @@ -126,12 +152,16 @@ pub(super) async fn update_provider_record( store: &Store, provider: Provider, ) -> Result { - use crate::persistence::ObjectName; + use crate::persistence::{ObjectId, ObjectName}; if provider.object_name().is_empty() { return Err(Status::invalid_argument("provider.name is required")); } + // Extract expected version from provider metadata + let expected_resource_version = provider.metadata.as_ref().map_or(0, |m| m.resource_version); + + // Resolve provider ID from name for CAS update let existing = store .get_message_by_name::(provider.object_name()) .await @@ -150,24 +180,78 @@ pub(super) async fn update_provider_record( )); } - let updated = Provider { - metadata: existing.metadata, - r#type: existing.r#type, - credentials: merge_map(existing.credentials, provider.credentials), - config: merge_map(existing.config, provider.config), - }; + let current_version = existing.metadata.as_ref().map_or(0, |m| m.resource_version); - // Ensure metadata is valid (defense in depth - existing.metadata should always be valid) - super::validation::validate_object_metadata(updated.metadata.as_ref(), "provider")?; + // Determine the version to use for CAS: + // - If expected_resource_version is 0, use current version (internal/backward compat) + // - Otherwise, validate that expected matches current (client-facing operations) + let cas_version = if expected_resource_version == 0 { + current_version + } else { + if expected_resource_version != current_version { + return Err(Status::aborted(format!( + "provider was modified concurrently (current resource_version: {current_version})" + ))); + } + expected_resource_version + }; - validate_provider_fields(&updated)?; + // Apply merge to create candidate + let mut candidate = existing.clone(); + candidate.credentials = merge_map(candidate.credentials, provider.credentials); + candidate.config = merge_map(candidate.config, provider.config); + + // Validate BEFORE writing to prevent persisting invalid state + super::validation::validate_object_metadata(candidate.metadata.as_ref(), "provider")?; + validate_provider_fields(&candidate)?; + + // Serialize labels for storage + let labels_map = candidate.object_labels(); + let labels_json = if labels_map + .as_ref() + .is_none_or(std::collections::HashMap::is_empty) + { + None + } else { + Some( + serde_json::to_string(&labels_map) + .map_err(|e| Status::internal(format!("serialize labels failed: {e}")))?, + ) + }; - store - .put_message(&updated) + // Write validated candidate with CAS condition + let result = store + .put_if( + Provider::object_type(), + candidate.object_id(), + candidate.object_name(), + &candidate.encode_to_vec(), + labels_json.as_deref(), + WriteCondition::MatchResourceVersion(cas_version), + ) .await - .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; + .map_err(|e| { + if matches!(e, crate::persistence::PersistenceError::Conflict { .. }) { + Status::aborted(format!( + "provider was modified concurrently (current resource_version: {})", + match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => current_resource_version.unwrap_or(0), + _ => 0, + } + )) + } else { + Status::internal(format!("update provider failed: {e}")) + } + })?; + + // Update resource_version from successful write + if let Some(metadata) = candidate.metadata.as_mut() { + metadata.resource_version = result.resource_version; + } - Ok(redact_provider_credentials(updated)) + Ok(redact_provider_credentials(candidate)) } pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result { @@ -656,6 +740,7 @@ fn stored_provider_profile(profile: ProviderProfile) -> StoredProviderProfile { name: profile.id.clone(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), profile: Some(profile), } @@ -796,6 +881,7 @@ mod tests { name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: [ @@ -1308,6 +1394,7 @@ mod tests { name: "sandbox-using-custom".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["custom-provider".to_string()], @@ -1400,6 +1487,7 @@ mod tests { name: "gitlab-local".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1475,6 +1563,7 @@ mod tests { name: "attached-sandbox".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["gitlab-local".to_string()], @@ -1496,6 +1585,78 @@ mod tests { ); } + #[tokio::test] + async fn provider_create_and_update_return_correct_resource_version() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create provider and verify resource_version: 1 in response + let created = provider_with_values("test-provider", "openai"); + let persisted = create_provider_record(&store, created).await.unwrap(); + assert_eq!( + persisted.metadata.as_ref().unwrap().resource_version, + 1, + "create_provider_record should return resource_version: 1 after insert" + ); + + // Update provider and verify resource_version: 2 in response + let updated = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "updated-key".to_string(), + )) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated.metadata.as_ref().unwrap().resource_version, + 2, + "update_provider_record should return resource_version: 2 after first update" + ); + + // Update again and verify resource_version: 3 + let updated_again = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "third-key".to_string(), + )) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated_again.metadata.as_ref().unwrap().resource_version, + 3, + "update_provider_record should return resource_version: 3 after second update" + ); + } + #[tokio::test] async fn provider_validation_errors() { let store = Store::connect("sqlite::memory:?cache=shared") @@ -1510,6 +1671,7 @@ mod tests { name: "bad-provider".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1534,6 +1696,7 @@ mod tests { name: "missing".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1562,6 +1725,7 @@ mod tests { name: "noop-test".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1609,6 +1773,7 @@ mod tests { name: "delete-key-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once(("SECONDARY".to_string(), String::new())).collect(), @@ -1660,6 +1825,7 @@ mod tests { name: "type-preserve-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1689,6 +1855,7 @@ mod tests { name: "type-change-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: HashMap::new(), @@ -1720,6 +1887,7 @@ mod tests { name: "validate-merge-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once((oversized_key, "value".to_string())).collect(), @@ -1748,6 +1916,7 @@ mod tests { name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: [ @@ -1791,6 +1960,7 @@ mod tests { name: "test-provider".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "test".to_string(), credentials: [ @@ -1823,6 +1993,7 @@ mod tests { name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1843,6 +2014,7 @@ mod tests { name: "gitlab-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(("GITLAB_TOKEN".to_string(), "glpat-xyz".to_string())) @@ -1874,6 +2046,7 @@ mod tests { name: "provider-a".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(("SHARED_KEY".to_string(), "first-value".to_string())) @@ -1891,6 +2064,7 @@ mod tests { name: "provider-b".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1927,6 +2101,7 @@ mod tests { name: "my-claude".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1946,6 +2121,7 @@ mod tests { name: "test-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["my-claude".to_string()], @@ -1982,6 +2158,7 @@ mod tests { name: "empty-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec::default()), status: None, @@ -2011,4 +2188,361 @@ mod tests { let result = store.get_message::("nonexistent").await.unwrap(); assert!(result.is_none()); } + + #[tokio::test] + async fn update_provider_validates_before_write() { + let store = Arc::new(Store::connect("sqlite::memory:").await.unwrap()); + + // Create a valid provider + let provider = provider_with_values("test-validate-provider", "test-type"); + let created = create_provider_record(&store, provider.clone()) + .await + .unwrap(); + + // Build update request with just the name and new credentials + let mut update_req = Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-validate-provider".to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: String::new(), // Empty type is ignored in update + credentials: HashMap::new(), + config: HashMap::new(), + }; + + // Attempt to update with an oversized credential key (exceeds MAX_MAP_KEY_LEN) + update_req.credentials.insert( + "k".repeat(MAX_MAP_KEY_LEN + 1), + "oversized-key-value".to_string(), + ); + + let result = update_provider_record(&store, update_req).await; + + // Update should fail with InvalidArgument due to oversized key + assert!(result.is_err(), "update with invalid data should fail"); + let err = result.unwrap_err(); + assert_eq!( + err.code(), + Code::InvalidArgument, + "should fail validation with InvalidArgument" + ); + assert!( + err.message().contains("key"), + "error message should mention key: {}", + err.message() + ); + + // Verify database still contains the ORIGINAL valid provider (not the invalid one) + let stored = store + .get_message_by_name::("test-validate-provider") + .await + .unwrap() + .expect("provider should still exist"); + + assert_eq!( + stored.object_id(), + created.object_id(), + "stored provider ID should match original" + ); + assert_eq!( + stored.credentials.len(), + created.credentials.len(), + "credentials count should not have changed" + ); + assert!( + !stored + .credentials + .contains_key(&"k".repeat(MAX_MAP_KEY_LEN + 1)), + "oversized key should NOT be in database" + ); + } + + #[tokio::test] + async fn concurrent_create_provider_rejects_duplicate() { + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + + let provider = provider_with_values("test-concurrent-provider", "test-type"); + + // Spawn two concurrent creation attempts for the same provider + let store1 = store.clone(); + let provider1 = provider.clone(); + let handle1 = tokio::spawn(async move { create_provider_record(&store1, provider1).await }); + + let store2 = store.clone(); + let provider2 = provider.clone(); + let handle2 = tokio::spawn(async move { create_provider_record(&store2, provider2).await }); + + // Wait for both to complete + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Exactly one should succeed, one should fail with AlreadyExists + let success_count = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let already_exists_count = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == Code::AlreadyExists) + }) + .count(); + + assert_eq!( + success_count, 1, + "exactly one creation should succeed, got results: {result1:?} {result2:?}" + ); + assert_eq!( + already_exists_count, 1, + "exactly one creation should fail with AlreadyExists, got results: {result1:?} {result2:?}" + ); + + // Verify the successful provider can be retrieved by name + let created_provider = [result1, result2] + .into_iter() + .find_map(Result::ok) + .expect("should have one successful creation"); + let retrieved = store + .get_message_by_name::("test-concurrent-provider") + .await + .unwrap(); + assert!( + retrieved.is_some(), + "created provider should be retrievable by name" + ); + assert_eq!( + retrieved.unwrap().object_id(), + created_provider.object_id(), + "retrieved provider should match created provider" + ); + } + + // ---- CAS (Client-driven optimistic concurrency) tests for UpdateProvider ---- + + #[tokio::test] + async fn update_provider_client_driven_cas_succeeds_with_correct_version() { + let state = test_server_state().await; + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // Fetch the provider to get its current resource_version + let current = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Prepare an update with the correct resource_version + let mut updated_provider = current.clone(); + updated_provider + .credentials + .insert("NEW_KEY".to_string(), "new-value".to_string()); + updated_provider.metadata.as_mut().unwrap().resource_version = current_version; + + // Update should succeed + let response = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(updated_provider.clone()), + }), + ) + .await + .unwrap() + .into_inner(); + + assert_eq!( + response.provider.as_ref().unwrap().object_name(), + "test-provider" + ); + assert_eq!( + response + .provider + .as_ref() + .unwrap() + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + 1 + ); + assert!( + response + .provider + .unwrap() + .credentials + .contains_key("NEW_KEY") + ); + } + + #[tokio::test] + async fn update_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // Fetch the current state + let current = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + let current_version = current.metadata.as_ref().unwrap().resource_version; + + // Prepare an update with a stale resource_version + let mut stale_provider = current.clone(); + stale_provider + .credentials + .insert("NEW_KEY".to_string(), "new-value".to_string()); + stale_provider.metadata.as_mut().unwrap().resource_version = 99; // stale version + + // Update should fail with ABORTED + let err = handle_update_provider( + &state, + Request::new(UpdateProviderRequest { + provider: Some(stale_provider), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the provider was not modified + let unchanged = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged.metadata.as_ref().unwrap().resource_version, + current_version + ); + assert!(!unchanged.credentials.contains_key("NEW_KEY")); + } + + #[tokio::test] + async fn update_provider_concurrent_updates_with_stale_versions() { + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create a provider + let mut provider = provider_with_values("test-provider", "generic"); + provider.metadata.as_mut().unwrap().id = String::new(); + handle_create_provider( + &state, + Request::new(CreateProviderRequest { + provider: Some(provider.clone()), + }), + ) + .await + .unwrap(); + + // All three clients fetch the provider and see the same version + let initial = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + let initial_version = initial.metadata.as_ref().unwrap().resource_version; + + // Launch 3 concurrent updates, all using the same initial version + let mut handles = vec![]; + for i in 0..3 { + let state_clone = Arc::clone(&state); + let mut updated = initial.clone(); + updated + .credentials + .insert(format!("KEY_{i}"), format!("value-{i}")); + updated.metadata.as_mut().unwrap().resource_version = initial_version; + + let handle = tokio::spawn(async move { + handle_update_provider( + &state_clone, + Request::new(UpdateProviderRequest { + provider: Some(updated), + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| r.as_ref().err().is_some_and(|e| e.code() == Code::Aborted)) + .count(); + + assert_eq!( + successes, 1, + "exactly one update should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two updates should fail with ABORTED due to stale version" + ); + + // Final provider should have exactly 1 new credential key and resource_version = initial_version + 1 + let final_provider = state + .store + .get_message_by_name::("test-provider") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_provider.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + + // Exactly one of KEY_0, KEY_1, or KEY_2 should be present + let new_keys_count = (0..3) + .filter(|i| final_provider.credentials.contains_key(&format!("KEY_{i}"))) + .count(); + assert_eq!(new_keys_count, 1); + } } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index ad37a5482..ae4f857b1 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -10,9 +10,8 @@ #![allow(clippy::cast_possible_wrap)] // Intentional u32->i32 conversions for proto compat use crate::ServerState; -use crate::persistence::{ObjectType, generate_name}; +use crate::persistence::{ObjectType, WriteCondition, generate_name}; use futures::future; -use openshell_core::ObjectId; use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, DeleteSandboxResponse, @@ -24,10 +23,12 @@ use openshell_core::proto::{ TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; +use openshell_core::{ObjectId, ObjectName}; use prost::Message; use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -110,6 +111,7 @@ pub(super) async fn handle_create_sandbox( name: name.clone(), created_at_ms: now_ms, labels: request.labels.clone(), + resource_version: 0, }), spec: Some(spec), status: None, @@ -218,6 +220,16 @@ pub(super) async fn handle_attach_sandbox_provider( return Err(Status::invalid_argument("provider_name is required")); } + // Validate provider name would not violate sandbox spec constraints if added + // (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + get_provider_record(state.store.as_ref(), &request.provider_name) .await .map_err(|err| { @@ -232,39 +244,61 @@ pub(super) async fn handle_attach_sandbox_provider( })?; let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; - let sandbox_name = sandbox + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox .metadata .as_ref() - .map_or_else(String::new, |metadata| metadata.name.clone()); + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? + .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) let spec = sandbox .spec - .as_mut() - .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; - - dedupe_provider_names(&mut spec.providers); - let attached = if spec - .providers - .iter() - .any(|name| name == &request.provider_name) + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; + + // Pre-check: fail fast if already at MAX_PROVIDERS limit (avoid spurious CAS conflicts) + // Note: This is an optimization; the CAS closure rechecks after dedupe in case of races + if spec.providers.len() >= MAX_PROVIDERS + && !spec + .providers + .iter() + .any(|name| name == &request.provider_name) { - false - } else { - if spec.providers.len() >= MAX_PROVIDERS { - return Err(Status::invalid_argument(format!( - "providers list exceeds maximum ({MAX_PROVIDERS})" - ))); - } - spec.providers.push(request.provider_name.clone()); - true - }; - validate_sandbox_spec(&sandbox_name, spec)?; + return Err(Status::invalid_argument(format!( + "providers list exceeds maximum ({MAX_PROVIDERS})" + ))); + } - state + let provider_name = request.provider_name.clone(); + let attached = Arc::new(AtomicBool::new(false)); + let attached_clone = attached.clone(); + + let sandbox = state .store - .put_message(&sandbox) + .update_message_cas::( + &sandbox_id, + request.expected_resource_version, + |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + dedupe_provider_names(&mut spec.providers); + if !spec.providers.iter().any(|name| name == &provider_name) + && spec.providers.len() < MAX_PROVIDERS + { + spec.providers.push(provider_name.clone()); + attached_clone.store(true, Ordering::Relaxed); + } + }, + ) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "attach sandbox provider"))?; + + let attached = attached.load(Ordering::Relaxed); info!( sandbox_name = %request.sandbox_name, @@ -288,28 +322,58 @@ pub(super) async fn handle_detach_sandbox_provider( return Err(Status::invalid_argument("provider_name is required")); } + // Validate provider name (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; - let sandbox_name = sandbox + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox .metadata .as_ref() - .map_or_else(String::new, |metadata| metadata.name.clone()); - let spec = sandbox + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? + .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) + let _spec = sandbox .spec - .as_mut() - .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; - let before_len = spec.providers.len(); - spec.providers.retain(|name| name != &request.provider_name); - let detached = spec.providers.len() != before_len; - dedupe_provider_names(&mut spec.providers); - validate_sandbox_spec(&sandbox_name, spec)?; + let provider_name = request.provider_name.clone(); + let detached = Arc::new(AtomicBool::new(false)); + let detached_clone = detached.clone(); - state + let sandbox = state .store - .put_message(&sandbox) + .update_message_cas::( + &sandbox_id, + request.expected_resource_version, + |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + let before_len = spec.providers.len(); + spec.providers.retain(|name| name != &provider_name); + if spec.providers.len() != before_len { + detached_clone.store(true, Ordering::Relaxed); + // Only dedupe after making a change + dedupe_provider_names(&mut spec.providers); + } + }, + ) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "detach sandbox provider"))?; + + let detached = detached.load(Ordering::Relaxed); info!( sandbox_name = %request.sandbox_name, @@ -1076,6 +1140,7 @@ pub(super) async fn handle_create_ssh_session( name: generate_name(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), sandbox_id: req.sandbox_id.clone(), token: token.clone(), @@ -1086,9 +1151,17 @@ pub(super) async fn handle_create_ssh_session( // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) super::validation::validate_object_metadata(session.metadata.as_ref(), "ssh_session")?; + // Use MustCreate to atomically ensure the session token is unique state .store - .put_message(&session) + .put_if( + SshSession::object_type(), + &token, + session.object_name(), + &session.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) .await .map_err(|e| Status::internal(format!("persist ssh session failed: {e}")))?; @@ -1129,12 +1202,26 @@ pub(super) async fn handle_revoke_ssh_session( return Ok(Response::new(RevokeSshSessionResponse { revoked: false })); }; + let resource_version = session + .metadata + .as_ref() + .map_or(0, |metadata| metadata.resource_version); + session.revoked = true; + + // Use CAS to prevent lost updates from concurrent revocations state .store - .put_message(&session) + .put_if( + SshSession::object_type(), + session.object_id(), + session.object_name(), + &session.encode_to_vec(), + None, + WriteCondition::MatchResourceVersion(resource_version), + ) .await - .map_err(|e| Status::internal(format!("persist ssh session failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "revoke ssh session"))?; Ok(Response::new(RevokeSshSessionResponse { revoked: true })) } @@ -1687,6 +1774,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once(("TOKEN".to_string(), "secret".to_string())).collect(), @@ -1701,6 +1789,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::iter::once(("team".to_string(), "agents".to_string())).collect(), + resource_version: 0, }), spec: Some(openshell_core::proto::SandboxSpec { log_level: "debug".to_string(), @@ -1733,6 +1822,7 @@ mod tests { Request::new(AttachSandboxProviderRequest { sandbox_name: "work".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -1775,6 +1865,7 @@ mod tests { Request::new(AttachSandboxProviderRequest { sandbox_name: "work".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -1815,6 +1906,7 @@ mod tests { Request::new(DetachSandboxProviderRequest { sandbox_name: "work".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -1838,6 +1930,7 @@ mod tests { Request::new(DetachSandboxProviderRequest { sandbox_name: "work".to_string(), provider_name: "work-github".to_string(), + expected_resource_version: 0, }), ) .await @@ -1892,6 +1985,7 @@ mod tests { Request::new(AttachSandboxProviderRequest { sandbox_name: "work".to_string(), provider_name: "missing".to_string(), + expected_resource_version: 0, }), ) .await @@ -1899,4 +1993,608 @@ mod tests { assert_eq!(err.code(), tonic::Code::FailedPrecondition); } + + #[tokio::test] + async fn attach_sandbox_provider_accepts_at_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS (32) providers + for i in 0..MAX_PROVIDERS { + state + .store + .put_message(&test_provider(&format!("provider-{i}"), "generic")) + .await + .unwrap(); + } + + // Create sandbox with 31 providers already attached + let mut existing_providers = Vec::new(); + for i in 0..(MAX_PROVIDERS - 1) { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attaching the 32nd provider should succeed + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-31".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_rejects_beyond_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS + 1 providers + for i in 0..=MAX_PROVIDERS { + state + .store + .put_message(&test_provider(&format!("provider-{i}"), "generic")) + .await + .unwrap(); + } + + // Create sandbox with MAX_PROVIDERS already attached + let mut existing_providers = Vec::new(); + for i in 0..MAX_PROVIDERS { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attempting to attach the 33rd provider should fail + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-32".to_string(), + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("exceeds maximum")); + + // Verify sandbox was not modified + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_pre_validation_fails_fast() { + let state = test_server_state().await; + + // Provider name that exceeds validation limits + let long_name = "a".repeat(1000); + state + .store + .put_message(&test_provider(&long_name, "generic")) + .await + .unwrap(); + + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Should fail validation before attempting CAS + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[tokio::test] + async fn detach_sandbox_provider_pre_validation_rejects_invalid_names() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", vec!["valid".to_string()])) + .await + .unwrap(); + + // Provider name that exceeds validation limits + let long_name = "a".repeat(1000); + + let err = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + expected_resource_version: 0, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[tokio::test] + async fn concurrent_create_ssh_session_prevents_duplicate_tokens() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Both requests try to create sessions for the same sandbox + // The token generation is random, so we can't force a collision, + // but we can verify that both succeed with different tokens + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_create_ssh_session( + &state1, + Request::new(CreateSshSessionRequest { + sandbox_id: "sandbox-work".to_string(), + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_create_ssh_session( + &state2, + Request::new(CreateSshSessionRequest { + sandbox_id: "sandbox-work".to_string(), + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // Both should succeed (tokens are random UUIDs, collision is astronomically unlikely) + assert!(result1.is_ok(), "first create should succeed"); + assert!(result2.is_ok(), "second create should succeed"); + + let token1 = result1.unwrap().into_inner().token; + let token2 = result2.unwrap().into_inner().token; + + // Tokens must be different + assert_ne!(token1, token2, "tokens should be unique"); + + // Both sessions should be in the database + let session1 = state + .store + .get_message::(&token1) + .await + .unwrap(); + let session2 = state + .store + .get_message::(&token2) + .await + .unwrap(); + assert!(session1.is_some()); + assert!(session2.is_some()); + } + + #[tokio::test] + async fn concurrent_revoke_ssh_session_handles_cas_properly() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Create a session first + let response = handle_create_ssh_session( + &state, + Request::new(CreateSshSessionRequest { + sandbox_id: "sandbox-work".to_string(), + }), + ) + .await + .unwrap(); + let token = response.into_inner().token; + + // Spawn two concurrent revocation attempts + let state1 = state.clone(); + let token1 = token.clone(); + let handle1 = tokio::spawn(async move { + handle_revoke_ssh_session( + &state1, + Request::new(RevokeSshSessionRequest { token: token1 }), + ) + .await + }); + + let state2 = state.clone(); + let token2 = token.clone(); + let handle2 = tokio::spawn(async move { + handle_revoke_ssh_session( + &state2, + Request::new(RevokeSshSessionRequest { token: token2 }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2] + .iter() + .filter(|r| r.is_ok() && r.as_ref().unwrap().get_ref().revoked) + .count(); + + // At least one should succeed in revoking + assert!( + successes >= 1, + "at least one revocation should succeed, got: {result1:?}, {result2:?}" + ); + + // The session should be revoked in the database + let session = state.store.get_message::(&token).await.unwrap(); + assert!(session.is_some()); + assert!(session.unwrap().revoked, "session should be revoked"); + } + + // ---- CAS (Client-driven optimistic concurrency) tests ---- + + #[tokio::test] + async fn attach_sandbox_provider_client_driven_cas_succeeds_with_correct_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Fetch the sandbox to get its current resource_version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Attach with correct expected_resource_version + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + + // Verify the resource_version incremented + let updated_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1 + ); + } + + #[tokio::test] + async fn attach_sandbox_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Get current version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Try to attach with a stale version (current_version - 1 would be 0, use 99 instead) + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: 99, + }), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), tonic::Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified + let unchanged_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged_sandbox + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + ); + assert!(unchanged_sandbox.spec.unwrap().providers.is_empty()); + } + + #[tokio::test] + async fn detach_sandbox_provider_client_driven_cas_succeeds_with_correct_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["github".to_string()])) + .await + .unwrap(); + + // Fetch the sandbox to get its current resource_version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Detach with correct expected_resource_version + let response = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: current_version, + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.detached); + + // Verify the resource_version incremented + let updated_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + updated_sandbox.metadata.as_ref().unwrap().resource_version, + current_version + 1 + ); + } + + #[tokio::test] + async fn detach_sandbox_provider_client_driven_cas_rejects_stale_version() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox("work", vec!["github".to_string()])) + .await + .unwrap(); + + // Get current version + let sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + let current_version = sandbox.metadata.as_ref().unwrap().resource_version; + + // Try to detach with a stale version + let err = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "github".to_string(), + expected_resource_version: 99, + }), + ) + .await + .unwrap_err(); + + // Should get ABORTED status for CAS conflict + assert_eq!(err.code(), tonic::Code::Aborted); + assert!( + err.message().contains("modified concurrently") + || err.message().contains("resource_version"), + "error message should mention concurrency conflict: {}", + err.message() + ); + + // Verify the sandbox was not modified + let unchanged_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!( + unchanged_sandbox + .metadata + .as_ref() + .unwrap() + .resource_version, + current_version + ); + assert_eq!(unchanged_sandbox.spec.unwrap().providers, vec!["github"]); + } + + #[tokio::test] + async fn attach_sandbox_provider_concurrent_with_stale_versions() { + use std::sync::Arc; + + let state = Arc::new(test_server_state().await); + + // Create multiple providers + for i in 0..3 { + state + .store + .put_message(&test_provider(&format!("provider-{i}"), "generic")) + .await + .unwrap(); + } + + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // All three clients fetch the sandbox and see version 1 + let initial_version = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .metadata + .as_ref() + .unwrap() + .resource_version; + + // Launch 3 concurrent attach operations, all using the same initial version + let mut handles = vec![]; + for i in 0..3 { + let state_clone = Arc::clone(&state); + let handle = tokio::spawn(async move { + handle_attach_sandbox_provider( + &state_clone, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: format!("provider-{i}"), + expected_resource_version: initial_version, + }), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others should get ABORTED + let successes = results.iter().filter(|r| r.is_ok()).count(); + let aborted_conflicts = results + .iter() + .filter(|r| { + r.as_ref() + .err() + .is_some_and(|e| e.code() == tonic::Code::Aborted) + }) + .count(); + + assert_eq!( + successes, 1, + "exactly one attach should succeed with client-driven CAS" + ); + assert_eq!( + aborted_conflicts, 2, + "two attaches should fail with ABORTED due to stale version" + ); + + // Final sandbox should have exactly 1 provider and resource_version = initial_version + 1 + let final_sandbox = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap(); + assert_eq!(final_sandbox.spec.as_ref().unwrap().providers.len(), 1); + assert_eq!( + final_sandbox.metadata.as_ref().unwrap().resource_version, + initial_version + 1 + ); + } } diff --git a/crates/openshell-server/src/grpc/service.rs b/crates/openshell-server/src/grpc/service.rs index 66d07f51c..1c6fd5cc1 100644 --- a/crates/openshell-server/src/grpc/service.rs +++ b/crates/openshell-server/src/grpc/service.rs @@ -15,7 +15,7 @@ use tonic::{Request, Response, Status}; use uuid::Uuid; use crate::ServerState; -use crate::persistence::ObjectType; +use crate::persistence::{ObjectType, WriteCondition}; use crate::service_routing; const MAX_SERVICE_NAME_LEN: usize = 28; @@ -42,29 +42,52 @@ pub(super) async fn handle_expose_service( let now = super::current_time_ms().map_err(|e| Status::internal(format!("clock error: {e}")))?; let key = service_routing::endpoint_key(&req.sandbox, &req.service); - let (id, created_at_ms, created) = match state + + // Fetch existing endpoint to determine create vs. update path + let existing = state .store .get_message_by_name::(&key) .await - { - Ok(Some(existing)) => ( + .map_err(|e| Status::internal(format!("fetch endpoint failed: {e}")))?; + + let (id, created_at_ms, condition, created) = if let Some(existing) = existing { + // Update path: preserve id and created_at, use CAS to prevent conflicts + let resource_version = existing + .metadata + .as_ref() + .map_or(0, |metadata| metadata.resource_version); + ( existing.object_id().to_string(), existing .metadata .as_ref() .map_or(now, |metadata| metadata.created_at_ms), + WriteCondition::MatchResourceVersion(resource_version), false, - ), - Ok(None) => (Uuid::new_v4().to_string(), now, true), - Err(e) => return Err(Status::internal(format!("fetch endpoint failed: {e}"))), + ) + } else { + // Create path: new id and created_at, use MustCreate to prevent races + ( + Uuid::new_v4().to_string(), + now, + WriteCondition::MustCreate, + true, + ) }; + let labels_json = serde_json::to_string(&HashMap::from([( + "sandbox".to_string(), + req.sandbox.clone(), + )])) + .map_err(|e| Status::internal(format!("serialize labels failed: {e}")))?; + let endpoint = ServiceEndpoint { metadata: Some(ObjectMeta { - id, - name: key, + id: id.clone(), + name: key.clone(), created_at_ms, labels: HashMap::from([("sandbox".to_string(), req.sandbox.clone())]), + resource_version: 0, }), sandbox_id: sandbox.object_id().to_string(), sandbox_name: req.sandbox.clone(), @@ -73,11 +96,19 @@ pub(super) async fn handle_expose_service( domain: true, }; + // Single-attempt CAS write: fails with ABORTED on concurrent modification state .store - .put_message(&endpoint) + .put_if( + ServiceEndpoint::object_type(), + &id, + &key, + &endpoint.encode_to_vec(), + Some(&labels_json), + condition, + ) .await - .map_err(|e| Status::internal(format!("persist endpoint failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "expose service"))?; let url = service_routing::endpoint_url(&state.config, &req.sandbox, &req.service) .unwrap_or_default(); @@ -280,6 +311,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(openshell_core::proto::SandboxSpec::default()), phase: SandboxPhase::Ready as i32, @@ -398,4 +430,142 @@ mod tests { .into_inner(); assert!(listed.services.is_empty()); } + + #[tokio::test] + async fn concurrent_expose_service_handles_cas_properly() { + let state = test_server_state().await; + seed_sandbox(&state, "my-sandbox").await; + + // Spawn two concurrent expose_service calls for the same endpoint + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_expose_service( + &state1, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 8080, + domain: true, + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_expose_service( + &state2, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 9090, + domain: true, + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed with MustCreate, the other may fail with ABORTED or succeed with update + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + // At least one should succeed + assert!( + successes >= 1, + "at least one expose should succeed, got: {result1:?}, {result2:?}" + ); + + // Only one endpoint should exist + let listed = handle_list_services( + &state, + Request::new(ListServicesRequest { + sandbox: "my-sandbox".to_string(), + limit: 0, + offset: 0, + }), + ) + .await + .unwrap() + .into_inner(); + assert_eq!(listed.services.len(), 1); + } + + #[tokio::test] + async fn concurrent_expose_service_update_uses_cas() { + let state = test_server_state().await; + seed_sandbox(&state, "my-sandbox").await; + + // Create an initial endpoint + handle_expose_service( + &state, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 7070, + domain: true, + }), + ) + .await + .unwrap(); + + // Spawn two concurrent updates + let state1 = state.clone(); + let handle1 = tokio::spawn(async move { + handle_expose_service( + &state1, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 8080, + domain: true, + }), + ) + .await + }); + + let state2 = state.clone(); + let handle2 = tokio::spawn(async move { + handle_expose_service( + &state2, + Request::new(ExposeServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + target_port: 9090, + domain: true, + }), + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + assert!( + successes >= 1, + "at least one update should succeed, got: {result1:?}, {result2:?}" + ); + + // The endpoint should have one of the new port values + let fetched = handle_get_service( + &state, + Request::new(GetServiceRequest { + sandbox: "my-sandbox".to_string(), + service: "web".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + let port = fetched.endpoint.as_ref().unwrap().target_port; + assert!( + port == 8080 || port == 9090, + "port should be one of the updated values, got {port}" + ); + assert_ne!(port, 7070, "port should not be the original value"); + } } diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 160b7e031..dbc380d82 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -874,6 +874,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials, diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index b52700f0d..4af9bfcc9 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -3,6 +3,7 @@ #![allow(clippy::result_large_err)] // gRPC handlers return Result, Status> +use openshell_core::ObjectId; use openshell_core::proto::{ ClusterInferenceConfig, GetClusterInferenceRequest, GetClusterInferenceResponse, GetInferenceBundleRequest, GetInferenceBundleResponse, InferenceRoute, Provider, ResolvedRoute, @@ -11,13 +12,14 @@ use openshell_core::proto::{ }; use openshell_router::config::ResolvedRoute as RouterResolvedRoute; use openshell_router::{ValidationFailureKind, verify_backend_endpoint}; +use prost::Message as _; use std::sync::Arc; use std::time::Duration; use tonic::{Request, Response, Status}; use crate::{ ServerState, - persistence::{ObjectName, ObjectType, Store, current_time_ms}, + persistence::{ObjectName, ObjectType, Store, WriteCondition, current_time_ms}, }; #[derive(Debug)] @@ -169,6 +171,7 @@ async fn upsert_cluster_inference_route( let config = build_cluster_inference_config(&provider, model_id, timeout_secs); + // Fetch existing route to determine create vs. update path let existing = store .get_message_by_name::(route_name) .await @@ -177,32 +180,49 @@ async fn upsert_cluster_inference_route( let now_ms = current_time_ms().map_err(|e| Status::internal(format!("get current time: {e}")))?; - let route = if let Some(existing) = existing { - InferenceRoute { - metadata: existing.metadata.clone(), - config: Some(config), - version: existing.version.saturating_add(1), - } + let (id, metadata, new_version, condition) = if let Some(existing) = existing { + // Update path: preserve metadata, increment version, use CAS + let resource_version = existing.metadata.as_ref().map_or(0, |m| m.resource_version); + ( + existing.object_id().to_string(), + existing.metadata.clone(), + existing.version.saturating_add(1), + WriteCondition::MatchResourceVersion(resource_version), + ) } else { - InferenceRoute { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: uuid::Uuid::new_v4().to_string(), - name: route_name.to_string(), - created_at_ms: now_ms, - labels: std::collections::HashMap::new(), - }), - config: Some(config), - version: 1, - } + // Create path: new metadata, version 1, use MustCreate + let new_id = uuid::Uuid::new_v4().to_string(); + let new_metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: new_id.clone(), + name: route_name.to_string(), + created_at_ms: now_ms, + labels: std::collections::HashMap::new(), + resource_version: 0, + }); + (new_id, new_metadata, 1, WriteCondition::MustCreate) + }; + + let route = InferenceRoute { + metadata, + config: Some(config), + version: new_version, }; // Ensure metadata is valid (defense in depth - should always be true for server-constructed metadata) crate::grpc::validate_object_metadata(route.metadata.as_ref(), "inference_route")?; + // Single-attempt CAS write: fails with ABORTED on concurrent modification store - .put_message(&route) + .put_if( + InferenceRoute::object_type(), + &id, + route_name, + &route.encode_to_vec(), + None, + condition, + ) .await - .map_err(|e| Status::internal(format!("persist route failed: {e}")))?; + .map_err(|e| crate::grpc::persistence_error_to_status(e, "upsert inference route"))?; Ok(UpsertedInferenceRoute { route, validation }) } @@ -490,6 +510,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: provider_name.to_string(), @@ -507,6 +528,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once((key_name.to_string(), key_value.to_string())).collect(), @@ -666,6 +688,7 @@ mod tests { name: "openai-dev".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-test".to_string())) @@ -687,6 +710,7 @@ mod tests { name: CLUSTER_INFERENCE_ROUTE_NAME.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: "openai-dev".to_string(), @@ -1047,4 +1071,153 @@ mod tests { let err = effective_route_name("unknown-route").unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); } + + #[tokio::test] + async fn concurrent_upsert_route_create_uses_must_create() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); + store.put_message(&provider).await.expect("persist"); + + // Spawn two concurrent upsert calls for the same route (create path) + let store1 = store.clone(); + let handle1 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store1, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o", + 0, + false, + ) + .await + }); + + let store2 = store.clone(); + let handle2 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store2, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4.1", + 0, + false, + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed with MustCreate, the other should fail + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + let failures = [&result1, &result2] + .iter() + .filter(|r| { + r.as_ref().is_err_and(|e| { + // Accept either ABORTED (from CAS) or Internal (from DB unique constraint) + e.code() == tonic::Code::Aborted + || (e.code() == tonic::Code::Internal + && e.message().contains("unique violation")) + }) + }) + .count(); + + assert_eq!( + successes, 1, + "exactly one create should succeed, got: {result1:?}, {result2:?}" + ); + assert_eq!( + failures, 1, + "exactly one create should fail, got: {result1:?}, {result2:?}" + ); + + // Only one route should exist + let route = store + .get_message_by_name::(CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("fetch") + .expect("route should exist"); + assert_eq!(route.version, 1); + } + + #[tokio::test] + async fn concurrent_upsert_route_update_uses_cas() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + + let provider = make_provider("openai-dev", "openai", "OPENAI_API_KEY", "sk-test"); + store.put_message(&provider).await.expect("persist"); + + // Create initial route + upsert_cluster_inference_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-3.5", + 0, + false, + ) + .await + .expect("initial create should succeed"); + + // Spawn two concurrent updates + let store1 = store.clone(); + let handle1 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store1, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o", + 0, + false, + ) + .await + }); + + let store2 = store.clone(); + let handle2 = tokio::spawn(async move { + upsert_cluster_inference_route( + &store2, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4.1", + 0, + false, + ) + .await + }); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + // One should succeed, one may fail with ABORTED due to CAS conflict + let successes = [&result1, &result2].iter().filter(|r| r.is_ok()).count(); + + assert!( + successes >= 1, + "at least one update should succeed, got: {result1:?}, {result2:?}" + ); + + // The route should have one of the new model values and version 2 + let route = store + .get_message_by_name::(CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("fetch") + .expect("route should exist"); + let config = route.config.expect("config"); + assert!( + config.model_id == "gpt-4o" || config.model_id == "gpt-4.1", + "model should be one of the updated values, got {}", + config.model_id + ); + assert_ne!( + config.model_id, "gpt-3.5", + "model should not be the original value" + ); + assert_eq!(route.version, 2, "version should be incremented to 2"); + } } diff --git a/crates/openshell-server/src/persistence/mod.rs b/crates/openshell-server/src/persistence/mod.rs index 1c926bd4a..2ec0991fe 100644 --- a/crates/openshell-server/src/persistence/mod.rs +++ b/crates/openshell-server/src/persistence/mod.rs @@ -41,6 +41,10 @@ pub enum PersistenceError { detail: Option, constraint_msg: String, }, + #[error("resource version conflict: expected version does not match current")] + Conflict { + current_resource_version: Option, + }, } impl PersistenceError { @@ -78,6 +82,28 @@ pub struct ObjectRecord { pub updated_at_ms: i64, /// JSON-serialized labels (key-value pairs). pub labels: Option, + /// Optimistic concurrency control version. + /// Incremented on each update for compare-and-swap operations. + pub resource_version: u64, +} + +/// Write condition for compare-and-swap operations. +#[derive(Debug, Clone, Copy)] +pub enum WriteCondition { + /// Object must not exist (insert only). + MustCreate, + /// Object must exist with the specified resource version (update only). + MatchResourceVersion(u64), + /// Unconditional write (insert or update). + Unconditional, +} + +/// Result of a successful write operation. +#[derive(Debug, Clone)] +pub struct WriteResult { + pub resource_version: u64, + pub created_at_ms: i64, + pub updated_at_ms: i64, } /// Persistence store implementations. @@ -94,7 +120,9 @@ pub trait ObjectType { // Import object metadata accessor traits from openshell-core // (implementations for all proto types are in openshell-core::metadata) -pub use openshell_core::{ObjectId, ObjectLabels, ObjectName}; +pub use openshell_core::{ + GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion, +}; /// Generate a random 6-character lowercase alphabetic name. pub fn generate_name() -> String { @@ -147,6 +175,74 @@ impl Store { } } + /// Insert or update a generic object with compare-and-swap support. + /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `name` - Human-readable object name + /// * `payload` - Serialized object data + /// * `labels` - Optional JSON-serialized labels + /// * `condition` - Write precondition (`MustCreate`, `MatchResourceVersion`, or `Unconditional`) + /// + /// # Returns + /// * `Ok(WriteResult)` - Write succeeded with new `resource_version` and timestamps + /// * `Err(Conflict)` - Resource version mismatch (for `MatchResourceVersion`) + /// * `Err(UniqueViolation)` - Object already exists (for `MustCreate`) or name conflict + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + Self::Sqlite(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + } + } + + /// Delete an object by id with compare-and-swap support. + /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `expected_resource_version` - Required resource version for the delete to proceed + /// + /// # Returns + /// * `Ok(true)` - Object was deleted + /// * `Ok(false)` - Object not found + /// * `Err(Conflict)` - Resource version mismatch + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + Self::Sqlite(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + } + } + /// Fetch an object by id. pub async fn get( &self, @@ -253,7 +349,7 @@ impl Store { } /// Fetch and decode a protobuf message by id. - pub async fn get_message( + pub async fn get_message( &self, id: &str, ) -> PersistenceResult> { @@ -262,13 +358,17 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) } /// Fetch and decode a protobuf message by name. - pub async fn get_message_by_name( + pub async fn get_message_by_name( &self, name: &str, ) -> PersistenceResult> { @@ -277,9 +377,101 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) + } + + /// Update a protobuf message using CAS (compare-and-swap). + /// + /// Fetches the current object, validates the expected version, applies the + /// mutation function, and attempts a single CAS write. Returns Conflict on + /// version mismatch for caller-driven retry. + /// + /// # Arguments + /// * `id` - Object ID to update + /// * `expected_version` - Required resource version for the update to proceed. + /// Pass 0 to use the current version (internal operations only). + /// For client-facing operations, pass the client-provided expected version. + /// * `mutate` - Function that modifies the object in place + /// + /// # Returns + /// * `Ok(T)` - Successfully updated object with new `resource_version` + /// * `Err(Conflict)` - Version mismatch; caller should retry + /// * `Err(Database)` - Object not found or other DB error + pub async fn update_message_cas( + &self, + id: &str, + expected_version: u64, + mut mutate: F, + ) -> PersistenceResult + where + T: Message + + Default + + ObjectType + + ObjectId + + ObjectName + + ObjectLabels + + SetResourceVersion + + GetResourceVersion + + Clone, + F: FnMut(&mut T), + { + // Fetch current object with authoritative resource_version + let current = self + .get_message::(id) + .await? + .ok_or_else(|| PersistenceError::Database(format!("object {id} not found")))?; + + let current_version = current.get_resource_version(); + + // Determine the version to use for CAS: + // - If expected_version is 0, use current version (internal operations) + // - Otherwise, validate that expected matches current (client-facing operations) + let cas_version = if expected_version == 0 { + current_version + } else { + if expected_version != current_version { + return Err(PersistenceError::Conflict { + current_resource_version: Some(current_version), + }); + } + expected_version + }; + + // Apply mutation + let mut updated = current.clone(); + mutate(&mut updated); + + // Serialize labels + let labels_map = updated.object_labels(); + let labels_json = if labels_map.as_ref().is_none_or(HashMap::is_empty) { + None + } else { + Some(serde_json::to_string(&labels_map).map_err(|e| { + PersistenceError::Encode(format!("failed to serialize labels: {e}")) + })?) + }; + + // Single-attempt CAS write - fails with Conflict on version mismatch + let result = self + .put_if( + T::object_type(), + updated.object_id(), + updated.object_name(), + &updated.encode_to_vec(), + labels_json.as_deref(), + WriteCondition::MatchResourceVersion(cas_version), + ) + .await?; + + // Success - hydrate the new resource_version and return + updated.set_resource_version(result.resource_version); + Ok(updated) } } diff --git a/crates/openshell-server/src/persistence/postgres.rs b/crates/openshell-server/src/persistence/postgres.rs index d9167b63b..6ae0d2460 100644 --- a/crates/openshell-server/src/persistence/postgres.rs +++ b/crates/openshell-server/src/persistence/postgres.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, map_db_error, - map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, + WriteCondition, WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, @@ -52,7 +52,7 @@ impl PostgresStore { let labels_jsonb: Option = labels .map(serde_json::from_str) .transpose() - .map_err(|e| super::PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; sqlx::query( r" @@ -76,6 +76,157 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + let now_ms = current_time_ms()?; + let labels_jsonb: Option = labels + .map(serde_json::from_str) + .transpose() + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check using RETURNING + let row_result = sqlx::query( + r" +UPDATE objects +SET payload = $4, labels = COALESCE($5, '{}'::jsonb), updated_at_ms = $6, resource_version = resource_version + 1 +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels_jsonb) + .bind(now_ms) + .fetch_optional(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if let Some(row) = row_result { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }) + } else { + Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))) + } + } + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET + payload = EXCLUDED.payload, + updated_at_ms = EXCLUDED.updated_at_ms, + labels = EXCLUDED.labels, + resource_version = objects.resource_version + 1 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r" +DELETE FROM objects +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }) + } else { + Ok(false) + } + } + } + pub async fn get( &self, object_type: &str, @@ -83,7 +234,7 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND id = $2 ", @@ -104,7 +255,7 @@ WHERE object_type = $1 AND id = $2 ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND name = $2 ", @@ -146,7 +297,7 @@ WHERE object_type = $1 AND name = $2 ) -> PersistenceResult> { let rows = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 ORDER BY created_at_ms ASC, name ASC @@ -173,13 +324,12 @@ LIMIT $2 OFFSET $3 use super::parse_label_selector; let required_labels = parse_label_selector(label_selector)?; - let labels_jsonb = serde_json::to_value(&required_labels).map_err(|e| { - super::PersistenceError::Encode(format!("failed to serialize labels: {e}")) - })?; + let labels_jsonb = serde_json::to_value(&required_labels) + .map_err(|e| PersistenceError::Encode(format!("failed to serialize labels: {e}")))?; let rows = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND labels @> $2 ORDER BY created_at_ms ASC, name ASC @@ -611,6 +761,7 @@ WHERE object_type = $1 AND scope = $2 fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { let labels_jsonb: Option = row.get("labels"); + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -619,6 +770,7 @@ fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: labels_jsonb.map(|value| value.to_string()), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/sqlite.rs b/crates/openshell-server/src/persistence/sqlite.rs index 1ed6bccd3..47b3d00ab 100644 --- a/crates/openshell-server/src/persistence/sqlite.rs +++ b/crates/openshell-server/src/persistence/sqlite.rs @@ -3,7 +3,7 @@ use super::{ DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, - current_time_ms, map_db_error, map_migrate_error, + WriteCondition, WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, @@ -98,6 +98,155 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + let now_ms = current_time_ms()?; + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + Ok(WriteResult { + resource_version: 1, + created_at_ms: now_ms, + updated_at_ms: now_ms, + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check + let result = sqlx::query( + r#" +UPDATE "objects" +SET "payload" = ?4, "labels" = ?5, "updated_at_ms" = ?6, "resource_version" = "resource_version" + 1 +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels.unwrap_or("{}")) + .bind(now_ms) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() == 0 { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }); + } + return Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))); + } + + // Fetch the updated record to get the new resource_version + let updated = self.get(object_type, id).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after update".to_string()) + })?; + + Ok(WriteResult { + resource_version: updated.resource_version, + created_at_ms: updated.created_at_ms, + updated_at_ms: updated.updated_at_ms, + }) + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET + "payload" = excluded."payload", + "updated_at_ms" = excluded."updated_at_ms", + "labels" = excluded."labels", + "resource_version" = "objects"."resource_version" + 1 +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + // Fetch the result to get the resource_version + let record = self.get_by_name(object_type, name).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after upsert".to_string()) + })?; + + Ok(WriteResult { + resource_version: record.resource_version, + created_at_ms: record.created_at_ms, + updated_at_ms: record.updated_at_ms, + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r#" +DELETE FROM "objects" +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }); + } + Ok(false) + } + } + pub async fn get( &self, object_type: &str, @@ -105,7 +254,7 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "id" = ?2 "#, @@ -126,7 +275,7 @@ WHERE "object_type" = ?1 AND "id" = ?2 ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "name" = ?2 "#, @@ -178,7 +327,7 @@ WHERE "object_type" = ?1 AND "name" = ?2 ) -> PersistenceResult> { let rows = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 ORDER BY "created_at_ms" ASC, "name" ASC @@ -669,6 +818,7 @@ pub(super) fn sqlite_sidecar_paths(path: &Path) -> [PathBuf; 2] { } fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -677,6 +827,7 @@ fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: row.get("labels"), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/tests.rs b/crates/openshell-server/src/persistence/tests.rs index 09549ad29..db85d2a0e 100644 --- a/crates/openshell-server/src/persistence/tests.rs +++ b/crates/openshell-server/src/persistence/tests.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::{ObjectType, Store, generate_name}; +use super::{ObjectType, PersistenceError, Store, generate_name}; use crate::policy_store::PolicyStoreExt; use openshell_core::proto::{ObjectForTest, SandboxPolicy}; use prost::Message; @@ -962,3 +962,460 @@ fn parse_label_selector_handles_whitespace() { assert_eq!(result.get("env"), Some(&"prod".to_string())); assert_eq!(result.get("tier"), Some(&"frontend".to_string())); } + +// --------------------------------------------------------------------------- +// CAS (compare-and-swap) tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn cas_put_if_must_create_succeeds() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let result = store + .put_if( + "sandbox", + "id-1", + "new-sandbox", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 1); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"payload"); +} + +#[tokio::test] +async fn cas_put_if_must_create_fails_on_duplicate() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // First insert succeeds + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Second insert with same ID fails + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-2", + b"payload2", + None, + WriteCondition::MustCreate, + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::UniqueViolation { .. }) + )); +} + +#[tokio::test] +async fn cas_put_if_match_version_succeeds() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with correct version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 2); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); + assert_eq!(record.payload, b"v2"); +} + +#[tokio::test] +async fn cas_put_if_match_version_fails_on_mismatch() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with wrong version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(99), + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Original payload unchanged + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"v1"); +} + +#[tokio::test] +async fn cas_delete_if_succeeds_with_correct_version() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let deleted = store.delete_if("sandbox", "id-1", 1).await.unwrap(); + assert!(deleted); + + let record = store.get("sandbox", "id-1").await.unwrap(); + assert!(record.is_none()); +} + +#[tokio::test] +async fn cas_delete_if_fails_with_wrong_version() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let result = store.delete_if("sandbox", "id-1", 99).await; + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Object still exists + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); +} + +#[tokio::test] +async fn cas_resource_version_increments() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create + let r1 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + assert_eq!(r1.resource_version, 1); + + // Update 1 + let r2 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + assert_eq!(r2.resource_version, 2); + + // Update 2 + let r3 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v3", + None, + WriteCondition::MatchResourceVersion(2), + ) + .await + .unwrap(); + assert_eq!(r3.resource_version, 3); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 3); +} + +#[tokio::test] +async fn cas_concurrent_updates_one_succeeds() { + use super::WriteCondition; + use std::sync::Arc; + + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"initial", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Spawn 10 concurrent updates trying to update from version 1 + let mut handles = vec![]; + for i in 0..10 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + format!("update-{i}").as_bytes(), + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Exactly one should succeed, rest should conflict + let successes = results.iter().filter(|r| r.is_ok()).count(); + let conflicts = results.iter().filter(|r| r.is_err()).count(); + + assert_eq!(successes, 1); + assert_eq!(conflicts, 9); + + // Final version should be 2 + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); +} + +#[tokio::test] +async fn cas_update_message_cas_succeeds() { + use openshell_core::proto::Sandbox; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Update using CAS with expected_version = 0 (use current version) + let updated = store + .update_message_cas::("test-id", 0, |s| { + s.phase = 2; // Set to Ready + s.current_policy_version = 42; + }) + .await + .unwrap(); + + assert_eq!(updated.phase, 2); + assert_eq!(updated.current_policy_version, 42); + assert_eq!( + updated.metadata.as_ref().map_or(0, |m| m.resource_version), + 2 + ); +} + +#[tokio::test] +async fn cas_update_message_cas_conflicts_on_concurrent_updates() { + use openshell_core::proto::Sandbox; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Track how many updates succeed + let success_count = Arc::new(AtomicU32::new(0)); + + // Spawn 5 concurrent CAS updates (using expected_version = 0 to use current) + let mut handles = vec![]; + for i in 0..5 { + let store = Arc::clone(&store); + let success_count = Arc::clone(&success_count); + let handle = tokio::spawn(async move { + let result = store + .update_message_cas::("test-id", 0, |s| { + s.current_policy_version = i; + }) + .await; + if result.is_ok() { + success_count.fetch_add(1, Ordering::SeqCst); + } + result + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Only one should succeed; others fail with Conflict due to single-attempt CAS + let successes = results.iter().filter(|r| r.is_ok()).count(); + let conflicts = results + .iter() + .filter(|r| matches!(r, Err(PersistenceError::Conflict { .. }))) + .count(); + assert_eq!(successes, 1, "exactly one concurrent update should succeed"); + assert_eq!(conflicts, 4, "four updates should fail with Conflict"); + assert_eq!(success_count.load(Ordering::SeqCst), 1); + + // Final version should be 2 (initial 1 + 1 successful update) + let final_sandbox = store + .get_message::("test-id") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_sandbox + .metadata + .as_ref() + .map_or(0, |m| m.resource_version), + 2, + "resource_version should be 2 (initial 1 + 1 successful update)" + ); +} diff --git a/crates/openshell-server/src/service_routing.rs b/crates/openshell-server/src/service_routing.rs index 194f10417..d126fdd7f 100644 --- a/crates/openshell-server/src/service_routing.rs +++ b/crates/openshell-server/src/service_routing.rs @@ -803,6 +803,7 @@ mod tests { name: "my-sandbox--web".to_string(), created_at_ms: 1_700_000_000_000, labels: std::collections::HashMap::default(), + resource_version: 0, }), sandbox_id: "sandbox-id".to_string(), sandbox_name: "my-sandbox".to_string(), diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs index c3294b361..e328a50ac 100644 --- a/crates/openshell-server/src/ssh_sessions.rs +++ b/crates/openshell-server/src/ssh_sessions.rs @@ -90,6 +90,7 @@ mod tests { name: format!("session-{id}"), created_at_ms: 1000, labels: HashMap::new(), + resource_version: 0, }), sandbox_id: sandbox_id.to_string(), token: id.to_string(), diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 19d358826..91f40c289 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -800,6 +800,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), ..Default::default() } diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index b96c0abbf..f8a393df0 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -1557,6 +1557,7 @@ fn spawn_create_provider(app: &App, tx: mpsc::UnboundedSender) { name: provider_name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype.clone(), credentials: credentials.clone(), @@ -1647,6 +1648,7 @@ fn spawn_update_provider(app: &App, tx: mpsc::UnboundedSender) { name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype, credentials, @@ -1990,6 +1992,7 @@ fn spawn_set_global_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: false, global: true, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2025,6 +2028,7 @@ fn spawn_delete_global_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: true, global: true, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2094,6 +2098,7 @@ fn spawn_set_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: false, global: false, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2133,6 +2138,7 @@ fn spawn_delete_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { delete_setting: true, global: false, merge_operations: vec![], + expected_resource_version: 0, }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; diff --git a/proto/datamodel.proto b/proto/datamodel.proto index 534b043ae..462088124 100644 --- a/proto/datamodel.proto +++ b/proto/datamodel.proto @@ -8,7 +8,7 @@ package openshell.datamodel.v1; // Kubernetes-style metadata shared by all top-level OpenShell domain objects. // // This structure provides consistent metadata (identity, labels, timestamps, -// versioning) across Sandbox, Provider, SshSession, and other resources. +// resource versioning) across Sandbox, Provider, SshSession, and other resources. message ObjectMeta { // Stable object ID generated by the gateway. string id = 1; @@ -22,6 +22,10 @@ message ObjectMeta { // Key-value labels for filtering and organization. // Labels must follow Kubernetes conventions: alphanumeric + `-._/`, max 63 chars per segment. map labels = 4; + + // Optimistic concurrency control version. + // Incremented by the gateway on each update. Clients can use this for compare-and-swap operations. + uint64 resource_version = 5; } // Provider model stored by OpenShell. diff --git a/proto/openshell.proto b/proto/openshell.proto index d6bcbece2..d05b0c9cd 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -379,6 +379,11 @@ message AttachSandboxProviderRequest { string sandbox_name = 1; // Provider name to attach. string provider_name = 2; + // Expected resource version for optimistic concurrency control. + // If 0, the server uses the current version (backward compatibility). + // If non-zero, the server validates that the sandbox's current resource_version + // matches this value before applying the mutation, returning ABORTED on mismatch. + uint64 expected_resource_version = 3; } // Detach provider from sandbox request. @@ -387,6 +392,11 @@ message DetachSandboxProviderRequest { string sandbox_name = 1; // Provider name to detach. string provider_name = 2; + // Expected resource version for optimistic concurrency control. + // If 0, the server uses the current version (backward compatibility). + // If non-zero, the server validates that the sandbox's current resource_version + // matches this value before applying the mutation, returning ABORTED on mismatch. + uint64 expected_resource_version = 3; } // Delete sandbox request. @@ -905,6 +915,12 @@ message UpdateConfigRequest { bool global = 6; // Batched incremental policy merge operations. Sandbox-scoped only. repeated PolicyMergeOperation merge_operations = 7; + // Expected resource version for optimistic concurrency control (sandbox-scoped only). + // If 0, the server uses the current version (backward compatibility). + // If non-zero, the server validates that the sandbox's current resource_version + // matches this value before applying the mutation, returning ABORTED on mismatch. + // Ignored for global-scoped updates. + uint64 expected_resource_version = 8; } message PolicyMergeOperation {