Skip to content

Commit

Permalink
fix(wandb): allow custom init args (#6989)
Browse files Browse the repository at this point in the history
* feat(wandb): allow custom init args

* style: pep8

* fix: get dict args

* refactor: simplify init args

* test: test init args

* style: pep8

* docs: update CHANGELOG

* test: check default resume value

* fix: default value of anonymous

* fix: respect order of parameters

* feat: use look-up table for anonymous

* yapf formatting

Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
borisdayma and awaelchli authored May 4, 2021
1 parent 82c19e1 commit 2a20102
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))


- Fixed custom init args for `WandbLogger` ([#6989](https://github.com/PyTorchLightning/pytorch-lightning/pull/6989))



## [1.2.7] - 2021-04-06

### Fixed
Expand Down
36 changes: 18 additions & 18 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class WandbLogger(LightningLoggerBase):
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
Raises:
ImportError:
Expand Down Expand Up @@ -93,7 +92,7 @@ def __init__(
save_dir: Optional[str] = None,
offline: Optional[bool] = False,
id: Optional[str] = None,
anonymous: Optional[bool] = False,
anonymous: Optional[bool] = None,
version: Optional[str] = None,
project: Optional[str] = None,
log_model: Optional[bool] = False,
Expand Down Expand Up @@ -122,16 +121,25 @@ def __init__(
)

super().__init__()
self._name = name
self._save_dir = save_dir
self._offline = offline
self._id = version or id
self._anonymous = 'allow' if anonymous else None
self._project = project
self._log_model = log_model
self._prefix = prefix
self._experiment = experiment
self._kwargs = kwargs
# set wandb init arguments
anonymous_lut = {True: 'allow', False: None}
self._wandb_init = dict(
name=name,
project=project,
id=version or id,
dir=save_dir,
resume='allow',
anonymous=anonymous_lut.get(anonymous, anonymous)
)
self._wandb_init.update(**kwargs)
# extract parameters
self._save_dir = self._wandb_init.get('dir')
self._name = self._wandb_init.get('name')
self._id = self._wandb_init.get('id')

def __getstate__(self):
state = self.__dict__.copy()
Expand All @@ -158,15 +166,7 @@ def experiment(self) -> Run:
if self._experiment is None:
if self._offline:
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(
name=self._name,
dir=self._save_dir,
project=self._project,
anonymous=self._anonymous,
id=self._id,
resume='allow',
**self._kwargs
) if wandb.run is None else wandb.run
self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
Expand Down
10 changes: 8 additions & 2 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ def test_wandb_logger_init(wandb, recwarn):

# test wandb.init called when there is no W&B run
wandb.run = None
logger = WandbLogger()
logger = WandbLogger(
name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
)
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init.assert_called_once_with(
name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
)
wandb.init().log.assert_called_once_with({'acc': 1.0})

# test wandb.init and setting logger experiment externally
Expand All @@ -55,6 +59,8 @@ def test_wandb_logger_init(wandb, recwarn):
wandb.init.reset_mock()
wandb.run = wandb.init()
logger = WandbLogger()
# verify default resume value
assert logger._wandb_init['resume'] == 'allow'
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
Expand Down

0 comments on commit 2a20102

Please sign in to comment.