diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD index f21cf56e18..8568efba82 100644 --- a/tensorboard/uploader/BUILD +++ b/tensorboard/uploader/BUILD @@ -68,6 +68,7 @@ py_library( visibility = ["//tensorboard:internal"], deps = [ ":auth", + ":dry_run_stubs", ":exporter", ":flags_parser", ":formatters", @@ -121,9 +122,12 @@ py_test( srcs = ["uploader_test.py"], srcs_version = "PY3", deps = [ + ":dry_run_stubs", + ":server_info", ":test_util", ":upload_tracker", ":uploader", + ":uploader_subcommand", ":util", "//tensorboard:data_compat", "//tensorboard:dataclass_compat", @@ -156,6 +160,26 @@ py_test( ], ) +py_library( + name = "dry_run_stubs", + srcs = ["dry_run_stubs.py"], + deps = [ + "//tensorboard/uploader/proto:protos_all_py_pb2", + "//tensorboard/uploader/proto:protos_all_py_pb2_grpc", + ], +) + +py_test( + name = "dry_run_stubs_test", + srcs = ["dry_run_stubs_test.py"], + srcs_version = "PY3", + deps = [ + ":dry_run_stubs", + "//tensorboard:test", + "//tensorboard/uploader/proto:protos_all_py_pb2", + ], +) + py_library( name = "auth", srcs = ["auth.py"], diff --git a/tensorboard/uploader/dry_run_stubs.py b/tensorboard/uploader/dry_run_stubs.py new file mode 100644 index 0000000000..f66f4db9a2 --- /dev/null +++ b/tensorboard/uploader/dry_run_stubs.py @@ -0,0 +1,57 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dry-run stubs for various rpc services.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorboard.uploader.proto import write_service_pb2 +from tensorboard.uploader.proto import write_service_pb2_grpc + + +class DryRunTensorBoardWriterStub(object): + """A dry-run TensorBoardWriter gRPC Server. + + Only the methods used by the `tensorboard dev upload` are + mocked out in this class. + + When additional methods start to be used by the command, + their mocks should be added to this class. + """ + + def CreateExperiment(self, request, **kwargs): + """Create a new experiment and remember it has been created.""" + del request, kwargs # Unused. + return write_service_pb2.CreateExperimentResponse() + + def WriteScalar(self, request, **kwargs): + del request, kwargs # Unused. + return write_service_pb2.WriteScalarResponse() + + def WriteTensor(self, request, **kwargs): + del request, kwargs # Unused. + return write_service_pb2.WriteTensorResponse() + + def GetOrCreateBlobSequence(self, request, **kwargs): + del request, kwargs # Unused. + return write_service_pb2.GetOrCreateBlobSequenceResponse( + blob_sequence_id="dummy_blob_sequence_id" + ) + + def WriteBlob(self, request, **kwargs): + del kwargs # Unused. + for item in request: + yield write_service_pb2.WriteBlobResponse() diff --git a/tensorboard/uploader/dry_run_stubs_test.py b/tensorboard/uploader/dry_run_stubs_test.py new file mode 100644 index 0000000000..cf5d91c1d4 --- /dev/null +++ b/tensorboard/uploader/dry_run_stubs_test.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dry-run rpc servicers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorboard import test as tb_test +from tensorboard.uploader import dry_run_stubs +from tensorboard.uploader.proto import write_service_pb2 + + +class DryRunTensorBoardWriterServicerTest(tb_test.TestCase): + def setUp(self): + super(DryRunTensorBoardWriterServicerTest, self).setUp() + self._stub = dry_run_stubs.DryRunTensorBoardWriterStub() + + def testCreateExperiment(self): + self._stub.CreateExperiment(write_service_pb2.CreateExperimentRequest()) + + def testWriteScalar(self): + self._stub.WriteScalar(write_service_pb2.WriteScalarRequest()) + + def testWriteTensor(self): + self._stub.WriteTensor(write_service_pb2.WriteTensorRequest()) + + def testGetOrCreateBlobSequence(self): + self._stub.GetOrCreateBlobSequence( + write_service_pb2.GetOrCreateBlobSequenceRequest() + ) + + def testWriteBlob(self): + def dummy_iterator(): + yield write_service_pb2.WriteBlobRequest() + yield write_service_pb2.WriteBlobRequest() + + for response in self._stub.WriteBlob(dummy_iterator()): + self.assertTrue(response) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/uploader/flags_parser.py b/tensorboard/uploader/flags_parser.py index 5d2297da21..44f5d1237c 100644 --- a/tensorboard/uploader/flags_parser.py +++ b/tensorboard/uploader/flags_parser.py @@ -106,6 +106,21 @@ def define_flags(parser): "0: no statistics printed during uploading. 1 (default): print data " "statistics as data is uploaded.", ) + upload.add_argument( + "--dry_run", + action="store_true", + help="Perform a dry run of uploading. In a dry run, the data is read " + "from the logdir as pointed to by the --logdir flag and statistics are " + "displayed (if --verbose is not 0), but no data is actually uploaded " + "to the server.", + ) + upload.add_argument( + "--one_shot", + action="store_true", + help="Upload only the existing data in the logdir and then exit " + "immediately, instead of continuing to listen for new data in the " + "logdir.", + ) upload.add_argument( "--plugins", type=lambda option: option.split(","), diff --git a/tensorboard/uploader/upload_tracker.py b/tensorboard/uploader/upload_tracker.py index 92777efa46..11ae23f31a 100644 --- a/tensorboard/uploader/upload_tracker.py +++ b/tensorboard/uploader/upload_tracker.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +from absl import logging + import contextlib from datetime import datetime import sys @@ -156,6 +158,21 @@ def blob_bytes_skipped(self): def plugin_names(self): return self._plugin_names + def has_data(self): + """Has any data been tracked by this instance. + + This counts the tensor and blob data that have been scanned + but skipped. + + Returns: + Whether this stats tracking object has tracked any data. + """ + return ( + self._num_scalars > 0 + or self._num_tensors > 0 + or self._num_blobs > 0 + ) + def summarize(self): """Get a summary string for actually-uploaded and skipped data. @@ -255,6 +272,7 @@ def __init__(self, verbosity): ) self._verbosity = verbosity self._stats = UploadStats() + self._send_count = 0 def _dummy_generator(self): while True: @@ -264,14 +282,30 @@ def _dummy_generator(self): def _update_uploading_status(self, message, color_code=_STYLE_GREEN): if not self._verbosity: return - message += "." * 3 sys.stdout.write( _STYLE_ERASE_LINE + color_code + message + _STYLE_RESET + "\r" ) sys.stdout.flush() + def _upload_start(self): + """Write an update indicating the start of the uploading.""" + if not self._verbosity: + return + start_message = "%s[%s]%s Uploader started.\n" % ( + _STYLE_BOLD, + readable_time_string(), + _STYLE_RESET, + ) + sys.stdout.write(start_message) + sys.stdout.flush() + + def has_data(self): + """Determine if any data has been uploaded under the tracker's watch.""" + return self._stats.has_data() + def _update_cumulative_status(self): + """Write an update summarizing the data uploaded since the start.""" if not self._verbosity: return if not self._stats.has_new_data_since_last_summarize(): @@ -299,6 +333,9 @@ def add_plugin_name(self, plugin_name): @contextlib.contextmanager def send_tracker(self): """Create a context manager for a round of data sending.""" + self._send_count += 1 + if self._send_count == 1: + self._upload_start() try: # self._reset_bars() self._update_uploading_status("Data upload starting") diff --git a/tensorboard/uploader/upload_tracker_test.py b/tensorboard/uploader/upload_tracker_test.py index e7d219b1df..847d493fae 100644 --- a/tensorboard/uploader/upload_tracker_test.py +++ b/tensorboard/uploader/upload_tracker_test.py @@ -85,16 +85,6 @@ def testAddTensorsNumTensorsSkippedGreaterThanNumTenosrsErrors(self): tensor_bytes_skipped=0, ) - def testAddTensorsNumTensorsSkippedGreaterThanNumTenosrsErrors(self): - stats = upload_tracker.UploadStats() - with self.assertRaises(AssertionError): - stats.add_tensors( - num_tensors=10, - num_tensors_skipped=12, - tensor_bytes=1000, - tensor_bytes_skipped=0, - ) - def testAddBlob(self): stats = upload_tracker.UploadStats() stats.add_blob(blob_bytes=1000, is_skipped=False) @@ -185,6 +175,45 @@ def testHasNewDataSinceLastSummarizeReturnsTrueAfterNewTensors(self): stats.add_blob(blob_bytes=2000, is_skipped=True) self.assertEqual(stats.has_new_data_since_last_summarize(), True) + def testHasDataInitiallyReturnsFalse(self): + stats = upload_tracker.UploadStats() + self.assertEqual(stats.has_data(), False) + + def testHasDataReturnsTrueWithScalars(self): + stats = upload_tracker.UploadStats() + stats.add_scalars(1) + self.assertEqual(stats.has_data(), True) + + def testHasDataReturnsTrueWithUnskippedTensors(self): + stats = upload_tracker.UploadStats() + stats.add_tensors( + num_tensors=10, + num_tensors_skipped=0, + tensor_bytes=1000, + tensor_bytes_skipped=0, + ) + self.assertEqual(stats.has_data(), True) + + def testHasDataReturnsTrueWithSkippedTensors(self): + stats = upload_tracker.UploadStats() + stats.add_tensors( + num_tensors=10, + num_tensors_skipped=10, + tensor_bytes=1000, + tensor_bytes_skipped=1000, + ) + self.assertEqual(stats.has_data(), True) + + def testHasDataReturnsTrueWithUnskippedBlob(self): + stats = upload_tracker.UploadStats() + stats.add_blob(blob_bytes=1000, is_skipped=False) + self.assertEqual(stats.has_data(), True) + + def testHasDataReturnsTrueWithSkippedBlob(self): + stats = upload_tracker.UploadStats() + stats.add_blob(blob_bytes=1000, is_skipped=True) + self.assertEqual(stats.has_data(), True) + class UploadTrackerTest(tb_test.TestCase): """Test for the UploadTracker class.""" @@ -213,17 +242,18 @@ def tearDown(self): def testSendTracker(self): tracker = upload_tracker.UploadTracker(verbosity=1) with tracker.send_tracker(): - self.assertEqual(self.mock_write.call_count, 1) - self.assertEqual(self.mock_flush.call_count, 1) + self.assertEqual(self.mock_write.call_count, 2) + self.assertEqual(self.mock_flush.call_count, 2) self.assertIn( "Data upload starting...", self.mock_write.call_args[0][0], ) - self.assertEqual(self.mock_write.call_count, 2) - self.assertEqual(self.mock_flush.call_count, 2) + self.assertEqual(self.mock_write.call_count, 3) + self.assertEqual(self.mock_flush.call_count, 3) self.assertIn( "Listening for new data in logdir...", self.mock_write.call_args[0][0], ) + self.assertEqual(tracker.has_data(), False) def testSendTrackerWithVerbosity0(self): tracker = upload_tracker.UploadTracker(verbosity=0) @@ -243,6 +273,7 @@ def testScalarsTracker(self): ) self.assertEqual(self.mock_write.call_count, 1) self.assertEqual(self.mock_flush.call_count, 1) + self.assertEqual(tracker.has_data(), True) def testScalarsTrackerWithVerbosity0(self): tracker = upload_tracker.UploadTracker(verbosity=0) @@ -266,6 +297,7 @@ def testTensorsTrackerWithSkippedTensors(self): "Uploading 150 tensors (2.0 kB) (Skipping 50 tensors, 3.9 kB)", self.mock_write.call_args[0][0], ) + self.assertEqual(tracker.has_data(), True) def testTensorsTrackerWithVerbosity0(self): tracker = upload_tracker.UploadTracker(verbosity=0) @@ -294,6 +326,7 @@ def testTensorsTrackerWithoutSkippedTensors(self): "Uploading 200 tensors (5.9 kB)", self.mock_write.call_args[0][0], ) + self.assertEqual(tracker.has_data(), True) def testBlobTrackerUploaded(self): tracker = upload_tracker.UploadTracker(verbosity=1) @@ -316,28 +349,32 @@ def testBlobTrackerWithVerbosity0(self): def testBlobTrackerNotUploaded(self): tracker = upload_tracker.UploadTracker(verbosity=1) with tracker.send_tracker(): - self.assertEqual(self.mock_write.call_count, 1) - self.assertEqual(self.mock_flush.call_count, 1) + self.assertEqual(self.mock_write.call_count, 2) + self.assertEqual(self.mock_flush.call_count, 2) + self.assertIn( + "Uploader started.", self.mock_write.call_args_list[0][0][0], + ) with tracker.blob_tracker( blob_bytes=2048 * 1024 * 1024 ) as blob_tracker: - self.assertEqual(self.mock_write.call_count, 2) - self.assertEqual(self.mock_flush.call_count, 2) + self.assertEqual(self.mock_write.call_count, 3) + self.assertEqual(self.mock_flush.call_count, 3) self.assertIn( "Uploading binary object (2048.0 MB)", self.mock_write.call_args[0][0], ) blob_tracker.mark_uploaded(is_uploaded=False) - self.assertEqual(self.mock_write.call_count, 5) - self.assertEqual(self.mock_flush.call_count, 4) + self.assertEqual(self.mock_write.call_count, 6) + self.assertEqual(self.mock_flush.call_count, 5) self.assertIn( "Total uploaded: 0 scalars, 0 tensors, 0 binary objects\n", - self.mock_write.call_args_list[2][0][0], + self.mock_write.call_args_list[3][0][0], ) self.assertIn( "Total skipped: 1 binary objects (2048.0 MB)\n", - self.mock_write.call_args_list[3][0][0], + self.mock_write.call_args_list[4][0][0], ) + self.assertEqual(tracker.has_data(), True) def testInvalidVerbosityRaisesError(self): with self.assertRaises(ValueError): diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 9cfb2e0d64..3d12658896 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -78,6 +78,7 @@ def __init__( name=None, description=None, verbosity=None, + one_shot=None, ): """Constructs a TensorBoardUploader. @@ -103,6 +104,10 @@ def __init__( verbosity: Level of verbosity, an integer. Supported value: 0 - No upload statistics is printed. 1 - Print upload statistics while uploading data (default). + one_shot: Once uploading starts, upload only the existing data in + the logdir and then return immediately, instead of the default + behavior of continuing to listen for new data in the logdir and + upload them when it appears. """ self._api = writer_client self._logdir = logdir @@ -112,6 +117,7 @@ def __init__( self._name = name self._description = description self._verbosity = 1 if verbosity is None else verbosity + self._one_shot = False if one_shot is None else one_shot self._request_sender = None if logdir_poll_rate_limiter is None: self._logdir_poll_rate_limiter = util.RateLimiter( @@ -191,6 +197,13 @@ def start_uploading(self): while True: self._logdir_poll_rate_limiter.tick() self._upload_once() + if self._one_shot: + break + if self._one_shot and not self._tracker.has_data(): + logger.warning( + "One-shot mode was used on a logdir (%s) " + "without any uploadable data" % self._logdir + ) def _upload_once(self): """Runs one upload cycle, sending zero or more RPCs.""" diff --git a/tensorboard/uploader/uploader_subcommand.py b/tensorboard/uploader/uploader_subcommand.py index ef0b43be8f..9ce11ce071 100644 --- a/tensorboard/uploader/uploader_subcommand.py +++ b/tensorboard/uploader/uploader_subcommand.py @@ -33,6 +33,7 @@ from tensorboard.uploader.proto import export_service_pb2_grpc from tensorboard.uploader.proto import write_service_pb2_grpc from tensorboard.uploader import auth +from tensorboard.uploader import dry_run_stubs from tensorboard.uploader import exporter as exporter_lib from tensorboard.uploader import flags_parser from tensorboard.uploader import formatters @@ -375,7 +376,7 @@ def _die_if_bad_experiment_description(description): ) -class _UploadIntent(_Intent): +class UploadIntent(_Intent): """The user intends to upload an experiment from the given logdir.""" _MESSAGE_TEMPLATE = textwrap.dedent( @@ -391,20 +392,31 @@ class _UploadIntent(_Intent): ) def __init__( - self, logdir, name=None, description=None, verbosity=None, + self, + logdir, + name=None, + description=None, + verbosity=None, + dry_run=None, + one_shot=None, ): self.logdir = logdir self.name = name self.description = description self.verbosity = verbosity + self.dry_run = False if dry_run is None else dry_run + self.one_shot = False if one_shot is None else one_shot def get_ack_message_body(self): return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) def execute(self, server_info, channel): - api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( - channel - ) + if self.dry_run: + api_client = dry_run_stubs.DryRunTensorBoardWriterStub() + else: + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( + channel + ) _die_if_bad_experiment_name(self.name) _die_if_bad_experiment_description(self.description) uploader = uploader_lib.TensorBoardUploader( @@ -415,6 +427,7 @@ def execute(self, server_info, channel): name=self.name, description=self.description, verbosity=self.verbosity, + one_shot=self.one_shot, ) experiment_id = uploader.create_experiment() url = server_info_lib.experiment_url(server_info, experiment_id) @@ -422,19 +435,24 @@ def execute(self, server_info, channel): "Upload started and will continue reading any new data as it's added" ) print("to the logdir. To stop uploading, press Ctrl-C.") - print("View your TensorBoard live at: %s" % url) + if self.dry_run: + print( + "\n** This is a dry run. " + "No data will be sent to tensorboard.dev. **\n" + ) + else: + print("View your TensorBoard live at: %s" % url) try: uploader.start_uploading() except uploader_lib.ExperimentNotFoundError: print("Experiment was deleted; uploading has been cancelled") return except KeyboardInterrupt: - print() - print("Upload stopped. View your TensorBoard at %s" % url) - return - # TODO(@nfelt): make it possible for the upload cycle to end once we - # detect that no more runs are active, so this code can be reached. - print("Done! View your TensorBoard at %s" % url) + pass + finally: + if not self.dry_run: + print() + print("Done! View your TensorBoard at %s" % url) class _ExportIntent(_Intent): @@ -503,11 +521,13 @@ def _get_intent(flags): raise base_plugin.FlagsError("Must specify subcommand (try --help).") if cmd == flags_parser.SUBCOMMAND_KEY_UPLOAD: if flags.logdir: - return _UploadIntent( + return UploadIntent( os.path.expanduser(flags.logdir), name=flags.name, description=flags.description, verbosity=flags.verbose, + dry_run=flags.dry_run, + one_shot=flags.one_shot, ) else: raise base_plugin.FlagsError( diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index 9e261bb0be..15e76c0eee 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -43,10 +43,13 @@ from tensorboard.uploader.proto import server_info_pb2 from tensorboard.uploader.proto import write_service_pb2 from tensorboard.uploader.proto import write_service_pb2_grpc +from tensorboard.uploader import dry_run_stubs from tensorboard.uploader import test_util from tensorboard.uploader import upload_tracker from tensorboard.uploader import uploader as uploader_lib +from tensorboard.uploader import uploader_subcommand from tensorboard.uploader import logdir_loader +from tensorboard.uploader import server_info as server_info_lib from tensorboard.uploader import util from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import graph_pb2 @@ -130,6 +133,7 @@ def _create_uploader( name=None, description=None, verbosity=0, # Use 0 to minimize littering the test output. + one_shot=None, ): if writer_client is _USE_DEFAULT: writer_client = _create_mock_client() @@ -148,12 +152,13 @@ def _create_uploader( if blob_rpc_rate_limiter is _USE_DEFAULT: blob_rpc_rate_limiter = util.RateLimiter(0) - upload_limits = server_info_pb2.UploadLimits() - upload_limits.max_scalar_request_size = max_scalar_request_size - upload_limits.max_tensor_request_size = 128000 - upload_limits.max_blob_request_size = max_blob_request_size - upload_limits.max_blob_size = max_blob_size - upload_limits.max_tensor_point_size = 11111 + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=max_scalar_request_size, + max_tensor_request_size=128000, + max_tensor_point_size=11111, + max_blob_request_size=max_blob_request_size, + max_blob_size=max_blob_size, + ) return uploader_lib.TensorBoardUploader( writer_client, @@ -167,6 +172,7 @@ def _create_uploader( name=name, description=description, verbosity=verbosity, + one_shot=one_shot, ) @@ -178,11 +184,12 @@ def _create_request_sender( if allowed_plugins is _USE_DEFAULT: allowed_plugins = _SCALARS_HISTOGRAMS_AND_GRAPHS - upload_limits = server_info_pb2.UploadLimits() - upload_limits.max_blob_size = 12345 - upload_limits.max_tensor_point_size = 11111 - upload_limits.max_scalar_request_size = 128000 - upload_limits.max_tensor_request_size = 128000 + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=128000, + max_tensor_request_size=128000, + max_tensor_point_size=11111, + max_blob_size=12345, + ) rpc_rate_limiter = util.RateLimiter(0) tensor_rpc_rate_limiter = util.RateLimiter(0) @@ -371,6 +378,60 @@ def scalar_event(tag, value): self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + def test_start_uploading_scalars_one_shot(self): + """Check that one-shot uploading stops without AbortUploadError.""" + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tensor_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_blob_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + mock_client, + "/logs/foo", + # Send each Event below in a separate WriteScalarRequest + max_scalar_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + tensor_rpc_rate_limiter=mock_tensor_rate_limiter, + blob_rpc_rate_limiter=mock_blob_rate_limiter, + verbosity=1, # In order to test the upload tracker. + one_shot=True, + ) + uploader.create_experiment() + + def scalar_event(tag, value): + return event_pb2.Event(summary=scalar_v2.scalar_pb(tag, value)) + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [scalar_event("1.1", 5.0), scalar_event("1.2", 5.0)] + ), + "run 2": _apply_compat( + [scalar_event("2.1", 5.0), scalar_event("2.2", 5.0)] + ), + }, + # Note the lack of AbortUploadError here. + ] + + with mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader): + uploader.start_uploading() + + self.assertEqual(4, mock_client.WriteScalar.call_count) + self.assertEqual(4, mock_rate_limiter.tick.call_count) + self.assertEqual(0, mock_tensor_rate_limiter.tick.call_count) + self.assertEqual(0, mock_blob_rate_limiter.tick.call_count) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 1) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 4) + self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + def test_start_uploading_tensors(self): mock_client = _create_mock_client() mock_rate_limiter = mock.create_autospec(util.RateLimiter) @@ -1926,6 +1987,36 @@ def test_varint_cost(self): self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) +class UploadIntentTest(tf.test.TestCase): + def testUploadIntentUnderDryRunOneShot(self): + """Test the upload intent under the dry-run + one-shot mode.""" + mock_server_info = mock.MagicMock() + mock_channel = mock.MagicMock() + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=128000, + max_tensor_request_size=128000, + max_tensor_point_size=11111, + max_blob_request_size=128000, + max_blob_size=128000, + ) + with mock.patch.object( + server_info_lib, + "allowed_plugins", + return_value=_SCALARS_HISTOGRAMS_AND_GRAPHS, + ), mock.patch.object( + server_info_lib, "upload_limits", return_value=upload_limits + ), mock.patch.object( + dry_run_stubs, + "DryRunTensorBoardWriterStub", + side_effect=dry_run_stubs.DryRunTensorBoardWriterStub, + ) as mock_dry_run_stub: + intent = uploader_subcommand.UploadIntent( + self.get_temp_dir(), dry_run=True, one_shot=True + ) + intent.execute(mock_server_info, mock_channel) + self.assertEqual(mock_dry_run_stub.call_count, 1) + + def _clear_wall_times(request): """Clears the wall_time fields in a WriteScalarRequest to be deterministic."""