Skip to content

Commit a8e6bba

Browse files
authored
feat: version filters (#6)
This PR adds support for Bigtable version filters.
1 parent 44fa86d commit a8e6bba

File tree

7 files changed

+291
-7
lines changed

7 files changed

+291
-7
lines changed

tensorflow_io/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ cc_library(
184184
"kernels/bigtable/bigtable_row_range.h",
185185
"kernels/bigtable/bigtable_row_set.cc",
186186
"kernels/bigtable/bigtable_row_set.h",
187+
"kernels/bigtable/bigtable_version_filters.cc",
188+
"kernels/bigtable/bigtable_version_filters.h",
187189
"ops/bigtable_ops.cc",
188190
],
189191
copts = tf_io_copts(),

tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "tensorflow/core/framework/resource_mgr.h"
2424
#include "tensorflow/core/framework/resource_op_kernel.h"
2525
#include "tensorflow_io/core/kernels/bigtable/bigtable_row_set.h"
26+
#include "tensorflow_io/core/kernels/bigtable/bigtable_version_filters.h"
2627

2728
namespace cbt = ::google::cloud::bigtable;
2829

@@ -148,6 +149,7 @@ class Iterator : public DatasetIterator<Dataset> {
148149
reader_(this->dataset()->CreateTable().ReadRows(
149150
this->dataset()->row_set(),
150151
cbt::Filter::Chain(CreateColumnsFilter(columns_),
152+
this->dataset()->filter(),
151153
cbt::Filter::Latest(1)))),
152154
it_(this->reader_.begin()),
153155
column_to_idx_(CreateColumnToIdxMap(columns_)) {
@@ -277,11 +279,12 @@ class Dataset : public DatasetBase {
277279
public:
278280
Dataset(OpKernelContext* ctx,
279281
const std::shared_ptr<cbt::DataClient>& data_client,
280-
cbt::RowSet row_set, std::string table_id,
282+
cbt::RowSet row_set, cbt::Filter filter, std::string table_id,
281283
std::vector<std::string> columns)
282284
: DatasetBase(DatasetContext(ctx)),
283285
data_client_(data_client),
284286
row_set_(std::move(row_set)),
287+
filter_(std::move(filter)),
285288
table_id_(table_id),
286289
columns_(columns) {
287290
dtypes_.push_back(DT_STRING);
@@ -319,6 +322,8 @@ class Dataset : public DatasetBase {
319322
return table;
320323
}
321324

325+
const cbt::Filter& filter() const { return filter_; }
326+
322327
protected:
323328
Status AsGraphDefInternal(SerializationContext* ctx,
324329
DatasetGraphDefBuilder* b,
@@ -332,6 +337,7 @@ class Dataset : public DatasetBase {
332337
private:
333338
std::shared_ptr<cbt::DataClient> const& data_client_;
334339
const cbt::RowSet row_set_;
340+
cbt::Filter filter_;
335341
const std::string table_id_;
336342
const std::vector<std::string> columns_;
337343
DataTypeVector dtypes_;
@@ -357,8 +363,14 @@ class BigtableDatasetOp : public DatasetOpKernel {
357363
GetResourceFromContext(ctx, "row_set", &row_set_resource));
358364
core::ScopedUnref row_set_resource_unref_(row_set_resource);
359365

366+
io::BigtableFilterResource* filter_resource;
367+
OP_REQUIRES_OK(ctx,
368+
GetResourceFromContext(ctx, "filter", &filter_resource));
369+
core::ScopedUnref filter_resource_unref_(filter_resource);
370+
360371
*output = new Dataset(ctx, client_resource->data_client(),
361-
row_set_resource->row_set(), table_id_, columns_);
372+
row_set_resource->row_set(),
373+
filter_resource->filter(), table_id_, columns_);
362374
}
363375

364376
private:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow_io/core/kernels/bigtable/bigtable_version_filters.h"
16+
17+
namespace cbt = ::google::cloud::bigtable;
18+
19+
namespace tensorflow {
20+
namespace io {
21+
22+
class BigtableLatestFilterOp
23+
: public AbstractBigtableResourceOp<BigtableFilterResource> {
24+
public:
25+
explicit BigtableLatestFilterOp(OpKernelConstruction* ctx)
26+
: AbstractBigtableResourceOp<BigtableFilterResource>(ctx) {
27+
VLOG(1) << "BigtableLatestFilterOp ctor ";
28+
}
29+
30+
private:
31+
StatusOr<BigtableFilterResource*> CreateResource() override {
32+
return new BigtableFilterResource(cbt::Filter::Latest(1));
33+
}
34+
};
35+
36+
REGISTER_KERNEL_BUILDER(Name("BigtableLatestFilter").Device(DEVICE_CPU),
37+
BigtableLatestFilterOp);
38+
39+
class BigtableTimestampRangeFilterOp
40+
: public AbstractBigtableResourceOp<BigtableFilterResource> {
41+
public:
42+
explicit BigtableTimestampRangeFilterOp(OpKernelConstruction* ctx)
43+
: AbstractBigtableResourceOp<BigtableFilterResource>(ctx) {
44+
VLOG(1) << "BigtableTimestampRangeFilterOp ctor ";
45+
OP_REQUIRES_OK(ctx, ctx->GetAttr("start_ts_us", &start_ts_us_));
46+
OP_REQUIRES_OK(ctx, ctx->GetAttr("end_ts_us", &end_ts_us_));
47+
}
48+
49+
private:
50+
StatusOr<BigtableFilterResource*> CreateResource() override {
51+
return new BigtableFilterResource(cbt::Filter::TimestampRangeMicros(start_ts_us_, end_ts_us_));
52+
}
53+
54+
private:
55+
int64_t start_ts_us_;
56+
int64_t end_ts_us_;
57+
};
58+
59+
REGISTER_KERNEL_BUILDER(Name("BigtableTimestampRangeFilter").Device(DEVICE_CPU),
60+
BigtableTimestampRangeFilterOp);
61+
62+
class BigtablePrintFilterOp : public OpKernel {
63+
public:
64+
explicit BigtablePrintFilterOp(OpKernelConstruction* context)
65+
: OpKernel(context) {}
66+
67+
void Compute(OpKernelContext* context) override {
68+
BigtableFilterResource* resource;
69+
OP_REQUIRES_OK(context,
70+
GetResourceFromContext(context, "filter", &resource));
71+
core::ScopedUnref unref(resource);
72+
73+
// Create an output tensor
74+
Tensor* output_tensor = NULL;
75+
OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_tensor));
76+
auto output_v = output_tensor->tensor<tstring, 1>();
77+
78+
output_v(0) = resource->ToString();
79+
}
80+
};
81+
82+
REGISTER_KERNEL_BUILDER(Name("BigtablePrintFilter").Device(DEVICE_CPU),
83+
BigtablePrintFilterOp);
84+
85+
} // namespace io
86+
} // namespace tensorflow
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef BIGTABLE_VERSION_FILTERS_H
17+
#define BIGTABLE_VERSION_FILTERS_H
18+
19+
#include "absl/memory/memory.h"
20+
#include "google/cloud/bigtable/table.h"
21+
#include "google/cloud/bigtable/table_admin.h"
22+
#include "tensorflow/core/framework/common_shape_fns.h"
23+
#include "tensorflow/core/framework/dataset.h"
24+
#include "tensorflow/core/framework/op.h"
25+
#include "tensorflow/core/framework/op_kernel.h"
26+
#include "tensorflow/core/framework/resource_mgr.h"
27+
#include "tensorflow/core/framework/resource_op_kernel.h"
28+
#include "tensorflow_io/core/kernels/bigtable/bigtable_resource_kernel.h"
29+
30+
31+
namespace tensorflow {
32+
namespace io {
33+
34+
class BigtableFilterResource : public ResourceBase {
35+
public:
36+
explicit BigtableFilterResource(google::cloud::bigtable::Filter filter)
37+
: filter_(std::move(filter)) {
38+
VLOG(1) << "BigtableFilterResource ctor";
39+
}
40+
41+
~BigtableFilterResource() { VLOG(1) << "BigtableFilterResource dtor"; }
42+
43+
std::string ToString() const {
44+
std::string res;
45+
google::protobuf::TextFormat::PrintToString(filter_.as_proto(), &res);
46+
return res;
47+
}
48+
49+
const google::cloud::bigtable::Filter& filter() const { return filter_; }
50+
51+
string DebugString() const override {
52+
return "BigtableFilterResource:{" + ToString() + "}";
53+
}
54+
55+
private:
56+
const google::cloud::bigtable::Filter filter_;
57+
};
58+
59+
60+
} // namespace io
61+
} // namespace tensorflow
62+
63+
#endif /* BIGTABLE_ROW_SET_H */

tensorflow_io/core/ops/bigtable_ops.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ REGISTER_OP("BigtableClient")
2828
REGISTER_OP("BigtableDataset")
2929
.Input("client: resource")
3030
.Input("row_set: resource")
31+
.Input("filter: resource")
3132
.Attr("table_id: string")
3233
.Attr("columns: list(string) >= 1")
3334
.Output("handle: variant")
@@ -106,3 +107,24 @@ REGISTER_OP("BigtableSplitRowSetEvenly")
106107
c->set_output(0, c->Vector(c->UnknownDim()));
107108
return tensorflow::Status::OK();
108109
});
110+
111+
REGISTER_OP("BigtableLatestFilter")
112+
.Attr("container: string = ''")
113+
.Attr("shared_name: string = ''")
114+
.Output("filter: resource")
115+
.SetIsStateful()
116+
.SetShapeFn(shape_inference::ScalarShape);
117+
118+
REGISTER_OP("BigtableTimestampRangeFilter")
119+
.Attr("container: string = ''")
120+
.Attr("shared_name: string = ''")
121+
.Attr("start_ts_us: int")
122+
.Attr("end_ts_us: int")
123+
.Output("filter: resource")
124+
.SetIsStateful()
125+
.SetShapeFn(shape_inference::ScalarShape);
126+
127+
REGISTER_OP("BigtablePrintFilter")
128+
.Input("filter: resource")
129+
.Output("output: string")
130+
.SetShapeFn(shape_inference::ScalarShape);

tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tensorflow.python.data.ops import dataset_ops
33
from tensorflow.python.framework import tensor_spec
44
from tensorflow_io.python.ops import core_ops
5+
import tensorflow_io.python.ops.bigtable.bigtable_version_filters as filters
56
from tensorflow.python.framework import dtypes
67
import tensorflow as tf
78
from tensorflow.python.data.ops import dataset_ops
@@ -34,14 +35,22 @@ def __init__(self, client_resource, table_id: str):
3435
self._table_id = table_id
3536
self._client_resource = client_resource
3637

37-
def read_rows(self, columns: List[str], row_set: RowSet):
38-
return _BigtableDataset(self._client_resource, self._table_id, columns, row_set)
38+
def read_rows(
39+
self,
40+
columns: List[str],
41+
row_set: RowSet,
42+
filter: filters.BigtableFilter = filters.latest(),
43+
):
44+
return _BigtableDataset(
45+
self._client_resource, self._table_id, columns, row_set, filter
46+
)
3947

4048
def parallel_read_rows(
4149
self,
4250
columns: List[str],
4351
num_parallel_calls=tf.data.AUTOTUNE,
4452
row_set: RowSet = from_rows_or_ranges(infinite()),
53+
filter: filters.BigtableFilter = filters.latest(),
4554
):
4655

4756
print("calling parallel read_rows with row_set:", row_set)
@@ -50,7 +59,7 @@ def parallel_read_rows(
5059
)
5160

5261
def map_func(idx):
53-
return self.read_rows(columns, RowSet(samples[idx]))
62+
return self.read_rows(columns, RowSet(samples[idx]), filter)
5463

5564
# We interleave a dataset of sample's indexes instead of a dataset of
5665
# samples, because Dataset.from_tensor_slices attempts to copy the
@@ -69,14 +78,20 @@ class _BigtableDataset(dataset_ops.DatasetSource):
6978
"""_BigtableDataset represents a dataset that retrieves keys and values."""
7079

7180
def __init__(
72-
self, client_resource, table_id: str, columns: List[str], row_set: RowSet,
81+
self,
82+
client_resource,
83+
table_id: str,
84+
columns: List[str],
85+
row_set: RowSet,
86+
filter,
7387
):
7488
self._table_id = table_id
7589
self._columns = columns
90+
self._filter = filter
7691
self._element_spec = tf.TensorSpec(shape=[len(columns)], dtype=dtypes.string)
7792

7893
variant_tensor = core_ops.bigtable_dataset(
79-
client_resource, row_set._impl, table_id, columns
94+
client_resource, row_set._impl, filter._impl, table_id, columns
8095
)
8196
super().__init__(variant_tensor)
8297

0 commit comments

Comments
 (0)