Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9c219d4
remove obsolete classes from bigtable_dataset_kernel
kboroszko Nov 5, 2021
c36527f
outline
kboroszko Oct 28, 2021
ae7e73b
simplified creating resource
kboroszko Oct 28, 2021
b8fdd3b
parallel read not exactly working
kboroszko Oct 28, 2021
93d9423
parallel not split work working
kboroszko Oct 29, 2021
7137163
sampleRowSet
kboroszko Oct 29, 2021
bd87e5b
parallel without row_set
kboroszko Oct 29, 2021
24eb8c9
rowset in parallel working
kboroszko Oct 29, 2021
d673b99
row_set const ref working
kboroszko Oct 29, 2021
e8fedea
working parallel all
kboroszko Nov 2, 2021
c4c1792
PR comments and linting
kboroszko Nov 2, 2021
45eedf7
added more tests for parallel read
kboroszko Nov 3, 2021
9ba5773
removed sample row_keys because it's unused
kboroszko Nov 3, 2021
d24fd1c
removed obsolete code and comments
kboroszko Nov 3, 2021
96981fb
code cleanup 1
kboroszko Nov 3, 2021
c756183
more tests for parallel read
kboroszko Nov 3, 2021
ace753a
run linter on python files
kboroszko Nov 5, 2021
360852e
linter on tests
kboroszko Nov 5, 2021
5b8110a
after rebase
kboroszko Nov 5, 2021
8975bc9
add tests
kboroszko Nov 5, 2021
963d3b3
linting
kboroszko Nov 5, 2021
71ca232
samples working but ugly
kboroszko Nov 9, 2021
6709520
removed accidental change
kboroszko Nov 10, 2021
099768f
Use resource tensor.
kboroszko Nov 10, 2021
0ca85bf
run linter and fixed naming
kboroszko Nov 10, 2021
127c355
fix naming
kboroszko Nov 10, 2021
c5a8fae
handled empty row_set
kboroszko Nov 10, 2021
628a07f
pr comments
kboroszko Nov 19, 2021
e9bf2ef
linter
kboroszko Nov 19, 2021
759afcb
removed missed comment
kboroszko Nov 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 142 additions & 16 deletions tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "google/cloud/bigtable/row_set.h"
#include "google/cloud/bigtable/table.h"
#include "google/cloud/bigtable/table_admin.h"
#include "tensorflow/core/framework/common_shape_fns.h"
Expand Down Expand Up @@ -81,13 +82,19 @@ class BigtableClientResource : public ResourceBase {
public:
explicit BigtableClientResource(const std::string& project_id,
const std::string& instance_id)
: data_client_(CreateDataClient(project_id, instance_id)) {}
: data_client_(CreateDataClient(project_id, instance_id)) {
VLOG(1) << "BigtableClientResource ctor";
}

cbt::Table CreateTable(const std::string& table_id) {
VLOG(1) << "CreateTable";
return cbt::Table(data_client_, table_id);
cbt::Table table(data_client_, table_id);
VLOG(1) << "table crated";
return table;
}

~BigtableClientResource() { VLOG(1) << "BigtableClientResource dtor"; }

string DebugString() const override { return "BigtableClientResource"; }

private:
Expand All @@ -108,6 +115,8 @@ class BigtableClientOp : public OpKernel {
VLOG(1) << "BigtableClientOp ctor";
}

~BigtableClientOp() { VLOG(1) << "BigtableClientOp dtor"; }

void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
VLOG(1) << "BigtableClientOp compute";
ResourceMgr* mgr = ctx->resource_manager();
Expand Down Expand Up @@ -139,13 +148,16 @@ class Iterator : public DatasetIterator<Dataset> {
const std::vector<std::string>& columns)
: DatasetIterator<Dataset>(params),
columns_(ColumnsToFamiliesAndQualifiers(columns)),
reader_(
this->dataset()->client_resource().CreateTable(table_id).ReadRows(
cbt::RowRange::InfiniteRange(),
cbt::Filter::Chain(CreateColumnsFilter(columns_),
cbt::Filter::Latest(1)))),
table_(this->dataset()->client_resource().CreateTable(table_id)),
reader_(this->table_.ReadRows(
this->dataset()->row_set_resource().row_set(),
// cbt::RowRange::InfiniteRange(),
cbt::Filter::Chain(CreateColumnsFilter(columns_),
cbt::Filter::Latest(1)))),
it_(this->reader_.begin()),
column_to_idx_(CreateColumnToIdxMap(columns_)) {}
column_to_idx_(CreateColumnToIdxMap(columns_)) {
VLOG(1) << "DatasetIterator ctor";
}

Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
Expand Down Expand Up @@ -256,8 +268,8 @@ class Iterator : public DatasetIterator<Dataset> {
}

mutex mu_;
const std::shared_ptr<cbt::DataClient> data_client_;
const std::vector<std::pair<std::string, std::string>> columns_;
cbt::Table table_;
cbt::RowReader reader_ GUARDED_BY(mu_);
cbt::v1::internal::RowReaderIterator it_ GUARDED_BY(mu_);
// we're using a map with const refs to avoid copying strings when searching
Expand All @@ -270,10 +282,11 @@ class Iterator : public DatasetIterator<Dataset> {
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, BigtableClientResource* client_resource,
std::string table_id, std::vector<std::string> columns)
io::BigtableRowSetResource* row_set_resource, std::string table_id,
std::vector<std::string> columns)
: DatasetBase(DatasetContext(ctx)),
client_resource_(*client_resource),
client_resource_unref_(client_resource),
row_set_resource_(*row_set_resource),
table_id_(table_id),
columns_(columns) {
dtypes_.push_back(DT_STRING);
Expand All @@ -300,6 +313,7 @@ class Dataset : public DatasetBase {
}

BigtableClientResource& client_resource() const { return client_resource_; }
io::BigtableRowSetResource& row_set_resource() const { return row_set_resource_; }

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
Expand All @@ -313,7 +327,7 @@ class Dataset : public DatasetBase {

private:
BigtableClientResource& client_resource_;
const core::ScopedUnref client_resource_unref_;
io::BigtableRowSetResource& row_set_resource_;
const std::string table_id_;
const std::vector<std::string> columns_;
DataTypeVector dtypes_;
Expand All @@ -330,10 +344,17 @@ class BigtableDatasetOp : public DatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
VLOG(1) << "Make Dataset";
BigtableClientResource* client_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
core::ScopedUnref client_resource_unref_(client_resource);
*output = new Dataset(ctx, client_resource, table_id_, columns_);
OP_REQUIRES_OK(ctx,
GetResourceFromContext(ctx, "client", &client_resource));
core::ScopedUnref unref_client(client_resource);

io::BigtableRowSetResource* row_set_resource;
OP_REQUIRES_OK(ctx,
GetResourceFromContext(ctx, "row_set", &row_set_resource));
core::ScopedUnref row_set_resource_unref_(row_set_resource);

*output = new Dataset(ctx, client_resource, row_set_resource, table_id_,
columns_);
}

private:
Expand All @@ -344,6 +365,111 @@ class BigtableDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
BigtableDatasetOp);

