
feat(wandb): add sync_step #5351

Merged
merged 16 commits on Jan 24, 2021
Changes from 7 commits
44 changes: 26 additions & 18 deletions pytorch_lightning/loggers/wandb.py
```diff
@@ -46,19 +46,20 @@ class WandbLogger(LightningLoggerBase):

     Args:
         name: Display name for the run.
-        save_dir: Path where data is saved.
+        save_dir: Path where data is saved (wandb dir by default).
         offline: Run offline (data can be streamed later to wandb servers).
         id: Sets the version, mainly used to resume a previous run.
+        version: Same as id.
         anonymous: Enables or explicitly disables anonymous logging.
-        version: Sets the version, mainly used to resume a previous run.
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
-        experiment: WandB experiment object.
         prefix: A string to put at the beginning of metric keys.
+        sync_step: Sync Trainer step with wandb step.
```
**Member** commented on `sync_step: Sync Trainer step with wandb step.`:

> Can you put this after the `experiment` flag, so that the old args order still works?

**Contributor (author)** replied:

> I changed the order within the function itself.
> For the description of the args, I kept this order because I thought it made more sense; for example, `id` and `version` are the same thing. Is that ok?
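The reviewer's concern is about positional callers: inserting a parameter before an existing one silently changes which value each name receives. A toy illustration (hypothetical `make_logger_*` functions, not code from the PR):

```python
def make_logger_v1(name=None, save_dir=None, experiment=None):
    return {'name': name, 'save_dir': save_dir, 'experiment': experiment}

# A caller relying on positional order:
args = ('run-1', '/tmp/logs', 'exp-obj')
v1 = make_logger_v1(*args)  # experiment correctly receives 'exp-obj'

# If a new parameter is inserted *before* experiment, the same positional
# call now binds 'exp-obj' to the wrong name:
def make_logger_v2(name=None, save_dir=None, sync_step=True, experiment=None):
    return {'name': name, 'save_dir': save_dir,
            'sync_step': sync_step, 'experiment': experiment}

v2 = make_logger_v2(*args)
print(v2['sync_step'])    # 'exp-obj'
print(v2['experiment'])   # None
```

Appending the new parameter after `experiment` (as the reviewer suggests) keeps old positional calls working.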

```diff
+        experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.

-    Example::
+    Example:

     .. code-block:: python

@@ -71,9 +72,9 @@ class WandbLogger(LightningLoggerBase):
     make sure to use `commit=False` so the logging step does not increase.

     See Also:
-        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
-          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
-          on how to use W&B with Pytorch Lightning.
+        - `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
+          on how to use W&B with Pytorch Lightning
         - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

     """
```

```diff
@@ -83,14 +84,15 @@ def __init__(
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: bool = False,
+        offline: Optional[bool] = False,
```
**Contributor** commented on `offline: Optional[bool] = False,`:

> I don't think the change here is necessary. `Optional[bool]` is equivalent to `Union[bool, None]`, and `offline` only accepts `bool`, right?

**Contributor (author)** replied:

> It should work, but I actually want to clean up the way we pass init parameters in a follow-up PR and suggest just using kwargs. The possible parameters have evolved, and we could simply refer to the `wandb.init` doc.

**Contributor** replied:

> That makes sense.
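The equivalence the reviewer mentions can be checked directly with the standard `typing` module (a quick sketch, not part of the PR):

```python
from typing import Optional, Union

# Optional[X] is shorthand for Union[X, None]; at runtime both
# spellings produce the same type object.
print(Optional[bool] == Union[bool, None])  # True
```

So `Optional[bool]` only widens the annotation to also admit `None`, which is why the reviewer considered the change unnecessary for a parameter that is always a `bool`.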

```diff
         id: Optional[str] = None,
-        anonymous: bool = False,
         version: Optional[str] = None,
+        anonymous: Optional[bool] = False,
         project: Optional[str] = None,
-        log_model: bool = False,
+        log_model: Optional[bool] = False,
+        prefix: Optional[str] = '',
+        sync_step: Optional[bool] = True,
         experiment=None,
-        prefix: str = '',
         **kwargs
     ):
```
```diff
         if wandb is None:
@@ -99,13 +101,14 @@ def __init__(
         super().__init__()
         self._name = name
         self._save_dir = save_dir
-        self._anonymous = 'allow' if anonymous else None
-        self._offline = offline
         self._id = version or id
+        self._anonymous = 'allow' if anonymous else None
         self._project = project
-        self._experiment = experiment
+        self._offline = offline
         self._log_model = log_model
         self._prefix = prefix
+        self._sync_step = sync_step
+        self._experiment = experiment
         self._kwargs = kwargs
         # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
```
```diff
@@ -161,11 +164,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

         metrics = self._add_prefix(metrics)
-        if step is not None and step + self._step_offset < self.experiment.step:
+        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
             self.warning_cache.warn(
-                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
-            )
-        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+                'Trying to log at a previous step. '
+                'Use `WandbLogger(sync_step=False)` or try logging with `commit=False` when calling manually `wandb.log`.')
+        if self._sync_step:
+            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+        elif step is not None:
+            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        else:
+            self.experiment.log(metrics)

     @property
     def save_dir(self) -> Optional[str]:
```
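The `sync_step` branching in `log_metrics` can be illustrated standalone (a simplified sketch with a stand-in experiment object; `FakeExperiment` and the free function below are hypothetical, not the PR's actual classes):

```python
class FakeExperiment:
    """Minimal stand-in for a wandb run: records what gets logged."""
    def __init__(self):
        self.step = 0
        self.calls = []

    def log(self, metrics, step=None):
        self.calls.append((metrics, step))


def log_metrics(exp, metrics, step=None, sync_step=True, step_offset=0):
    # Mirrors the PR's branching: with sync_step the Trainer step drives
    # the wandb step; without it, the step is logged as a plain metric.
    if sync_step:
        exp.log(metrics, step=(step + step_offset) if step is not None else None)
    elif step is not None:
        exp.log({**metrics, 'trainer_step': step + step_offset})
    else:
        exp.log(metrics)


exp = FakeExperiment()
log_metrics(exp, {'acc': 1.0}, step=3)                   # wandb step follows Trainer step
log_metrics(exp, {'acc': 1.0}, step=3, sync_step=False)  # step becomes a 'trainer_step' metric
print(exp.calls)
```

The second mode lets wandb keep its own monotonically increasing step (avoiding the "previous step" warning) while still preserving the Trainer step in the logged data.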
12 changes: 12 additions & 0 deletions tests/loggers/test_wandb.py
```diff
@@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

+    # test sync_step functionality
+    wandb.init().log.reset_mock()
+    wandb.init.reset_mock()
+    wandb.run = None
+    wandb.init().step = 0
+    logger = WandbLogger(sync_step=False)
+    logger.log_metrics({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.reset_mock()
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
+
     # mock wandb step
     wandb.init().step = 0
```
Expand Down