From 052aefc3420eb507cd4eb7256ca452118448cd66 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 4 Aug 2021 18:36:57 +0800 Subject: [PATCH] WandbLogger to log model topology by default (#8662) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- pytorch_lightning/loggers/wandb.py | 4 ++-- tests/loggers/test_wandb.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec40d0441b451..d6bcde01dfd42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701)) -- +- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662)) - Fault-tolerant training: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 6b1c178003b48..72ae61aaffe1f 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -199,8 +199,8 @@ def experiment(self) -> Run: return self._experiment - def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100): - self.experiment.watch(model, log=log, log_freq=log_freq) + def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True): + self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 6063a558402c9..40243860b1cb9 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -69,8 +69,8 @@ def test_wandb_logger_init(wandb): ) # watch a model - logger.watch("model", "log", 10) - wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10) + logger.watch("model", "log", 10, False) + wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False) assert logger.name == wandb.init().project_name() assert logger.version == wandb.init().id