Skip to content

Commit

Permalink
Update LearningRateMonitor docs and tests for log_weight_decay (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Peiffap authored May 21, 2024
1 parent d76feef commit b1bb3f3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/lightning/pytorch/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class LearningRateMonitor(Callback):
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
``False``.
Raises:
MisconfigurationException:
Expand All @@ -58,7 +60,7 @@ class LearningRateMonitor(Callback):
Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
``Adam-1`` etc. If a optimizer has multiple parameter groups they will
``Adam-1`` etc. If an optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schedulers.
A ``name`` keyword can also be used for parameter groups in the
Expand Down
3 changes: 3 additions & 0 deletions tests/tests_pytorch/callbacks/test_lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def test_lr_monitor_single_lr(tmp_path):

assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
assert all(
v is None for v in lr_monitor.last_weight_decay_values.values()
), "Weight decay should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
assert list(lr_monitor.lrs) == ["lr-SGD"]

Expand Down

0 comments on commit b1bb3f3

Please sign in to comment.