Merged
36 commits
5954746
poc
kboroszko Oct 7, 2021
9bb40bb
working example
kboroszko Oct 8, 2021
fbc32ec
output string
kboroszko Oct 8, 2021
cdffd68
ugly reading
kboroszko Oct 8, 2021
5d7cbd3
a little prettier
kboroszko Oct 8, 2021
0f39250
basic dataset iterator
kboroszko Oct 12, 2021
e99e083
dataset not working
kboroszko Oct 12, 2021
0708e46
i am retarded
kboroszko Oct 12, 2021
9e71b7b
working example read with reading multiple columns
kboroszko Oct 12, 2021
234c434
namespace and utils
kboroszko Oct 12, 2021
7f53ac7
reading from a dataset
kboroszko Oct 12, 2021
42743cd
removing not vital comments and fixing typos
kboroszko Oct 19, 2021
3997c1d
output types and shapes as members
kboroszko Oct 19, 2021
97eaa44
refactored the structure
kboroszko Oct 19, 2021
8658f5c
refactored getNextInternal
kboroszko Oct 19, 2021
a442449
use absl flat hash map
kboroszko Oct 19, 2021
7a1bedc
refactored subclasses to the outside
kboroszko Oct 20, 2021
7d417ed
removed multiple types and reordered the map to avoid -Wreorder warning
kboroszko Oct 20, 2021
1c8c696
minor fixes and renames
kboroszko Oct 20, 2021
c9f810b
column name pairs are a member, map uses reference
kboroszko Oct 20, 2021
8b8b001
linting and find column_to_idx_
kboroszko Oct 20, 2021
256897a
renamed
kboroszko Oct 21, 2021
bdfe99b
map uses only string references, segfault fixed
kboroszko Oct 21, 2021
fa9c757
client resource
kboroszko Oct 21, 2021
ceec250
removed unused file, and moved client resource
kboroszko Oct 21, 2021
c0e0ac4
fix namespaces
kboroszko Oct 21, 2021
8b9e803
cleanup and client api
kboroszko Oct 21, 2021
73fce88
bigtableTable api
kboroszko Oct 21, 2021
ab3380b
removed obsolete code
kboroszko Oct 21, 2021
31d7d80
linting
kboroszko Oct 22, 2021
913b14d
deleted obsolete file and minor fixes
kboroszko Oct 22, 2021
ed65a6b
deleted obsolete file and minor fixes
kboroszko Oct 22, 2021
da6ebf9
PR comments and scopedUnref
kboroszko Oct 25, 2021
b195a14
comments from pr, compiling
kboroszko Oct 26, 2021
bfdc90e
linting and consts and references
kboroszko Oct 26, 2021
e7fd27c
linting
kboroszko Oct 26, 2021
8 changes: 4 additions & 4 deletions WORKSPACE
@@ -112,11 +112,11 @@ http_archive(
http_archive(
name = "com_google_googleapis",
build_file = "@com_github_googleapis_google_cloud_cpp//bazel:googleapis.BUILD",
sha256 = "7ebab01b06c555f4b6514453dc3e1667f810ef91d1d4d2d3aa29bb9fcb40a900",
strip_prefix = "googleapis-541b1ded4abadcc38e8178680b0677f65594ea6f",
sha256 = "a53e15405f81d5a32594d7f6486e649131fadda5431cf28377dff4ae54d45d16",
strip_prefix = "googleapis-d4d09eb3aec152015f35717102f9b423988b94f7",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
"https://github.com/googleapis/googleapis/archive/541b1ded4abadcc38e8178680b0677f65594ea6f.zip",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
"https://github.com/googleapis/googleapis/archive/d4d09eb3aec152015f35717102f9b423988b94f7.zip",
],
)

1 change: 1 addition & 0 deletions tensorflow_io/BUILD
@@ -12,6 +12,7 @@ cc_binary(
deps = [
"//tensorflow_io/core:arrow_ops",
"//tensorflow_io/core:bigquery_ops",
"//tensorflow_io/core:bigtable_ops",
"//tensorflow_io/core:audio_video_ops",
"//tensorflow_io/core:avro_ops",
"//tensorflow_io/core:orc_ops",
19 changes: 19 additions & 0 deletions tensorflow_io/core/BUILD
@@ -175,6 +175,25 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "bigtable_ops",
srcs = [
"kernels/bigtable/bigtable_kernels.cc",
"kernels/bigtable/bigtable_dataset_kernel.cc",
"ops/bigtable_ops.cc",
],
copts = tf_io_copts(),
linkstatic = True,
deps = [
"@com_github_googleapis_google_cloud_cpp//:bigtable_client",
"@com_github_grpc_grpc//:grpc++",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
alwayslink = 1,
)


# A library for use in the bigquery kernels.
cc_library(
name = "bigquery_lib_cc",
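The `srcs` list above also names `ops/bigtable_ops.cc`, which does not appear in this diff. A minimal sketch of what that registration could look like, assuming the attr names the kernels read via `GetAttr` and the usual variant-handle output for dataset ops; the exact file contents are an assumption, not part of this PR:

// Hypothetical sketch of ops/bigtable_ops.cc (the real file is not shown in
// this diff). Attr names mirror the GetAttr calls in the kernels below.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

REGISTER_OP("BigtableDataset")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Output("handle: variant")  // dataset ops hand back a variant handle
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

REGISTER_OP("BigtableTest")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Output("rows: string")  // [num_rows, num_columns] matrix of cell values
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);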
257 changes: 257 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
@@ -0,0 +1,257 @@
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using ::tensorflow::DT_STRING;
using ::tensorflow::PartialTensorShape;
using ::tensorflow::Status;

namespace cbt = ::google::cloud::bigtable;

namespace tensorflow {
namespace data {
namespace {

class BigtableDatasetOp : public DatasetOpKernel {
public:
explicit BigtableDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("columns", &columns_));
}

void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
// All dataset parameters arrive as op attrs and were parsed in the
// constructor, so there are no input tensors to validate here.
VLOG(1) << "Make Dataset";

// Create the dataset object, passing the already-validated attrs.
*output = new Dataset(ctx, project_id_, instance_id_, table_id_, columns_);
}

private:
std::string project_id_;
std::string instance_id_;
std::string table_id_;
std::vector<std::string> columns_;

class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::string project_id,
std::string instance_id, std::string table_id,
std::vector<std::string> columns)
: DatasetBase(DatasetContext(ctx)),
project_id_(std::move(project_id)),
instance_id_(std::move(instance_id)),
table_id_(std::move(table_id)),
columns_(std::move(columns)) {}

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const std::string& prefix) const override {
VLOG(1) << "MakeIteratorInternal. instance=" << project_id_ << ":"
<< instance_id_ << ":" << table_id_;
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableDataset")},
project_id_, instance_id_, table_id_, columns_));
}

// Record structure: each record is a rank-1 string tensor holding one
// value per requested column, in the order given by the `columns` attr.
const DataTypeVector& output_dtypes() const override {
static auto* const dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}

const std::vector<PartialTensorShape>& output_shapes() const override {
// Shape {-1}: a vector of unknown length. GetNextInternal emits one
// element per requested column, so a scalar shape would be inconsistent.
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{-1}});
return *shapes;
}

std::string DebugString() const override {
return "BigtableDatasetOp::Dataset";
}

protected:
// Optional: Implementation of `GraphDef` serialization for this dataset.
//
// Implement this method if you want to be able to save and restore
// instances of this dataset (and any iterators over it).
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
// Construct nodes to represent any of the input tensors from this
// object's member variables using `b->AddScalar()` and `b->AddVector()`.

return errors::Unimplemented("%s does not support serialization",
DebugString());
}

Status CheckExternalState() const override { return Status::OK(); }

private:
std::string project_id_;
std::string instance_id_;
std::string table_id_;
std::vector<std::string> columns_;

class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params, std::string const& project_id,
std::string const& instance_id,
std::string const& table_id,
std::vector<std::string> columns)
: DatasetIterator<Dataset>(params),
// Initializers are listed in declaration order (silencing -Wreorder):
// column_map_ must be built before reader_'s filter consults it.
column_map_(CreateColumnMap(columns)),
data_client_(CreateDataClient(project_id, instance_id)),
reader_(CreateTable(this->data_client_, table_id)
->ReadRows(
cbt::RowRange::InfiniteRange(),
cbt::Filter::Chain(CreateColumnsFilter(column_map_),
cbt::Filter::Latest(1)))),
it_(this->reader_.begin()) {}

// Reading logic. There are three cases:
//
// 1. If a row is read successfully, store its cells as a string tensor
// in `*out_tensors`, set `*end_of_sequence = false` and return
// `Status::OK()`.
// 2. If the end of input is reached, set `*end_of_sequence = true` and
// return `Status::OK()`.
// 3. If an error occurs, return an error status using one of the helper
// functions from "tensorflow/core/lib/core/errors.h".
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
// NOTE: `GetNextInternal()` may be called concurrently, so it is
// recommended that you protect the iterator state with a mutex.

VLOG(1) << "GetNextInternal";
mutex_lock l(mu_);
if (it_ == reader_.end()) {
VLOG(1) << "End of sequence";
*end_of_sequence = true;
} else {
VLOG(1) << "alocating tensor";
long n_cols = column_map_.size();
Tensor record_tensor(ctx->allocator({}), DT_STRING, {n_cols});
auto record_v = record_tensor.tensor<tstring, 1>();

VLOG(1) << "getting row";
auto const& row = *it_;
for (const auto& cell : row.value().cells()) {
std::pair<std::string, std::string> key(cell.family_name(),
cell.column_qualifier());
// Look the column up with find() so an unexpected cell cannot
// default-insert a new entry into the map.
const auto column_it = column_map_.find(key);
if (column_it != column_map_.end()) {
VLOG(1) << "getting column:" << column_it->second;
record_v(column_it->second) = cell.value();
}
}
VLOG(1) << "returning value";
out_tensors->emplace_back(std::move(record_tensor));
*end_of_sequence = false;

VLOG(1) << "incrementing iterator";
++it_;
}

return Status::OK();
}

protected:
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal");
}

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"Iterator does not support 'RestoreInternal'");
}

private:
std::unique_ptr<cbt::Table> CreateTable(
std::shared_ptr<cbt::DataClient> const& data_client,
std::string const& table_id) {
VLOG(1) << "CreateTable";
return std::make_unique<cbt::Table>(data_client, table_id);
}

std::shared_ptr<cbt::DataClient> CreateDataClient(
std::string const& project_id, std::string const& instance_id) {
VLOG(1) << "CreateDataClient";
// The parameters are const references, so std::move would silently
// copy anyway; pass them through directly.
return cbt::CreateDefaultDataClient(project_id, instance_id,
cbt::ClientOptions());
}

cbt::Filter CreateColumnsFilter(
std::map<std::pair<std::string, std::string>, size_t> const&
columns) {
VLOG(1) << "CreateColumnsFilter";
std::vector<cbt::Filter> filters;

for (const auto& key : columns) {
std::pair<std::string, std::string> pair = key.first;
cbt::Filter f = cbt::Filter::ColumnName(pair.first, pair.second);
filters.push_back(std::move(f));
}

return cbt::Filter::InterleaveFromRange(filters.begin(), filters.end());
}

static std::pair<std::string, std::string> ColumnNameToPair(
std::string const& col_name_full) {
size_t delimiter_pos = col_name_full.find(':');
if (delimiter_pos == std::string::npos)
throw std::invalid_argument("Invalid column name:" + col_name_full +
"\nColumn name must be in format " +
"column_family:column_name.");
std::string col_family = col_name_full.substr(0, delimiter_pos);
std::string col_name =
col_name_full.substr(delimiter_pos + 1, col_name_full.length());
std::pair<std::string, std::string> pair(col_family, col_name);
return pair;
}

static std::map<std::pair<std::string, std::string>, size_t>
CreateColumnMap(std::vector<std::string> const& columns) {
std::map<std::pair<std::string, std::string>, size_t> column_map;
size_t index = 0;
for (const auto& column_name : columns) {
std::pair<std::string, std::string> pair =
ColumnNameToPair(column_name);
column_map[pair] = index++;
}
return column_map;
}

mutex mu_;
// Mapping between column names and their indices in output tensors. We
// use std::map because std::unordered_map cannot hash a std::pair by
// default.
std::map<std::pair<std::string, std::string>, size_t> column_map_
GUARDED_BY(mu_);
std::shared_ptr<cbt::DataClient> data_client_ GUARDED_BY(mu_);
cbt::RowReader reader_ GUARDED_BY(mu_);
cbt::v1::internal::RowReaderIterator it_ GUARDED_BY(mu_);
};
};
};

// Register the kernel implementation for BigtableDataset.
REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
BigtableDatasetOp);

} // namespace
} // namespace data
} // namespace tensorflow
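The record layout produced by `GetNextInternal` is fixed by `CreateColumnMap`: each `family:qualifier` string is split and assigned a stable index. A standalone sketch of that mapping logic, using illustrative copies of the two helpers rather than the kernel's actual entry points:

// Standalone illustration of the column-mapping logic; these are copies for
// demonstration, not the kernel's code paths.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Split "family:qualifier" into a (family, qualifier) pair.
std::pair<std::string, std::string> ColumnNameToPair(
    std::string const& col_name_full) {
  size_t delimiter_pos = col_name_full.find(':');
  if (delimiter_pos == std::string::npos)
    throw std::invalid_argument("Invalid column name: " + col_name_full);
  return {col_name_full.substr(0, delimiter_pos),
          col_name_full.substr(delimiter_pos + 1)};
}

// Assign each requested column a fixed index in the output tensor.
std::map<std::pair<std::string, std::string>, size_t> CreateColumnMap(
    std::vector<std::string> const& columns) {
  std::map<std::pair<std::string, std::string>, size_t> column_map;
  size_t index = 0;
  for (const auto& column_name : columns)
    column_map[ColumnNameToPair(column_name)] = index++;
  return column_map;
}

int main() {
  for (auto const& entry : CreateColumnMap({"cf1:c1", "cf1:c2", "cf2:c1"}))
    std::cout << entry.first.first << ":" << entry.first.second << " -> "
              << entry.second << "\n";  // cf1:c1 -> 0, cf1:c2 -> 1, cf2:c1 -> 2
}

A cell from `cf1:c2` therefore always lands at index 1 of the record tensor, regardless of the order in which cells arrive.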
80 changes: 80 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_kernels.cc
@@ -0,0 +1,80 @@
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

using namespace tensorflow;

namespace cbt = ::google::cloud::bigtable;

class BigtableTestOp : public OpKernel {
public:
explicit BigtableTestOp(OpKernelConstruction* context) : OpKernel(context) {
// Read the connection parameters and requested columns from attrs.
OP_REQUIRES_OK(context, context->GetAttr("project_id", &project_id_));
OP_REQUIRES_OK(context, context->GetAttr("instance_id", &instance_id_));
OP_REQUIRES_OK(context, context->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(context, context->GetAttr("columns", &columns_));

int index = 0;
for (auto const& column : columns_) {
column_map[column] = index++;
}
}

void Compute(OpKernelContext* context) override {
// Connect to the table; this op takes no input tensors.

cbt::Table table(cbt::CreateDefaultDataClient(project_id_, instance_id_,
cbt::ClientOptions()),
table_id_);

cbt::RowReader reader = table.ReadRows(
cbt::RowRange::InfiniteRange(), cbt::Filter::PassAllFilter());

std::vector<std::vector<std::string>> rows_vec;

for (auto const& row : reader) {
// Fail the op via the context instead of throwing from Compute.
OP_REQUIRES(context, row.ok(), errors::Internal(row.status().message()));
// Pre-fill with a placeholder for columns absent from this row.
std::vector<std::string> row_vec(column_map.size(), "Nothing");

for (auto const& cell : row->cells()) {
std::string col_name =
cell.family_name() + ":" + cell.column_qualifier();
if (column_map.find(col_name) != column_map.end()) {
row_vec[column_map[col_name]] = cell.value();
}
}
rows_vec.push_back(row_vec);
}

// Create an output tensor
Tensor* output_tensor = nullptr;
const long N_rows = static_cast<long>(rows_vec.size());
const long N_cols = static_cast<long>(column_map.size());
OP_REQUIRES_OK(
context, context->allocate_output(0, {N_rows, N_cols}, &output_tensor));
auto output_v = output_tensor->tensor<tstring, 2>();

// Copy the collected rows into the output tensor.

for (long i = 0; i < N_rows; i++) {
for (long j = 0; j < N_cols; j++) {
output_v(i, j) = rows_vec[i][j];
}
}
}

private:
string project_id_;
string instance_id_;
string table_id_;
std::vector<string> columns_;
std::map<std::string, int> column_map;
};

REGISTER_KERNEL_BUILDER(Name("BigtableTest").Device(DEVICE_CPU),
BigtableTestOp);
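The `Compute` loop above is the Cloud Bigtable C++ client's standard read pattern with a pass-all filter. For comparison, a minimal sketch of the same read outside TensorFlow; the project, instance, and table IDs are placeholders:

// Minimal standalone read using the same google-cloud-cpp client API the
// kernel relies on; "my-project"/"my-instance"/"my-table" are placeholders.
#include <iostream>
#include "google/cloud/bigtable/table.h"

namespace cbt = ::google::cloud::bigtable;

int main() {
  cbt::Table table(cbt::CreateDefaultDataClient("my-project", "my-instance",
                                                cbt::ClientOptions()),
                   "my-table");
  // Same range/filter combination as BigtableTestOp::Compute.
  for (auto const& row : table.ReadRows(cbt::RowRange::InfiniteRange(),
                                        cbt::Filter::PassAllFilter())) {
    if (!row) {
      std::cerr << row.status().message() << "\n";
      return 1;
    }
    for (auto const& cell : row->cells()) {
      std::cout << row->row_key() << " " << cell.family_name() << ":"
                << cell.column_qualifier() << "=" << cell.value() << "\n";
    }
  }
  return 0;
}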