Skip to content

Commit 2cec527

Browse files
committed
Experimental: Add initial wavefront/obj parser for vertices
This PR is an early experimental implementation of wavefront obj parser in tensorflow-io for 3D objects. This PR is the first step to obtain raw vertices in float32 tensor with shape of `[n, 3]`. Additional follow up PRs will be needed to handle meshs with different shapes (not sure if ragged tensor will be a good fit in that case) Signed-off-by: Yong Tang <[email protected]>
1 parent ac75e1c commit 2cec527

File tree

9 files changed

+208
-0
lines changed

9 files changed

+208
-0
lines changed

WORKSPACE

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,3 +1110,14 @@ http_archive(
11101110
"https://github.com/mongodb/mongo-c-driver/releases/download/1.16.2/mongo-c-driver-1.16.2.tar.gz",
11111111
],
11121112
)
1113+
1114+
http_archive(
1115+
name = "tinyobjloader",
1116+
build_file = "//third_party:tinyobjloader.BUILD",
1117+
sha256 = "b8c972dfbbcef33d55554e7c9031abe7040795b67778ad3660a50afa7df6ec56",
1118+
strip_prefix = "tinyobjloader-2.0.0rc8",
1119+
urls = [
1120+
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz",
1121+
"https://github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz",
1122+
],
1123+
)

tensorflow_io/core/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,22 @@ cc_library(
695695
alwayslink = 1,
696696
)
697697

698+
cc_library(
699+
name = "obj_ops",
700+
srcs = [
701+
"kernels/obj_kernels.cc",
702+
"ops/obj_ops.cc",
703+
],
704+
copts = tf_io_copts(),
705+
linkstatic = True,
706+
deps = [
707+
"@local_config_tf//:libtensorflow_framework",
708+
"@local_config_tf//:tf_header_lib",
709+
"@tinyobjloader",
710+
],
711+
alwayslink = 1,
712+
)
713+
698714
cc_binary(
699715
name = "python/ops/libtensorflow_io.so",
700716
copts = tf_io_copts(),
@@ -717,6 +733,7 @@ cc_binary(
717733
"//tensorflow_io/core:parquet_ops",
718734
"//tensorflow_io/core:pcap_ops",
719735
"//tensorflow_io/core:pulsar_ops",
736+
"//tensorflow_io/core:obj_ops",
720737
"//tensorflow_io/core:operation_ops",
721738
"//tensorflow_io/core:pubsub_ops",
722739
"//tensorflow_io/core:serialization_ops",
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow/core/platform/logging.h"
18+
#include "tiny_obj_loader.h"
19+
20+
namespace tensorflow {
21+
namespace io {
22+
namespace {
23+
24+
class DecodeObjVertexOp : public OpKernel {
25+
public:
26+
explicit DecodeObjVertexOp(OpKernelConstruction* context)
27+
: OpKernel(context) {}
28+
29+
void Compute(OpKernelContext* context) override {
30+
const Tensor* input_tensor;
31+
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
32+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_tensor->shape()),
33+
errors::InvalidArgument("input must be scalar, got shape ",
34+
input_tensor->shape().DebugString()));
35+
const tstring& input = input_tensor->scalar<tstring>()();
36+
37+
tinyobj::ObjReader reader;
38+
39+
if (!reader.ParseFromString(input.c_str(), "")) {
40+
OP_REQUIRES(
41+
context, false,
42+
errors::Internal("Unable to read obj file: ", reader.Error()));
43+
}
44+
45+
if (!reader.Warning().empty()) {
46+
LOG(WARNING) << "TinyObjReader: " << reader.Warning();
47+
}
48+
49+
auto& attrib = reader.GetAttrib();
50+
51+
int64 count = attrib.vertices.size() / 3;
52+
53+
Tensor* output_tensor = nullptr;
54+
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({count, 3}),
55+
&output_tensor));
56+
// Loop over attrib.vertices:
57+
for (int64 i = 0; i < count; i++) {
58+
tinyobj::real_t x = attrib.vertices[i * 3 + 0];
59+
tinyobj::real_t y = attrib.vertices[i * 3 + 1];
60+
tinyobj::real_t z = attrib.vertices[i * 3 + 2];
61+
output_tensor->tensor<float, 2>()(i, 0) = x;
62+
output_tensor->tensor<float, 2>()(i, 1) = y;
63+
output_tensor->tensor<float, 2>()(i, 2) = z;
64+
}
65+
}
66+
};
67+
REGISTER_KERNEL_BUILDER(Name("IO>DecodeObjVertex").Device(DEVICE_CPU),
68+
DecodeObjVertexOp);
69+
70+
} // namespace
71+
} // namespace io
72+
} // namespace tensorflow

