Skip to content

Commit

Permalink
Fixed setting of _save_dir when run initiated externally (#7106)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
THasthika and awaelchli authored Apr 23, 2021
1 parent f48ac62 commit c502e47
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` in distributed settings for validation / testing ([#7102](https://github.com/PyTorchLightning/pytorch-lightning/pull/7102))


- Fixed the save_dir in `WandbLogger` when the run was initiated externally ([#7106](https://github.com/PyTorchLightning/pytorch-lightning/pull/7106))


- Fixed parsing for pre-release package versions ([#6999](https://github.com/PyTorchLightning/pytorch-lightning/pull/6999))


Expand Down
16 changes: 8 additions & 8 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)

return self._experiment

Expand Down
8 changes: 8 additions & 0 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def test_wandb_logger_init(wandb, recwarn):
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0})

# test wandb.init and setting logger experiment externally
wandb.run = None
run = wandb.init()
logger = WandbLogger(experiment=run)
assert logger.experiment
assert run.dir is not None
assert logger.save_dir == run.dir

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
wandb.init.reset_mock()
Expand Down

0 comments on commit c502e47

Please sign in to comment.