Skip to content

Commit

Permalink
WandbLogger to log model topology by default (#8662)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
gau-nernst and awaelchli authored Aug 4, 2021
1 parent 560a5c3 commit 052aefc
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701))


-
- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))


- Fault-tolerant training:
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
@@ -199,8 +199,8 @@ def experiment(self) -> Run:

return self._experiment

def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100):
self.experiment.watch(model, log=log, log_freq=log_freq)
def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True):
self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
4 changes: 2 additions & 2 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
@@ -69,8 +69,8 @@ def test_wandb_logger_init(wandb):
)

# watch a model
logger.watch("model", "log", 10)
wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10)
logger.watch("model", "log", 10, False)
wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)

assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id

0 comments on commit 052aefc

Please sign in to comment.