Skip to content

Commit

Permalink
Add batch_size, rank_zero_only arguments for log_dict to match …
Browse files Browse the repository at this point in the history
…`log` (#8628)
  • Loading branch information
eladsegal authored Aug 3, 2021
1 parent 98319f8 commit 08fba96
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))


- Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666))


Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def log_dict(
sync_dist_op: Optional[Any] = None, # todo: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: Optional[bool] = None,
) -> None:
"""
Log a dictionary of values at once.
Expand All @@ -502,6 +504,10 @@ def log_dict(
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
for k, v in dictionary.items():
self.log(
Expand All @@ -519,6 +525,8 @@ def log_dict(
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)

@staticmethod
Expand Down

0 comments on commit 08fba96

Please sign in to comment.