// Compute the index of the first tablet assigned to `worker_id`. Worker i
// covers tablets [GetWorkerStartIndex(i), GetWorkerStartIndex(i + 1)), so the
// partition is contiguous and exhaustive. Every worker receives
// floor(num_tablets / num_workers) tablets, and the first
// (num_tablets % num_workers) workers receive one extra. Naively rounding up
// instead would starve the tail: with 100 tablets and 11 workers, ceil(100/11)
// = 10 apiece would leave nothing for the eleventh worker.
int GetWorkerStartIndex(size_t num_tablets, size_t num_workers,
                        size_t worker_id) {
  // With at least as many workers as tablets, worker i simply starts at
  // tablet i (workers past the end get an empty range).
  if (num_tablets <= num_workers) {
    return std::min(num_tablets, worker_id);
  }
  // Guaranteed minimum per worker, plus how many workers carry one extra.
  size_t const base = num_tablets / num_workers;
  size_t const extra = num_tablets % num_workers;
  // Workers [0, extra) each took base + 1 tablets before us; the rest took
  // base. Equivalent closed form:
  return worker_id * base + std::min(worker_id, extra);
}

// True iff `row_set` has a non-empty overlap with the row range
// [start_key, end_key).
bool RowSetIntersectsRange(cbt::RowSet const& row_set,
                           std::string const& start_key,
                           std::string const& end_key) {
  return !row_set.Intersect(cbt::RowRange::Range(start_key, end_key))
              .IsEmpty();
}

class BigtableSampleRowSetsOp : public OpKernel {
public:
explicit BigtableSampleRowSetsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
VLOG(1) << "BigtableSampleRowSetsOp ctor ";
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("num_parallel_calls", &num_parallel_calls_));
}

void Compute(OpKernelContext* context) override {
mutex_lock l(mu_);
BigtableClientResource* client_resource;
OP_REQUIRES_OK(context,
GetResourceFromContext(context, "client", &client_resource));
core::ScopedUnref unref_client(client_resource);

io::BigtableRowSetResource* row_set_resource;
OP_REQUIRES_OK(
context, GetResourceFromContext(context, "row_set", &row_set_resource));
core::ScopedUnref unref_row_set(row_set_resource);

auto table = client_resource->CreateTable(table_id_);
auto maybe_sample_row_keys = table.SampleRows();
if (!maybe_sample_row_keys.ok())
throw std::runtime_error(maybe_sample_row_keys.status().message());
auto& sample_row_keys = maybe_sample_row_keys.value();

std::vector<std::pair<std::string, std::string>> tablets;

std::string start_key;
for (auto& sample_row_key : sample_row_keys) {
auto& end_key = sample_row_key.row_key;
tablets.emplace_back(start_key, end_key);
start_key = std::move(end_key);
}
if (!start_key.empty()) {
tablets.emplace_back(start_key, "");
}
tablets.erase(
std::remove_if(
tablets.begin(), tablets.end(),
[row_set_resource](std::pair<std::string, std::string> const& p) {
return !RowSetIntersectsRange(row_set_resource->row_set(), p.first,
p.second);
}),
tablets.end());

VLOG(1) << "got array of tablets of size:" << tablets.size();

size_t output_size = std::min(tablets.size(), (size_t)num_parallel_calls_);

Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, {(long)output_size, 2},
&output_tensor));
auto output_v = output_tensor->tensor<tstring, 2>();

