Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions tensorstore/driver/kvs_backed_chunk_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ absl::Status ValidateNewMetadata(DataCacheBase* cache,
Result<MetadataPtr> GetUpdatedMetadataWithAssumeCachedMetadata(
KvsMetadataDriverBase& driver, DataCacheBase& cache,
internal::OpenTransactionPtr transaction) {

assert(driver.assumed_metadata_time_ != absl::InfiniteFuture() &&
driver.assumed_metadata_);
assert(&cache == driver.cache());
Expand Down Expand Up @@ -370,6 +371,7 @@ Result<MetadataPtr> ValidateNewMetadata(
Result<IndexTransform<>> GetInitialTransform(DataCacheBase* cache,
const void* metadata,
size_t component_index) {

TENSORSTORE_ASSIGN_OR_RETURN(
auto new_transform, cache->GetExternalToInternalTransform(
cache->initial_metadata_.get(), component_index));
Expand Down Expand Up @@ -712,6 +714,7 @@ Future<IndexTransform<>> KvsChunkedDriverBase::Resize(
Result<IndexTransform<>> KvsMetadataDriverBase::GetBoundSpecData(
internal::OpenTransactionPtr transaction, KvsDriverSpec& spec,
IndexTransformView<> transform_view) {

auto* cache = this->cache();
auto* metadata_cache = cache->metadata_cache();
TENSORSTORE_ASSIGN_OR_RETURN(spec.store.driver,
Expand Down Expand Up @@ -794,17 +797,36 @@ Result<std::size_t> ValidateOpenRequest(OpenState* state,
return absl::NotFoundError(
GetMetadataMissingErrorMessage(base.metadata_cache_entry_.get()));
}
return state->GetComponentIndex(metadata, base.spec_->open_mode());
auto result = state->GetComponentIndex(
metadata, base.spec_->open_mode());

return result;
}

/// Provides a way to open stored struct data as a flat byte array: validates
/// that metadata is present, then delegates to `OpenState::AsByteArray`.
/// Validates that `metadata` is present and asks `state` to reinterpret it as
/// byte-array metadata.
///
/// \param state The open state; laid out with `PrivateOpenState` as a base.
/// \param metadata Decoded metadata, or `nullptr` if none was found.
/// \error `absl::StatusCode::kNotFound` if `metadata` is null.
Result<std::shared_ptr<void>> ValidateByteArray(
    OpenState* state, const void* metadata) {
  // C-style cast is required to reach the private `PrivateOpenState` base
  // (same pattern as the other helpers in this file).
  auto& open_base = *(PrivateOpenState*)state;
  if (metadata == nullptr) {
    return absl::NotFoundError(
        GetMetadataMissingErrorMessage(open_base.metadata_cache_entry_.get()));
  }
  return state->AsByteArray(metadata, open_base.spec_->open_mode());
}


/// \pre `component_index` is the result of a previous call to
/// `state->GetComponentIndex` with the same `metadata`.
/// \pre `metadata != nullptr`
Result<internal::Driver::Handle> CreateTensorStoreFromMetadata(
OpenState::Ptr state, std::shared_ptr<const void> metadata,
size_t component_index) {

ABSL_LOG_IF(INFO, TENSORSTORE_KVS_DRIVER_DEBUG)
<< "CreateTensorStoreFromMetadata: state=" << state.get();
<< "CreateTensorStoreFromMetadata : state=" << state.get();
auto& base = *(PrivateOpenState*)state.get(); // Cast to private base
// TODO(jbms): The read-write mode should be determined based on the kvstore
// mode, once that is exposed.
Expand Down Expand Up @@ -1070,11 +1092,15 @@ Future<const void> MetadataCache::Entry::RequestAtomicUpdate(

/// Returns the metadata as seen by `transaction`.
///
/// Without a transaction, the cached (committed) metadata is returned
/// directly; otherwise the transaction node's updated metadata is used.
Result<MetadataCache::MetadataPtr> MetadataCache::Entry::GetMetadata(
    internal::OpenTransactionPtr transaction) {
  if (!transaction) return GetMetadata();
  TENSORSTORE_ASSIGN_OR_RETURN(auto node,
                               GetTransactionNode(*this, transaction));
  TENSORSTORE_ASSIGN_OR_RETURN(auto updated_metadata,
                               node->GetUpdatedMetadata(),
                               this->AnnotateError(_, /*reading=*/false));
  return updated_metadata;
}

Expand Down Expand Up @@ -1252,10 +1278,29 @@ internal::CachePtr<MetadataCache> GetOrCreateMetadataCache(

/// Creates a driver handle from previously-decoded `metadata`.
///
/// First validates the open request normally.  If that fails, falls back to
/// opening the data as a byte array via `ValidateByteArray`; if the byte-array
/// conversion reports `kInvalidArgument` (not byte-array compatible), the
/// original validation error is returned instead.
Result<internal::Driver::Handle> OpenState::CreateDriverHandleFromMetadata(
    std::shared_ptr<const void> metadata) {
  auto component_index_result = ValidateOpenRequest(this, metadata.get());
  if (component_index_result.ok()) {
    return CreateTensorStoreFromMetadata(OpenState::Ptr(this),
                                         std::move(metadata),
                                         *component_index_result);
  }
  // Fallback: check whether the metadata can instead be opened as a byte
  // array.  Call `ValidateByteArray` once and reuse the result (the previous
  // version called it twice, constructing the synthetic metadata twice).
  auto byte_array_metadata = ValidateByteArray(this, metadata.get());
  if (absl::IsInvalidArgument(byte_array_metadata.status())) {
    // Not byte-array compatible; report the original validation error.
    return component_index_result.status();
  }
  TENSORSTORE_RETURN_IF_ERROR(byte_array_metadata.status());
  // The synthetic byte-array metadata always has a single component.
  return CreateTensorStoreFromMetadata(OpenState::Ptr(this),
                                       std::move(*byte_array_metadata),
                                       /*component_index=*/0);
}

Future<internal::Driver::Handle> OpenDriver(MetadataOpenState::Ptr state) {
Expand Down
8 changes: 7 additions & 1 deletion tensorstore/driver/kvs_backed_chunk_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ class MetadataCache
public:
using OwningCache = MetadataCache;

MetadataPtr GetMetadata() { return ReadLock<void>(*this).shared_data(); }
MetadataPtr GetMetadata() {
auto result = ReadLock<void>(*this).shared_data();
return result;
}

Result<MetadataPtr> GetMetadata(internal::OpenTransactionPtr transaction);

Expand Down Expand Up @@ -740,6 +743,9 @@ class OpenState : public MetadataOpenState {
/// If the `metadata` is not compatible, returns an error.
virtual Result<size_t> GetComponentIndex(const void* metadata,
OpenMode open_mode) = 0;
  /// Attempts to produce metadata that represents the data as a byte array.
virtual Result<std::shared_ptr<void>> AsByteArray(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can make this change purely in the zarr driver, rather than here in kvs_backed_chunk_driver.h

In particular, if the user does not specify a field, then we can always just open as a byte array.

I also think rather than create synthetic metadata it would be better to just create an additional special dtype field for the byte array representation. The synthetic metadata approach has some drawbacks.

const void* metadata, OpenMode open_mode) = 0;
};

/// Attempts to open a TensorStore with a kvstore-backed chunk driver.
Expand Down
1 change: 0 additions & 1 deletion tensorstore/driver/zarr/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ tensorstore_cc_test(
"//tensorstore:strided_layout",
"//tensorstore:transaction",
"//tensorstore/driver:driver_testutil",
"//tensorstore/driver/n5",
"//tensorstore/index_space:dim_expression",
"//tensorstore/index_space:index_transform",
"//tensorstore/internal:decoded_matches",
Expand Down
82 changes: 79 additions & 3 deletions tensorstore/driver/zarr/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,26 @@ TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(
DimensionSeparatorJsonBinder))),
jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>(
jb::DefaultValue<jb::kNeverIncludeDefaults>(
[](auto* obj) { *obj = std::string{}; }))),
[](auto* obj) {
*obj = std::string{};
}))),
jb::Initialize([](auto* obj) {
/*resolve the issue here obj is type driver*/
// this only has the partial metadata it hasn't built
// the ZarrMetadata

TENSORSTORE_ASSIGN_OR_RETURN(auto info, obj->GetSpecInfo());

if (info.full_rank != dynamic_rank) {
TENSORSTORE_RETURN_IF_ERROR(
obj->schema.Set(RankConstraint(info.full_rank)));
}

//It's setting something here ...
if (info.field) {
TENSORSTORE_RETURN_IF_ERROR(obj->schema.Set(info.field->dtype));
}

return absl::OkStatus();
})));

Expand Down Expand Up @@ -262,6 +272,7 @@ internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
metadata.chunks.size());
std::iota(chunked_to_cell_dimensions.begin(),
chunked_to_cell_dimensions.end(), static_cast<DimensionIndex>(0));

for (std::size_t field_i = 0; field_i < metadata.dtype.fields.size();
++field_i) {
const auto& field = metadata.dtype.fields[field_i];
Expand Down Expand Up @@ -289,6 +300,7 @@ internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
for (DimensionIndex cell_dim = fill_value_start_dim; cell_dim < cell_rank;
++cell_dim) {
const Index size = field_layout.full_chunk_shape()[cell_dim];

assert(fill_value.shape()[cell_dim - fill_value_start_dim] == size);
chunk_fill_value.shape()[cell_dim] = size;
chunk_fill_value.byte_strides()[cell_dim] =
Expand Down Expand Up @@ -438,6 +450,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {

Result<std::shared_ptr<const void>> Create(
const void* existing_metadata) override {

if (existing_metadata) {
return absl::AlreadyExistsError("");
}
Expand All @@ -454,32 +467,95 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
std::string result;
const auto& spec = this->spec();
const auto& zarr_metadata = *static_cast<const ZarrMetadata*>(metadata);

internal::EncodeCacheKey(
&result, spec.store.path,
GetDimensionSeparator(spec.partial_metadata, zarr_metadata),
zarr_metadata, spec.metadata_key);

return result;
}

  /// Constructs the zarr `DataCache` for the decoded metadata carried in
  /// `initializer`.
  std::unique_ptr<internal_kvs_backed_chunk_driver::DataCacheBase> GetDataCache(
      DataCache::Initializer&& initializer) override {
    // NOTE(review): appears to run after `GetDataCacheKey` during creation —
    // TODO confirm ordering.  `initializer.metadata` holds the decoded zarr
    // metadata.
    const auto& metadata =
        *static_cast<const ZarrMetadata*>(initializer.metadata.get());

    return std::make_unique<DataCache>(
        std::move(initializer), spec().store.path,
        GetDimensionSeparator(spec().partial_metadata, metadata),
        spec().metadata_key);
  }

/// The concept here is to create a new metadata object that has the
/// the dtype change such that we can create new driver for loading
/// byte arrays. It copies the metadata and updates the dtype, fill_value,
/// and chunk_layout.
Result<std::shared_ptr<void>> AsByteArray(
const void* metadata_ptr, OpenMode open_mode) override {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);

if(metadata.dtype.fields.size() == 1 && metadata.dtype.fields[0].dtype != tensorstore::dtype_v<std::byte>) {
return absl::InvalidArgumentError(
"Trying to convert dtype rank 1 to byte array, but dtype is not std::byte"
);
}

ZarrMetadata new_metadata(metadata);
new_metadata.dtype = ParseDType("|V" + getDtypeTotalBytes(metadata_ptr)).value();

auto field = new_metadata.dtype.fields[0];
new_metadata.fill_value = std::vector<SharedArray<const void>>(
{
AllocateArray(
field.field_shape, ContiguousLayoutOrder::c,
default_init, field.dtype
)
}
);

TENSORSTORE_ASSIGN_OR_RETURN(
new_metadata.chunk_layout, ComputeChunkLayout(
new_metadata.dtype, ContiguousLayoutOrder::c, new_metadata.chunks
)
)

return std::make_shared<ZarrMetadata>(new_metadata);
}

std::string getDtypeTotalBytes(const void* metadata_ptr) {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);

// TODO: Ensure that fields of rank > 1 are handled
int bytes = 0;
for(auto field : metadata.dtype.fields) {
bytes += field.num_bytes;
}

return std::to_string(bytes);
}

Result<std::size_t> GetComponentIndex(const void* metadata_ptr,
OpenMode open_mode) override {
// This will get called by the kvs and call driver/GetComponentIndex
// to make sure the dtype "field" is set and agrees with the dtype.
// and we have the open/create mode here too ...
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);

TENSORSTORE_RETURN_IF_ERROR(
ValidateMetadata(metadata, spec().partial_metadata));
ValidateMetadata(metadata, spec().partial_metadata)
);

// GetFieldIndex will return "0" if there the selected field is empty.
// And the dtype is not a struct array
TENSORSTORE_ASSIGN_OR_RETURN(
auto field_index, GetFieldIndex(metadata.dtype, spec().selected_field));
auto field_index, GetFieldIndex(metadata.dtype, spec().selected_field));

TENSORSTORE_RETURN_IF_ERROR(
ValidateMetadataSchema(metadata, field_index, spec().schema));

return field_index;
}
};
Expand Down
Loading