diff --git a/torchtnt/utils/loggers/tensorboard.py b/torchtnt/utils/loggers/tensorboard.py index 8b014b320d..0b6cac6aa2 100644 --- a/torchtnt/utils/loggers/tensorboard.py +++ b/torchtnt/utils/loggers/tensorboard.py @@ -134,6 +134,29 @@ def log_hparams( if self._writer: self._writer.add_hparams(hparams, metrics) + def log_image(self, *args: Any, **kwargs: Any) -> None: + """Add image data to TensorBoard. + + + Args: + *args (Any): Positional arguments passed to SummaryWriter.add_image + **kwargs(Any): Keyword arguments passed to SummaryWriter.add_image + """ + writer = self._writer + if writer: + writer.add_image(*args, **kwargs) + + def log_images(self, *args: Any, **kwargs: Any) -> None: + """Add batched image data to summary. + + Args: + *args (Any): Positional arguments passed to SummaryWriter.add_images + **kwargs(Any): Keyword arguments passed to SummaryWriter.add_images + """ + writer = self._writer + if writer: + writer.add_images(*args, **kwargs) + def flush(self) -> None: """Writes pending logs to disk."""