Merged
Changes from 19 of 36 commits

Commits
5954746
poc
kboroszko Oct 7, 2021
9bb40bb
working example
kboroszko Oct 8, 2021
fbc32ec
output string
kboroszko Oct 8, 2021
cdffd68
ugly reading
kboroszko Oct 8, 2021
5d7cbd3
a little prettier
kboroszko Oct 8, 2021
0f39250
basic dataset iterator
kboroszko Oct 12, 2021
e99e083
dataset not working
kboroszko Oct 12, 2021
0708e46
i am retarded
kboroszko Oct 12, 2021
9e71b7b
working example read with reading multiple columns
kboroszko Oct 12, 2021
234c434
namespace and utils
kboroszko Oct 12, 2021
7f53ac7
reading from a dataset
kboroszko Oct 12, 2021
42743cd
removing not vital comments and fixing typos
kboroszko Oct 19, 2021
3997c1d
output types and shapes as members
kboroszko Oct 19, 2021
97eaa44
refactored the structure
kboroszko Oct 19, 2021
8658f5c
refactored getNextInternal
kboroszko Oct 19, 2021
a442449
use absl flat hash map
kboroszko Oct 19, 2021
7a1bedc
refactored subclasses to the outside
kboroszko Oct 20, 2021
7d417ed
removed multiple types and reordered the map to avoid -Wreorder warning
kboroszko Oct 20, 2021
1c8c696
minor fixes and renames
kboroszko Oct 20, 2021
c9f810b
column name pairs are a member, map uses reference
kboroszko Oct 20, 2021
8b8b001
linting and find column_to_idx_
kboroszko Oct 20, 2021
256897a
renamed
kboroszko Oct 21, 2021
bdfe99b
map uses only string references, segfault fixed
kboroszko Oct 21, 2021
fa9c757
client resource
kboroszko Oct 21, 2021
ceec250
removed unused file, and moved client resource
kboroszko Oct 21, 2021
c0e0ac4
fix namespaces
kboroszko Oct 21, 2021
8b9e803
cleanup and client api
kboroszko Oct 21, 2021
73fce88
bigtableTable api
kboroszko Oct 21, 2021
ab3380b
removed obsolete code
kboroszko Oct 21, 2021
31d7d80
linting
kboroszko Oct 22, 2021
913b14d
deleted obsolete file and minor fixes
kboroszko Oct 22, 2021
ed65a6b
deleted obsolete file and minor fixes
kboroszko Oct 22, 2021
da6ebf9
PR comments and scopedUnref
kboroszko Oct 25, 2021
b195a14
comments from pr, compiling
kboroszko Oct 26, 2021
bfdc90e
linting and consts and references
kboroszko Oct 26, 2021
e7fd27c
linting
kboroszko Oct 26, 2021
8 changes: 4 additions & 4 deletions WORKSPACE
@@ -112,11 +112,11 @@ http_archive(
 http_archive(
     name = "com_google_googleapis",
     build_file = "@com_github_googleapis_google_cloud_cpp//bazel:googleapis.BUILD",
-    sha256 = "7ebab01b06c555f4b6514453dc3e1667f810ef91d1d4d2d3aa29bb9fcb40a900",
-    strip_prefix = "googleapis-541b1ded4abadcc38e8178680b0677f65594ea6f",
+    sha256 = "a53e15405f81d5a32594d7f6486e649131fadda5431cf28377dff4ae54d45d16",
+    strip_prefix = "googleapis-d4d09eb3aec152015f35717102f9b423988b94f7",
     urls = [
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
-        "https://github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
+        "https://github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
     ],
 )

1 change: 1 addition & 0 deletions tensorflow_io/BUILD
@@ -12,6 +12,7 @@ cc_binary(
     deps = [
         "//tensorflow_io/core:arrow_ops",
         "//tensorflow_io/core:bigquery_ops",
+        "//tensorflow_io/core:bigtable_ops",
         "//tensorflow_io/core:audio_video_ops",
         "//tensorflow_io/core:avro_ops",
         "//tensorflow_io/core:orc_ops",
19 changes: 19 additions & 0 deletions tensorflow_io/core/BUILD
@@ -175,6 +175,25 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "bigtable_ops",
+    srcs = [
+        "kernels/bigtable/bigtable_kernels.cc",
+        "kernels/bigtable/bigtable_dataset_kernel.cc",
+        "ops/bigtable_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    linkstatic = True,
+    deps = [
+        "@com_github_googleapis_google_cloud_cpp//:bigtable_client",
+        "@com_github_grpc_grpc//:grpc++",
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ],
+    alwayslink = 1,
+)
+
+
 # A library for use in the bigquery kernels.
 cc_library(
     name = "bigquery_lib_cc",
236 changes: 236 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
@@ -0,0 +1,236 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using ::tensorflow::DT_STRING;
using ::tensorflow::PartialTensorShape;
using ::tensorflow::Status;

namespace cbt = ::google::cloud::bigtable;

namespace tensorflow {
namespace data {
namespace {


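// Streams rows from a Bigtable table. Each call to GetNextInternal() emits
// one rank-1 DT_STRING tensor containing the latest cell value of every
// requested column for the current row.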
template <typename Dataset>
class Iterator : public DatasetIterator<Dataset> {
 public:
  explicit Iterator(const typename DatasetIterator<Dataset>::Params& params,
                    std::string const& project_id,
                    std::string const& instance_id,
                    std::string const& table_id,
                    std::vector<std::string> const& columns)
      // The initializer list follows the declaration order of the members:
      // column_to_idx_ must be built before reader_, which consumes it, and
      // it_ must be created last, from the fully constructed reader_.
      : DatasetIterator<Dataset>(params),
        data_client_(CreateDataClient(project_id, instance_id)),
        column_to_idx_(CreateColumnMap(columns)),
        reader_(CreateTable(this->data_client_, table_id)
                    ->ReadRows(cbt::RowRange::InfiniteRange(),
                               cbt::Filter::Chain(
                                   CreateColumnsFilter(column_to_idx_),
                                   cbt::Filter::Latest(1)))),
        it_(this->reader_.begin()) {}

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    VLOG(1) << "GetNextInternal";
    mutex_lock l(mu_);
    if (it_ == reader_.end()) {
      VLOG(1) << "End of sequence";
      *end_of_sequence = true;
      return Status::OK();
    }

    VLOG(1) << "allocating tensor";
    const int64_t n_cols = static_cast<int64_t>(column_to_idx_.size());
    Tensor res(ctx->allocator({}), DT_STRING, {n_cols});
    auto res_data = res.tensor<tstring, 1>();

    VLOG(1) << "getting row";
    auto const& row = *it_;
    if (!row.ok()) {
      return errors::Internal("Error reading row: ", row.status().message());
    }
    for (auto const& cell : row.value().cells()) {
      auto const key =
          std::make_pair(cell.family_name(), cell.column_qualifier());
      auto const column_it = column_to_idx_.find(key);
      if (column_it != column_to_idx_.end()) {
        VLOG(1) << "getting column:" << column_it->second;
        res_data(column_it->second) = cell.value();
      }
    }

    VLOG(1) << "returning value";
    out_tensors->emplace_back(std::move(res));
    *end_of_sequence = false;

    VLOG(1) << "incrementing iterator";
    ++it_;

    return Status::OK();
  }

protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal");
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
    return errors::Unimplemented(
        "Iterator does not support 'RestoreInternal'");
}

private:
std::unique_ptr<cbt::Table> CreateTable(
std::shared_ptr<cbt::DataClient> const& data_client,
std::string const& table_id) {
VLOG(1) << "CreateTable";
return std::make_unique<cbt::Table>(data_client, table_id);
}

std::shared_ptr<cbt::DataClient> CreateDataClient(
std::string const& project_id, std::string const& instance_id) {
VLOG(1) << "CreateDataClient";
    return cbt::CreateDefaultDataClient(project_id, instance_id,
                                        cbt::ClientOptions());
}

  cbt::Filter CreateColumnsFilter(
      absl::flat_hash_map<std::pair<std::string, std::string>, size_t> const&
          columns) {
    VLOG(1) << "CreateColumnsFilter";
    std::vector<cbt::Filter> filters;

    for (auto const& key : columns) {
      // key.first is the (column_family, column_qualifier) pair.
      filters.push_back(
          cbt::Filter::ColumnName(key.first.first, key.first.second));
    }

return cbt::Filter::InterleaveFromRange(filters.begin(), filters.end());
}

static std::pair<std::string, std::string> ColumnNameToPair(
std::string const& col_name_full) {
    size_t delimiter_pos = col_name_full.find(':');
    if (delimiter_pos == std::string::npos) {
      throw std::invalid_argument("Invalid column name: " + col_name_full +
                                  "\nColumn name must be in format "
                                  "column_family:column_name.");
    }
    std::string col_family = col_name_full.substr(0, delimiter_pos);
    std::string col_name = col_name_full.substr(delimiter_pos + 1);
    return std::make_pair(std::move(col_family), std::move(col_name));
}

static absl::flat_hash_map<std::pair<std::string, std::string>, size_t> CreateColumnMap(
std::vector<std::string> const& columns) {
absl::flat_hash_map<std::pair<std::string, std::string>, size_t> column_map;
size_t index = 0;
for (const auto& column_name : columns) {
std::pair<std::string, std::string> pair = ColumnNameToPair(column_name);
column_map[pair] = index++;
}
return column_map;
}

mutex mu_;
std::shared_ptr<cbt::DataClient> data_client_ GUARDED_BY(mu_);
  absl::flat_hash_map<std::pair<std::string, std::string>, size_t>
      column_to_idx_ GUARDED_BY(mu_);
  cbt::RowReader reader_ GUARDED_BY(mu_);
  cbt::RowReader::iterator it_ GUARDED_BY(mu_);
};


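// A tf.data Dataset backed by a Bigtable table. Stores only the connection
// parameters and the requested columns; rows are streamed lazily by the
// Iterator above.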
class Dataset : public DatasetBase {
public:
  Dataset(OpKernelContext* ctx, std::string project_id, std::string instance_id,
          std::string table_id, std::vector<std::string> columns)
      : DatasetBase(DatasetContext(ctx)),
        project_id_(std::move(project_id)),
        instance_id_(std::move(instance_id)),
        table_id_(std::move(table_id)),
        columns_(std::move(columns)) {
    dtypes_.push_back(DT_STRING);
    // Each element is a single rank-1 string tensor with one entry per
    // requested column.
    output_shapes_.push_back(
        PartialTensorShape({static_cast<int64_t>(columns_.size())}));
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const std::string& prefix) const override {
VLOG(1) << "MakeIteratorInternal. table=" << project_id_ << ":"
<< instance_id_ << ":" << table_id_;
return std::unique_ptr<IteratorBase>(
new Iterator<Dataset>({this, strings::StrCat(prefix, "::BigtableDataset")},
project_id_, instance_id_, table_id_, columns_));
}

const DataTypeVector& output_dtypes() const override { return dtypes_; }

const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}

std::string DebugString() const override {
return "BigtableDatasetOp::Dataset";
}

protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    // errors::Unimplemented concatenates its arguments; it does not do
    // printf-style "%s" substitution.
    return errors::Unimplemented(DebugString(),
                                 " does not support serialization");
  }

Status CheckExternalState() const override { return Status::OK(); }

private:
std::string project_id_;
std::string instance_id_;
std::string table_id_;
std::vector<std::string> columns_;
DataTypeVector dtypes_;
std::vector<PartialTensorShape> output_shapes_;
};


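// Op kernel that reads the string attrs and materializes the Dataset.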
class BigtableDatasetOp : public DatasetOpKernel {
public:
explicit BigtableDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("columns", &columns_));
}

void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
VLOG(1) << "Make Dataset";

*output = new Dataset(ctx, project_id_, instance_id_, table_id_, columns_);
}

private:
std::string project_id_;
std::string instance_id_;
std::string table_id_;
std::vector<std::string> columns_;
};



REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
BigtableDatasetOp);

} // namespace
} // namespace data
} // namespace tensorflow
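The op registrations themselves live in tensorflow_io/core/ops/bigtable_ops.cc, which this view does not expand. As a minimal sketch only, assuming the conventional tf.data registration pattern: the attr names below are taken from the kernels above, while the variant-typed scalar output and the shape function are assumptions, not confirmed by this diff.

// Sketch of ops/bigtable_ops.cc (assumed; the file is not shown in this view).
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

REGISTER_OP("BigtableDataset")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);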
94 changes: 94 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_kernels.cc
@@ -0,0 +1,94 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

using namespace tensorflow;

namespace cbt = ::google::cloud::bigtable;

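// Proof-of-concept op that synchronously reads the whole table into a single
// rank-2 string tensor of shape {num_rows, num_columns}. Unlike the dataset
// kernel, it buffers the entire table in memory, so it is suited to tests only.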
class BigtableTestOp : public OpKernel {
public:
explicit BigtableTestOp(OpKernelConstruction* context) : OpKernel(context) {
    // Read the connection parameters and column list from the op's attrs.
OP_REQUIRES_OK(context, context->GetAttr("project_id", &project_id_));
OP_REQUIRES_OK(context, context->GetAttr("instance_id", &instance_id_));
OP_REQUIRES_OK(context, context->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(context, context->GetAttr("columns", &columns_));

int index = 0;
for (auto const& column : columns_) {
column_map[column] = index++;
}
}

void Compute(OpKernelContext* context) override {
    // Open the table and synchronously read every row with all of its cells.
cbt::Table table(cbt::CreateDefaultDataClient(project_id_, instance_id_,
cbt::ClientOptions()),
table_id_);

    cbt::RowReader reader = table.ReadRows(cbt::RowRange::InfiniteRange(),
                                           cbt::Filter::PassAllFilter());

    std::vector<std::vector<std::string>> rows_vec;

    for (auto const& row : reader) {
      OP_REQUIRES(context, row.ok(),
                  errors::Internal("Error reading row: ",
                                   row.status().message()));
      // Columns missing from this row keep the "Nothing" placeholder.
      std::vector<std::string> row_vec(column_map.size(), "Nothing");

for (auto const& cell : row->cells()) {
std::string col_name =
cell.family_name() + ":" + cell.column_qualifier();
if (column_map.find(col_name) != column_map.end()) {
row_vec[column_map[col_name]] = cell.value();
}
}
rows_vec.push_back(row_vec);
}

    // Allocate the output tensor: one row per Bigtable row, one column per
    // requested column.
    Tensor* output_tensor = nullptr;
    const int64_t n_rows = static_cast<int64_t>(rows_vec.size());
    const int64_t n_cols = static_cast<int64_t>(column_map.size());
    OP_REQUIRES_OK(context, context->allocate_output(0, {n_rows, n_cols},
                                                     &output_tensor));
    auto output_v = output_tensor->tensor<tstring, 2>();

    // Copy the buffered values into the output tensor.
    for (int64_t i = 0; i < n_rows; i++) {
      for (int64_t j = 0; j < n_cols; j++) {
        output_v(i, j) = rows_vec[i][j];
      }
    }
}

private:
string project_id_;
string instance_id_;
string table_id_;
std::vector<string> columns_;
std::map<std::string, int> column_map;
};

REGISTER_KERNEL_BUILDER(Name("BigtableTest").Device(DEVICE_CPU),
BigtableTestOp);
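A similarly hedged sketch of the test op's registration (again assumed; the real one lives in ops/bigtable_ops.cc): since the op returns the whole table as one rank-2 string tensor of a priori unknown row count, an unknown shape function is the safe choice.

REGISTER_OP("BigtableTest")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Output("output: string")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);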