for (size_t i = 0; i < output_size; i++) {
size_t start_idx = GetWorkerStartIndex(tablets.size(), output_size, i);
size_t next_worker_start_idx =
GetWorkerStartIndex(tablets.size(), output_size, i + 1);
size_t end_idx = next_worker_start_idx - 1;
start_key = tablets.at(start_idx).first;
std::string end_key = tablets.at(end_idx).second;
output_v(i, 0) = start_key;
output_v(i, 1) = end_key;
}
}

private:
mutable mutex mu_;
std::string table_id_;
int num_parallel_calls_;
};

REGISTER_KERNEL_BUILDER(Name("BigtableSampleRowSets").Device(DEVICE_CPU),
BigtableSampleRowSetsOp);

} // namespace
} // namespace data
} // namespace tensorflow
51 changes: 51 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,56 @@ class BigtableRowSetIntersectOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableRowSetIntersect").Device(DEVICE_CPU),
BigtableRowSetIntersectOp);


class BigtableRowSetIntersectTensorOp : public OpKernel {
public:
explicit BigtableRowSetIntersectTensorOp(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
ResourceMgr* mgr = context->resource_manager();
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

BigtableRowSetResource* row_set_resource;
OP_REQUIRES_OK(context, GetResourceFromContext(context, "row_set",
&row_set_resource));
core::ScopedUnref row_set_resource_unref(row_set_resource);


const Tensor* row_keys_tensor;
OP_REQUIRES_OK(context, context->input("row_range_tensor", &row_keys_tensor));
auto row_keys = row_keys_tensor->tensor<tstring, 1>();

VLOG(1) << "RowsetIntersectTensor intersecting: [" << row_keys(0) << "," << row_keys(1) << ")";

BigtableRowSetResource* result_resource;
OP_REQUIRES_OK(
context,
mgr->LookupOrCreate<BigtableRowSetResource>(
cinfo_.container(), cinfo_.name(), &result_resource,
[this, row_set_resource, &row_keys](
BigtableRowSetResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new BigtableRowSetResource(
row_set_resource->Intersect(
cbt::RowRange::RightOpen(row_keys(0), row_keys(1))
));
return Status::OK();
}));

OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
TypeIndex::Make<BigtableRowSetResource>()));
}

protected:
// Variables accessible from subclasses.
mutex mu_;
ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
};

REGISTER_KERNEL_BUILDER(Name("BigtableRowSetIntersectTensor").Device(DEVICE_CPU),
BigtableRowSetIntersectTensorOp);

} // namespace io
} // namespace tensorflow
4 changes: 4 additions & 0 deletions tensorflow_io/core/kernels/bigtable/bigtable_row_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class BigtableRowSetResource : public ResourceBase {
return row_set_.Intersect(row_range);
}

// Read-only view of the wrapped RowSet. Const-qualified so it can be called
// through const references/pointers to the resource; the returned reference
// is valid for the lifetime of this resource.
google::cloud::bigtable::RowSet const& row_set() const { return row_set_; }

string DebugString() const override {
return "BigtableRowSetResource:{" + ToString() + "}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ int64 MemcachedFileBlockCache::AddToCacheBuffer(const string& memc_key,
cache_buffer_keys_.push_back(memc_key);
auto page = absl::make_unique<std::vector<char>>();
page->assign(data->begin(), data->end());
cache_buffer_map_.emplace(memc_key, page.release());
cache_buffer_map_.emplace(memc_key, std::move(page));
}
return cache_buffer_keys_.size();
}
Expand Down
20 changes: 20 additions & 0 deletions tensorflow_io/core/ops/bigtable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ REGISTER_OP("BigtableClient")

REGISTER_OP("BigtableDataset")
.Input("client: resource")
.Input("row_set: resource")
.Attr("table_id: string")
.Attr("columns: list(string) >= 1")
.Output("handle: variant")
Expand Down Expand Up @@ -93,3 +94,22 @@ REGISTER_OP("BigtableRowSetIntersect")
.Output("result_row_set: resource")
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("BigtableRowSetIntersectTensor")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Input("row_set: resource")
    .Input("row_range_tensor: string")
    .Output("result_row_set: resource")
    // The output is a resource handle, i.e. a scalar — declare it as such,
    // matching BigtableRowSetIntersect above. UnchangedShape would merely
    // propagate the shape of the first input instead.
    .SetShapeFn(shape_inference::ScalarShape);


REGISTER_OP("BigtableSampleRowSets")
    .Input("client: resource")
    .Input("row_set: resource")
    .Attr("table_id: string")
    .Attr("num_parallel_calls: int")
    .Output("samples: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
      // The kernel allocates a rank-2 tensor of shape [num_workers, 2]
      // (a [start_key, end_key) pair per worker), not a vector.
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      return tensorflow::Status::OK();
    });
Loading