diff --git a/WORKSPACE b/WORKSPACE
index 6297d1ef4..ce72c155f 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -112,11 +112,11 @@ http_archive(
 http_archive(
     name = "com_google_googleapis",
     build_file = "@com_github_googleapis_google_cloud_cpp//bazel:googleapis.BUILD",
-    sha256 = "7ebab01b06c555f4b6514453dc3e1667f810ef91d1d4d2d3aa29bb9fcb40a900",
-    strip_prefix = "googleapis-541b1ded4abadcc38e8178680b0677f65594ea6f",
+    sha256 = "a53e15405f81d5a32594d7f6486e649131fadda5431cf28377dff4ae54d45d16",
+    strip_prefix = "googleapis-d4d09eb3aec152015f35717102f9b423988b94f7",
     urls = [
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
-        "https://github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
+        "https://github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
     ],
 )
 
diff --git a/tensorflow_io/BUILD b/tensorflow_io/BUILD
index 91bd5724c..58e42752d 100644
--- a/tensorflow_io/BUILD
+++ b/tensorflow_io/BUILD
@@ -12,6 +12,7 @@ cc_binary(
     deps = [
         "//tensorflow_io/core:arrow_ops",
         "//tensorflow_io/core:bigquery_ops",
+        "//tensorflow_io/core:bigtable_ops",
         "//tensorflow_io/core:audio_video_ops",
         "//tensorflow_io/core:avro_ops",
         "//tensorflow_io/core:orc_ops",
diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 87c516938..739405d8e 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -175,6 +175,23 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "bigtable_ops",
+    srcs = [
+        "kernels/bigtable/bigtable_dataset_kernel.cc",
+        "ops/bigtable_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    linkstatic = True,
+    deps = [
+        "@com_github_googleapis_google_cloud_cpp//:bigtable_client",
+        "@com_github_grpc_grpc//:grpc++",
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ],
+    alwayslink = 1,
+)
+
 # A library for use in the bigquery kernels.
 cc_library(
     name = "bigquery_lib_cc",
diff --git a/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc b/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
new file mode 100644
index 000000000..fd9e0200d
--- /dev/null
+++ b/tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
@@ -0,0 +1,321 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "absl/container/flat_hash_map.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_split.h"
+#include "google/cloud/bigtable/table.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+using ::tensorflow::DT_STRING;
+using ::tensorflow::PartialTensorShape;
+using ::tensorflow::Status;
+
+namespace cbt = ::google::cloud::bigtable;
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// Holds a shared Bigtable data client and hands out cbt::Table objects
+// backed by it.
+class BigtableClientResource : public ResourceBase {
+ public:
+  explicit BigtableClientResource(const std::string& project_id,
+                                  const std::string& instance_id)
+      : data_client_(CreateDataClient(project_id, instance_id)) {}
+
+  cbt::Table CreateTable(const std::string& table_id) {
+    VLOG(1) << "CreateTable";
+    return cbt::Table(data_client_, table_id);
+  }
+
+  string DebugString() const override { return "BigtableClientResource"; }
+
+ private:
+  std::shared_ptr<cbt::DataClient> CreateDataClient(
+      const std::string& project_id, const std::string& instance_id) {
+    VLOG(1) << "CreateDataClient";
+    return cbt::CreateDefaultDataClient(project_id, instance_id,
+                                        cbt::ClientOptions());
+  }
+  std::shared_ptr<cbt::DataClient> data_client_;
+};
+
+class BigtableClientOp : public OpKernel {
+ public:
+  explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
+    VLOG(1) << "BigtableClientOp ctor";
+  }
+
+  ~BigtableClientOp() override {
+    VLOG(1) << "BigtableClientOp dtor";
+    if (cinfo_.resource_is_private_to_kernel()) {
+      if (!cinfo_.resource_manager()
+               ->Delete<BigtableClientResource>(cinfo_.container(),
+                                                cinfo_.name())
+               .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+      }
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
+    VLOG(1) << "BigtableClientOp compute";
+    mutex_lock l(mu_);
+    if (!initialized_) {
+      ResourceMgr* mgr = ctx->resource_manager();
+      OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+      BigtableClientResource* resource;
+      OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<BigtableClientResource>(
+                              cinfo_.container(), cinfo_.name(), &resource,
+                              [this, ctx](BigtableClientResource** ret)
+                                  TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                                    *ret = new BigtableClientResource(
+                                        project_id_, instance_id_);
+                                    return Status::OK();
+                                  }));
+      core::ScopedUnref resource_cleanup(resource);
+      initialized_ = true;
+    }
+    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+                            ctx, 0, cinfo_.container(), cinfo_.name(),
+                            TypeIndex::Make<BigtableClientResource>()));
+  }
+
+ private:
+  mutex mu_;
+  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
+  bool initialized_ TF_GUARDED_BY(mu_) = false;
+  string project_id_;
+  string instance_id_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
+                        BigtableClientOp);
+
+// Iterates over the rows of a Bigtable table, emitting one DT_STRING tensor
+// of shape [num_columns] per row.
+template <typename Dataset>
+class Iterator : public DatasetIterator<Dataset> {
+ public:
+  explicit Iterator(const typename DatasetIterator<Dataset>::Params& params,
+                    const std::string& table_id,
+                    const std::vector<std::string>& columns)
+      : DatasetIterator<Dataset>(params),
+        columns_(ColumnsToFamiliesAndQualifiers(columns)),
+        reader_(
+            this->dataset()->client_resource().CreateTable(table_id).ReadRows(
+                cbt::RowRange::InfiniteRange(),
+                cbt::Filter::Chain(CreateColumnsFilter(columns_),
+                                   cbt::Filter::Latest(1)))),
+        it_(this->reader_.begin()),
+        column_to_idx_(CreateColumnToIdxMap(columns_)) {}
+
+  Status GetNextInternal(IteratorContext* ctx,
+                         std::vector<Tensor>* out_tensors,
+                         bool* end_of_sequence) override {
+    VLOG(1) << "GetNextInternal";
+    mutex_lock l(mu_);
+    if (it_ == reader_.end()) {
+      VLOG(1) << "End of sequence";
+      *end_of_sequence = true;
+      return Status::OK();
+    }
+    *end_of_sequence = false;
+
+    VLOG(1) << "allocating tensor";
+    const std::size_t kNumCols = column_to_idx_.size();
+    Tensor res(ctx->allocator({}), DT_STRING,
+               {static_cast<int64_t>(kNumCols)});
+    auto res_data = res.tensor<tstring, 1>();
+
+    VLOG(1) << "getting row";
+    const auto& row = *it_;
+    for (const auto& cell : row.value().cells()) {
+      std::pair<const std::string&, const std::string&> key(
+          cell.family_name(), cell.column_qualifier());
+      const auto column_idx = column_to_idx_.find(key);
+      if (column_idx != column_to_idx_.end()) {
+        VLOG(1) << "getting column:" << column_idx->second;
+        res_data(column_idx->second) = std::move(cell.value());
+      } else {
+        LOG(ERROR) << "column " << cell.family_name() << ":"
+                   << cell.column_qualifier()
+                   << " was unexpectedly read from bigtable";
+      }
+    }
+    VLOG(1) << "returning value";
+    out_tensors->emplace_back(std::move(res));
+
+    VLOG(1) << "incrementing iterator";
+    it_ = std::next(it_);
+
+    return Status::OK();
+  }
+
+ protected:
+  Status SaveInternal(SerializationContext* ctx,
+                      IteratorStateWriter* writer) override {
+    return errors::Unimplemented("SaveInternal");
+  }
+
+  Status RestoreInternal(IteratorContext* ctx,
+                         IteratorStateReader* reader) override {
+    return errors::Unimplemented(
+        "Iterator does not support 'RestoreInternal'");
+  }
+
+ private:
+  cbt::Filter CreateColumnsFilter(
+      const std::vector<std::pair<std::string, std::string>>& columns) {
+    VLOG(1) << "CreateColumnsFilter";
+    std::vector<cbt::Filter> filters;
+
+    for (const auto& column : columns) {
+      cbt::Filter f = cbt::Filter::ColumnName(column.first, column.second);
+      filters.push_back(std::move(f));
+    }
+
+    return cbt::Filter::InterleaveFromRange(filters.begin(), filters.end());
+  }
+
+  static std::pair<std::string, std::string> ColumnToFamilyAndQualifier(
+      const std::string& col_name_full) {
+    VLOG(1) << "ColumnToFamilyAndQualifier: " << col_name_full;
+    std::vector<std::string> result_vector =
+        absl::StrSplit(col_name_full, ":");
+    if (result_vector.size() != 2 || result_vector[0].empty())
+      throw std::invalid_argument("Invalid column name:" + col_name_full +
+                                  "\nColumn name must be in format " +
+                                  "column_family:column_name.");
+    return std::make_pair(result_vector[0], result_vector[1]);
+  }
+
+  static std::vector<std::pair<std::string, std::string>>
+  ColumnsToFamiliesAndQualifiers(const std::vector<std::string>& columns) {
+    VLOG(1) << "ColumnsToFamiliesAndQualifiers";
+    std::vector<std::pair<std::string, std::string>> columnPairs(
+        columns.size());
+    std::transform(columns.begin(), columns.end(), columnPairs.begin(),
+                   &ColumnToFamilyAndQualifier);
+    return columnPairs;
+  }
+
+  static absl::flat_hash_map<std::pair<const std::string&, const std::string&>,
+                             size_t>
+  CreateColumnToIdxMap(
+      const std::vector<std::pair<std::string, std::string>>& columns) {
+    VLOG(1) << "CreateColumnToIdxMap";
+    absl::flat_hash_map<std::pair<const std::string&, const std::string&>,
+                        size_t>
+        column_map;
+    std::size_t index = 0;
+    for (const auto& column : columns) {
+      std::pair<const std::string&, const std::string&> key(column.first,
+                                                            column.second);
+      column_map[key] = index++;
+    }
+    return column_map;
+  }
+
+  mutex mu_;
+  const std::shared_ptr<cbt::DataClient> data_client_;
+  const std::vector<std::pair<std::string, std::string>> columns_;
+  cbt::RowReader reader_ TF_GUARDED_BY(mu_);
+  cbt::v1::internal::RowReaderIterator it_ TF_GUARDED_BY(mu_);
+  // We key the map on const references to avoid copying strings when
+  // searching for a value.
+  const absl::flat_hash_map<std::pair<const std::string&, const std::string&>,
+                            size_t>
+      column_to_idx_;
+};
+
+class Dataset : public DatasetBase {
+ public:
+  Dataset(OpKernelContext* ctx, BigtableClientResource* client_resource,
+          std::string table_id, std::vector<std::string> columns)
+      : DatasetBase(DatasetContext(ctx)),
+        client_resource_(*client_resource),
+        client_resource_unref_(client_resource),
+        table_id_(table_id),
+        columns_(columns) {
+    dtypes_.push_back(DT_STRING);
+    output_shapes_.push_back({});
+  }
+
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const std::string& prefix) const override {
+    VLOG(1) << "MakeIteratorInternal. table=" << table_id_;
+    return absl::make_unique<Iterator<Dataset>>(
+        typename DatasetIterator<Dataset>::Params{
+            this, strings::StrCat(prefix, "::BigtableDataset")},
+        table_id_, columns_);
+  }
+
+  const DataTypeVector& output_dtypes() const override { return dtypes_; }
+
+  const std::vector<PartialTensorShape>& output_shapes() const override {
+    return output_shapes_;
+  }
+
+  std::string DebugString() const override {
+    return "BigtableDatasetOp::Dataset";
+  }
+
+  BigtableClientResource& client_resource() const { return client_resource_; }
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    return errors::Unimplemented(DebugString(),
+                                 " does not support serialization");
+  }
+
+  Status CheckExternalState() const override { return Status::OK(); }
+
+ private:
+  BigtableClientResource& client_resource_;
+  const core::ScopedUnref client_resource_unref_;
+  const std::string table_id_;
+  const std::vector<std::string> columns_;
+  DataTypeVector dtypes_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
+class BigtableDatasetOp : public DatasetOpKernel {
+ public:
+  explicit BigtableDatasetOp(OpKernelConstruction* ctx)
+      : DatasetOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("columns", &columns_));
+  }
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    VLOG(1) << "Make Dataset";
+    BigtableClientResource* client_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
+    core::ScopedUnref client_resource_unref_(client_resource);
+    *output = new Dataset(ctx, client_resource, table_id_, columns_);
+  }
+
+ private:
+  std::string table_id_;
+  std::vector<std::string> columns_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
+                        BigtableDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/ops/bigtable_ops.cc b/tensorflow_io/core/ops/bigtable_ops.cc
new file mode 100644
index 000000000..18339b105
--- /dev/null
+++ b/tensorflow_io/core/ops/bigtable_ops.cc
@@ -0,0 +1,34 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("BigtableClient")
+    .Attr("project_id: string")
+    .Attr("instance_id: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Output("client: resource")
+    .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableDataset")
+    .Input("client: resource")
+    .Attr("table_id: string")
+    .Attr("columns: list(string) >= 1")
+    .Output("handle: variant")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape);
\ No newline at end of file
diff --git a/tensorflow_io/python/ops/bigtable_dataset_ops.py b/tensorflow_io/python/ops/bigtable_dataset_ops.py
new file mode 100644
index 000000000..89d74583d
--- /dev/null
+++ b/tensorflow_io/python/ops/bigtable_dataset_ops.py
@@ -0,0 +1,46 @@
+from typing import List
+
+import tensorflow as tf
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow_io.python.ops import core_ops
+
+
+class BigtableClient:
+    """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
+
+    BigtableClient encapsulates a connection to Cloud Bigtable and exposes
+    `get_table`, which returns a `BigtableTable` object used to read rows.
+    """
+
+    def __init__(self, project_id, instance_id):
+        """Creates a BigtableClient connected to the given project and instance."""
+        self._client_resource = core_ops.bigtable_client(project_id, instance_id)
+
+    def get_table(self, table_id):
+        return BigtableTable(self._client_resource, table_id)
+
+
+class BigtableTable:
+    def __init__(self, client_resource, table_id: str):
+        self._table_id = table_id
+        self._client_resource = client_resource
+
+    def read_rows(self, columns: List[str]):
+        return _BigtableDataset(self._client_resource, self._table_id, columns)
+
+
+class _BigtableDataset(dataset_ops.DatasetSource):
+    """_BigtableDataset represents a dataset that retrieves keys and values."""
+
+    def __init__(self, client_resource, table_id: str, columns: List[str]):
+        self._table_id = table_id
+        self._columns = columns
+        self._element_spec = tf.TensorSpec(shape=[len(columns)], dtype=dtypes.string)
+
+        variant_tensor = core_ops.bigtable_dataset(client_resource, table_id, columns)
+        super().__init__(variant_tensor)
+
+    @property
+    def element_spec(self):
+        return self._element_spec
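
Usage sketch (not part of the diff): a minimal example of how the new Python API fits together, based only on the classes added in bigtable_dataset_ops.py above. The project, instance, table, and column names are placeholders, and the import path assumes the module is used directly rather than through a public tensorflow_io export:

    from tensorflow_io.python.ops import bigtable_dataset_ops

    # Connect to a Cloud Bigtable instance (placeholder IDs).
    client = bigtable_dataset_ops.BigtableClient("my-project", "my-instance")
    table = client.get_table("my-table")

    # Each dataset element is a string tensor of shape [len(columns)],
    # one slot per requested "family:qualifier" column.
    dataset = table.read_rows(["cf1:col1", "cf1:col2"])
    for row in dataset:
        print(row)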