diff --git a/CHANGELOG.md b/CHANGELOG.md index ec40d0441b451..d6bcde01dfd42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701)) -- +- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662)) - Fault-tolerant training: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 6b1c178003b48..72ae61aaffe1f 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -199,8 +199,8 @@ def experiment(self) -> Run: return self._experiment - def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100): - self.experiment.watch(model, log=log, log_freq=log_freq) + def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True): + self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 6063a558402c9..40243860b1cb9 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -69,8 +69,8 @@ def test_wandb_logger_init(wandb): ) # watch a model - logger.watch("model", "log", 10) - wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10) + logger.watch("model", "log", 10, False) + wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False) assert logger.name == wandb.init().project_name() assert logger.version == wandb.init().id