Skip to content

Commit 4489e24

Browse files
committed
add filters to python api
1 parent 1f6088e commit 4489e24

File tree

6 files changed

+51
-30
lines changed

6 files changed

+51
-30
lines changed

tensorflow_io/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ cc_library(
184184
"kernels/bigtable/bigtable_row_set.h",
185185
"kernels/bigtable/bigtable_row_range.cc",
186186
"kernels/bigtable/bigtable_row_set.cc",
187+
"kernels/bigtable/bigtable_version_filters.cc",
188+
"kernels/bigtable/bigtable_version_filters.h",
187189
"ops/bigtable_ops.cc",
188190
],
189191
copts = tf_io_copts(),

tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ limitations under the License.
2222
#include "tensorflow/core/framework/resource_mgr.h"
2323
#include "tensorflow/core/framework/resource_op_kernel.h"
2424
#include "tensorflow_io/core/kernels/bigtable/bigtable_row_set.h"
25+
#include "tensorflow_io/core/kernels/bigtable/bigtable_version_filters.h"
2526

2627
namespace cbt = ::google::cloud::bigtable;
2728

2829
namespace tensorflow {
2930
namespace data {
3031
namespace {
3132

32-
tensorflow::error::Code GoogleCloudErrorCodeToTfErrorCode(::google::cloud::StatusCode code) {
33+
tensorflow::error::Code GoogleCloudErrorCodeToTfErrorCode(
34+
::google::cloud::StatusCode code) {
3335
switch (code) {
3436
case ::google::cloud::StatusCode::kOk:
3537
return ::tensorflow::error::OK;
@@ -72,9 +74,9 @@ Status GoogleCloudStatusToTfStatus(const ::google::cloud::Status& status) {
7274
if (status.ok()) {
7375
return Status::OK();
7476
}
75-
return Status(GoogleCloudErrorCodeToTfErrorCode(status.code()),
76-
strings::StrCat("Error reading from Cloud Bigtable: ",
77-
status.message()));
77+
return Status(
78+
GoogleCloudErrorCodeToTfErrorCode(status.code()),
79+
strings::StrCat("Error reading from Cloud Bigtable: ", status.message()));
7880
}
7981

8082
class BigtableClientResource : public ResourceBase {
@@ -143,6 +145,7 @@ class Iterator : public DatasetIterator<Dataset> {
143145
this->dataset()->client_resource().CreateTable(table_id).ReadRows(
144146
cbt::RowRange::InfiniteRange(),
145147
cbt::Filter::Chain(CreateColumnsFilter(columns_),
148+
this->dataset()->filter_resource().filter(),
146149
cbt::Filter::Latest(1)))),
147150
it_(this->reader_.begin()),
148151
column_to_idx_(CreateColumnToIdxMap(columns_)) {}
@@ -270,10 +273,13 @@ class Iterator : public DatasetIterator<Dataset> {
270273
class Dataset : public DatasetBase {
271274
public:
272275
Dataset(OpKernelContext* ctx, BigtableClientResource* client_resource,
273-
std::string table_id, std::vector<std::string> columns)
276+
io::BigtableFilterResource* filter_resource, std::string table_id,
277+
std::vector<std::string> columns)
274278
: DatasetBase(DatasetContext(ctx)),
275279
client_resource_(*client_resource),
276280
client_resource_unref_(client_resource),
281+
filter_resource_(*filter_resource),
282+
filter_resource_unref_(filter_resource),
277283
table_id_(table_id),
278284
columns_(columns) {
279285
dtypes_.push_back(DT_STRING);
@@ -300,6 +306,9 @@ class Dataset : public DatasetBase {
300306
}
301307

302308
BigtableClientResource& client_resource() const { return client_resource_; }
309+
io::BigtableFilterResource& filter_resource() const {
310+
return filter_resource_;
311+
}
303312

304313
protected:
305314
Status AsGraphDefInternal(SerializationContext* ctx,
@@ -314,6 +323,8 @@ class Dataset : public DatasetBase {
314323
private:
315324
BigtableClientResource& client_resource_;
316325
const core::ScopedUnref client_resource_unref_;
326+
io::BigtableFilterResource& filter_resource_;
327+
const core::ScopedUnref filter_resource_unref_;
317328
const std::string table_id_;
318329
const std::vector<std::string> columns_;
319330
DataTypeVector dtypes_;
@@ -330,10 +341,15 @@ class BigtableDatasetOp : public DatasetOpKernel {
330341
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
331342
VLOG(1) << "Make Dataset";
332343
BigtableClientResource* client_resource;
333-
OP_REQUIRES_OK(
334-
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
344+
OP_REQUIRES_OK(ctx,
345+
GetResourceFromContext(ctx, "client", &client_resource));
335346
core::ScopedUnref client_resource_unref_(client_resource);
336-
*output = new Dataset(ctx, client_resource, table_id_, columns_);
347+
io::BigtableFilterResource* filter_resource;
348+
OP_REQUIRES_OK(ctx,
349+
GetResourceFromContext(ctx, "filter", &filter_resource));
350+
core::ScopedUnref filter_resource_unref_(filter_resource);
351+
*output =
352+
new Dataset(ctx, client_resource, filter_resource, table_id_, columns_);
337353
}
338354

339355
private:

tensorflow_io/core/kernels/bigtable/bigtable_version_filters.cc

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,16 @@ namespace tensorflow {
2020
namespace io {
2121

2222
class BigtableLatestFilterOp
23-
: public OpKernelCreatingResource<BigtableFilterResource> {
23+
: public AbstractBigtableResourceOp<BigtableFilterResource> {
2424
public:
2525
explicit BigtableLatestFilterOp(OpKernelConstruction* ctx)
26-
: OpKernelCreatingResource<BigtableFilterResource>(ctx) {
26+
: AbstractBigtableResourceOp<BigtableFilterResource>(ctx) {
2727
VLOG(1) << "BigtableLatestFilterOp ctor ";
2828
}
2929

3030
private:
31-
Status CreateResource(BigtableFilterResource** resource)
32-
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
33-
*resource = new BigtableFilterResource(cbt::Filter::Latest(1));
34-
return Status::OK();
31+
StatusOr<BigtableFilterResource*> CreateResource() override {
32+
return new BigtableFilterResource(cbt::Filter::Latest(1));
3533
}
3634

3735
private:
@@ -42,20 +40,18 @@ REGISTER_KERNEL_BUILDER(Name("BigtableLatestFilter").Device(DEVICE_CPU),
4240
BigtableLatestFilterOp);
4341

4442
class BigtableTimestampRangeFilterOp
45-
: public OpKernelCreatingResource<BigtableFilterResource> {
43+
: public AbstractBigtableResourceOp<BigtableFilterResource> {
4644
public:
4745
explicit BigtableTimestampRangeFilterOp(OpKernelConstruction* ctx)
48-
: OpKernelCreatingResource<BigtableFilterResource>(ctx) {
46+
: AbstractBigtableResourceOp<BigtableFilterResource>(ctx) {
4947
VLOG(1) << "BigtableTimestampRangeFilterOp ctor ";
5048
OP_REQUIRES_OK(ctx, ctx->GetAttr("start", &start_));
5149
OP_REQUIRES_OK(ctx, ctx->GetAttr("end", &end_));
5250
}
5351

5452
private:
55-
Status CreateResource(BigtableFilterResource** resource)
56-
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
57-
*resource = new BigtableFilterResource(cbt::Filter::TimestampRangeMicros(start_, end_));
58-
return Status::OK();
53+
StatusOr<BigtableFilterResource*> CreateResource() override {
54+
return new BigtableFilterResource(cbt::Filter::TimestampRangeMicros(start_, end_));
5955
}
6056

6157
private:

tensorflow_io/core/ops/bigtable_ops.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ REGISTER_OP("BigtableClient")
2727

2828
REGISTER_OP("BigtableDataset")
2929
.Input("client: resource")
30+
.Input("filter: resource")
3031
.Attr("table_id: string")
3132
.Attr("columns: list(string) >= 1")
3233
.Output("handle: variant")
@@ -113,7 +114,7 @@ REGISTER_OP("BigtableSampleRowSets")
113114
return tensorflow::Status::OK();
114115
});
115116

116-
REGISTER_OP("BigtableEmptyRowRange")
117+
REGISTER_OP("BigtableLatestFilter")
117118
.Attr("container: string = ''")
118119
.Attr("shared_name: string = ''")
119120
.Output("filter: resource")
@@ -124,7 +125,7 @@ REGISTER_OP("BigtableTimestampRangeFilter")
124125
.Attr("container: string = ''")
125126
.Attr("shared_name: string = ''")
126127
.Attr("start: int")
127-
.Attr("start: int")
128+
.Attr("end: int")
128129
.Output("filter: resource")
129130
.SetIsStateful()
130131
.SetShapeFn(shape_inference::ScalarShape);

tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tensorflow.python.data.ops import dataset_ops
33
from tensorflow.python.framework import tensor_spec
44
from tensorflow_io.python.ops import core_ops
5+
import tensorflow_io.python.ops.bigtable.bigtable_version_filters as filters
56
from tensorflow.python.framework import dtypes
67
import tensorflow as tf
78

@@ -28,22 +29,23 @@ def __init__(self, client_resource, table_id: str):
2829
self._table_id = table_id
2930
self._client_resource = client_resource
3031

31-
def read_rows(self, columns: List[str]):
32-
return _BigtableDataset(self._client_resource, self._table_id, columns)
32+
def read_rows(self, columns: List[str], filter: filters.BigtableFilter = filters.latest()):
33+
return _BigtableDataset(self._client_resource, self._table_id, columns, filter)
3334

3435

3536
class _BigtableDataset(dataset_ops.DatasetSource):
3637
"""_BigtableDataset represents a dataset that retrieves keys and values."""
3738

38-
def __init__(self, client_resource, table_id: str, columns: List[str]):
39+
def __init__(self, client_resource, table_id: str, columns: List[str], filter):
3940
self._table_id = table_id
4041
self._columns = columns
42+
self._filter = filter
4143
self._element_spec = tf.TensorSpec(
4244
shape=[len(columns)], dtype=dtypes.string
4345
)
4446

4547
variant_tensor = core_ops.bigtable_dataset(
46-
client_resource, table_id, columns
48+
client_resource, filter._impl, table_id, columns
4749
)
4850
super().__init__(variant_tensor)
4951

tensorflow_io/python/ops/bigtable/bigtable_version_filters.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def latest():
3434
Returns:
3535
BigtableFilter: Filter passing only the most recent version of a value.
3636
"""
37-
return core_ops.bigtable_latest_filter()
37+
return BigtableFilter(core_ops.bigtable_latest_filter())
3838

3939

4040
def timestamp_range(
@@ -60,7 +60,9 @@ def timestamp_range(
6060
else:
6161
end_timestamp = int(end * 1e6)
6262

63-
return core_ops.bigtable_timestamp_range_filter(start_timestamp, end_timestamp)
63+
return BigtableFilter(
64+
core_ops.bigtable_timestamp_range_filter(start_timestamp, end_timestamp)
65+
)
6466

6567

6668
def timestamp_range_micros(
@@ -75,6 +77,8 @@ def timestamp_range_micros(
7577
Returns:
7678
BigtableFilter: Filter passing only those versions of values that fall within the specified range.
7779
"""
78-
return core_ops.bigtable_timestamp_range_filter(
79-
int(start_timestamp), int(end_timestamp)
80+
return BigtableFilter(
81+
core_ops.bigtable_timestamp_range_filter(
82+
int(start_timestamp), int(end_timestamp)
83+
)
8084
)

0 commit comments

Comments
 (0)