Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tensorflow_io/bigtable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cloud Bigtable Client for TensorFlow.

This package allows TensorFlow to interface directly with Cloud Bigtable
for high-speed data loading.

@@BigtableClient
@@BigtableTable
@@RowRange
@@RowSet


"""


# `remove_undocumented` prunes this module's namespace down to the symbols
# explicitly allowed below (plus any symbols referenced via the @@ notation
# in the module docstring, which `all_util` parses — so the docstring is
# behavior-relevant here, not just documentation).
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import BigtableClient
from tensorflow_io.python.ops.bigtable.bigtable_dataset_ops import BigtableTable
import tensorflow_io.python.ops.bigtable.bigtable_version_filters as filters
import tensorflow_io.python.ops.bigtable.bigtable_row_set as row_set
import tensorflow_io.python.ops.bigtable.bigtable_row_range as row_range

# Public API surface of this module: the two client classes plus the three
# helper submodules. Every other module attribute is stripped below.
# NOTE(review): the module docstring advertises @@RowRange and @@RowSet,
# but the names actually exported are the `row_range`/`row_set` submodules,
# and `filters` is exported without being mentioned in the docstring —
# confirm the intended public surface and reconcile the two lists.
_allowed_symbols = [
"BigtableClient",
"BigtableTable",
"filters",
"row_set",
"row_range",
]

# Remove every attribute of this module that is not in the allow-list,
# keeping the importable API minimal and intentional.
remove_undocumented(__name__, _allowed_symbols)
12 changes: 6 additions & 6 deletions tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow_io/core/kernels/bigtable/serialization.h"
#include "tensorflow_io/core/kernels/bigtable/bigtable_row_set.h"
#include "tensorflow_io/core/kernels/bigtable/bigtable_version_filters.h"
#include "tensorflow_io/core/kernels/bigtable/serialization.h"

namespace cbt = ::google::cloud::bigtable;

Expand Down Expand Up @@ -117,7 +117,7 @@ class BigtableClientOp : public OpKernel {

~BigtableClientOp() { VLOG(1) << "BigtableClientOp dtor"; }

void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "BigtableClientOp compute";
ResourceMgr* mgr = ctx->resource_manager();
ContainerInfo cinfo;
Expand Down Expand Up @@ -375,9 +375,9 @@ class BigtableDatasetOp : public DatasetOpKernel {
GetResourceFromContext(ctx, "filter", &filter_resource));
core::ScopedUnref filter_resource_unref_(filter_resource);

*output = new Dataset(ctx, client_resource->data_client(),
row_set_resource->row_set(),
filter_resource->filter(), table_id_, columns_, output_type_);
*output = new Dataset(
ctx, client_resource->data_client(), row_set_resource->row_set(),
filter_resource->filter(), table_id_, columns_, output_type_);
}

private:
Expand Down Expand Up @@ -513,7 +513,7 @@ class BigtableSplitRowSetEvenlyOp : public OpKernel {
}

private:
mutable mutex mu_;
mutex mu_;
std::string table_id_ GUARDED_BY(mu_);
int num_splits_ GUARDED_BY(mu_);
};
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_io/core/kernels/bigtable/bigtable_row_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BigtableRowRangeOp
}

private:
mutable mutex mu_;
mutex mu_;
std::string left_row_key_ TF_GUARDED_BY(mu_);
bool left_open_ TF_GUARDED_BY(mu_);
std::string right_row_key_ TF_GUARDED_BY(mu_);
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_io/core/kernels/bigtable/bigtable_row_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class BigtableRowSetAppendRowRangeOp : public OpKernel {
}

private:
mutable mutex mu_;
mutex mu_;
std::string row_key_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class BigtableLatestFilterOp

REGISTER_KERNEL_BUILDER(Name("BigtableLatestFilter").Device(DEVICE_CPU),
BigtableLatestFilterOp);

class BigtableTimestampRangeFilterOp
: public AbstractBigtableResourceOp<BigtableFilterResource> {
public:
Expand All @@ -48,7 +48,8 @@ class BigtableTimestampRangeFilterOp

private:
StatusOr<BigtableFilterResource*> CreateResource() override {
return new BigtableFilterResource(cbt::Filter::TimestampRangeMicros(start_ts_us_, end_ts_us_));
return new BigtableFilterResource(
cbt::Filter::TimestampRangeMicros(start_ts_us_, end_ts_us_));
}

private:
Expand Down Expand Up @@ -82,5 +83,5 @@ class BigtablePrintFilterOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtablePrintFilter").Device(DEVICE_CPU),
BigtablePrintFilterOp);

} // namespace io
} // namespace tensorflow
} // namespace io
} // namespace tensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow_io/core/kernels/bigtable/bigtable_resource_kernel.h"


namespace tensorflow {
namespace io {

Expand Down Expand Up @@ -56,7 +55,6 @@ class BigtableFilterResource : public ResourceBase {
const google::cloud::bigtable::Filter filter_;
};


} // namespace io
} // namespace tensorflow

Expand Down
6 changes: 3 additions & 3 deletions tensorflow_io/core/kernels/bigtable/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/

#include "tensorflow_io/core/kernels/bigtable/serialization.h"

#include "rpc/xdr.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
Expand Down Expand Up @@ -72,9 +73,8 @@ inline StatusOr<bool_t> BytesToBool(std::string const& s) {
return v;
}

Status PutCellValueInTensor(Tensor& tensor, size_t index,
DataType cell_type,
google::cloud::bigtable::Cell const& cell) {
Status PutCellValueInTensor(Tensor& tensor, size_t index, DataType cell_type,
google::cloud::bigtable::Cell const& cell) {
switch (cell_type) {
case DT_STRING: {
auto tensor_data = tensor.tensor<tstring, 1>();
Expand Down
Empty file.
5 changes: 4 additions & 1 deletion tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def parallel_read_rows(
):

samples = core_ops.bigtable_split_row_set_evenly(
self._client_resource, row_set._impl, self._table_id, num_parallel_calls,
self._client_resource,
row_set._impl,
self._table_id,
num_parallel_calls,
)

def map_func(idx):
Expand Down
10 changes: 4 additions & 6 deletions tests/test_bigtable/bigtable_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,15 @@ def _get_cbt_binary_path(env_var_name, search_paths, description):
res = os.environ.get(env_var_name)
if res is not None:
if not os.path.isfile(res):
raise EnvironmentError(
(
f"{description} specified in the {env_var_name} "
"environment variable does not exist"
)
raise OSError(
f"{description} specified in the {env_var_name} "
"environment variable does not exist"
)
return res
for candidate in search_paths:
if os.path.isfile(candidate):
return candidate
raise EnvironmentError(f"Could not find {description}")
raise OSError(f"Could not find {description}")


def _get_cbt_emulator_path():
Expand Down
25 changes: 20 additions & 5 deletions tests/test_bigtable/test_parallel_read_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def test_split_row_set(self):
samples = [
s
for s in core_ops.bigtable_split_row_set_evenly(
client._client_resource, rs._impl, "test-table", num_parallel_calls,
client._client_resource,
rs._impl,
"test-table",
num_parallel_calls,
)
]
self.assertEqual(len(samples), num_parallel_calls)
Expand All @@ -143,7 +146,10 @@ def test_split_row_set(self):
samples = [
s
for s in core_ops.bigtable_split_row_set_evenly(
client._client_resource, rs._impl, "test-table", num_parallel_calls,
client._client_resource,
rs._impl,
"test-table",
num_parallel_calls,
)
]

Expand All @@ -155,15 +161,21 @@ def test_split_row_set(self):
samples = [
s
for s in core_ops.bigtable_split_row_set_evenly(
client._client_resource, rs._impl, "test-table", num_parallel_calls,
client._client_resource,
rs._impl,
"test-table",
num_parallel_calls,
)
]
self.assertEqual(len(samples), num_parallel_calls)

def test_split_empty(self):
os.environ["BIGTABLE_EMULATOR_HOST"] = self.emulator.get_addr()
self.emulator.create_table(
"fake_project", "fake_instance", "test-table", ["fam1", "fam2"],
"fake_project",
"fake_instance",
"test-table",
["fam1", "fam2"],
)

client = BigtableClient("fake_project", "fake_instance")
Expand All @@ -175,6 +187,9 @@ def test_split_empty(self):
self.assertRaises(
tf.errors.FailedPreconditionError,
lambda: core_ops.bigtable_split_row_set_evenly(
client._client_resource, rs._impl, "test-table", num_parallel_calls,
client._client_resource,
rs._impl,
"test-table",
num_parallel_calls,
),
)