diff --git a/tensorstore/driver/zarr3/chunk_cache.cc b/tensorstore/driver/zarr3/chunk_cache.cc index 64b6d69fd..f14efd607 100644 --- a/tensorstore/driver/zarr3/chunk_cache.cc +++ b/tensorstore/driver/zarr3/chunk_cache.cc @@ -75,10 +75,12 @@ ZarrChunkCache::~ZarrChunkCache() = default; ZarrLeafChunkCache::ZarrLeafChunkCache( kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state, - ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/) + ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/, + bool open_as_void) : Base(std::move(store)), codec_state_(std::move(codec_state)), - dtype_(std::move(dtype)) {} + dtype_(std::move(dtype)), + open_as_void_(open_as_void) {} void ZarrLeafChunkCache::Read(ZarrChunkCache::ReadRequest request, AnyFlowReceiver chunk_indices, absl::InlinedVector, 1> field_arrays(num_fields); // Special case: void access - return raw bytes directly - if (num_fields == 1 && dtype_.fields[0].name == "") { + if (open_as_void_) { TENSORSTORE_ASSIGN_OR_RETURN( field_arrays[0], codec_state_->DecodeArray(grid().components[0].shape(), std::move(data))); @@ -221,11 +223,13 @@ kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() { ZarrShardedChunkCache::ZarrShardedChunkCache( kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state, - ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool) + ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool, + bool open_as_void) : base_kvstore_(std::move(store)), codec_state_(std::move(codec_state)), dtype_(std::move(dtype)), - data_cache_pool_(std::move(data_cache_pool)) {} + data_cache_pool_(std::move(data_cache_pool)), + open_as_void_(open_as_void) {} Result> TranslateCellToSourceTransformForShard( IndexTransform<> transform, span grid_cell_indices, @@ -534,7 +538,7 @@ void ZarrShardedChunkCache::Entry::DoInitialize() { *sharding_state.sub_chunk_codec_chain, std::move(sharding_kvstore), cache.executor(), ZarrShardingCodec::PreparedState::Ptr(&sharding_state), - cache.dtype_, cache.data_cache_pool_); + cache.dtype_, cache.data_cache_pool_, cache.open_as_void_); zarr_chunk_cache = new_cache.release(); return std::unique_ptr(&zarr_chunk_cache->cache()); }) diff --git a/tensorstore/driver/zarr3/chunk_cache.h b/tensorstore/driver/zarr3/chunk_cache.h index 5933115d7..a39eb1dc8 100644 --- a/tensorstore/driver/zarr3/chunk_cache.h +++ b/tensorstore/driver/zarr3/chunk_cache.h @@ -158,7 +158,8 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache, explicit ZarrLeafChunkCache(kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state, ZarrDType dtype, - internal::CachePool::WeakPtr data_cache_pool); + internal::CachePool::WeakPtr data_cache_pool, + bool open_as_void = false); void Read(ZarrChunkCache::ReadRequest request, AnyFlowReceiver( @@ -246,6 +249,7 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache { kvstore::DriverPtr base_kvstore_; ZarrCodecChain::PreparedState::Ptr codec_state_; ZarrDType dtype_; + bool open_as_void_; // Data cache pool, if it differs from `this->pool()` (which is equal to the // metadata cache pool). @@ -260,11 +264,13 @@ class ZarrShardSubChunkCache : public ChunkCacheImpl { explicit ZarrShardSubChunkCache( kvstore::DriverPtr store, Executor executor, ZarrShardingCodec::PreparedState::Ptr sharding_state, - ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool) + ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool, + bool open_as_void = false) : ChunkCacheImpl(std::move(store), ZarrCodecChain::PreparedState::Ptr( sharding_state->sub_chunk_codec_state), - std::move(dtype), std::move(data_cache_pool)), + std::move(dtype), std::move(data_cache_pool), + open_as_void), sharding_state_(std::move(sharding_state)), executor_(std::move(executor)) {} diff --git a/tensorstore/driver/zarr3/driver.cc b/tensorstore/driver/zarr3/driver.cc index dd95c711b..f4c0ad9d7 100644 --- a/tensorstore/driver/zarr3/driver.cc +++ b/tensorstore/driver/zarr3/driver.cc @@ -149,20 +149,9 @@ class ZarrDriverSpec jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>( jb::DefaultValue( [](auto* obj) { *obj = std::string{}; }))), - - // NEW: wrap the open_as_void projection in a Validate - jb::Member("open_as_void", - jb::Validate( - [](const auto& options, ZarrDriverSpec* obj) -> absl::Status { - // At this point, Projection has already set obj->open_as_void - if (obj->open_as_void) { - obj->selected_field = ""; - } - return absl::OkStatus(); - }, - jb::Projection<&ZarrDriverSpec::open_as_void>( + jb::Member("open_as_void", jb::Projection<&ZarrDriverSpec::open_as_void>( jb::DefaultValue( - [](auto* v) { *v = false; }))))); + [](auto* v) { *v = false; })))); @@ -592,10 +581,7 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase { grid_(DataCacheBase::GetChunkGridSpecification( metadata(), // Check if this is void access by examining the dtype - (ChunkCacheImpl::dtype_.fields.size() == 1 && - ChunkCacheImpl::dtype_.fields[0].name == "") - ? kVoidFieldIndex - : 0)) {} + ChunkCacheImpl::open_as_void_ ? kVoidFieldIndex : false)) {} const internal::LexicographicalGridIndexKeyParser& GetChunkStorageKeyParser() final { @@ -626,9 +612,8 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase { const void* metadata_ptr, size_t component_index) override { const auto& metadata = *static_cast(metadata_ptr); - // Check if this is void access by examining the cache's dtype - const bool is_void_access = (ChunkCacheImpl::dtype_.fields.size() == 1 && - ChunkCacheImpl::dtype_.fields[0].name == ""); + // Check if this is void access by examining the stored flag + const bool is_void_access = ChunkCacheImpl::open_as_void_; if (is_void_access) { // For void access, create transform with extra bytes dimension @@ -802,7 +787,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { TENSORSTORE_ASSIGN_OR_RETURN( auto metadata, internal_zarr3::GetNewMetadata(spec().metadata_constraints, - spec().schema), + spec().schema, spec().selected_field, spec().open_as_void), tensorstore::MaybeAnnotateStatus( _, "Cannot create using specified \"metadata\" and schema")); return metadata; @@ -819,15 +804,15 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { *static_cast(initializer.metadata.get()); // For void access, modify the dtype to indicate special handling ZarrDType dtype = metadata.data_type; - if (spec().selected_field == "") { + if (spec().open_as_void) { // Create a synthetic dtype for void access dtype = ZarrDType{ /*.has_fields=*/false, /*.fields=*/{ZarrDType::Field{ - ZarrDType::BaseDType{"", dtype_v, + ZarrDType::BaseDType{"", dtype_v, {metadata.data_type.bytes_per_outer_element}}, /*.outer_shape=*/{}, - /*.name=*/"", + /*.name=*/"", /*.field_shape=*/{metadata.data_type.bytes_per_outer_element}, /*.num_inner_elements=*/metadata.data_type.bytes_per_outer_element, /*.byte_offset=*/0, @@ -837,7 +822,8 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { return internal_zarr3::MakeZarrChunkCache( *metadata.codecs, std::move(initializer), spec().store.path, metadata.codec_state, dtype, - /*data_cache_pool=*/*cache_pool()); + /*data_cache_pool=*/*cache_pool(), + spec().open_as_void); } Result GetComponentIndex(const void* metadata_ptr, @@ -847,7 +833,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { ValidateMetadata(metadata, spec().metadata_constraints)); TENSORSTORE_ASSIGN_OR_RETURN( auto field_index, - GetFieldIndex(metadata.data_type, spec().selected_field)); + GetFieldIndex(metadata.data_type, spec().selected_field, spec().open_as_void)); // For void access, map to component index 0 if (field_index == kVoidFieldIndex) { field_index = 0; diff --git a/tensorstore/driver/zarr3/metadata.cc b/tensorstore/driver/zarr3/metadata.cc index 9aef7bd0b..ba4454de4 100644 --- a/tensorstore/driver/zarr3/metadata.cc +++ b/tensorstore/driver/zarr3/metadata.cc @@ -799,12 +799,14 @@ std::string GetFieldNames(const ZarrDType& dtype) { constexpr size_t kVoidFieldIndex = size_t(-1); Result GetFieldIndex(const ZarrDType& dtype, - std::string_view selected_field) { - // Special case: "" requests raw byte access (works for any dtype) - if (selected_field == "") { + std::string_view selected_field, + bool open_as_void) { + // Special case: open_as_void requests raw byte access (works for any dtype) + + if (open_as_void) { if (dtype.fields.empty()) { return absl::FailedPreconditionError( - "Requested field \"\" but dtype has no fields"); + "Requested void access but dtype has no fields"); } return kVoidFieldIndex; } @@ -1138,7 +1140,7 @@ absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata, Result> GetNewMetadata( const ZarrMetadataConstraints& metadata_constraints, const Schema& schema, - std::string_view selected_field) { + std::string_view selected_field, bool open_as_void) { auto metadata = std::make_shared(); metadata->zarr_format = metadata_constraints.zarr_format.value_or(3); @@ -1165,7 +1167,7 @@ Result> GetNewMetadata( } TENSORSTORE_ASSIGN_OR_RETURN( - size_t field_index, GetFieldIndex(metadata->data_type, selected_field)); + size_t field_index, GetFieldIndex(metadata->data_type, selected_field, open_as_void)); SpecRankAndFieldInfo info; info.field = &metadata->data_type.fields[field_index]; info.chunked_rank = metadata_constraints.rank; diff --git a/tensorstore/driver/zarr3/metadata.h b/tensorstore/driver/zarr3/metadata.h index 4c7871b0d..857210546 100644 --- a/tensorstore/driver/zarr3/metadata.h +++ b/tensorstore/driver/zarr3/metadata.h @@ -230,12 +230,14 @@ absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata, /// unspecified. Result> GetNewMetadata( const ZarrMetadataConstraints& metadata_constraints, - const Schema& schema, std::string_view selected_field = {}); + const Schema& schema, std::string_view selected_field = {}, + bool open_as_void = false); absl::Status ValidateDataType(DataType dtype); Result GetFieldIndex(const ZarrDType& dtype, - std::string_view selected_field); + std::string_view selected_field, + bool open_as_void = false); struct SpecRankAndFieldInfo { DimensionIndex chunked_rank = dynamic_rank; diff --git a/tensorstore/driver/zarr3/metadata_test.cc b/tensorstore/driver/zarr3/metadata_test.cc index 11c97619f..ba7a26593 100644 --- a/tensorstore/driver/zarr3/metadata_test.cc +++ b/tensorstore/driver/zarr3/metadata_test.cc @@ -438,7 +438,7 @@ Result> TestGetNewMetadata( TENSORSTORE_RETURN_IF_ERROR(status); TENSORSTORE_ASSIGN_OR_RETURN( auto constraints, ZarrMetadataConstraints::FromJson(constraints_json)); - return GetNewMetadata(constraints, schema); + return GetNewMetadata(constraints, schema, /*selected_field=*/{}, /*open_as_void=*/false); } TEST(GetNewMetadataTest, DuplicateDimensionNames) {