diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 17213605f..8979b00b4 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -180,10 +180,10 @@ cc_library(
     srcs = [
         "kernels/bigtable/bigtable_dataset_kernel.cc",
         "kernels/bigtable/bigtable_resource_kernel.h",
-        "kernels/bigtable/bigtable_row_range.h",
-        "kernels/bigtable/bigtable_row_set.h",
         "kernels/bigtable/bigtable_row_range.cc",
+        "kernels/bigtable/bigtable_row_range.h",
         "kernels/bigtable/bigtable_row_set.cc",
+        "kernels/bigtable/bigtable_row_set.h",
         "ops/bigtable_ops.cc",
     ],
     copts = tf_io_copts(),
diff --git a/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc b/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
index 400683c96..e77bd39fd 100644
--- a/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
+++ b/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "absl/memory/memory.h"
+#include "google/cloud/bigtable/row_set.h"
 #include "google/cloud/bigtable/table.h"
 #include "google/cloud/bigtable/table_admin.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
@@ -29,7 +30,8 @@ namespace tensorflow {
 namespace data {
 namespace {
 
-tensorflow::error::Code GoogleCloudErrorCodeToTfErrorCode(::google::cloud::StatusCode code) {
+tensorflow::error::Code GoogleCloudErrorCodeToTfErrorCode(
+    ::google::cloud::StatusCode code) {
   switch (code) {
     case ::google::cloud::StatusCode::kOk:
       return ::tensorflow::error::OK;
@@ -72,22 +74,25 @@ Status GoogleCloudStatusToTfStatus(const ::google::cloud::Status& status) {
   if (status.ok()) {
     return Status::OK();
   }
-  return Status(GoogleCloudErrorCodeToTfErrorCode(status.code()),
-                strings::StrCat("Error reading from Cloud Bigtable: ",
-                                status.message()));
+  return Status(
+      GoogleCloudErrorCodeToTfErrorCode(status.code()),
+      strings::StrCat("Error reading from Cloud Bigtable: ", status.message()));
 }
 
 class BigtableClientResource : public ResourceBase {
  public:
   explicit BigtableClientResource(const std::string& project_id,
                                   const std::string& instance_id)
-      : data_client_(CreateDataClient(project_id, instance_id)) {}
+      : data_client_(CreateDataClient(project_id, instance_id)) {
+    VLOG(1) << "BigtableClientResource ctor";
+  }
 
-  cbt::Table CreateTable(const std::string& table_id) {
-    VLOG(1) << "CreateTable";
-    return cbt::Table(data_client_, table_id);
+  const std::shared_ptr<cbt::DataClient>& data_client() const {
+    return data_client_;
   }
 
+  ~BigtableClientResource() { VLOG(1) << "BigtableClientResource dtor"; }
+
   string DebugString() const override { return "BigtableClientResource"; }
 
  private:
@@ -108,6 +113,8 @@ class BigtableClientOp : public OpKernel {
     VLOG(1) << "BigtableClientOp ctor";
   }
 
+  ~BigtableClientOp() { VLOG(1) << "BigtableClientOp dtor"; }
+
   void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
     VLOG(1) << "BigtableClientOp compute";
     ResourceMgr* mgr = ctx->resource_manager();
@@ -135,17 +142,17 @@ template <typename Dataset>
 class Iterator : public DatasetIterator<Dataset> {
  public:
   explicit Iterator(const typename DatasetIterator<Dataset>::Params& params,
-                    const std::string& table_id,
                     const std::vector<std::string>& columns)
       : DatasetIterator<Dataset>(params),
         columns_(ColumnsToFamiliesAndQualifiers(columns)),
-        reader_(
-            this->dataset()->client_resource().CreateTable(table_id).ReadRows(
-                cbt::RowRange::InfiniteRange(),
-                cbt::Filter::Chain(CreateColumnsFilter(columns_),
-                                   cbt::Filter::Latest(1)))),
+        reader_(this->dataset()->CreateTable().ReadRows(
+            this->dataset()->row_set(),
+            cbt::Filter::Chain(CreateColumnsFilter(columns_),
+                               cbt::Filter::Latest(1)))),
         it_(this->reader_.begin()),
-        column_to_idx_(CreateColumnToIdxMap(columns_)) {}
+        column_to_idx_(CreateColumnToIdxMap(columns_)) {
+    VLOG(1) << "DatasetIterator ctor";
+  }
 
   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) override {
@@ -256,7 +263,6 @@ class Iterator : public DatasetIterator<Dataset> {
   }
 
   mutex mu_;
-  const std::shared_ptr<cbt::DataClient> data_client_;
   const std::vector<std::pair<std::string, std::string>> columns_;
   cbt::RowReader reader_ GUARDED_BY(mu_);
   cbt::v1::internal::RowReaderIterator it_ GUARDED_BY(mu_);
@@ -269,11 +275,13 @@ class Iterator : public DatasetIterator<Dataset> {
 
 class Dataset : public DatasetBase {
  public:
-  Dataset(OpKernelContext* ctx, BigtableClientResource* client_resource,
-          std::string table_id, std::vector<std::string> columns)
+  Dataset(OpKernelContext* ctx,
+          const std::shared_ptr<cbt::DataClient>& data_client,
+          cbt::RowSet row_set, std::string table_id,
+          std::vector<std::string> columns)
       : DatasetBase(DatasetContext(ctx)),
-        client_resource_(*client_resource),
-        client_resource_unref_(client_resource),
+        data_client_(data_client),
+        row_set_(std::move(row_set)),
         table_id_(table_id),
         columns_(columns) {
     dtypes_.push_back(DT_STRING);
@@ -286,7 +294,7 @@ class Dataset : public DatasetBase {
     return absl::make_unique<Iterator<Dataset>>(
         typename DatasetIterator<Dataset>::Params{
             this, strings::StrCat(prefix, "::BigtableDataset")},
-        table_id_, columns_);
+        columns_);
   }
 
   const DataTypeVector& output_dtypes() const override { return dtypes_; }
 
@@ -299,7 +307,17 @@ class Dataset : public DatasetBase {
     return "BigtableDatasetOp::Dataset";
   }
 
-  BigtableClientResource& client_resource() const { return client_resource_; }
+  const std::shared_ptr<cbt::DataClient>& data_client() const {
+    return data_client_;
+  }
+  const cbt::RowSet& row_set() const { return row_set_; }
+
+  cbt::Table CreateTable() const {
+    VLOG(1) << "CreateTable";
+    cbt::Table table(data_client_, table_id_);
+    VLOG(1) << "table created";
+    return table;
+  }
 
  protected:
   Status AsGraphDefInternal(SerializationContext* ctx,
@@ -312,8 +330,8 @@ class Dataset : public DatasetBase {
   Status CheckExternalState() const override { return Status::OK(); }
 
  private:
-  BigtableClientResource& client_resource_;
-  const core::ScopedUnref client_resource_unref_;
+  const std::shared_ptr<cbt::DataClient> data_client_;
+  const cbt::RowSet row_set_;
   const std::string table_id_;
   const std::vector<std::string> columns_;
   DataTypeVector dtypes_;
@@ -330,10 +348,17 @@ class BigtableDatasetOp : public DatasetOpKernel {
   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     VLOG(1) << "Make Dataset";
     BigtableClientResource* client_resource;
-    OP_REQUIRES_OK(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
-    core::ScopedUnref client_resource_unref_(client_resource);
-    *output = new Dataset(ctx, client_resource, table_id_, columns_);
+    OP_REQUIRES_OK(ctx,
+                   GetResourceFromContext(ctx, "client", &client_resource));
+    core::ScopedUnref unref_client(client_resource);
+
+    io::BigtableRowSetResource* row_set_resource;
+    OP_REQUIRES_OK(ctx,
+                   GetResourceFromContext(ctx, "row_set", &row_set_resource));
+    core::ScopedUnref row_set_resource_unref_(row_set_resource);
+
+    *output = new Dataset(ctx, client_resource->data_client(),
+                          row_set_resource->row_set(), table_id_, columns_);
   }
 
  private:
@@ -344,6 +369,138 @@ class BigtableDatasetOp : public DatasetOpKernel {
 REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
                         BigtableDatasetOp);
 
+// Returns the index of the tablet that a worker should start with. Each
+// worker starts at its first tablet and finishes on the tablet just before
+// the next worker's first tablet. Each worker should get num_tablets /
+// num_workers tablets rounded down, plus at most one extra. If we simply
+// rounded up, the last worker could be starved: with 100 tablets and 11
+// workers, giving round_up(100/11) = 10 tablets to each would leave the
+// first 10 workers with 10 tablets each and the last one with nothing.
+int GetWorkerStartIndex(size_t num_tablets, size_t num_workers,
+                        size_t worker_id) {
+  // If there are more workers than tablets, each worker gets at most one.
+  if (num_tablets <= num_workers) return std::min(num_tablets, worker_id);
+  // tablets_per_worker: the minimum number of tablets each worker receives.
+  size_t const tablets_per_worker = num_tablets / num_workers;
+  // surplus_tablets: the excess, distributed evenly among the first workers
+  // so that no worker gets more than tablets_per_worker + 1.
+  size_t const surplus_tablets = num_tablets % num_workers;
+  size_t const workers_before = worker_id;
+  return tablets_per_worker * workers_before +
+         std::min(surplus_tablets, workers_before);
+}
+
+bool RowSetIntersectsRange(cbt::RowSet const& row_set,
+                           std::string const& start_key,
+                           std::string const& end_key) {
+  auto range = cbt::RowRange::Range(start_key, end_key);
+  return !row_set.Intersect(range).IsEmpty();
+}
+
+class BigtableSplitRowSetEvenlyOp : public OpKernel {
+ public:
+  explicit BigtableSplitRowSetEvenlyOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    VLOG(1) << "BigtableSplitRowSetEvenlyOp ctor";
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_splits", &num_splits_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    mutex_lock l(mu_);
+
+    ResourceMgr* mgr = context->resource_manager();
+    ContainerInfo cinfo;
+    OP_REQUIRES_OK(context, cinfo.Init(mgr, def()));
+
+    BigtableClientResource* client_resource;
+    OP_REQUIRES_OK(context,
+                   GetResourceFromContext(context, "client", &client_resource));
+    core::ScopedUnref unref_client(client_resource);
+
+    io::BigtableRowSetResource* row_set_resource;
+    OP_REQUIRES_OK(
+        context, GetResourceFromContext(context, "row_set", &row_set_resource));
+    core::ScopedUnref unref_row_set(row_set_resource);
+
+    VLOG(1) << "BigtableSplitRowSetEvenlyOp got RowSet: "
+            << row_set_resource->ToString();
+    if (row_set_resource->row_set().IsEmpty()) {
+      OP_REQUIRES_OK(context,
+                     errors::FailedPrecondition("row_set cannot be empty!"));
+    }
+
+    auto table = cbt::Table(client_resource->data_client(), table_id_);
+    auto maybe_sample_row_keys = table.SampleRows();
+    OP_REQUIRES_OK(context,
+                   GoogleCloudStatusToTfStatus(maybe_sample_row_keys.status()));
+
+    auto& sample_row_keys = maybe_sample_row_keys.value();
+
+    std::vector<std::pair<std::string, std::string>> tablets;
+
+    std::string start_key;
+    for (auto& sample_row_key : sample_row_keys) {
+      auto& end_key = sample_row_key.row_key;
+      tablets.emplace_back(start_key, end_key);
+      start_key = std::move(end_key);
+    }
+    if (!start_key.empty() || tablets.size() == 0) {
+      tablets.emplace_back(start_key, "");
+    }
+    tablets.erase(
+        std::remove_if(
+            tablets.begin(), tablets.end(),
+            [row_set_resource](std::pair<std::string, std::string> const& p) {
+              return !RowSetIntersectsRange(row_set_resource->row_set(),
+                                            p.first, p.second);
+            }),
+        tablets.end());
+
+    VLOG(1) << "got array of tablets of size: " << tablets.size();
+
+    size_t output_size = std::min<size_t>(tablets.size(), num_splits_);
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, {static_cast<int64>(output_size)},
+                                            &output_tensor));
+    auto output_v = output_tensor->tensor<ResourceHandle, 1>();
+
+    for (size_t i = 0; i < output_size; i++) {
+      size_t start_idx = GetWorkerStartIndex(tablets.size(), output_size, i);
+      size_t next_worker_start_idx =
+          GetWorkerStartIndex(tablets.size(), output_size, i + 1);
+      size_t end_idx = next_worker_start_idx - 1;
+      start_key = tablets.at(start_idx).first;
+      std::string end_key = tablets.at(end_idx).second;
+      io::BigtableRowSetResource* work_chunk_row_set =
+          new io::BigtableRowSetResource(row_set_resource->Intersect(
+              cbt::RowRange::RightOpen(start_key, end_key)));
+
+      std::string container_name = cinfo.name() + std::to_string(i);
+
+      VLOG(1) << "creating resource:" << cinfo.container() << ":"
+              << container_name;
+
+      OP_REQUIRES_OK(
+          context, mgr->Create<io::BigtableRowSetResource>(
+                       cinfo.container(), container_name, work_chunk_row_set));
+      output_v(i) = MakeResourceHandle(
+          cinfo.container(), container_name, *context->device(),
+          TypeIndex::Make<io::BigtableRowSetResource>());
+    }
+  }
+
+ private:
+  mutable mutex mu_;
+  std::string table_id_ GUARDED_BY(mu_);
+  int num_splits_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableSplitRowSetEvenly").Device(DEVICE_CPU),
+                        BigtableSplitRowSetEvenlyOp);
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow_io/core/kernels/bigtable/bigtable_row_range.cc b/tensorflow_io/core/kernels/bigtable/bigtable_row_range.cc
index 49278219a..7e33154a1 100644
--- a/tensorflow_io/core/kernels/bigtable/bigtable_row_range.cc
+++ b/tensorflow_io/core/kernels/bigtable/bigtable_row_range.cc
@@ -132,8 +132,7 @@ class BigtablePrefixRowRangeOp
   }
 
  private:
-  StatusOr<BigtableRowRangeResource*> CreateResource()
-      override {
+  StatusOr<BigtableRowRangeResource*> CreateResource() override {
     return new BigtableRowRangeResource(cbt::RowRange::Prefix(prefix_));
   }
 
diff --git a/tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc b/tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc
index 1841d01ed..fb8297a3b 100644
--- a/tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc
+++ b/tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc
@@ -28,8 +28,7 @@ class BigtableEmptyRowSetOp
   }
 
  private:
-  StatusOr<BigtableRowSetResource*> CreateResource()
-      override {
+  StatusOr<BigtableRowSetResource*> CreateResource() override {
     return new BigtableRowSetResource(cbt::RowSet());
   }
 };
@@ -90,14 +89,13 @@ class BigtableRowSetAppendRowRangeOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     mutex_lock lock(mu_);
     BigtableRowSetResource* row_set_resource;
-    OP_REQUIRES_OK(context, GetResourceFromContext(context, "row_set",
-                                                   &row_set_resource));
+    OP_REQUIRES_OK(
+        context, GetResourceFromContext(context, "row_set", &row_set_resource));
     core::ScopedUnref row_set_resource_unref(row_set_resource);
 
     BigtableRowRangeResource* row_range_resource;
-    OP_REQUIRES_OK(context,
-                   GetResourceFromContext(context, "row_range",
-                                          &row_range_resource));
+    OP_REQUIRES_OK(context, GetResourceFromContext(context, "row_range",
+                                                   &row_range_resource));
     core::ScopedUnref row_range_resource_unref(row_range_resource);
 
     row_set_resource->AppendRowRange(row_range_resource->row_range());
@@ -122,21 +120,21 @@ class BigtableRowSetIntersectOp : public OpKernel {
     OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
 
     BigtableRowSetResource* row_set_resource;
-    OP_REQUIRES_OK(context, GetResourceFromContext(context, "row_set",
-                                                   &row_set_resource));
+    OP_REQUIRES_OK(
+        context, GetResourceFromContext(context, "row_set", &row_set_resource));
     core::ScopedUnref row_set_resource_unref(row_set_resource);
 
     BigtableRowRangeResource* row_range_resource;
-    OP_REQUIRES_OK(context,
-                   GetResourceFromContext(context, "row_range",
-                                          &row_range_resource));
+    OP_REQUIRES_OK(context, GetResourceFromContext(context, "row_range",
+                                                   &row_range_resource));
     core::ScopedUnref row_range_resource_unref(row_range_resource);
 
     BigtableRowSetResource* result_resource = new BigtableRowSetResource(
-            row_set_resource->Intersect(row_range_resource->row_range()));
+        row_set_resource->Intersect(row_range_resource->row_range()));
 
-    OP_REQUIRES_OK(context, mgr->Create<BigtableRowSetResource>(
-                                cinfo_.container(), cinfo_.name(), result_resource));
+    OP_REQUIRES_OK(context,
+                   mgr->Create<BigtableRowSetResource>(
+                       cinfo_.container(), cinfo_.name(), result_resource));
 
     OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                                 context, 0, cinfo_.container(), cinfo_.name(),
diff --git a/tensorflow_io/core/kernels/bigtable/bigtable_row_set.h b/tensorflow_io/core/kernels/bigtable/bigtable_row_set.h
index 4dc641593..14fc49d71 100644
--- a/tensorflow_io/core/kernels/bigtable/bigtable_row_set.h
+++ b/tensorflow_io/core/kernels/bigtable/bigtable_row_set.h
@@ -16,7 +16,6 @@ limitations under the License.
 #ifndef BIGTABLE_ROW_SET_H
 #define BIGTABLE_ROW_SET_H
 
-
 #include "google/cloud/bigtable/table.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -50,6 +49,8 @@ class BigtableRowSetResource : public ResourceBase {
     return row_set_.Intersect(row_range);
   }
 
+  google::cloud::bigtable::RowSet const& row_set() { return row_set_; }
+
   string DebugString() const override {
     return "BigtableRowSetResource:{" + ToString() + "}";
   }
diff --git a/tensorflow_io/core/ops/bigtable_ops.cc b/tensorflow_io/core/ops/bigtable_ops.cc
index 4ebff7f3f..73492e8e6 100644
--- a/tensorflow_io/core/ops/bigtable_ops.cc
+++ b/tensorflow_io/core/ops/bigtable_ops.cc
@@ -27,13 +27,13 @@ REGISTER_OP("BigtableClient")
 
 REGISTER_OP("BigtableDataset")
     .Input("client: resource")
+    .Input("row_set: resource")
     .Attr("table_id: string")
     .Attr("columns: list(string) >= 1")
     .Output("handle: variant")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ScalarShape);
 
-
 REGISTER_OP("BigtableEmptyRowSet")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
@@ -93,3 +93,16 @@ REGISTER_OP("BigtableRowSetIntersect")
     .Output("result_row_set: resource")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("BigtableSplitRowSetEvenly")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Input("client: resource")
+    .Input("row_set: resource")
+    .Attr("table_id: string")
+    .Attr("num_splits: int")
+    .Output("samples: resource")
+    .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Vector(c->UnknownDim()));
+      return tensorflow::Status::OK();
+    });
diff --git a/tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py b/tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py
index 81b344e65..ca047a4d8 100644
--- a/tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py
+++ b/tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py
@@ -4,6 +4,14 @@ from tensorflow_io.python.ops import core_ops
 from tensorflow.python.framework import dtypes
 import tensorflow as tf
+from tensorflow.python.data.ops import dataset_ops
+
+from tensorflow_io.python.ops.bigtable.bigtable_row_set import (
+    from_rows_or_ranges,
+    RowSet,
+    intersect,
+)
+from tensorflow_io.python.ops.bigtable.bigtable_row_range import infinite
 
 
 class BigtableClient:
@@ -15,9 +23,7 @@ class BigtableClient:
     def __init__(self, project_id: str, instance_id: str):
         """Creates a BigtableClient to start Bigtable read sessions."""
-        self._client_resource = core_ops.bigtable_client(
-            project_id, instance_id
-        )
+        self._client_resource = core_ops.bigtable_client(project_id, instance_id)
 
     def get_table(self, table_id):
         return BigtableTable(self._client_resource, table_id)
@@ -28,22 +34,47 @@ def __init__(self, client_resource, table_id: str):
         self._table_id = table_id
         self._client_resource = client_resource
 
-    def read_rows(self, columns: List[str]):
-        return _BigtableDataset(self._client_resource, self._table_id, columns)
+    def read_rows(self, columns: List[str], row_set: RowSet):
+        return _BigtableDataset(self._client_resource, self._table_id, columns, row_set)
+
+    def parallel_read_rows(
+        self,
+        columns: List[str],
+        num_parallel_calls=tf.data.AUTOTUNE,
+        row_set: RowSet = from_rows_or_ranges(infinite()),
+    ):
+        samples = core_ops.bigtable_split_row_set_evenly(
+            self._client_resource, row_set._impl, self._table_id, num_parallel_calls,
+        )
+
+        def map_func(idx):
+            return self.read_rows(columns, RowSet(samples[idx]))
+
+        # We interleave a dataset of sample indexes instead of a dataset of
+        # samples, because Dataset.from_tensor_slices attempts to copy the
+        # resource tensors using DeepCopy from tensor_util.cc, which is not
+        # possible for tensors of type DT_RESOURCE.
+        return tf.data.Dataset.range(samples.shape[0]).interleave(
+            map_func=map_func,
+            cycle_length=num_parallel_calls,
+            block_length=1,
+            num_parallel_calls=num_parallel_calls,
+            deterministic=False,
+        )
 
 
 class _BigtableDataset(dataset_ops.DatasetSource):
     """_BigtableDataset represents a dataset that retrieves keys and values."""
 
-    def __init__(self, client_resource, table_id: str, columns: List[str]):
+    def __init__(
+        self, client_resource, table_id: str, columns: List[str], row_set: RowSet,
+    ):
         self._table_id = table_id
         self._columns = columns
-        self._element_spec = tf.TensorSpec(
-            shape=[len(columns)], dtype=dtypes.string
-        )
+        self._element_spec = tf.TensorSpec(shape=[len(columns)], dtype=dtypes.string)
 
         variant_tensor = core_ops.bigtable_dataset(
-            client_resource, table_id, columns
+            client_resource, row_set._impl, table_id, columns
         )
         super().__init__(variant_tensor)
diff --git a/tensorflow_io/python/ops/bigtable/bigtable_row_range.py b/tensorflow_io/python/ops/bigtable/bigtable_row_range.py
index 23192bd73..f01ec939a 100644
--- a/tensorflow_io/python/ops/bigtable/bigtable_row_range.py
+++ b/tensorflow_io/python/ops/bigtable/bigtable_row_range.py
@@ -15,7 +15,6 @@
 """Module implementing basic functions for obtaining BigTable RowRanges"""
 from tensorflow_io.python.ops import core_ops
-import tensorflow
 
 
 class RowRange:
diff --git a/tensorflow_io/python/ops/bigtable/bigtable_row_set.py b/tensorflow_io/python/ops/bigtable/bigtable_row_set.py
index 92f1ba9a4..acded8f7f 100644
--- a/tensorflow_io/python/ops/bigtable/bigtable_row_set.py
+++ b/tensorflow_io/python/ops/bigtable/bigtable_row_set.py
@@ -14,7 +14,6 @@
 """Module implementing basic functions for obtaining BigTable RowSets"""
 
-from tensorflow.python.framework import dtypes
 from tensorflow_io.python.ops import core_ops
 from . import bigtable_row_range
 from typing import Union
 
@@ -31,9 +30,7 @@ def append(self, row_or_range):
         if isinstance(row_or_range, str):
             core_ops.bigtable_row_set_append_row(self._impl, row_or_range)
         else:
-            core_ops.bigtable_row_set_append_row_range(
-                self._impl, row_or_range._impl
-            )
+            core_ops.bigtable_row_set_append_row_range(self._impl, row_or_range._impl)
 
 
 def empty():
@@ -72,6 +69,4 @@ def intersect(row_set: RowSet, row_range: bigtable_row_range.RowRange):
     Returns:
         RowSet: an intersection of the given row set and row range.
     """
-    return RowSet(
-        core_ops.bigtable_row_set_intersect(row_set._impl, row_range._impl)
-    )
+    return RowSet(core_ops.bigtable_row_set_intersect(row_set._impl, row_range._impl))
diff --git a/tests/test_bigtable/bigtable_emulator.py b/tests/test_bigtable/bigtable_emulator.py
index 7258a7ac6..d42144fa3 100644
--- a/tests/test_bigtable/bigtable_emulator.py
+++ b/tests/test_bigtable/bigtable_emulator.py
@@ -67,9 +67,7 @@ def _get_cbt_emulator_path():
 
 
 def _get_cbt_cli_path():
-    return _get_cbt_binary_path(
-        CBT_CLI_PATH_ENV_VAR, CBT_CLI_SEARCH_PATHS, "cbt cli"
-    )
+    return _get_cbt_binary_path(CBT_CLI_PATH_ENV_VAR, CBT_CLI_SEARCH_PATHS, "cbt cli")
 
 
 def _extract_emulator_addr_from_output(emulator_output):
@@ -82,9 +80,7 @@ def _extract_emulator_addr_from_output(emulator_output):
             for word in words:
                 if re.fullmatch("[a-z.0-9]+:[0-9]+", word):
                     return word
-    raise RuntimeError(
-        f"Failed to find CBT emulator in the line {line}"
-    )
+    raise RuntimeError(f"Failed to find CBT emulator in the line {line}")
 
 
 class BigtableEmulator:
diff --git a/tests/test_bigtable/test_parallel_read_rows.py b/tests/test_bigtable/test_parallel_read_rows.py
new file mode 100644
index 000000000..b20b1ef56
--- /dev/null
+++ b/tests/test_bigtable/test_parallel_read_rows.py
@@ -0,0 +1,180 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# disable module docstring for tests
+# pylint: disable=C0114
+# disable class docstring for tests
+# pylint: disable=C0115
+
+import os
+from re import escape
+from .bigtable_emulator import BigtableEmulator
+from tensorflow_io.python.ops import core_ops
+from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import BigtableClient
+import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range
+import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set
+import tensorflow as tf
+from tensorflow import test
+
+
+class BigtableParallelReadTest(test.TestCase):
+    def setUp(self):
+        self.emulator = BigtableEmulator()
+
+    def tearDown(self):
+        self.emulator.stop()
+
+    def test_parallel_read(self):
+        os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
+        self.emulator.create_table(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ["fam1", "fam2"],
+            splits=["row005", "row010", "row015"],
+        )
+
+        values = [[f"[{i,j}]" for j in range(2)] for i in range(20)]
+        flat_values = [value for row in values for value in row]
+
+        ten = tf.constant(values)
+
+        client = BigtableClient("fake_project", "fake_instance")
+        table = client.get_table("test-table")
+
+        self.emulator.write_tensor(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ten,
+            ["row" + str(i).rjust(3, "0") for i in range(20)],
+            ["fam1:col1", "fam2:col2"],
+        )
+
+        for r in table.parallel_read_rows(
+            ["fam1:col1", "fam2:col2"],
+            row_set=row_set.from_rows_or_ranges(row_range.infinite()),
+        ):
+            for c in r:
+                self.assertTrue(c.numpy().decode() in flat_values)
+
+    def test_not_parallel_read(self):
+        os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
+        self.emulator.create_table(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ["fam1", "fam2"],
+            splits=["row005", "row010", "row015"],
+        )
+
+        values = [[f"[{i,j}]" for j in range(2)] for i in range(20)]
+
+        ten = tf.constant(values)
+
+        client = BigtableClient("fake_project", "fake_instance")
+        table = client.get_table("test-table")
+
+        self.emulator.write_tensor(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ten,
+            ["row" + str(i).rjust(3, "0") for i in range(20)],
+            ["fam1:col1", "fam2:col2"],
+        )
+
+        dataset = table.parallel_read_rows(
+            ["fam1:col1", "fam2:col2"],
+            row_set=row_set.from_rows_or_ranges(row_range.infinite()),
+            num_parallel_calls=2,
+        )
+        results = [[v.numpy().decode() for v in row] for row in dataset]
+        self.assertEqual(repr(sorted(values)), repr(sorted(results)))
+
+    def test_split_row_set(self):
+        os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
+        self.emulator.create_table(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ["fam1", "fam2"],
+            splits=["row005", "row010", "row015", "row020", "row025", "row030"],
+        )
+
+        values = [[f"[{i,j}]" for j in range(2)] for i in range(40)]
+
+        ten = tf.constant(values)
+
+        client = BigtableClient("fake_project", "fake_instance")
+
+        self.emulator.write_tensor(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ten,
+            ["row" + str(i).rjust(3, "0") for i in range(40)],
+            ["fam1:col1", "fam2:col2"],
+        )
+
+        rs = row_set.from_rows_or_ranges(row_range.infinite())
+
+        num_parallel_calls = 2
+        samples = [
+            s
+            for s in core_ops.bigtable_split_row_set_evenly(
+                client._client_resource, rs._impl, "test-table", num_parallel_calls,
+            )
+        ]
+        self.assertEqual(len(samples), num_parallel_calls)
+
+        num_parallel_calls = 6
+        samples = [
+            s
+            for s in core_ops.bigtable_split_row_set_evenly(
+                client._client_resource, rs._impl, "test-table", num_parallel_calls,
+            )
+        ]
+
+        # The emulator may return a different set of sample row keys each
+        # time, so we cannot expect an exact number of splits, but it must be
+        # no more than num_parallel_calls.
+        self.assertLessEqual(len(samples), num_parallel_calls)
+
+        num_parallel_calls = 1
+        samples = [
+            s
+            for s in core_ops.bigtable_split_row_set_evenly(
+                client._client_resource, rs._impl, "test-table", num_parallel_calls,
+            )
+        ]
+        self.assertEqual(len(samples), num_parallel_calls)
+
+    def test_split_empty(self):
+        os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
+        self.emulator.create_table(
+            "fake_project", "fake_instance", "test-table", ["fam1", "fam2"],
+        )
+
+        client = BigtableClient("fake_project", "fake_instance")
+
+        rs = row_set.from_rows_or_ranges(row_range.empty())
+
+        num_parallel_calls = 2
+
+        self.assertRaises(
+            tf.errors.FailedPreconditionError,
+            lambda: core_ops.bigtable_split_row_set_evenly(
+                client._client_resource, rs._impl, "test-table", num_parallel_calls,
+            ),
+        )
diff --git a/tests/test_bigtable/test_read_rows.py b/tests/test_bigtable/test_read_rows.py
index 05c8a3319..823908713 100644
--- a/tests/test_bigtable/test_read_rows.py
+++ b/tests/test_bigtable/test_read_rows.py
@@ -19,9 +19,9 @@
 import os
 from .bigtable_emulator import BigtableEmulator
-from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import (
-    BigtableClient,
-)
+from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import BigtableClient
+import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range
+import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set
 import tensorflow as tf
 from tensorflow import test
 
@@ -55,7 +55,40 @@ def test_read(self):
             ["fam1:col1", "fam2:col2"],
         )
 
-        for i, r in enumerate(table.read_rows(["fam1:col1", "fam2:col2"])):
+        for i, r in enumerate(
+            table.read_rows(
+                ["fam1:col1", "fam2:col2"],
+                row_set=row_set.from_rows_or_ranges(row_range.infinite()),
+            )
+        ):
             for j, c in enumerate(r):
-                print("ij", i, j, c)
                 self.assertEqual(values[i][j], c.numpy().decode())
+
+    def test_read_row_set(self):
+        os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
+        self.emulator.create_table(
+            "fake_project", "fake_instance", "test-table", ["fam1", "fam2"]
+        )
+
+        values = [[f"[{i,j}]" for j in range(2)] for i in range(20)]
+
+        ten = tf.constant(values)
+
+        client = BigtableClient("fake_project", "fake_instance")
+        table = client.get_table("test-table")
+
+        self.emulator.write_tensor(
+            "fake_project",
+            "fake_instance",
+            "test-table",
+            ten,
+            ["row" + str(i).rjust(3, "0") for i in range(20)],
+            ["fam1:col1", "fam2:col2"],
+        )
+
+        row_s = row_set.from_rows_or_ranges(row_range.closed_range("row000", "row009"))
+
+        read_rows = [
+            r for r in table.read_rows(["fam1:col1", "fam2:col2"], row_set=row_s)
+        ]
+        self.assertEqual(len(read_rows), 10)
diff --git a/tests/test_bigtable/test_row_set.py b/tests/test_bigtable/test_row_set.py
index 385e835b8..949f53d31 100644
--- a/tests/test_bigtable/test_row_set.py
+++ b/tests/test_bigtable/test_row_set.py
@@ -16,9 +16,11 @@
 # pylint: disable=C0114
 # disable class docstring for tests
 # pylint: disable=C0115
+from tensorflow_io.python.ops import core_ops
 import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range
 import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set
 from tensorflow import test
+import tensorflow as tf
 
 
 class RowRangeTest(test.TestCase):
@@ -104,14 +106,10 @@ def test_from_rows_or_ranges(self):
         self.assertEqual(expected, repr(r_set))
 
     def test_intersect(self):
-        r_set = row_set.from_rows_or_ranges(
-            row_range.open_range("row1", "row5")
-        )
+        r_set = row_set.from_rows_or_ranges(row_range.open_range("row1", "row5"))
         r_set = row_set.intersect(r_set, row_range.closed_range("row3", "row7"))
         expected = (
-            "row_ranges {\n"
-            + '  start_key_closed: "row3"\n'
-            + "  end_key_open: "
+            "row_ranges {\n" + '  start_key_closed: "row3"\n' + "  end_key_open: "
             '"row5"\n' + "}\n"
         )
         self.assertEqual(expected, repr(r_set))
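
Note: a quick sanity check of the tablet-splitting arithmetic implemented by GetWorkerStartIndex above. This is a minimal Python sketch of the same formula; the function name and the concrete numbers are illustrative only, not part of the change.

def worker_start_index(num_tablets: int, num_workers: int, worker_id: int) -> int:
    # More workers than tablets: each worker takes at most one tablet.
    if num_tablets <= num_workers:
        return min(num_tablets, worker_id)
    per_worker, surplus = divmod(num_tablets, num_workers)
    # The first `surplus` workers absorb one extra tablet each.
    return per_worker * worker_id + min(surplus, worker_id)


# 100 tablets over 11 workers: per_worker = 9, surplus = 1, so worker 0 gets
# 10 tablets and workers 1..10 get 9 each -- no worker is starved.
starts = [worker_start_index(100, 11, w) for w in range(12)]
assert starts == [0, 10, 19, 28, 37, 46, 55, 64, 73, 82, 91, 100]
assert [b - a for a, b in zip(starts, starts[1:])] == [10] + [9] * 10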
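Note: an end-to-end usage sketch of the new Python surface, mirroring the test fixtures (fake_project / fake_instance / test-table); the emulator address is hypothetical and an already-populated table is assumed, so treat this as a sketch rather than a prescribed setup.

import os

import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range
import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set
from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import BigtableClient

# Assumes an emulator is running at this (hypothetical) address and that
# "test-table" has already been created and populated, as in the tests.
os.environ["BIGTABLE_EMULATOR_HOST"] = "localhost:8086"

client = BigtableClient("fake_project", "fake_instance")
table = client.get_table("test-table")

# Restrict the scan to rows row000..row009, then fan it out over two workers.
rs = row_set.from_rows_or_ranges(row_range.closed_range("row000", "row009"))
dataset = table.parallel_read_rows(
    ["fam1:col1", "fam2:col2"], num_parallel_calls=2, row_set=rs
)

# Each element is a string tensor of shape [num_columns]; because the
# interleave uses deterministic=False, row order is not guaranteed.
for row in dataset:
    cells = [cell.numpy().decode() for cell in row]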
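Note: a companion sketch of how the RowSet/RowRange Python helpers touched here compose, grounded in test_row_set.py; the row keys are arbitrary, and the append/intersect behaviour follows the C++ resource methods shown above.

import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range
import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set

# Start from an open range, append a single row key, then narrow the set.
rs = row_set.from_rows_or_ranges(row_range.open_range("row1", "row5"))
rs.append("row7")  # plain strings are appended as individual row keys
rs = row_set.intersect(rs, row_range.closed_range("row3", "row7"))

# repr() prints the underlying RowSet proto text: intersecting the open range
# ("row1", "row5") with ["row3", "row7"] leaves ["row3", "row5"), while
# "row7" survives because it lies inside the intersecting range.
print(repr(rs))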