tensorflow_io/core/ops/obj_ops.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
namespace io {
22+
namespace {
23+
24+
REGISTER_OP("IO>DecodeObjVertex")
25+
.Input("input: string")
26+
.Output("output: float32")
27+
.SetShapeFn([](shape_inference::InferenceContext* c) {
28+
shape_inference::ShapeHandle unused;
29+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
30+
c->set_output(0, c->MakeShape({c->UnknownDim(), 3}));
31+
return Status::OK();
32+
});
33+
34+
} // namespace
35+
} // namespace io
36+
} // namespace tensorflow

tensorflow_io/core/python/api/experimental/image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@
2727
decode_yuy2,
2828
decode_avif,
2929
decode_jp2,
30+
decode_obj_vertex,
3031
)

tensorflow_io/core/python/experimental/image_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,17 @@ def decode_jp2(contents, dtype=tf.uint8, name=None):
208208
A `Tensor` of type `uint8` and shape of `[height, width, 3]` (RGB).
209209
"""
210210
return core_ops.io_decode_jpeg2k(contents, dtype=dtype, name=name)
211+
212+
213+
def decode_obj_vertex(contents, name=None):
214+
"""
215+
Decode a Wavefront (obj) file into a float32 tensor.
216+
217+
Args:
218+
contents: A `Tensor` of type `string`. 0-D. The Wavefront (obj) file.
219+
name: A name for the operation (optional).
220+
221+
Returns:
222+
A `Tensor` of type `float32` and shape of `[n, 3]` for vertices.
223+
"""
224+
return core_ops.io_decode_obj_vertex(contents, name=name)

tests/test_obj.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
# use this file except in compliance with the License. You may obtain a copy of
5+
# the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
# ==============================================================================
15+
"""Test Wavefront OBJ"""
16+
17+
import os
18+
import numpy as np
19+
import pytest
20+
21+
import tensorflow as tf
22+
import tensorflow_io as tfio
23+
24+
25+
def test_decode_obj_vertices():
26+
"""Test case for decode obj"""
27+
filename = os.path.join(
28+
os.path.dirname(os.path.abspath(__file__)), "test_obj", "sample.obj",
29+
)
30+
filename = "file://" + filename
31+
32+
obj = tfio.experimental.image.decode_obj_vertex(tf.io.read_file(filename))
33+
expected = np.array(
34+
[[-0.5, 0.0, 0.4], [-0.5, 0.0, -0.8], [-0.5, 1.0, -0.8], [-0.5, 1.0, 0.4]],
35+
dtype=np.float32,
36+
)
37+
assert np.array_equal(obj, expected)

tests/test_obj/sample.obj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Simple Wavefront file
2+
v -0.500000 0.000000 0.400000
3+
v -0.500000 0.000000 -0.800000
4+
v -0.500000 1.000000 -0.800000
5+
v -0.500000 1.000000 0.400000
6+
f -4 -3 -2 -1

third_party/tinyobjloader.BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
licenses(["notice"]) # MIT license
4+
5+
cc_library(
6+
name = "tinyobjloader",
7+
srcs = [
8+
"tiny_obj_loader.cc",
9+
],
10+
hdrs = [
11+
"tiny_obj_loader.h",
12+
],
13+
copts = [],
14+
)

0 commit comments

Comments
 (0)