From 76a7e29fd309eeacd00b6d76593476d3242d59b6 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Tue, 18 Dec 2018 23:25:17 -0500 Subject: [PATCH 1/3] initial commits --- tensorflow_io/libsvm/BUILD | 54 ++++++ tensorflow_io/libsvm/__init__.py | 32 ++++ .../libsvm/kernels/decode_libsvm_op.cc | 168 ++++++++++++++++++ tensorflow_io/libsvm/ops/libsvm_ops.cc | 58 ++++++ tensorflow_io/libsvm/python/__init__.py | 0 .../kernel_tests/decode_libsvm_op_test.py | 71 ++++++++ tensorflow_io/libsvm/python/ops/__init__.py | 0 .../libsvm/python/ops/libsvm_dataset_ops.py | 69 +++++++ 8 files changed, 452 insertions(+) create mode 100644 tensorflow_io/libsvm/BUILD create mode 100644 tensorflow_io/libsvm/__init__.py create mode 100644 tensorflow_io/libsvm/kernels/decode_libsvm_op.cc create mode 100644 tensorflow_io/libsvm/ops/libsvm_ops.cc create mode 100644 tensorflow_io/libsvm/python/__init__.py create mode 100644 tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py create mode 100644 tensorflow_io/libsvm/python/ops/__init__.py create mode 100644 tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py diff --git a/tensorflow_io/libsvm/BUILD b/tensorflow_io/libsvm/BUILD new file mode 100644 index 000000000..7305b1798 --- /dev/null +++ b/tensorflow_io/libsvm/BUILD @@ -0,0 +1,54 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +cc_binary( + name = "python/ops/_libsvm_ops.so", + srcs = [ + "kernels/decode_libsvm_op.cc", + "ops/libsvm_ops.cc", + ], + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + "@kafka//:kafka", + ], + copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0", "-DNDEBUG"] +) + +py_library( + name = "libsvm_ops_py", + srcs = [ + "python/ops/libsvm_dataset_ops.py", + ], + data = [ + ":python/ops/_libsvm_ops.so", + ], + srcs_version = "PY2AND3", +) + +py_test( + name = "decode_libsvm_op_test", + srcs = [ + "python/kernel_tests/decode_libsvm_op_test.py" + ], + main = "python/kernel_tests/decode_libsvm_ops_test.py", + deps = [ + ":libsvm_ops_py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "libsvm_py", + srcs = ([ + "__init__.py", + "python/__init__.py", + "python/ops/__init__.py", + ]), + deps = [ + ":libsvm_ops_py" + ], + srcs_version = "PY2AND3", +) diff --git a/tensorflow_io/libsvm/__init__.py b/tensorflow_io/libsvm/__init__.py new file mode 100644 index 000000000..7b7e62ed2 --- /dev/null +++ b/tensorflow_io/libsvm/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""LibSVM Dataset. + +@@make_libsvm_dataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.libsvm.python.ops.libsvm_dataset_ops import make_libsvm_dataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "make_libsvm_dataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow_io/libsvm/kernels/decode_libsvm_op.cc b/tensorflow_io/libsvm/kernels/decode_libsvm_op.cc new file mode 100644 index 000000000..720c74e3d --- /dev/null +++ b/tensorflow_io/libsvm/kernels/decode_libsvm_op.cc @@ -0,0 +1,168 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +template +class DecodeLibsvmOp : public OpKernel { + public: + explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_)); + OP_REQUIRES(ctx, (num_features_ >= 1), + errors::InvalidArgument("Invalid number of features \"", + num_features_, "\"")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* label_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor)); + auto label = label_tensor->flat(); + + std::vector out_values; + std::vector> out_indices; + for (int i = 0; i < input_flat.size(); ++i) { + StringPiece line(input_flat(i)); + str_util::RemoveWhitespaceContext(&line); + + StringPiece piece; + OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece), + errors::InvalidArgument("No label found for input[", i, + "]: \"", input_flat(i), "\"")); + + Tlabel label_value; + OP_REQUIRES(ctx, + strings::SafeStringToNumeric(piece, &label_value), + errors::InvalidArgument("Label format incorrect: ", piece)); + + label(i) = label_value; + + str_util::RemoveLeadingWhitespace(&line); + while (str_util::ConsumeNonWhitespace(&line, &piece)) { + size_t p = piece.find(':'); + OP_REQUIRES(ctx, (p != StringPiece::npos), + errors::InvalidArgument("Invalid feature \"", piece, "\"")); + + int64 feature_index; + OP_REQUIRES( + ctx, strings::safe_strto64(piece.substr(0, p), &feature_index), + errors::InvalidArgument("Feature format incorrect: ", piece)); + OP_REQUIRES(ctx, (feature_index >= 0), + errors::InvalidArgument( + "Feature index should be >= 0, got ", feature_index)); + + T feature_value; + OP_REQUIRES( + + ctx, + strings::SafeStringToNumeric(piece.substr(p + 1), + &feature_value), + errors::InvalidArgument("Feature format incorrect: ", piece)); + + out_values.emplace_back(feature_value); + out_indices.emplace_back(std::pair(i, feature_index)); + + str_util::RemoveLeadingWhitespace(&line); + } + } + + Tensor* indices_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 1, + TensorShape({static_cast(out_indices.size()), + input_tensor->shape().dims() + 1}), + &indices_tensor)); + auto indices = indices_tensor->matrix(); + // Translate flat index to shaped index like np.unravel_index + // Calculate factors for each dimension + std::vector factors(input_tensor->shape().dims()); + factors[input_tensor->shape().dims() - 1] = 1; + for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) { + factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1); + } + for (int i = 0; i < out_indices.size(); i++) { + indices(i, 0) = out_indices[i].first; + int64 value = out_indices[i].first; + for (int j = 0; j < input_tensor->shape().dims(); j++) { + indices(i, j) = value / factors[j]; + value = value % factors[j]; + } + indices(i, input_tensor->shape().dims()) = out_indices[i].second; + } + + Tensor* values_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_output( + 2, TensorShape({static_cast(out_values.size())}), + &values_tensor)); + auto values = values_tensor->vec(); + std::copy_n(out_values.begin(), out_values.size(), &values(0)); + + Tensor* shape_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 3, TensorShape({input_tensor->shape().dims() + 1}), + &shape_tensor)); + auto shape = shape_tensor->flat(); + for (int i = 0; i < input_tensor->shape().dims(); i++) { + shape(i) = input_tensor->shape().dim_size(i); + } + shape(input_tensor->shape().dims()) = num_features_; + } + + private: + int64 num_features_; +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(int64); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow_io/libsvm/ops/libsvm_ops.cc b/tensorflow_io/libsvm/ops/libsvm_ops.cc new file mode 100644 index 000000000..dec946189 --- /dev/null +++ b/tensorflow_io/libsvm/ops/libsvm_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("DecodeLibsvm") + .Input("input: string") + .Output("label: label_dtype") + .Output("feature_indices: int64") + .Output("feature_values: dtype") + .Output("feature_shape: int64") + .Attr("dtype: {float, double, int32, int64} = DT_FLOAT") + .Attr("label_dtype: {float, double, int32, int64} = DT_INT64") + .Attr("num_features: int >= 1") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + + c->set_output(1, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(3, c->Vector(InferenceContext::kUnknownDim)); + + return Status::OK(); + }) + + .Doc(R"doc( +Convert LibSVM input to tensors. The output consists of +a label and a feature tensor. The shape of the label tensor +is the same as input and the shape of the feature tensor is +`[input_shape, num_features]`. + +input: Each string is a record in the LibSVM. +label: A tensor of the same shape as input. +feature_indices: A 2-D int64 tensor of dense_shape [N, ndims]. +feature_values: A 1-D tensor of any type and dense_shape [N]. +feature_shape: A 1-D int64 tensor of dense_shape [ndims]. +num_features: The number of features. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow_io/libsvm/python/__init__.py b/tensorflow_io/libsvm/python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py new file mode 100644 index 000000000..8390ddda9 --- /dev/null +++ b/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DecodeLibsvm op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.libsvm.python.ops import libsvm_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class DecodeLibsvmOpTest(test.TestCase): + + def testBasic(self): + with self.cached_session() as sess: + content = [ + "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503", + "2 3:2.5 2:nan 1:0.105" + ] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [1, 1, 2]) + self.assertAllClose( + features, [[0, 3.4, 0.5, 0, 0.231, 0], [0, 0, 2.5, np.inf, 0, 0.503], + [0, 0.105, np.nan, 2.5, 0, 0]]) + + def testNDimension(self): + with self.cached_session() as sess: + content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], + ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], + ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6, label_dtype=dtypes.float64) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3, 2]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [[1, 1], [1, 1], [2, 2]]) + self.assertAllClose( + features, [[[0, 3.4, 0.5, 0, 0.231, 0], [0, 3.4, 0.5, 0, 0.231, 0]], [ + [0, 0, 2.5, np.inf, 0, 0.503], [0, 0, 2.5, np.inf, 0, 0.503] + ], [[0, 0.105, np.nan, 2.5, 0, 0], [0, 0.105, np.nan, 2.5, 0, 0]]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_io/libsvm/python/ops/__init__.py b/tensorflow_io/libsvm/python/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py b/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py new file mode 100644 index 000000000..97673cf18 --- /dev/null +++ b/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py @@ -0,0 +1,69 @@ +"""LibSVM Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import partial + +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.experimental.ops import batching + +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +libsvm_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile('_libsvm_ops.so')) + + +decode_libsvm = libsvm_ops.decode_libsvm + + +def make_libsvm_dataset(file_names, + num_features, + dtype=None, + label_dtype=None, + batch_size=1, + compression_type='', + buffer_size=None, + num_parallel_parser_calls=None, + drop_final_batch=False, + prefetch_buffer_size=0): + """Reads LibSVM files into a dataset. + + Args: + file_names: A `tf.string` tensor containing one or more filenames. + num_features: The number of features. + dtype(Optional): The type of the output feature tensor. Default to tf.float32. + label_dtype(Optional): The type of the output label tensor. Default to tf.int64. + batch_size: (Optional.) An int representing the number of records to combine + in a single batch, default 1. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer. A value of 0 results in the default buffering values chosen + based on the compression type. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + """ + dataset = core_readers.TextLineDataset(file_names, + compression_type=compression_type, + buffer_size=buffer_size) + parsing_func = partial(decode_libsvm, + num_features=num_features, + dtype=dtype, + label_type=label_type) + dataset = dataset.apply(batching.map_and_batch( + parsing_func, + batch_size, + num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) From f917de154ac54beb3e8a5392130dce3d77569476 Mon Sep 17 00:00:00 2001 From: Peng YU Date: Wed, 19 Dec 2018 12:52:10 -0500 Subject: [PATCH 2/3] fix the name issues --- tensorflow_io/libsvm/BUILD | 2 +- .../kernel_tests/decode_libsvm_op_test.py | 6 +-- .../libsvm/python/ops/libsvm_dataset_ops.py | 39 ++++++++++++------- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tensorflow_io/libsvm/BUILD b/tensorflow_io/libsvm/BUILD index 7305b1798..24af97f6e 100644 --- a/tensorflow_io/libsvm/BUILD +++ b/tensorflow_io/libsvm/BUILD @@ -33,7 +33,7 @@ py_test( srcs = [ "python/kernel_tests/decode_libsvm_op_test.py" ], - main = "python/kernel_tests/decode_libsvm_ops_test.py", + main = "python/kernel_tests/decode_libsvm_op_test.py", deps = [ ":libsvm_ops_py", ], diff --git a/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py index 8390ddda9..6cd01e4a5 100644 --- a/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py +++ b/tensorflow_io/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -20,7 +20,7 @@ import numpy as np -from tensorflow.contrib.libsvm.python.ops import libsvm_ops +from tensorflow_io.libsvm.python.ops import libsvm_dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test @@ -34,7 +34,7 @@ def testBasic(self): "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503", "2 3:2.5 2:nan 1:0.105" ] - sparse_features, labels = libsvm_ops.decode_libsvm( + sparse_features, labels = libsvm_dataset_ops.decode_libsvm( content, num_features=6) features = sparse_ops.sparse_tensor_to_dense( sparse_features, validate_indices=False) @@ -52,7 +52,7 @@ def testNDimension(self): content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] - sparse_features, labels = libsvm_ops.decode_libsvm( + sparse_features, labels = libsvm_dataset_ops.decode_libsvm( content, num_features=6, label_dtype=dtypes.float64) features = sparse_ops.sparse_tensor_to_dense( sparse_features, validate_indices=False) diff --git a/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py b/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py index 97673cf18..29fed35c6 100644 --- a/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py +++ b/tensorflow_io/libsvm/python/ops/libsvm_dataset_ops.py @@ -8,14 +8,28 @@ from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.framework import load_library +from tensorflow.python.framework import load_library, sparse_tensor from tensorflow.python.platform import resource_loader -libsvm_ops = load_library.load_op_library( +gen_libsvm_ops = load_library.load_op_library( resource_loader.get_path_to_datafile('_libsvm_ops.so')) -decode_libsvm = libsvm_ops.decode_libsvm +def decode_libsvm(content, num_features, dtype=None, label_dtype=None): + """Convert Libsvm records to a tensor of label and a tensor of feature. + Args: + content: A `Tensor` of type `string`. Each string is a record/row in + the Libsvm format. + num_features: The number of features. + dtype: The type of the output feature tensor. Default to tf.float32. + label_dtype: The type of the output label tensor. Default to tf.int64. + Returns: + features: A `SparseTensor` of the shape `[input_shape, num_features]`. + labels: A `Tensor` of the same shape as content. + """ + labels, indices, values, shape = gen_libsvm_ops.decode_libsvm( + content, num_features, dtype=dtype, label_dtype=label_dtype) + return sparse_tensor.SparseTensor(indices, values, shape), labels def make_libsvm_dataset(file_names, @@ -28,7 +42,7 @@ def make_libsvm_dataset(file_names, num_parallel_parser_calls=None, drop_final_batch=False, prefetch_buffer_size=0): - """Reads LibSVM files into a dataset. + """Reads LibSVM files into a dataset. Args: file_names: A `tf.string` tensor containing one or more filenames. @@ -51,19 +65,18 @@ def make_libsvm_dataset(file_names, feature batches to prefetch for performance improvement. Defaults to auto-tune. Set to 0 to disable prefetching. """ - dataset = core_readers.TextLineDataset(file_names, + dataset = core_readers.TextLineDataset(file_names, compression_type=compression_type, buffer_size=buffer_size) - parsing_func = partial(decode_libsvm, - num_features=num_features, - dtype=dtype, - label_type=label_type) + def parsing_func(content): + return decode_libsvm(content, num_features, dtype, label_type) + dataset = dataset.apply(batching.map_and_batch( parsing_func, batch_size, num_parallel_calls=num_parallel_parser_calls, drop_remainder=drop_final_batch)) - if prefetch_buffer_size == 0: - return dataset - else: - return dataset.prefetch(buffer_size=prefetch_buffer_size) + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) From ccc6b100dd879cc9c85a6758b42aa1f9177598c4 Mon Sep 17 00:00:00 2001 From: Peng YU Date: Fri, 21 Dec 2018 10:38:36 -0500 Subject: [PATCH 3/3] remove kafka dependency --- tensorflow_io/libsvm/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow_io/libsvm/BUILD b/tensorflow_io/libsvm/BUILD index 24af97f6e..94a3c557d 100644 --- a/tensorflow_io/libsvm/BUILD +++ b/tensorflow_io/libsvm/BUILD @@ -12,7 +12,6 @@ cc_binary( deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", - "@kafka//:kafka", ], copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0", "-DNDEBUG"] )