From b8325aa8625db0e642fe451cd8e65e785a45589e Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Mon, 29 Apr 2019 16:10:43 -0700 Subject: [PATCH 01/12] Initial commit of MeshVisualizer plugin (server-side) --- tensorboard/plugins/mesh_visualizer/BUILD | 112 ++++++++ .../plugins/mesh_visualizer/__init__.py | 0 .../plugins/mesh_visualizer/mesh_plugin.py | 253 ++++++++++++++++++ .../mesh_visualizer/mesh_plugin_test.py | 224 ++++++++++++++++ .../plugins/mesh_visualizer/mesh_summary.py | 180 +++++++++++++ .../mesh_visualizer/mesh_summary_test.py | 101 +++++++ .../plugins/mesh_visualizer/metadata.py | 98 +++++++ .../plugins/mesh_visualizer/metadata_test.py | 86 ++++++ .../plugins/mesh_visualizer/plugin_data.proto | 29 ++ .../plugins/mesh_visualizer/test_utils.py | 180 +++++++++++++ 10 files changed, 1263 insertions(+) create mode 100644 tensorboard/plugins/mesh_visualizer/BUILD create mode 100644 tensorboard/plugins/mesh_visualizer/__init__.py create mode 100644 tensorboard/plugins/mesh_visualizer/mesh_plugin.py create mode 100644 tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py create mode 100644 tensorboard/plugins/mesh_visualizer/mesh_summary.py create mode 100644 tensorboard/plugins/mesh_visualizer/mesh_summary_test.py create mode 100644 tensorboard/plugins/mesh_visualizer/metadata.py create mode 100644 tensorboard/plugins/mesh_visualizer/metadata_test.py create mode 100644 tensorboard/plugins/mesh_visualizer/plugin_data.proto create mode 100644 tensorboard/plugins/mesh_visualizer/test_utils.py diff --git a/tensorboard/plugins/mesh_visualizer/BUILD b/tensorboard/plugins/mesh_visualizer/BUILD new file mode 100644 index 00000000000..24fdd7ad4ee --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/BUILD @@ -0,0 +1,112 @@ +package(default_visibility = ["//tensorboard:internal"]) + +licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + +load("//tensorboard/defs:protos.bzl", "tb_proto_library") + +py_library( + name = "metadata", + srcs = ["metadata.py"], + srcs_version = "PY2AND3", + deps = [ + ":protos_all_py_pb2", + "//tensorboard/compat/proto:protos_all_py_pb2", + ], +) + +py_test( + name = "metadata_test", + size = "small", + srcs = ["metadata_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":metadata", + "//tensorboard:expect_tensorflow_installed", + "@org_pythonhosted_six", + ], +) + +py_library( + name = "mesh_plugin", + srcs = ["mesh_plugin.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":metadata", + ":protos_all_py_pb2", + "//tensorboard:expect_numpy_installed", + "//tensorboard:plugin_util", + "//tensorboard/backend:http_util", + "//tensorboard/backend/event_processing:event_accumulator", + "//tensorboard/compat:tensorflow", + "//tensorboard/plugins:base_plugin", + "@org_pythonhosted_six", + "@org_pocoo_werkzeug", + ], +) + +py_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":mesh_summary", + "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend:application", + "//tensorboard/backend/event_processing:event_multiplexer", + "@org_pocoo_werkzeug", + ], +) + +py_test( + name = "mesh_plugin_test", + size = "small", + srcs = ["mesh_plugin_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":mesh_plugin", + ":mesh_summary", + ":test_utils", + "//tensorboard:expect_numpy_installed", + "//tensorboard:expect_tensorflow_installed", + "//tensorboard/backend:application", + "//tensorboard/backend/event_processing:event_multiplexer", + "//tensorboard/plugins:base_plugin", + "@org_pocoo_werkzeug", + "@org_pythonhosted_six", + ], +) + +py_library( + name = "mesh_summary", + srcs = ["mesh_summary.py"], + srcs_version = "PY2AND3", + visibility = [ + "//visibility:public", + ], + deps = [ + ":metadata", + ":protos_all_py_pb2", + "//tensorboard/compat:tensorflow", + ], +) + +py_test( + name = "mesh_summary_test", + size = "small", + srcs = ["mesh_summary_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":mesh_summary", + ":test_utils", + "//tensorboard:expect_tensorflow_installed", + ], +) + +tb_proto_library( + name = "protos_all", + srcs = ["plugin_data.proto"], + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/__init__.py b/tensorboard/plugins/mesh_visualizer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py new file mode 100644 index 00000000000..ac574204e01 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py @@ -0,0 +1,253 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorBoard 3D mesh visualizer plugin.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +import six +import tensorflow as tf +from tensorflow_graphics.tensorboard.mesh_visualizer import metadata +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from werkzeug import wrappers +from tensorboard.backend import http_util +from tensorboard.plugins import base_plugin + + +class MeshPlugin(base_plugin.TBPlugin): + """A plugin that serves 3D visualization of meshes.""" + + plugin_name = 'mesh' + + def __init__(self, context): + """Instantiates a MeshPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. A magic container that + TensorBoard uses to make objects available to the plugin. + """ + # Retrieve the multiplexer from the context and store a reference to it. + self._multiplexer = context.multiplexer + self._tag_to_instance_tags = collections.defaultdict(list) + self._instance_tag_to_tag = dict() + self._instance_tag_to_metadata = dict() + self.prepare_metadata() + + def prepare_metadata(self): + """Processes all tags and caches metadata for each.""" + if self._tag_to_instance_tags: + return + # This is a dictionary mapping from run to (tag to string content). + # To be clear, the values of the dictionary are dictionaries. + all_runs = self._multiplexer.PluginRunToTagToContent(MeshPlugin.plugin_name) + + # tagToContent is itself a dictionary mapping tag name to string + # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary + # to obtain a list of tags associated with each run. For each tag, estimate + # the number of samples. + self._tag_to_instance_tags = collections.defaultdict(list) + self._instance_tag_to_metadata = dict() + for _, tag_to_content in six.iteritems(all_runs): + for tag, content in six.iteritems(tag_to_content): + meta = metadata.parse_plugin_metadata(content) + self._instance_tag_to_metadata[tag] = meta + # Remember instance_name (instance_tag) for future reference. + self._tag_to_instance_tags[meta.name].append(tag) + self._instance_tag_to_tag[tag] = meta.name + + @wrappers.Request.application + def _serve_tags(self, request): + """A route (HTTP handler) that returns a response with tags. + + Args: + request: The werkzeug.Request object. + + Returns: + A response that contains a JSON object. The keys of the object + are all the runs. Each run is mapped to a (potentially empty) + list of all tags that are relevant to this plugin. + """ + # This is a dictionary mapping from run to (tag to string content). + # To be clear, the values of the dictionary are dictionaries. + all_runs = self._multiplexer.PluginRunToTagToContent( + MeshPlugin.plugin_name) + + # Make sure we populate tags mapping structures. + self.prepare_metadata() + + # tagToContent is itself a dictionary mapping tag name to string + # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary + # to obtain a list of tags associated with each run. For each tag estimate + # number of samples. + response = dict() + for run, tag_to_content in six.iteritems(all_runs): + response[run] = dict() + for instance_tag, _ in six.iteritems(tag_to_content): + # Make sure we only operate on user-defined tags here. + tag = self._instance_tag_to_tag[instance_tag] + meta = self._instance_tag_to_metadata[instance_tag] + # Shape should be at least BxNx3 where B represents the batch dimensions + # and N - the number of points, each with x,y,z coordinates. + assert len(meta.shape) >= 3 + response[run][tag] = {'samples': meta.shape[0]} + return http_util.Respond(request, response, 'application/json') + + def get_plugin_apps(self): + """Gets all routes offered by the plugin. + + This method is called by TensorBoard when retrieving all the + routes offered by the plugin. + Returns: + A dictionary mapping URL path to route that handles it. + """ + # Note that the methods handling routes are decorated with + # @wrappers.Request.application. + return { + '/tags': self._serve_tags, + '/meshes': self._serve_mesh_metadata, + '/data': self._serve_mesh_data + } + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is only active if TensorBoard sampled any summaries + relevant to the mesh plugin. + Returns: + Whether this plugin is active. + """ + all_runs = self._multiplexer.PluginRunToTagToContent( + MeshPlugin.plugin_name) + + # The plugin is active if any of the runs has a tag relevant + # to the plugin. + return bool(self._multiplexer and any(six.itervalues(all_runs))) + + def _get_sample(self, tensor_event, sample): + """Returns a single sample from a batch of samples.""" + data = tf.make_ndarray(tensor_event.tensor_proto) + return data[sample].tolist() + + def _get_tensor_metadata(self, event, content_type, data_shape, config): + """Converts a TensorEvent into a JSON-compatible response. + + Args: + event: TensorEvent object containing data in proto format. + content_type: enum plugin_data_pb2.MeshPluginData.ContentType value, + representing content type in TensorEvent. + data_shape: list of dimensions sizes of the tensor. + config: rendering scene configuration as dictionary. + Returns: + Dictionary of transformed metadata. + """ + return { + 'wall_time': event.wall_time, + 'step': event.step, + 'content_type': content_type, + 'config': config, + 'data_shape': list(data_shape) + } + + def _get_tensor_data(self, event, sample): + """Convert a TensorEvent into a JSON-compatible response.""" + data = self._get_sample(event, sample) + return data + + def _collect_tensor_events(self, request): + """Collects list of tensor events based on request.""" + run = request.args.get('run') + tag = request.args.get('tag') + + # TODO(b/128995556): investigate why this additional metadata mapping is + # necessary, it must have something todo with the lifecycle of the request. + # Make sure we populate tags mapping structures. + self.prepare_metadata() + + # We fetch all the tensor events that contain tag. + tensor_events = [] # List of tuples (meta, tensor). + for instance_tag in self._tag_to_instance_tags[tag]: + tensors = self._multiplexer.Tensors(run, instance_tag) + meta = self._instance_tag_to_metadata[instance_tag] + tensor_events += [(meta, tensor) for tensor in tensors] + + # Make sure tensors sorted by timestamp in ascending order. + tensor_events = sorted( + tensor_events, key=lambda tensor_data: tensor_data[1].wall_time) + + return tensor_events + + @wrappers.Request.application + def _serve_mesh_data(self, request): + """A route that returns data for particular summary of specified type. + + Data can represent vertices coordinates, vertices indices in faces, + vertices colors and so on. Each mesh may have different combination of + abovementioned data and each type/part of mesh summary must be served as + separate roundtrip to the server. + + Args: + request: werkzeug.Request containing content_type as a name of enum + plugin_data_pb2.MeshPluginData.ContentType. + Returns: + werkzeug.Response either float32 or int32 data in binary format. + """ + tensor_events = self._collect_tensor_events(request) + content_type = request.args.get('content_type') + content_type = plugin_data_pb2.MeshPluginData.ContentType.Value( + content_type) + sample = int(request.args.get('sample', 0)) + + response = [ + self._get_tensor_data(tensor, sample) + for meta, tensor in tensor_events + if meta.content_type == content_type + ] + + np_type = np.float32 + if content_type == plugin_data_pb2.MeshPluginData.ContentType.FACE: + np_type = np.int32 + elif content_type == plugin_data_pb2.MeshPluginData.ContentType.COLOR: + np_type = np.uint8 + response = np.array(response, dtype=np_type) + # Looks like reshape can take around 160ms, so why not store it reshaped. + response = response.reshape(-1).tobytes() + + return http_util.Respond(request, response, 'arraybuffer') + + @wrappers.Request.application + def _serve_mesh_metadata(self, request): + """A route that returns the mesh metadata associated with a tag. + + Metadata consists of wall time, type of elements in tensor, scene + configuration and so on. + + Args: + request: The werkzeug.Request object. + + Returns: + A JSON list of mesh data associated with the run and tag + combination. + """ + tensor_events = self._collect_tensor_events(request) + + # We convert the tensor data to text. + response = [ + self._get_tensor_metadata(tensor, meta.content_type, meta.shape, + meta.json_config) + for meta, tensor in tensor_events + ] + return http_util.Respond(request, response, 'application/json') \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py new file mode 100644 index 00000000000..110cc487fa9 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -0,0 +1,224 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests the Tensorboard mesh plugin.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import shutil +from mock import patch +import numpy as np +import tensorflow as tf + +from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_plugin +from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_summary +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorflow_graphics.tensorboard.mesh_visualizer import test_utils +from werkzeug import test as werkzeug_test +from werkzeug import wrappers +from google3.third_party.tensorboard.backend import application +from google3.third_party.tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from google3.third_party.tensorboard.plugins import base_plugin + + +class MeshPluginTest(tf.test.TestCase): + """Tests for mesh plugin server.""" + + def setUp(self): + # We use numpy.random to generate meshes. We seed to avoid non-determinism + # in this test. + np.random.seed(17) + + # Log dir to save temp events into. + self.log_dir = self.get_temp_dir() + + # Create mesh summary. + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + point_cloud = test_utils.get_random_mesh(1000) + point_cloud_vertices = tf.compat.v1.placeholder(tf.float32, + point_cloud.vertices.shape) + + mesh_no_color = test_utils.get_random_mesh(2000, add_faces=True) + mesh_no_color_vertices = tf.compat.v1.placeholder( + tf.float32, mesh_no_color.vertices.shape) + mesh_no_color_faces = tf.compat.v1.placeholder(tf.int32, + mesh_no_color.faces.shape) + + mesh_color = test_utils.get_random_mesh( + 3000, add_faces=True, add_colors=True) + mesh_color_vertices = tf.compat.v1.placeholder(tf.float32, + mesh_color.vertices.shape) + mesh_color_faces = tf.compat.v1.placeholder(tf.int32, + mesh_color.faces.shape) + mesh_color_colors = tf.compat.v1.placeholder(tf.uint8, + mesh_color.colors.shape) + self.data = [point_cloud, mesh_no_color, mesh_color] + + # In case when name is present and display_name is not, we will reuse name + # as display_name. Summaries below intended to test both cases. + self.names = ["point_cloud", "mesh_no_color", "mesh_color"] + mesh_summary.op( + self.names[0], + point_cloud_vertices, + description="just point cloud") + mesh_summary.op( + self.names[1], + mesh_no_color_vertices, + faces=mesh_no_color_faces, + display_name="name_to_display_in_ui", + description="beautiful mesh in grayscale") + mesh_summary.op( + self.names[2], + mesh_color_vertices, + faces=mesh_color_faces, + colors=mesh_color_colors, + description="mesh with random colors") + + merged_summary_op = tf.compat.v1.summary.merge_all() + self.runs = ["bar"] + self.steps = 20 + bar_directory = os.path.join(self.log_dir, self.runs[0]) + with test_utils.FileWriterCache.get(bar_directory) as writer: + writer.add_graph(sess.graph) + for step in xrange(self.steps): + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + point_cloud_vertices: point_cloud.vertices, + mesh_no_color_vertices: mesh_no_color.vertices, + mesh_no_color_faces: mesh_no_color.faces, + mesh_color_vertices: mesh_color.vertices, + mesh_color_faces: mesh_color.faces, + mesh_color_colors: mesh_color.colors, + }), + global_step=step) + + # Start a server that will receive requests. + self.multiplexer = event_multiplexer.EventMultiplexer({ + "bar": bar_directory, + }) + self.context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=self.multiplexer) + self.plugin = mesh_plugin.MeshPlugin(self.context) + wsgi_app = application.TensorBoardWSGIApp( + self.log_dir, [self.plugin], + self.multiplexer, + reload_interval=0, + path_prefix="") + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + self.multiplexer.Reload() + self.routes = self.plugin.get_plugin_apps() + + def tearDown(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + + def testRoutes(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + self.assertIsInstance(self.routes["/tags"], collections.Callable) + self.assertIsInstance(self.routes["/meshes"], collections.Callable) + self.assertIsInstance(self.routes["/data"], collections.Callable) + + def testTagsRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self.server.get("/data/plugin/mesh/tags") + self.assertEqual(200, response.status_code) + tags = test_utils.deserialize_json_response(response.get_data()) + self.assertIn(self.runs[0], tags) + for name in self.names: + self.assertIn(name, tags[self.runs[0]]) + + def testDataRoute(self): + """Tests that the /data route returns correct data for meshes.""" + response = self.server.get( + "/data/plugin/mesh/data?run=%s&tag=%s&sample=%d&content_type=%s" % + (self.runs[0], self.names[0], 0, "VERTEX")) + self.assertEqual(200, response.status_code) + data = test_utils.deserialize_array_buffer_response( + response.response.next(), np.float32) + vertices = np.tile(self.data[0].vertices.reshape(-1), self.steps) + self.assertEqual(vertices.tolist(), data.tolist()) + + response = self.server.get( + "/data/plugin/mesh/data?run=%s&tag=%s&sample=%d&content_type=%s" % + (self.runs[0], self.names[1], 0, "FACE")) + self.assertEqual(200, response.status_code) + data = test_utils.deserialize_array_buffer_response( + response.response.next(), np.int32) + faces = np.tile(self.data[1].faces.reshape(-1), self.steps) + self.assertEqual(faces.tolist(), data.tolist()) + + response = self.server.get( + "/data/plugin/mesh/data?run=%s&tag=%s&sample=%d&content_type=%s" % + (self.runs[0], self.names[2], 0, "COLOR")) + self.assertEqual(200, response.status_code) + data = test_utils.deserialize_array_buffer_response( + response.response.next(), np.uint8) + colors = np.tile(self.data[2].colors.reshape(-1), self.steps) + self.assertListEqual(colors.tolist(), data.tolist()) + + def testMetadataRoute(self): + """Tests that the /meshes route returns correct metadata for meshes.""" + response = self.server.get( + "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" % + (self.runs[0], self.names[0], 0)) + self.assertEqual(200, response.status_code) + metadata = test_utils.deserialize_json_response(response.get_data()) + self.assertLen(metadata, self.steps) + self.assertAllEqual(metadata[0]["content_type"], + plugin_data_pb2.MeshPluginData.VERTEX) + self.assertAllEqual(metadata[0]["data_shape"], self.data[0].vertices.shape) + + def testsEventsAlwaysSortedByWallTime(self): + """Tests that events always sorted by wall time.""" + response = self.server.get( + "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" % + (self.runs[0], self.names[1], 0)) + self.assertEqual(200, response.status_code) + metadata = test_utils.deserialize_json_response(response.get_data()) + for i in range(1, self.steps): + # Timestamp will be equal when two tensors of different content type + # belong to the same mesh. + self.assertLessEqual(metadata[i - 1]["wall_time"], + metadata[i]["wall_time"]) + + @patch.object( + event_multiplexer.EventMultiplexer, + "PluginRunToTagToContent", + return_value={"bar": { + "foo": "" + }}) + def testMetadataComputedOnce(self, run_to_tag_mock): + """Tests that metadata mapping computed once.""" + self.plugin.prepare_metadata() + self.plugin.prepare_metadata() + self.assertEqual(1, run_to_tag_mock.call_count) + + def testIsActive(self): + self.assertTrue(self.plugin.is_active()) + + @patch.object( + event_multiplexer.EventMultiplexer, + "PluginRunToTagToContent", + return_value={}) + def testIsInactive(self, get_random_mesh_stub): + self.assertFalse(self.plugin.is_active()) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary.py b/tensorboard/plugins/mesh_visualizer/mesh_summary.py new file mode 100644 index 00000000000..766d94659e1 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary.py @@ -0,0 +1,180 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mesh summaries and TensorFlow operations to create them.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import tensorflow as tf + +from tensorflow_graphics.tensorboard.mesh_visualizer import metadata +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 + +PLUGIN_NAME = 'mesh' + + +def _get_tensor_summary( + name, display_name, description, tensor, content_type, json_config, + collections): + """Creates a tensor summary with summary metadata. + + Args: + name: Uniquely identifiable name of the summary op. Could be replaced by + combination of name and type to make it unique even outside of this + summary. + display_name: Will be used as the display name in TensorBoard. + Defaults to `tag`. + description: A longform readable description of the summary data. Markdown + is supported. + tensor: Tensor to display in summary. + content_type: Type of content inside the Tensor. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + collections: List of collections to add this summary to. + + Returns: + Tensor summary with metadata. + """ + tensor = tf.convert_to_tensor(value=tensor) + tensor_metadata = metadata.create_summary_metadata( + name, + display_name, + content_type, + tensor.shape.as_list(), + description, + json_config=json_config) + tensor_summary = tf.summary.tensor_summary( + metadata.get_instance_name(name, content_type), + tensor, + summary_metadata=tensor_metadata, + collections=collections) + return tensor_summary + + +def _get_display_name(name, display_name): + """Returns display_name from display_name and name.""" + if display_name is None: + return name + return display_name + + +def _get_json_config(config_dict): + """Parses and returns JSON string from python dictionary.""" + json_config = '{}' + if config_dict is not None: + json_config = json.dumps(config_dict, sort_keys=True) + return json_config + + +def op(name, vertices, faces=None, colors=None, display_name=None, + description=None, collections=None, config_dict=None): + """Creates a TensorFlow summary op for mesh rendering. + + Args: + name: A name for this summary operation. + vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each + vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + collections: Which TensorFlow graph collections to add the summary op to. + Defaults to `['summaries']`. Can usually be ignored. + config_dict: Dictionary with ThreeJS classes names and configuration. + Returns: + Merged summary for mesh/point cloud representation. + """ + display_name = _get_display_name(name, display_name) + json_config = _get_json_config(config_dict) + + # All tensors representing a single mesh will be represented as separate + # summaries internally. Those summaries will be regrouped on the client before + # rendering. + summaries = [] + tensors = [(vertices, plugin_data_pb2.MeshPluginData.VERTEX), + (faces, plugin_data_pb2.MeshPluginData.FACE), + (colors, plugin_data_pb2.MeshPluginData.COLOR)] + + for tensor, content_type in tensors: + if tensor is None: + continue + summaries.append( + _get_tensor_summary(name, display_name, description, tensor, + content_type, json_config, collections)) + + all_summaries = tf.summary.merge( + summaries, collections=collections, name=name) + return all_summaries + + +def pb(name, + vertices, + faces=None, + colors=None, + display_name=None, + description=None, + config_dict=None): + """Create a mesh summary to save in pb format. + + Args: + name: A name for this summary operation. + vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for + each vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + config_dict: Dictionary with ThreeJS classes names and configuration. + + Returns: + Instance of tf.Summary class. + """ + display_name = _get_display_name(name, display_name) + json_config = _get_json_config(config_dict) + + summaries = [] + tensors = [(vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32), + (faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32), + (colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8)] + for tensor, content_type, data_type in tensors: + if tensor is None: + continue + tensor_shape = tensor.shape + tensor = tf.compat.v1.make_tensor_proto(tensor, dtype=data_type) + summary_metadata = metadata.create_summary_metadata( + name, + display_name, + content_type, + tensor_shape, + description, + json_config=json_config) + tag = metadata.get_instance_name(name, content_type) + summaries.append((tag, summary_metadata, tensor)) + + summary = tf.Summary() + for tag, summary_metadata, tensor in summaries: + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString()) + summary.value.add(tag=tag, metadata=tf_summary_metadata, tensor=tensor) + return summary \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py new file mode 100644 index 00000000000..2cd1fe551c5 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py @@ -0,0 +1,101 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow_graphics.tensorboard.mesh_visualizer.mesh_summary.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import tensorflow as tf + +from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_summary +from tensorflow_graphics.tensorboard.mesh_visualizer import metadata +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorflow_graphics.tensorboard.mesh_visualizer import test_utils +from tensorboard.compat.proto import summary_pb2 + + +class MeshSummaryTest(tf.test.TestCase): + + def pb_via_op(self, summary_op): + """Parses pb proto.""" + actual_pbtxt = summary_op.eval() + actual_proto = summary_pb2.Summary() + actual_proto.ParseFromString(actual_pbtxt) + return actual_proto + + def verify_proto(self, proto, name): + """Validates proto.""" + self.assertEqual(3, len(proto.value)) + self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) + self.assertEqual("%s_FACE" % name, proto.value[1].tag) + self.assertEqual("%s_COLOR" % name, proto.value[2].tag) + + def test_get_tensor_summary(self): + """Tests proper creation of tensor summary with mesh plugin metadata.""" + name = "my_mesh" + display_name = "my_display_name" + description = "my mesh is the best of meshes" + tensor_data = test_utils.get_random_mesh(100) + tensor_summary = mesh_summary._get_tensor_summary( + name, display_name, description, tensor_data.vertices, + plugin_data_pb2.MeshPluginData.VERTEX, "", None) + with self.test_session(): + proto = self.pb_via_op(tensor_summary) + self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) + self.assertEqual(metadata.PLUGIN_NAME, + proto.value[0].metadata.plugin_data.plugin_name) + + def test_op(self): + """Tests merged summary with different types of data.""" + name = "my_mesh" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True) + config_dict = {"foo": 1} + tensor_summary = mesh_summary.op( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict) + with self.test_session(): + proto = self.pb_via_op(tensor_summary) + self.verify_proto(proto, name) + plugin_metadata = metadata.parse_plugin_metadata( + proto.value[0].metadata.plugin_data.content) + self.assertEqual( + json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config) + + def test_pb(self): + """Tests merged summary protobuf with different types of data.""" + name = "my_mesh" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True) + config_dict = {"foo": 1} + proto = mesh_summary.pb( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict) + self.verify_proto(proto, name) + plugin_metadata = metadata.parse_plugin_metadata( + proto.value[0].metadata.plugin_data.content) + self.assertEqual( + json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/metadata.py b/tensorboard/plugins/mesh_visualizer/metadata.py new file mode 100644 index 00000000000..a705fdacba4 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/metadata.py @@ -0,0 +1,98 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal information about the mesh plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorboard.compat.proto import summary_pb2 + +# TODO(podlipensky): use this variable everywhere, avoid duplicating 'mesh'. +# b/129002587 +PLUGIN_NAME = 'mesh' + +# The most recent value for the `version` field of the +# `PrCurvePluginData` proto. +_PROTO_VERSION = 0 + + +def get_current_version(): + """Returns current verions of the proto.""" + return _PROTO_VERSION + + +def get_instance_name(name, content_type): + """Returns a unique instance name for a given summary related to the mesh.""" + return '%s_%s' % ( + name, + plugin_data_pb2.MeshPluginData.ContentType.Name(content_type)) + + +def create_summary_metadata(name, + display_name, + content_type, + shape, + description=None, + json_config=None): + """Creates summary metadata which defined at MeshPluginData proto. + + Arguments: + name: Original merged (summaries of different types) summary name. + display_name: The display name used in TensorBoard. + content_type: Value from MeshPluginData.ContentType enum describing data. + shape: list of dimensions sizes of the tensor. + description: The description to show in TensorBoard. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + mesh_plugin_data = plugin_data_pb2.MeshPluginData( + version=get_current_version(), + name=name, + content_type=content_type, + shape=shape, + json_config=json_config) + content = mesh_plugin_data.SerializeToString() + return summary_pb2.SummaryMetadata( + display_name=display_name, # Will not be used in TensorBoard UI. + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, + content=content)) + + +def parse_plugin_metadata(content): + """Parse summary metadata to a Python object. + + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the mesh plugin. + + Returns: + A `MeshPluginData` protobuf object. + Raises: Error if the version of the plugin is not supported. + """ + if not isinstance(content, bytes): + raise TypeError('Content type must be bytes.') + result = plugin_data_pb2.MeshPluginData.FromString(content) + if result.version == get_current_version(): + return result + raise ValueError('Unknown metadata version: %s. The latest version known to ' + 'this build of TensorBoard is %s; perhaps a newer build is ' + 'available?' % (result.version, get_current_version())) \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh_visualizer/metadata_test.py new file mode 100644 index 00000000000..5fd5d30feeb --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/metadata_test.py @@ -0,0 +1,86 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for util functions to create/parse mesh plugin metadata.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from mock import patch +import six +import tensorflow as tf +from tensorflow_graphics.tensorboard.mesh_visualizer import metadata +from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 + + +class MetadataTest(tf.test.TestCase): + + def _create_metadata(self): + """Creates metadata with dummy data.""" + self.name = 'unique_name' + self.display_name = 'my mesh' + self.json_config = '{}' + self.shape = [1, 100, 3] + self.summary_metadata = metadata.create_summary_metadata( + self.name, + self.display_name, + plugin_data_pb2.MeshPluginData.ContentType.VERTEX, + self.shape, + json_config=self.json_config) + + def test_get_instance_name(self): + """Tests proper creation of instance name based on display_name.""" + display_name = 'my_mesh' + instance_name = metadata.get_instance_name( + display_name, plugin_data_pb2.MeshPluginData.ContentType.VERTEX) + self.assertEqual('%s_VERTEX' % display_name, instance_name) + + def test_create_summary_metadata(self): + """Tests MeshPlugin metadata creation.""" + self._create_metadata() + self.assertEqual(self.display_name, + self.summary_metadata.display_name) + self.assertEqual(metadata.PLUGIN_NAME, + self.summary_metadata.plugin_data.plugin_name) + + def test_parse_plugin_metadata(self): + """Tests parsing of saved plugin metadata.""" + self._create_metadata() + parsed_metadata = metadata.parse_plugin_metadata( + self.summary_metadata.plugin_data.content) + self.assertEqual(self.name, parsed_metadata.name) + self.assertEqual(plugin_data_pb2.MeshPluginData.ContentType.VERTEX, + parsed_metadata.content_type) + self.assertEqual(self.shape, parsed_metadata.shape) + self.assertEqual(self.json_config, parsed_metadata.json_config) + + def test_metadata_version(self): + """Tests that only the latest version of metadata is supported.""" + self._create_metadata() + # Change the version. + with patch.object(metadata, 'get_current_version', return_value=100): + # Try to parse metadata from a prior version. + with self.assertRaises(ValueError): + metadata.parse_plugin_metadata( + self.summary_metadata.plugin_data.content) + + def test_metadata_format(self): + """Tests that metadata content must be passed as a serialized string.""" + with six.assertRaisesRegex(self, TypeError, r'Content type must be bytes.'): + metadata.parse_plugin_metadata(123) + + +if __name__ == '__main__': + tf.test.main() \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/plugin_data.proto b/tensorboard/plugins/mesh_visualizer/plugin_data.proto new file mode 100644 index 00000000000..c87dd8bdf33 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/plugin_data.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package tensorboard.mesh_visualizer; + +// A MeshPluginData encapsulates information on which plugins are able to make +// use of a certain summary value. +message MeshPluginData { + enum ContentType { + UNDEFINED = 0; + VERTEX = 1; + FACE = 2; // Triangle face. + COLOR = 3; + } + + // Version `0` is the only supported version. + int32 version = 1; + + // The name of the mesh summary this particular summary belongs to. + string name = 2; + + // Type of data in the summary. + ContentType content_type = 3; + + // JSON-serialized dictionary of ThreeJS classes configuration. + string json_config = 5; + + // Shape of underlying data. Cache it here for performance reasons. + repeated int32 shape = 6; +} \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/test_utils.py b/tensorboard/plugins/mesh_visualizer/test_utils.py new file mode 100644 index 00000000000..cace30b69ec --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/test_utils.py @@ -0,0 +1,180 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utils for mesh plugin tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import google_type_annotations +from __future__ import print_function + +import collections +import json +import threading +import numpy as np +import tensorflow as tf +from google3.third_party.tensorboard.compat.proto import event_pb2 +from google3.third_party.tensorboard.compat.proto import graph_pb2 +from google3.third_party.tensorboard.compat.proto import meta_graph_pb2 +from google3.third_party.tensorboard.compat.proto import summary_pb2 +from google3.third_party.tensorboard.util import tb_logging + +Mesh = collections.namedtuple('Mesh', ('vertices', 'faces', 'colors')) +logger = tb_logging.get_logger() + + +# NOTE: copy FileWriter and FileWriterCache from tensorboard test_util.py +# until this plugin start live in TensorBoard github repo. +class FileWriter(tf.compat.v1.summary.FileWriter): + """FileWriter for test. + + TensorFlow FileWriter uses TensorFlow's Protobuf Python binding which is + largely discouraged in TensorBoard. We do not want a TB.Writer but require one + for testing in integrational style (writing out event files and use the real + event readers). + """ + + def add_event(self, event): + if isinstance(event, event_pb2.Event): + tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) + else: + logger.warn('Added TensorFlow event proto. ' + 'Please prefer TensorBoard copy of the proto') + tf_event = event + super(FileWriter, self).add_event(tf_event) + + def add_summary(self, summary, global_step=None): + if isinstance(summary, summary_pb2.Summary): + tf_summary = tf.compat.v1.Summary.FromString(summary.SerializeToString()) + else: + logger.warn('Added TensorFlow summary proto. ' + 'Please prefer TensorBoard copy of the proto') + tf_summary = summary + super(FileWriter, self).add_summary(tf_summary, global_step) + + def add_session_log(self, session_log, global_step=None): + if isinstance(session_log, event_pb2.SessionLog): + tf_session_log = tf.compat.v1.SessionLog.FromString( + session_log.SerializeToString()) + else: + logger.warn('Added TensorFlow session_log proto. ' + 'Please prefer TensorBoard copy of the proto') + tf_session_log = session_log + super(FileWriter, self).add_session_log(tf_session_log, global_step) + + def add_graph(self, graph, global_step=None, graph_def=None): + if isinstance(graph_def, graph_pb2.GraphDef): + tf_graph_def = tf.compat.v1.GraphDef.FromString( + graph_def.SerializeToString()) + else: + tf_graph_def = graph_def + + super(FileWriter, self).add_graph( + graph, global_step=global_step, graph_def=tf_graph_def) + + def add_meta_graph(self, meta_graph_def, global_step=None): + if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): + tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString( + meta_graph_def.SerializeToString()) + else: + tf_meta_graph_def = meta_graph_def + + super(FileWriter, self).add_meta_graph( + meta_graph_def=tf_meta_graph_def, global_step=global_step) + + +class FileWriterCache(object): + """Cache for TensorBoard test file writers.""" + # Cache, keyed by directory. + _cache = {} + + # Lock protecting _FILE_WRITERS. + _lock = threading.RLock() + + @staticmethod + def get(logdir): + """Returns the FileWriter for the specified directory. + + Args: + logdir: str, name of the directory. + + Returns: + A `FileWriter`. + """ + with FileWriterCache._lock: + if logdir not in FileWriterCache._cache: + FileWriterCache._cache[logdir] = FileWriter( + logdir, graph=tf.compat.v1.get_default_graph()) + return FileWriterCache._cache[logdir] + + +def get_random_mesh(num_vertices, + add_faces=False, + add_colors=False, + batch_size=1): + """Returns a random point cloud, optionally with random disconnected faces. + + Args: + num_vertices: Number of vertices in the point cloud or mesh. + add_faces: Random faces will be generated and added to the mesh when True. + add_colors: Random colors will be assigned to each vertex when True. Each + color will be in a range of [0, 255]. + batch_size: Size of batch dimension in output array. + + Returns: + Mesh namedtuple with vertices and optionally with faces and/or colors. + """ + vertices = np.random.random([num_vertices, 3]) * 1000 + # Add batch dimension. + vertices = np.tile(vertices, [batch_size, 1, 1]) + faces = None + colors = None + if add_faces: + arranged_vertices = np.random.permutation(num_vertices) + faces = [] + for i in range(num_vertices - 2): + faces.append([ + arranged_vertices[i], arranged_vertices[i + 1], + arranged_vertices[i + 2] + ]) + faces = np.array(faces) + faces = np.tile(faces, [batch_size, 1, 1]).astype(np.int32) + if add_colors: + colors = np.random.randint(low=0, high=255, size=[num_vertices, 3]) + colors = np.tile(colors, [batch_size, 1, 1]).astype(np.uint8) + return Mesh(vertices.astype(np.float32), faces, colors) + + +def deserialize_json_response(byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a response. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode('utf-8')) + + +def deserialize_array_buffer_response(byte_content, data_type): + """Deserializes arraybuffer response and optionally tiles the array. + + Args: + byte_content: The byte content of a response. + data_type: Numpy type to parse data with. + + Returns: + Flat numpy array with the data. + """ + return np.frombuffer(byte_content, dtype=data_type) \ No newline at end of file From 3e7a8e3513c845b41a71e00d9539f660de361203 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Wed, 1 May 2019 18:27:21 -0700 Subject: [PATCH 02/12] Address several Stephan's comments. Switching to TB repo completely. Update tests. --- tensorboard/plugins/mesh_visualizer/BUILD | 10 +- .../plugins/mesh_visualizer/__init__.py | 14 +++ .../plugins/mesh_visualizer/mesh_plugin.py | 14 +-- .../mesh_visualizer/mesh_plugin_test.py | 35 ++++--- .../plugins/mesh_visualizer/mesh_summary.py | 4 +- .../mesh_visualizer/mesh_summary_test.py | 12 ++- .../plugins/mesh_visualizer/metadata.py | 3 +- .../plugins/mesh_visualizer/metadata_test.py | 12 ++- .../plugins/mesh_visualizer/test_utils.py | 96 +------------------ 9 files changed, 72 insertions(+), 128 deletions(-) diff --git a/tensorboard/plugins/mesh_visualizer/BUILD b/tensorboard/plugins/mesh_visualizer/BUILD index 24fdd7ad4ee..c0ba46e9817 100644 --- a/tensorboard/plugins/mesh_visualizer/BUILD +++ b/tensorboard/plugins/mesh_visualizer/BUILD @@ -23,6 +23,7 @@ py_test( deps = [ ":metadata", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/util:test_util", "@org_pythonhosted_six", ], ) @@ -38,9 +39,9 @@ py_library( "//tensorboard:expect_numpy_installed", "//tensorboard:plugin_util", "//tensorboard/backend:http_util", - "//tensorboard/backend/event_processing:event_accumulator", - "//tensorboard/compat:tensorflow", + "//tensorboard/backend/event_processing:event_accumulator", "//tensorboard/plugins:base_plugin", + "//tensorboard/util:tensor_util", "@org_pythonhosted_six", "@org_pocoo_werkzeug", ], @@ -74,8 +75,10 @@ py_test( "//tensorboard/backend:application", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/plugins:base_plugin", + "//tensorboard/util:test_util", "@org_pocoo_werkzeug", "@org_pythonhosted_six", + "@org_pythonhosted_mock", ], ) @@ -100,8 +103,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":mesh_summary", - ":test_utils", + ":test_utils", "//tensorboard:expect_tensorflow_installed", + "//tensorboard/util:test_util", ], ) diff --git a/tensorboard/plugins/mesh_visualizer/__init__.py b/tensorboard/plugins/mesh_visualizer/__init__.py index e69de29bb2d..931c2ef11db 100644 --- a/tensorboard/plugins/mesh_visualizer/__init__.py +++ b/tensorboard/plugins/mesh_visualizer/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py index ac574204e01..5f2e7d610a9 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py @@ -20,18 +20,18 @@ import collections import numpy as np import six -import tensorflow as tf -from tensorflow_graphics.tensorboard.mesh_visualizer import metadata -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorboard.util import tensor_util from werkzeug import wrappers from tensorboard.backend import http_util from tensorboard.plugins import base_plugin +from tensorboard.plugins.mesh_visualizer import metadata +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 class MeshPlugin(base_plugin.TBPlugin): """A plugin that serves 3D visualization of meshes.""" - plugin_name = 'mesh' + plugin_name = metadata.PLUGIN_NAME def __init__(self, context): """Instantiates a MeshPlugin via TensorBoard core. @@ -139,7 +139,7 @@ def is_active(self): def _get_sample(self, tensor_event, sample): """Returns a single sample from a batch of samples.""" - data = tf.make_ndarray(tensor_event.tensor_proto) + data = tensor_util.make_ndarray(tensor_event.tensor_proto) return data[sample].tolist() def _get_tensor_metadata(self, event, content_type, data_shape, config): @@ -218,9 +218,9 @@ def _serve_mesh_data(self, request): ] np_type = np.float32 - if content_type == plugin_data_pb2.MeshPluginData.ContentType.FACE: + if content_type == plugin_data_pb2.MeshPluginData.ContentType.Value('FACE'): np_type = np.int32 - elif content_type == plugin_data_pb2.MeshPluginData.ContentType.COLOR: + elif content_type == plugin_data_pb2.MeshPluginData.ContentType.Value('COLOR'): np_type = np.uint8 response = np.array(response, dtype=np_type) # Looks like reshape can take around 160ms, so why not store it reshaped. diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index 110cc487fa9..2a70b4cc54c 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -20,21 +20,28 @@ import collections import os import shutil -from mock import patch import numpy as np import tensorflow as tf -from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_plugin -from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_summary -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 -from tensorflow_graphics.tensorboard.mesh_visualizer import test_utils from werkzeug import test as werkzeug_test from werkzeug import wrappers -from google3.third_party.tensorboard.backend import application -from google3.third_party.tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer -from google3.third_party.tensorboard.plugins import base_plugin - - +from tensorboard.backend import application +from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.plugins import base_plugin +from tensorboard.plugins.mesh_visualizer import mesh_plugin +from tensorboard.plugins.mesh_visualizer import mesh_summary +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh_visualizer import test_utils +from tensorboard.util import test_util as tensorboard_test_util + +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + + +@tensorboard_test_util.run_v1_only('Uses contrib') class MeshPluginTest(tf.test.TestCase): """Tests for mesh plugin server.""" @@ -93,7 +100,7 @@ def setUp(self): self.runs = ["bar"] self.steps = 20 bar_directory = os.path.join(self.log_dir, self.runs[0]) - with test_utils.FileWriterCache.get(bar_directory) as writer: + with tensorboard_test_util.FileWriterCache.get(bar_directory) as writer: writer.add_graph(sess.graph) for step in xrange(self.steps): writer.add_summary( @@ -179,7 +186,7 @@ def testMetadataRoute(self): (self.runs[0], self.names[0], 0)) self.assertEqual(200, response.status_code) metadata = test_utils.deserialize_json_response(response.get_data()) - self.assertLen(metadata, self.steps) + self.assertEqual(len(metadata), self.steps) self.assertAllEqual(metadata[0]["content_type"], plugin_data_pb2.MeshPluginData.VERTEX) self.assertAllEqual(metadata[0]["data_shape"], self.data[0].vertices.shape) @@ -197,7 +204,7 @@ def testsEventsAlwaysSortedByWallTime(self): self.assertLessEqual(metadata[i - 1]["wall_time"], metadata[i]["wall_time"]) - @patch.object( + @mock.patch.object( event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={"bar": { @@ -212,7 +219,7 @@ def testMetadataComputedOnce(self, run_to_tag_mock): def testIsActive(self): self.assertTrue(self.plugin.is_active()) - @patch.object( + @mock.patch.object( event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={}) diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary.py b/tensorboard/plugins/mesh_visualizer/mesh_summary.py index 766d94659e1..7952f18a197 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_summary.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary.py @@ -20,8 +20,8 @@ import json import tensorflow as tf -from tensorflow_graphics.tensorboard.mesh_visualizer import metadata -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh_visualizer import metadata +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 PLUGIN_NAME = 'mesh' diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py index 2cd1fe551c5..a76d612fa91 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow_graphics.tensorboard.mesh_visualizer.mesh_summary.""" +"""Tests for tensorboard.plugins.mesh_visualizer.mesh_summary.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,13 +20,15 @@ import json import tensorflow as tf -from tensorflow_graphics.tensorboard.mesh_visualizer import mesh_summary -from tensorflow_graphics.tensorboard.mesh_visualizer import metadata -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 -from tensorflow_graphics.tensorboard.mesh_visualizer import test_utils from tensorboard.compat.proto import summary_pb2 +from tensorboard.plugins.mesh_visualizer import mesh_summary +from tensorboard.plugins.mesh_visualizer import metadata +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh_visualizer import test_utils +from tensorboard.util import test_util +@test_util.run_v1_only('Uses contrib') class MeshSummaryTest(tf.test.TestCase): def pb_via_op(self, summary_op): diff --git a/tensorboard/plugins/mesh_visualizer/metadata.py b/tensorboard/plugins/mesh_visualizer/metadata.py index a705fdacba4..2572114c6cb 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata.py +++ b/tensorboard/plugins/mesh_visualizer/metadata.py @@ -18,8 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 from tensorboard.compat.proto import summary_pb2 +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 + # TODO(podlipensky): use this variable everywhere, avoid duplicating 'mesh'. # b/129002587 diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh_visualizer/metadata_test.py index 5fd5d30feeb..1747d3c5b25 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata_test.py +++ b/tensorboard/plugins/mesh_visualizer/metadata_test.py @@ -21,10 +21,12 @@ from mock import patch import six import tensorflow as tf -from tensorflow_graphics.tensorboard.mesh_visualizer import metadata -from tensorflow_graphics.tensorboard.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh_visualizer import metadata +from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.util import test_util +@test_util.run_v1_only('Uses contrib') class MetadataTest(tf.test.TestCase): def _create_metadata(self): @@ -36,7 +38,7 @@ def _create_metadata(self): self.summary_metadata = metadata.create_summary_metadata( self.name, self.display_name, - plugin_data_pb2.MeshPluginData.ContentType.VERTEX, + plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX'), self.shape, json_config=self.json_config) @@ -44,7 +46,7 @@ def test_get_instance_name(self): """Tests proper creation of instance name based on display_name.""" display_name = 'my_mesh' instance_name = metadata.get_instance_name( - display_name, plugin_data_pb2.MeshPluginData.ContentType.VERTEX) + display_name, plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX')) self.assertEqual('%s_VERTEX' % display_name, instance_name) def test_create_summary_metadata(self): @@ -61,7 +63,7 @@ def test_parse_plugin_metadata(self): parsed_metadata = metadata.parse_plugin_metadata( self.summary_metadata.plugin_data.content) self.assertEqual(self.name, parsed_metadata.name) - self.assertEqual(plugin_data_pb2.MeshPluginData.ContentType.VERTEX, + self.assertEqual(plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX'), parsed_metadata.content_type) self.assertEqual(self.shape, parsed_metadata.shape) self.assertEqual(self.json_config, parsed_metadata.json_config) diff --git a/tensorboard/plugins/mesh_visualizer/test_utils.py b/tensorboard/plugins/mesh_visualizer/test_utils.py index cace30b69ec..4b751680f02 100644 --- a/tensorboard/plugins/mesh_visualizer/test_utils.py +++ b/tensorboard/plugins/mesh_visualizer/test_utils.py @@ -15,7 +15,6 @@ """Test utils for mesh plugin tests.""" from __future__ import absolute_import from __future__ import division -from __future__ import google_type_annotations from __future__ import print_function import collections @@ -23,101 +22,16 @@ import threading import numpy as np import tensorflow as tf -from google3.third_party.tensorboard.compat.proto import event_pb2 -from google3.third_party.tensorboard.compat.proto import graph_pb2 -from google3.third_party.tensorboard.compat.proto import meta_graph_pb2 -from google3.third_party.tensorboard.compat.proto import summary_pb2 -from google3.third_party.tensorboard.util import tb_logging +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.util import tb_logging Mesh = collections.namedtuple('Mesh', ('vertices', 'faces', 'colors')) logger = tb_logging.get_logger() -# NOTE: copy FileWriter and FileWriterCache from tensorboard test_util.py -# until this plugin start live in TensorBoard github repo. -class FileWriter(tf.compat.v1.summary.FileWriter): - """FileWriter for test. - - TensorFlow FileWriter uses TensorFlow's Protobuf Python binding which is - largely discouraged in TensorBoard. We do not want a TB.Writer but require one - for testing in integrational style (writing out event files and use the real - event readers). - """ - - def add_event(self, event): - if isinstance(event, event_pb2.Event): - tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) - else: - logger.warn('Added TensorFlow event proto. ' - 'Please prefer TensorBoard copy of the proto') - tf_event = event - super(FileWriter, self).add_event(tf_event) - - def add_summary(self, summary, global_step=None): - if isinstance(summary, summary_pb2.Summary): - tf_summary = tf.compat.v1.Summary.FromString(summary.SerializeToString()) - else: - logger.warn('Added TensorFlow summary proto. ' - 'Please prefer TensorBoard copy of the proto') - tf_summary = summary - super(FileWriter, self).add_summary(tf_summary, global_step) - - def add_session_log(self, session_log, global_step=None): - if isinstance(session_log, event_pb2.SessionLog): - tf_session_log = tf.compat.v1.SessionLog.FromString( - session_log.SerializeToString()) - else: - logger.warn('Added TensorFlow session_log proto. ' - 'Please prefer TensorBoard copy of the proto') - tf_session_log = session_log - super(FileWriter, self).add_session_log(tf_session_log, global_step) - - def add_graph(self, graph, global_step=None, graph_def=None): - if isinstance(graph_def, graph_pb2.GraphDef): - tf_graph_def = tf.compat.v1.GraphDef.FromString( - graph_def.SerializeToString()) - else: - tf_graph_def = graph_def - - super(FileWriter, self).add_graph( - graph, global_step=global_step, graph_def=tf_graph_def) - - def add_meta_graph(self, meta_graph_def, global_step=None): - if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): - tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString( - meta_graph_def.SerializeToString()) - else: - tf_meta_graph_def = meta_graph_def - - super(FileWriter, self).add_meta_graph( - meta_graph_def=tf_meta_graph_def, global_step=global_step) - - -class FileWriterCache(object): - """Cache for TensorBoard test file writers.""" - # Cache, keyed by directory. - _cache = {} - - # Lock protecting _FILE_WRITERS. - _lock = threading.RLock() - - @staticmethod - def get(logdir): - """Returns the FileWriter for the specified directory. - - Args: - logdir: str, name of the directory. - - Returns: - A `FileWriter`. - """ - with FileWriterCache._lock: - if logdir not in FileWriterCache._cache: - FileWriterCache._cache[logdir] = FileWriter( - logdir, graph=tf.compat.v1.get_default_graph()) - return FileWriterCache._cache[logdir] - - def get_random_mesh(num_vertices, add_faces=False, add_colors=False, From 0f9eaff921560bff7d1e3e57c2b33666bd7e4edb Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Thu, 2 May 2019 06:57:16 -0700 Subject: [PATCH 03/12] Add HTTP API document. --- .../plugins/mesh_visualizer/http_api.md | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tensorboard/plugins/mesh_visualizer/http_api.md diff --git a/tensorboard/plugins/mesh_visualizer/http_api.md b/tensorboard/plugins/mesh_visualizer/http_api.md new file mode 100644 index 00000000000..30f2353fc42 --- /dev/null +++ b/tensorboard/plugins/mesh_visualizer/http_api.md @@ -0,0 +1,65 @@ +# Mesh plugin HTTP API + +The mesh plugin name is `mesh`, so all its routes are under +`/data/plugin/mesh`. + +## `/data/plugin/mesh/tags` + +Retrieves an index of tags containing mesh data. + +Returns a dictionary mapping from `runName` (quoted string) to +dictionaries that map a `tagName` (quoted string) to an object +containing that tag’s `displayName` and `description`, the latter of +which is a string containing sanitized HTML to be rendered into the DOM. +Here is an example: + { + "train_run": { + "mesh_color_tensor": { + "samples": 1 + }, + "point_cloud": { + "samples": 1 + } + } + } + +Note that runs without any mesh tags are included as keys with value the empty dictionary. + +## `/data/plugin/mesh/meshes?tag=mesh_color_tensor&run=train_run&sample=0` + +Retrieves all necessary metadata to render a mesh with particular tag. + +Returns list of metadata for each data (tensor) that should be retrieved next. This includes content type (i.e. vertices, faces or colors), shape of the data, scene configuration, wall time etc. Type of the content maps directly to underlying binary data type, i.e. `float32`, `int32` or `uint8`. + +Here is an example: + [ + { + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "data_shape": [1, 17192, 3], + "step": 0, + "content_type": 2, + "wall_time": 1556678491.836787 + }, + { + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "data_shape": [1, 9771, 3], + "step": 0, + "content_type": 3, + "wall_time": 1556678491.836787 + }, + { + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "data_shape": [1, 9771, 3], + "step": 0, + "content_type": 1, + "wall_time": 1556678491.836787 + } + ] + +Scene configuration is a JSON string passed to `config_dict` during summary creation and may contain the following high-level keys: `camera`, `lights` and `material`. Each such key must correspond to an object with `cls` property which must be a valid THREE.js class. The rest of the keys of the object will be used as parameters to the class constructor and should also be valid THREE.js options. Invalid keys will be ignored by the library. + +## `/data/plugin/mesh/data?tag=mesh_color_tensor&run=train_run&content_type=VERTEX&sample=0` + +Retrieves binary data of particular type representing some part of the mesh, for example vertices with 3D coordinates. + +Returns stream of binary data, which will represent either mesh vertices, faces or RGB colors. Response type of this request is set to `arraybuffer` therefore Typed Array will be received instead of a JSON string. From 066c56e527081064386f455790cc465ed982199f Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Thu, 2 May 2019 09:57:16 -0700 Subject: [PATCH 04/12] Make Python3 compatible. --- tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index 2a70b4cc54c..e85093a192b 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -102,7 +102,7 @@ def setUp(self): bar_directory = os.path.join(self.log_dir, self.runs[0]) with tensorboard_test_util.FileWriterCache.get(bar_directory) as writer: writer.add_graph(sess.graph) - for step in xrange(self.steps): + for step in range(self.steps): writer.add_summary( sess.run( merged_summary_op, From 9d511f469d08026da671d506d5b2dcb9a5a1eab8 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Thu, 2 May 2019 11:10:56 -0700 Subject: [PATCH 05/12] Fix unit test in python3 --- tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index e85093a192b..fe529aaab27 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -157,7 +157,7 @@ def testDataRoute(self): (self.runs[0], self.names[0], 0, "VERTEX")) self.assertEqual(200, response.status_code) data = test_utils.deserialize_array_buffer_response( - response.response.next(), np.float32) + next(response.response), np.float32) vertices = np.tile(self.data[0].vertices.reshape(-1), self.steps) self.assertEqual(vertices.tolist(), data.tolist()) @@ -166,7 +166,7 @@ def testDataRoute(self): (self.runs[0], self.names[1], 0, "FACE")) self.assertEqual(200, response.status_code) data = test_utils.deserialize_array_buffer_response( - response.response.next(), np.int32) + next(response.response), np.int32) faces = np.tile(self.data[1].faces.reshape(-1), self.steps) self.assertEqual(faces.tolist(), data.tolist()) @@ -175,7 +175,7 @@ def testDataRoute(self): (self.runs[0], self.names[2], 0, "COLOR")) self.assertEqual(200, response.status_code) data = test_utils.deserialize_array_buffer_response( - response.response.next(), np.uint8) + next(response.response), np.uint8) colors = np.tile(self.data[2].colors.reshape(-1), self.steps) self.assertListEqual(colors.tolist(), data.tolist()) From cc49abda4f085f3afc2ecb6757a650b8a9ecdfb5 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Thu, 2 May 2019 11:43:02 -0700 Subject: [PATCH 06/12] Explicitly convert string to bytes in tests. --- tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index fe529aaab27..17212fef809 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -208,7 +208,7 @@ def testsEventsAlwaysSortedByWallTime(self): event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={"bar": { - "foo": "" + "foo": "".encode("utf-8") }}) def testMetadataComputedOnce(self, run_to_tag_mock): """Tests that metadata mapping computed once.""" From 8f6ecb947e693b03b545e2995ea8fd9a9c271aee Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 08:38:29 -0700 Subject: [PATCH 07/12] Addressing William's comments. --- tensorboard/plugins/mesh_visualizer/BUILD | 5 ++- .../plugins/mesh_visualizer/http_api.md | 36 ++++++++++++++----- .../plugins/mesh_visualizer/mesh_plugin.py | 36 ++++++++++++------- .../mesh_visualizer/mesh_plugin_test.py | 5 +-- .../plugins/mesh_visualizer/mesh_summary.py | 11 +++--- .../mesh_visualizer/mesh_summary_test.py | 3 +- .../plugins/mesh_visualizer/metadata.py | 6 ++-- .../plugins/mesh_visualizer/metadata_test.py | 3 +- .../plugins/mesh_visualizer/test_utils.py | 3 +- 9 files changed, 71 insertions(+), 37 deletions(-) diff --git a/tensorboard/plugins/mesh_visualizer/BUILD b/tensorboard/plugins/mesh_visualizer/BUILD index c0ba46e9817..757c2e93333 100644 --- a/tensorboard/plugins/mesh_visualizer/BUILD +++ b/tensorboard/plugins/mesh_visualizer/BUILD @@ -38,8 +38,7 @@ py_library( ":protos_all_py_pb2", "//tensorboard:expect_numpy_installed", "//tensorboard:plugin_util", - "//tensorboard/backend:http_util", - "//tensorboard/backend/event_processing:event_accumulator", + "//tensorboard/backend:http_util", "//tensorboard/plugins:base_plugin", "//tensorboard/util:tensor_util", "@org_pythonhosted_six", @@ -113,4 +112,4 @@ tb_proto_library( name = "protos_all", srcs = ["plugin_data.proto"], visibility = ["//visibility:public"], -) \ No newline at end of file +) diff --git a/tensorboard/plugins/mesh_visualizer/http_api.md b/tensorboard/plugins/mesh_visualizer/http_api.md index 30f2353fc42..6d549b8f6df 100644 --- a/tensorboard/plugins/mesh_visualizer/http_api.md +++ b/tensorboard/plugins/mesh_visualizer/http_api.md @@ -3,6 +3,7 @@ The mesh plugin name is `mesh`, so all its routes are under `/data/plugin/mesh`. + ## `/data/plugin/mesh/tags` Retrieves an index of tags containing mesh data. @@ -12,6 +13,7 @@ dictionaries that map a `tagName` (quoted string) to an object containing that tag’s `displayName` and `description`, the latter of which is a string containing sanitized HTML to be rendered into the DOM. Here is an example: +```json { "train_run": { "mesh_color_tensor": { @@ -22,44 +24,62 @@ Here is an example: } } } +``` +Note that runs without any mesh tags are included as keys with value the empty +dictionary. -Note that runs without any mesh tags are included as keys with value the empty dictionary. ## `/data/plugin/mesh/meshes?tag=mesh_color_tensor&run=train_run&sample=0` Retrieves all necessary metadata to render a mesh with particular tag. -Returns list of metadata for each data (tensor) that should be retrieved next. This includes content type (i.e. vertices, faces or colors), shape of the data, scene configuration, wall time etc. Type of the content maps directly to underlying binary data type, i.e. `float32`, `int32` or `uint8`. +Returns list of metadata for each data (tensor) that should be retrieved next. +This includes content type (i.e. vertices, faces or colors), shape of the +data, scene configuration, wall time etc. Type of the content maps directly to +underlying binary data type, i.e. `float32`, `int32` or `uint8`. Here is an example: +```json [ { - "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": + 75}}", "data_shape": [1, 17192, 3], "step": 0, "content_type": 2, "wall_time": 1556678491.836787 }, { - "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": + 75}}", "data_shape": [1, 9771, 3], "step": 0, "content_type": 3, "wall_time": 1556678491.836787 }, { - "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": 75}}", + "config": "{\"camera\": {\"cls\": \"PerspectiveCamera\", \"fov\": + 75}}", "data_shape": [1, 9771, 3], "step": 0, "content_type": 1, "wall_time": 1556678491.836787 } ] +``` +Scene configuration is a JSON string passed to `config_dict` during summary +creation and may contain the following high-level keys: `camera`, `lights` and +`material`. Each such key must correspond to an object with `cls` property +which must be a valid THREE.js class. The rest of the keys of the object will +be used as parameters to the class constructor and should also be valid +THREE.js options. Invalid keys will be ignored by the library. -Scene configuration is a JSON string passed to `config_dict` during summary creation and may contain the following high-level keys: `camera`, `lights` and `material`. Each such key must correspond to an object with `cls` property which must be a valid THREE.js class. The rest of the keys of the object will be used as parameters to the class constructor and should also be valid THREE.js options. Invalid keys will be ignored by the library. ## `/data/plugin/mesh/data?tag=mesh_color_tensor&run=train_run&content_type=VERTEX&sample=0` -Retrieves binary data of particular type representing some part of the mesh, for example vertices with 3D coordinates. +Retrieves binary data of particular type representing some part of the mesh, +for example vertices with 3D coordinates. -Returns stream of binary data, which will represent either mesh vertices, faces or RGB colors. Response type of this request is set to `arraybuffer` therefore Typed Array will be received instead of a JSON string. +Returns stream of binary data, which will represent either mesh vertices, +faces or RGB colors. Response type of this request is set to `arraybuffer` +therefore Typed Array will be received instead of a JSON string. diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py index 5f2e7d610a9..be7e8e00c06 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py @@ -13,19 +13,25 @@ # limitations under the License. # ============================================================================== """TensorBoard 3D mesh visualizer plugin.""" +# Parser directives from __future__ import absolute_import from __future__ import division from __future__ import print_function +# Standard library modules import collections + +# Third-party modules import numpy as np import six -from tensorboard.util import tensor_util from werkzeug import wrappers + +# First-party modules from tensorboard.backend import http_util from tensorboard.plugins import base_plugin from tensorboard.plugins.mesh_visualizer import metadata from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.util import tensor_util class MeshPlugin(base_plugin.TBPlugin): @@ -111,6 +117,7 @@ def get_plugin_apps(self): This method is called by TensorBoard when retrieving all the routes offered by the plugin. + Returns: A dictionary mapping URL path to route that handles it. """ @@ -119,7 +126,7 @@ def get_plugin_apps(self): return { '/tags': self._serve_tags, '/meshes': self._serve_mesh_metadata, - '/data': self._serve_mesh_data + '/data': self._serve_mesh_data, } def is_active(self): @@ -127,6 +134,7 @@ def is_active(self): This plugin is only active if TensorBoard sampled any summaries relevant to the mesh plugin. + Returns: Whether this plugin is active. """ @@ -151,6 +159,7 @@ def _get_tensor_metadata(self, event, content_type, data_shape, config): representing content type in TensorEvent. data_shape: list of dimensions sizes of the tensor. config: rendering scene configuration as dictionary. + Returns: Dictionary of transformed metadata. """ @@ -159,7 +168,7 @@ def _get_tensor_metadata(self, event, content_type, data_shape, config): 'step': event.step, 'content_type': content_type, 'config': config, - 'data_shape': list(data_shape) + 'data_shape': list(data_shape), } def _get_tensor_data(self, event, sample): @@ -177,8 +186,7 @@ def _collect_tensor_events(self, request): # Make sure we populate tags mapping structures. self.prepare_metadata() - # We fetch all the tensor events that contain tag. - tensor_events = [] # List of tuples (meta, tensor). + tensor_events = [] # List of tuples (meta, tensor) that contain tag. for instance_tag in self._tag_to_instance_tags[tag]: tensors = self._multiplexer.Tensors(run, instance_tag) meta = self._instance_tag_to_metadata[instance_tag] @@ -186,7 +194,7 @@ def _collect_tensor_events(self, request): # Make sure tensors sorted by timestamp in ascending order. tensor_events = sorted( - tensor_events, key=lambda tensor_data: tensor_data[1].wall_time) + tensor_events, key=lambda (_, event): event.wall_time) return tensor_events @@ -202,11 +210,12 @@ def _serve_mesh_data(self, request): Args: request: werkzeug.Request containing content_type as a name of enum plugin_data_pb2.MeshPluginData.ContentType. + Returns: werkzeug.Response either float32 or int32 data in binary format. """ tensor_events = self._collect_tensor_events(request) - content_type = request.args.get('content_type') + content_type = request.args['content_type'] content_type = plugin_data_pb2.MeshPluginData.ContentType.Value( content_type) sample = int(request.args.get('sample', 0)) @@ -217,11 +226,12 @@ def _serve_mesh_data(self, request): if meta.content_type == content_type ] - np_type = np.float32 - if content_type == plugin_data_pb2.MeshPluginData.ContentType.Value('FACE'): - np_type = np.int32 - elif content_type == plugin_data_pb2.MeshPluginData.ContentType.Value('COLOR'): - np_type = np.uint8 + np_type = { + plugin_data_pb2.MeshPluginData.VERTEX: np.float32, + plugin_data_pb2.MeshPluginData.FACE: np.int32, + plugin_data_pb2.MeshPluginData.COLOR: np.uint8, + }[content_type] + response = np.array(response, dtype=np_type) # Looks like reshape can take around 160ms, so why not store it reshaped. response = response.reshape(-1).tobytes() @@ -250,4 +260,4 @@ def _serve_mesh_metadata(self, request): meta.json_config) for meta, tensor in tensor_events ] - return http_util.Respond(request, response, 'application/json') \ No newline at end of file + return http_util.Respond(request, response, 'application/json') diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index 17212fef809..6f1704b816f 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -208,7 +208,7 @@ def testsEventsAlwaysSortedByWallTime(self): event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={"bar": { - "foo": "".encode("utf-8") + "foo": "".encode("utf-8"), }}) def testMetadataComputedOnce(self, run_to_tag_mock): """Tests that metadata mapping computed once.""" @@ -228,4 +228,5 @@ def testIsInactive(self, get_random_mesh_stub): if __name__ == "__main__": - tf.test.main() \ No newline at end of file + tf.test.main() + \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary.py b/tensorboard/plugins/mesh_visualizer/mesh_summary.py index 7952f18a197..2ef1ccafad2 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_summary.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary.py @@ -98,6 +98,7 @@ def op(name, vertices, faces=None, colors=None, display_name=None, collections: Which TensorFlow graph collections to add the summary op to. Defaults to `['summaries']`. Can usually be ignored. config_dict: Dictionary with ThreeJS classes names and configuration. + Returns: Merged summary for mesh/point cloud representation. """ @@ -108,9 +109,11 @@ def op(name, vertices, faces=None, colors=None, display_name=None, # summaries internally. Those summaries will be regrouped on the client before # rendering. summaries = [] - tensors = [(vertices, plugin_data_pb2.MeshPluginData.VERTEX), - (faces, plugin_data_pb2.MeshPluginData.FACE), - (colors, plugin_data_pb2.MeshPluginData.COLOR)] + tensors = [ + (vertices, plugin_data_pb2.MeshPluginData.VERTEX), + (faces, plugin_data_pb2.MeshPluginData.FACE), + (colors, plugin_data_pb2.MeshPluginData.COLOR) + ] for tensor, content_type in tensors: if tensor is None: @@ -177,4 +180,4 @@ def pb(name, tf_summary_metadata = tf.SummaryMetadata.FromString( summary_metadata.SerializeToString()) summary.value.add(tag=tag, metadata=tf_summary_metadata, tensor=tensor) - return summary \ No newline at end of file + return summary diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py index a76d612fa91..247577812f4 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py @@ -100,4 +100,5 @@ def test_pb(self): if __name__ == "__main__": - tf.test.main() \ No newline at end of file + tf.test.main() + \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/metadata.py b/tensorboard/plugins/mesh_visualizer/metadata.py index 2572114c6cb..1d4b2a3f271 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata.py +++ b/tensorboard/plugins/mesh_visualizer/metadata.py @@ -22,12 +22,10 @@ from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 -# TODO(podlipensky): use this variable everywhere, avoid duplicating 'mesh'. -# b/129002587 PLUGIN_NAME = 'mesh' # The most recent value for the `version` field of the -# `PrCurvePluginData` proto. +# `MeshPluginData` proto. _PROTO_VERSION = 0 @@ -96,4 +94,4 @@ def parse_plugin_metadata(content): return result raise ValueError('Unknown metadata version: %s. The latest version known to ' 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?' % (result.version, get_current_version())) \ No newline at end of file + 'available?' % (result.version, get_current_version())) diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh_visualizer/metadata_test.py index 1747d3c5b25..eb9491c088e 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata_test.py +++ b/tensorboard/plugins/mesh_visualizer/metadata_test.py @@ -85,4 +85,5 @@ def test_metadata_format(self): if __name__ == '__main__': - tf.test.main() \ No newline at end of file + tf.test.main() + \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/test_utils.py b/tensorboard/plugins/mesh_visualizer/test_utils.py index 4b751680f02..a490a1d0223 100644 --- a/tensorboard/plugins/mesh_visualizer/test_utils.py +++ b/tensorboard/plugins/mesh_visualizer/test_utils.py @@ -91,4 +91,5 @@ def deserialize_array_buffer_response(byte_content, data_type): Returns: Flat numpy array with the data. """ - return np.frombuffer(byte_content, dtype=data_type) \ No newline at end of file + return np.frombuffer(byte_content, dtype=data_type) + \ No newline at end of file From 7104a2a1244c53f2fc1ae26f3425508cb66c822d Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 08:45:10 -0700 Subject: [PATCH 08/12] Rename mesh_summary to summary. --- tensorboard/plugins/mesh_visualizer/BUILD | 14 +++++++------- .../plugins/mesh_visualizer/mesh_plugin_test.py | 9 ++++----- .../{mesh_summary.py => summary.py} | 0 .../{mesh_summary_test.py => summary_test.py} | 10 +++++----- 4 files changed, 16 insertions(+), 17 deletions(-) rename tensorboard/plugins/mesh_visualizer/{mesh_summary.py => summary.py} (100%) rename tensorboard/plugins/mesh_visualizer/{mesh_summary_test.py => summary_test.py} (93%) diff --git a/tensorboard/plugins/mesh_visualizer/BUILD b/tensorboard/plugins/mesh_visualizer/BUILD index 757c2e93333..211f0c2c6b1 100644 --- a/tensorboard/plugins/mesh_visualizer/BUILD +++ b/tensorboard/plugins/mesh_visualizer/BUILD @@ -52,7 +52,7 @@ py_library( srcs = ["test_utils.py"], srcs_version = "PY2AND3", deps = [ - ":mesh_summary", + ":summary", "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend:application", "//tensorboard/backend/event_processing:event_multiplexer", @@ -67,7 +67,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":mesh_plugin", - ":mesh_summary", + ":summary", ":test_utils", "//tensorboard:expect_numpy_installed", "//tensorboard:expect_tensorflow_installed", @@ -82,8 +82,8 @@ py_test( ) py_library( - name = "mesh_summary", - srcs = ["mesh_summary.py"], + name = "summary", + srcs = ["summary.py"], srcs_version = "PY2AND3", visibility = [ "//visibility:public", @@ -96,12 +96,12 @@ py_library( ) py_test( - name = "mesh_summary_test", + name = "summary_test", size = "small", - srcs = ["mesh_summary_test.py"], + srcs = ["summary_test.py"], srcs_version = "PY2AND3", deps = [ - ":mesh_summary", + ":summary", ":test_utils", "//tensorboard:expect_tensorflow_installed", "//tensorboard/util:test_util", diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index 6f1704b816f..593112d82d0 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -29,7 +29,7 @@ from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer from tensorboard.plugins import base_plugin from tensorboard.plugins.mesh_visualizer import mesh_plugin -from tensorboard.plugins.mesh_visualizer import mesh_summary +from tensorboard.plugins.mesh_visualizer import summary from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 from tensorboard.plugins.mesh_visualizer import test_utils from tensorboard.util import test_util as tensorboard_test_util @@ -79,17 +79,17 @@ def setUp(self): # In case when name is present and display_name is not, we will reuse name # as display_name. Summaries below intended to test both cases. self.names = ["point_cloud", "mesh_no_color", "mesh_color"] - mesh_summary.op( + summary.op( self.names[0], point_cloud_vertices, description="just point cloud") - mesh_summary.op( + summary.op( self.names[1], mesh_no_color_vertices, faces=mesh_no_color_faces, display_name="name_to_display_in_ui", description="beautiful mesh in grayscale") - mesh_summary.op( + summary.op( self.names[2], mesh_color_vertices, faces=mesh_color_faces, @@ -229,4 +229,3 @@ def testIsInactive(self, get_random_mesh_stub): if __name__ == "__main__": tf.test.main() - \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary.py b/tensorboard/plugins/mesh_visualizer/summary.py similarity index 100% rename from tensorboard/plugins/mesh_visualizer/mesh_summary.py rename to tensorboard/plugins/mesh_visualizer/summary.py diff --git a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py b/tensorboard/plugins/mesh_visualizer/summary_test.py similarity index 93% rename from tensorboard/plugins/mesh_visualizer/mesh_summary_test.py rename to tensorboard/plugins/mesh_visualizer/summary_test.py index 247577812f4..23f40af22ad 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_summary_test.py +++ b/tensorboard/plugins/mesh_visualizer/summary_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorboard.plugins.mesh_visualizer.mesh_summary.""" +"""Tests for tensorboard.plugins.mesh_visualizer.summary.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,7 +21,7 @@ import tensorflow as tf from tensorboard.compat.proto import summary_pb2 -from tensorboard.plugins.mesh_visualizer import mesh_summary +from tensorboard.plugins.mesh_visualizer import summary from tensorboard.plugins.mesh_visualizer import metadata from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 from tensorboard.plugins.mesh_visualizer import test_utils @@ -51,7 +51,7 @@ def test_get_tensor_summary(self): display_name = "my_display_name" description = "my mesh is the best of meshes" tensor_data = test_utils.get_random_mesh(100) - tensor_summary = mesh_summary._get_tensor_summary( + tensor_summary = summary._get_tensor_summary( name, display_name, description, tensor_data.vertices, plugin_data_pb2.MeshPluginData.VERTEX, "", None) with self.test_session(): @@ -66,7 +66,7 @@ def test_op(self): tensor_data = test_utils.get_random_mesh( 100, add_faces=True, add_colors=True) config_dict = {"foo": 1} - tensor_summary = mesh_summary.op( + tensor_summary = summary.op( name, tensor_data.vertices, faces=tensor_data.faces, @@ -86,7 +86,7 @@ def test_pb(self): tensor_data = test_utils.get_random_mesh( 100, add_faces=True, add_colors=True) config_dict = {"foo": 1} - proto = mesh_summary.pb( + proto = summary.pb( name, tensor_data.vertices, faces=tensor_data.faces, From b2bc07c08688776c18b0692bb75a7ba5be7f8f55 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 09:48:58 -0700 Subject: [PATCH 09/12] Fix PY3 build. --- tensorboard/plugins/mesh_visualizer/mesh_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py index be7e8e00c06..0f97489dbe9 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py @@ -194,7 +194,7 @@ def _collect_tensor_events(self, request): # Make sure tensors sorted by timestamp in ascending order. tensor_events = sorted( - tensor_events, key=lambda (_, event): event.wall_time) + tensor_events, key=lambda tensor_data: tensor_data[1].wall_time) return tensor_events From ba41c5e32ea82cfc771c9bd5382839dd88e35fb6 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 10:42:41 -0700 Subject: [PATCH 10/12] Update api docs. --- tensorboard/plugins/mesh_visualizer/http_api.md | 2 +- tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py | 2 +- tensorboard/plugins/mesh_visualizer/metadata_test.py | 2 +- tensorboard/plugins/mesh_visualizer/summary_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorboard/plugins/mesh_visualizer/http_api.md b/tensorboard/plugins/mesh_visualizer/http_api.md index 6d549b8f6df..03685ca161f 100644 --- a/tensorboard/plugins/mesh_visualizer/http_api.md +++ b/tensorboard/plugins/mesh_visualizer/http_api.md @@ -36,7 +36,7 @@ Retrieves all necessary metadata to render a mesh with particular tag. Returns list of metadata for each data (tensor) that should be retrieved next. This includes content type (i.e. vertices, faces or colors), shape of the data, scene configuration, wall time etc. Type of the content maps directly to -underlying binary data type, i.e. `float32`, `int32` or `uint8`. +underlying binary data type, i.e. `float32`, `int32` or `uint8`. Content type mapping to heir enum constant representations is given by a [proto definition](https://github.com/tensorflow/tensorboard/plugins/mesh_visualizer/plugin_data.proto). Here is an example: ```json diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index 593112d82d0..e2529bcf3f8 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -41,7 +41,7 @@ import mock # pylint: disable=g-import-not-at-top,unused-import -@tensorboard_test_util.run_v1_only('Uses contrib') +@tensorboard_test_util.run_v1_only('requires tf.Session') class MeshPluginTest(tf.test.TestCase): """Tests for mesh plugin server.""" diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh_visualizer/metadata_test.py index eb9491c088e..6095f0d5efc 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata_test.py +++ b/tensorboard/plugins/mesh_visualizer/metadata_test.py @@ -26,7 +26,7 @@ from tensorboard.util import test_util -@test_util.run_v1_only('Uses contrib') +@test_util.run_v1_only('requires tf.Session') class MetadataTest(tf.test.TestCase): def _create_metadata(self): diff --git a/tensorboard/plugins/mesh_visualizer/summary_test.py b/tensorboard/plugins/mesh_visualizer/summary_test.py index 23f40af22ad..86219ac1187 100644 --- a/tensorboard/plugins/mesh_visualizer/summary_test.py +++ b/tensorboard/plugins/mesh_visualizer/summary_test.py @@ -28,7 +28,7 @@ from tensorboard.util import test_util -@test_util.run_v1_only('Uses contrib') +@test_util.run_v1_only('requires tf.Session') class MeshSummaryTest(tf.test.TestCase): def pb_via_op(self, summary_op): From 0d3e5ec5b005c0eda518532502ed8c754c6117a1 Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 15:06:08 -0700 Subject: [PATCH 11/12] Another round of comments from William. --- tensorboard/plugins/mesh_visualizer/http_api.md | 4 +++- tensorboard/plugins/mesh_visualizer/mesh_plugin.py | 3 --- .../plugins/mesh_visualizer/mesh_plugin_test.py | 5 ++--- tensorboard/plugins/mesh_visualizer/metadata.py | 4 ++++ tensorboard/plugins/mesh_visualizer/metadata_test.py | 11 +++++++++-- tensorboard/plugins/mesh_visualizer/plugin_data.proto | 2 +- tensorboard/plugins/mesh_visualizer/summary_test.py | 1 - tensorboard/plugins/mesh_visualizer/test_utils.py | 2 ++ 8 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tensorboard/plugins/mesh_visualizer/http_api.md b/tensorboard/plugins/mesh_visualizer/http_api.md index 03685ca161f..bdf0db8db24 100644 --- a/tensorboard/plugins/mesh_visualizer/http_api.md +++ b/tensorboard/plugins/mesh_visualizer/http_api.md @@ -36,7 +36,9 @@ Retrieves all necessary metadata to render a mesh with particular tag. Returns list of metadata for each data (tensor) that should be retrieved next. This includes content type (i.e. vertices, faces or colors), shape of the data, scene configuration, wall time etc. Type of the content maps directly to -underlying binary data type, i.e. `float32`, `int32` or `uint8`. Content type mapping to heir enum constant representations is given by a [proto definition](https://github.com/tensorflow/tensorboard/plugins/mesh_visualizer/plugin_data.proto). +underlying binary data type, i.e. `float32`, `int32` or `uint8`. Content type +mapping to their enum constant representations is given by a +[proto definition](https://github.com/tensorflow/tensorboard/plugins/mesh_visualizer/plugin_data.proto). Here is an example: ```json diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py index 0f97489dbe9..5c83f9cf176 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin.py @@ -106,9 +106,6 @@ def _serve_tags(self, request): # Make sure we only operate on user-defined tags here. tag = self._instance_tag_to_tag[instance_tag] meta = self._instance_tag_to_metadata[instance_tag] - # Shape should be at least BxNx3 where B represents the batch dimensions - # and N - the number of points, each with x,y,z coordinates. - assert len(meta.shape) >= 3 response[run][tag] = {'samples': meta.shape[0]} return http_util.Respond(request, response, 'application/json') diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py index e2529bcf3f8..c0df82d9ad1 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py @@ -207,9 +207,8 @@ def testsEventsAlwaysSortedByWallTime(self): @mock.patch.object( event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", - return_value={"bar": { - "foo": "".encode("utf-8"), - }}) + return_value={"bar": {"foo": "".encode("utf-8")}}, + ) def testMetadataComputedOnce(self, run_to_tag_mock): """Tests that metadata mapping computed once.""" self.plugin.prepare_metadata() diff --git a/tensorboard/plugins/mesh_visualizer/metadata.py b/tensorboard/plugins/mesh_visualizer/metadata.py index 1d4b2a3f271..58fcdf9c77d 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata.py +++ b/tensorboard/plugins/mesh_visualizer/metadata.py @@ -61,6 +61,10 @@ def create_summary_metadata(name, Returns: A `summary_pb2.SummaryMetadata` protobuf object. """ + # Shape should be at least BxNx3 where B represents the batch dimensions + # and N - the number of points, each with x,y,z coordinates. + if len(shape) != 3: + raise ValueError('Tensor shape should be of shape BxNx3, but got %s.' % str(shape)) mesh_plugin_data = plugin_data_pb2.MeshPluginData( version=get_current_version(), name=name, diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh_visualizer/metadata_test.py index 6095f0d5efc..c9f7b779a13 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata_test.py +++ b/tensorboard/plugins/mesh_visualizer/metadata_test.py @@ -29,12 +29,14 @@ @test_util.run_v1_only('requires tf.Session') class MetadataTest(tf.test.TestCase): - def _create_metadata(self): + def _create_metadata(self, shape=None): """Creates metadata with dummy data.""" self.name = 'unique_name' self.display_name = 'my mesh' self.json_config = '{}' - self.shape = [1, 100, 3] + if shape is None: + shape = [1, 100, 3] + self.shape = shape self.summary_metadata = metadata.create_summary_metadata( self.name, self.display_name, @@ -78,6 +80,11 @@ def test_metadata_version(self): metadata.parse_plugin_metadata( self.summary_metadata.plugin_data.content) + def test_tensor_shape(self): + """Tests that target tensor should be of particular shape.""" + with six.assertRaisesRegex(self, ValueError, r'Tensor shape should be of shape BxNx3.*'): + self._create_metadata([1]) + def test_metadata_format(self): """Tests that metadata content must be passed as a serialized string.""" with six.assertRaisesRegex(self, TypeError, r'Content type must be bytes.'): diff --git a/tensorboard/plugins/mesh_visualizer/plugin_data.proto b/tensorboard/plugins/mesh_visualizer/plugin_data.proto index c87dd8bdf33..67ad389b4e7 100644 --- a/tensorboard/plugins/mesh_visualizer/plugin_data.proto +++ b/tensorboard/plugins/mesh_visualizer/plugin_data.proto @@ -26,4 +26,4 @@ message MeshPluginData { // Shape of underlying data. Cache it here for performance reasons. repeated int32 shape = 6; -} \ No newline at end of file +} diff --git a/tensorboard/plugins/mesh_visualizer/summary_test.py b/tensorboard/plugins/mesh_visualizer/summary_test.py index 86219ac1187..a4d627c9f0c 100644 --- a/tensorboard/plugins/mesh_visualizer/summary_test.py +++ b/tensorboard/plugins/mesh_visualizer/summary_test.py @@ -101,4 +101,3 @@ def test_pb(self): if __name__ == "__main__": tf.test.main() - \ No newline at end of file diff --git a/tensorboard/plugins/mesh_visualizer/test_utils.py b/tensorboard/plugins/mesh_visualizer/test_utils.py index a490a1d0223..1b8a271d200 100644 --- a/tensorboard/plugins/mesh_visualizer/test_utils.py +++ b/tensorboard/plugins/mesh_visualizer/test_utils.py @@ -20,8 +20,10 @@ import collections import json import threading + import numpy as np import tensorflow as tf + from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import graph_pb2 from tensorboard.compat.proto import meta_graph_pb2 From f5041b03ef7d13b6c7bd9d676892ee02f53462ec Mon Sep 17 00:00:00 2001 From: Pavel Podlipensky Date: Fri, 3 May 2019 15:45:14 -0700 Subject: [PATCH 12/12] Rename mesh_visualizer to mesh. --- tensorboard/plugins/{mesh_visualizer => mesh}/BUILD | 0 .../plugins/{mesh_visualizer => mesh}/__init__.py | 0 .../plugins/{mesh_visualizer => mesh}/http_api.md | 2 +- .../plugins/{mesh_visualizer => mesh}/mesh_plugin.py | 4 ++-- .../{mesh_visualizer => mesh}/mesh_plugin_test.py | 8 ++++---- .../plugins/{mesh_visualizer => mesh}/metadata.py | 2 +- .../plugins/{mesh_visualizer => mesh}/metadata_test.py | 4 ++-- .../{mesh_visualizer => mesh}/plugin_data.proto | 2 +- .../plugins/{mesh_visualizer => mesh}/summary.py | 4 ++-- .../plugins/{mesh_visualizer => mesh}/summary_test.py | 10 +++++----- .../plugins/{mesh_visualizer => mesh}/test_utils.py | 0 11 files changed, 18 insertions(+), 18 deletions(-) rename tensorboard/plugins/{mesh_visualizer => mesh}/BUILD (100%) rename tensorboard/plugins/{mesh_visualizer => mesh}/__init__.py (100%) rename tensorboard/plugins/{mesh_visualizer => mesh}/http_api.md (98%) rename tensorboard/plugins/{mesh_visualizer => mesh}/mesh_plugin.py (98%) rename tensorboard/plugins/{mesh_visualizer => mesh}/mesh_plugin_test.py (97%) rename tensorboard/plugins/{mesh_visualizer => mesh}/metadata.py (98%) rename tensorboard/plugins/{mesh_visualizer => mesh}/metadata_test.py (96%) rename tensorboard/plugins/{mesh_visualizer => mesh}/plugin_data.proto (94%) rename tensorboard/plugins/{mesh_visualizer => mesh}/summary.py (98%) rename tensorboard/plugins/{mesh_visualizer => mesh}/summary_test.py (92%) rename tensorboard/plugins/{mesh_visualizer => mesh}/test_utils.py (100%) diff --git a/tensorboard/plugins/mesh_visualizer/BUILD b/tensorboard/plugins/mesh/BUILD similarity index 100% rename from tensorboard/plugins/mesh_visualizer/BUILD rename to tensorboard/plugins/mesh/BUILD diff --git a/tensorboard/plugins/mesh_visualizer/__init__.py b/tensorboard/plugins/mesh/__init__.py similarity index 100% rename from tensorboard/plugins/mesh_visualizer/__init__.py rename to tensorboard/plugins/mesh/__init__.py diff --git a/tensorboard/plugins/mesh_visualizer/http_api.md b/tensorboard/plugins/mesh/http_api.md similarity index 98% rename from tensorboard/plugins/mesh_visualizer/http_api.md rename to tensorboard/plugins/mesh/http_api.md index bdf0db8db24..92fd0875106 100644 --- a/tensorboard/plugins/mesh_visualizer/http_api.md +++ b/tensorboard/plugins/mesh/http_api.md @@ -38,7 +38,7 @@ This includes content type (i.e. vertices, faces or colors), shape of the data, scene configuration, wall time etc. Type of the content maps directly to underlying binary data type, i.e. `float32`, `int32` or `uint8`. Content type mapping to their enum constant representations is given by a -[proto definition](https://github.com/tensorflow/tensorboard/plugins/mesh_visualizer/plugin_data.proto). +[proto definition](https://github.com/tensorflow/tensorboard/plugins/mesh/plugin_data.proto). Here is an example: ```json diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py b/tensorboard/plugins/mesh/mesh_plugin.py similarity index 98% rename from tensorboard/plugins/mesh_visualizer/mesh_plugin.py rename to tensorboard/plugins/mesh/mesh_plugin.py index 5c83f9cf176..88d0a8bce0b 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin.py +++ b/tensorboard/plugins/mesh/mesh_plugin.py @@ -29,8 +29,8 @@ # First-party modules from tensorboard.backend import http_util from tensorboard.plugins import base_plugin -from tensorboard.plugins.mesh_visualizer import metadata -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh import metadata +from tensorboard.plugins.mesh import plugin_data_pb2 from tensorboard.util import tensor_util diff --git a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py b/tensorboard/plugins/mesh/mesh_plugin_test.py similarity index 97% rename from tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py rename to tensorboard/plugins/mesh/mesh_plugin_test.py index c0df82d9ad1..15554f7c156 100644 --- a/tensorboard/plugins/mesh_visualizer/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh/mesh_plugin_test.py @@ -28,10 +28,10 @@ from tensorboard.backend import application from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer from tensorboard.plugins import base_plugin -from tensorboard.plugins.mesh_visualizer import mesh_plugin -from tensorboard.plugins.mesh_visualizer import summary -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 -from tensorboard.plugins.mesh_visualizer import test_utils +from tensorboard.plugins.mesh import mesh_plugin +from tensorboard.plugins.mesh import summary +from tensorboard.plugins.mesh import plugin_data_pb2 +from tensorboard.plugins.mesh import test_utils from tensorboard.util import test_util as tensorboard_test_util try: diff --git a/tensorboard/plugins/mesh_visualizer/metadata.py b/tensorboard/plugins/mesh/metadata.py similarity index 98% rename from tensorboard/plugins/mesh_visualizer/metadata.py rename to tensorboard/plugins/mesh/metadata.py index 58fcdf9c77d..98126bf00ab 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata.py +++ b/tensorboard/plugins/mesh/metadata.py @@ -19,7 +19,7 @@ from __future__ import print_function from tensorboard.compat.proto import summary_pb2 -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh import plugin_data_pb2 PLUGIN_NAME = 'mesh' diff --git a/tensorboard/plugins/mesh_visualizer/metadata_test.py b/tensorboard/plugins/mesh/metadata_test.py similarity index 96% rename from tensorboard/plugins/mesh_visualizer/metadata_test.py rename to tensorboard/plugins/mesh/metadata_test.py index c9f7b779a13..ca31e483f65 100644 --- a/tensorboard/plugins/mesh_visualizer/metadata_test.py +++ b/tensorboard/plugins/mesh/metadata_test.py @@ -21,8 +21,8 @@ from mock import patch import six import tensorflow as tf -from tensorboard.plugins.mesh_visualizer import metadata -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh import metadata +from tensorboard.plugins.mesh import plugin_data_pb2 from tensorboard.util import test_util diff --git a/tensorboard/plugins/mesh_visualizer/plugin_data.proto b/tensorboard/plugins/mesh/plugin_data.proto similarity index 94% rename from tensorboard/plugins/mesh_visualizer/plugin_data.proto rename to tensorboard/plugins/mesh/plugin_data.proto index 67ad389b4e7..f1206c93675 100644 --- a/tensorboard/plugins/mesh_visualizer/plugin_data.proto +++ b/tensorboard/plugins/mesh/plugin_data.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package tensorboard.mesh_visualizer; +package tensorboard.mesh; // A MeshPluginData encapsulates information on which plugins are able to make // use of a certain summary value. diff --git a/tensorboard/plugins/mesh_visualizer/summary.py b/tensorboard/plugins/mesh/summary.py similarity index 98% rename from tensorboard/plugins/mesh_visualizer/summary.py rename to tensorboard/plugins/mesh/summary.py index 2ef1ccafad2..47cb2129ca5 100644 --- a/tensorboard/plugins/mesh_visualizer/summary.py +++ b/tensorboard/plugins/mesh/summary.py @@ -20,8 +20,8 @@ import json import tensorflow as tf -from tensorboard.plugins.mesh_visualizer import metadata -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 +from tensorboard.plugins.mesh import metadata +from tensorboard.plugins.mesh import plugin_data_pb2 PLUGIN_NAME = 'mesh' diff --git a/tensorboard/plugins/mesh_visualizer/summary_test.py b/tensorboard/plugins/mesh/summary_test.py similarity index 92% rename from tensorboard/plugins/mesh_visualizer/summary_test.py rename to tensorboard/plugins/mesh/summary_test.py index a4d627c9f0c..244fd0c25c3 100644 --- a/tensorboard/plugins/mesh_visualizer/summary_test.py +++ b/tensorboard/plugins/mesh/summary_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorboard.plugins.mesh_visualizer.summary.""" +"""Tests for tensorboard.plugins.mesh.summary.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -21,10 +21,10 @@ import tensorflow as tf from tensorboard.compat.proto import summary_pb2 -from tensorboard.plugins.mesh_visualizer import summary -from tensorboard.plugins.mesh_visualizer import metadata -from tensorboard.plugins.mesh_visualizer import plugin_data_pb2 -from tensorboard.plugins.mesh_visualizer import test_utils +from tensorboard.plugins.mesh import summary +from tensorboard.plugins.mesh import metadata +from tensorboard.plugins.mesh import plugin_data_pb2 +from tensorboard.plugins.mesh import test_utils from tensorboard.util import test_util diff --git a/tensorboard/plugins/mesh_visualizer/test_utils.py b/tensorboard/plugins/mesh/test_utils.py similarity index 100% rename from tensorboard/plugins/mesh_visualizer/test_utils.py rename to tensorboard/plugins/mesh/test_utils.py