Add optimizer stats logging
ibro45 committed Jan 15, 2024
1 parent 0695800 commit 24634aa
Showing 2 changed files with 45 additions and 5 deletions.
16 changes: 11 additions & 5 deletions lighter/system.py
@@ -12,7 +12,7 @@
 from torchmetrics import Metric, MetricCollection
 
 from lighter.utils.collate import collate_replace_corrupted
-from lighter.utils.misc import apply_fns, ensure_dict_schema, get_name, hasarg
+from lighter.utils.misc import apply_fns, ensure_dict_schema, get_name, get_optimizer_stats, hasarg
 
 
 class LighterSystem(pl.LightningModule):
@@ -203,15 +203,21 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un

         # Logging
         default_kwargs = {"logger": True, "batch_size": self.batch_size}
-        step_kwargs = {"on_epoch": False, "on_step": True, "sync_dist": False}
-        epoch_kwargs = {"on_epoch": True, "on_step": False, "sync_dist": True}
+        step_kwargs = {"on_epoch": False, "on_step": True}
+        epoch_kwargs = {"on_epoch": True, "on_step": False}
         # - Loss
         if loss is not None:
             self.log(f"{mode}/loss/step", loss, **default_kwargs, **step_kwargs)
-            self.log(f"{mode}/loss/epoch", loss, **default_kwargs, **epoch_kwargs)
+            self.log(f"{mode}/loss/epoch", loss, **default_kwargs, **epoch_kwargs, sync_dist=True)
         # - Metrics
         if metrics is not None:
             for k, v in metrics.items():
                 self.log(f"{mode}/metrics/{k}/step", v, **default_kwargs, **step_kwargs)
-                self.log(f"{mode}/metrics/{k}/epoch", v, **default_kwargs, **epoch_kwargs)
+                self.log(f"{mode}/metrics/{k}/epoch", v, **default_kwargs, **epoch_kwargs, sync_dist=True)
+        # - Optimizer's learning rate, momentum, beta. Logged in train mode and once per epoch.
+        if mode == "train" and batch_idx == 0:
+            for k, v in get_optimizer_stats(self.optimizer).items():
+                self.log(f"{mode}/{k}", v, **default_kwargs, **epoch_kwargs)
 
         return loss

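Note: a minimal sketch (not part of the diff) of the logger keys that _base_step emits after this change, assuming mode == "train", a single metric registered under the name "accuracy", and an Adam optimizer; the metric and optimizer names depend on the user's configuration.

# Illustrative only: keys produced by the logging calls above.
expected_keys = [
    "train/loss/step",
    "train/loss/epoch",              # reduced across processes (sync_dist=True)
    "train/metrics/accuracy/step",
    "train/metrics/accuracy/epoch",  # reduced across processes (sync_dist=True)
    "train/optimizer/Adam/lr",       # from get_optimizer_stats(), logged once per epoch
    "train/optimizer/Adam/momentum", # betas[0] for Adam
]
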
34 changes: 34 additions & 0 deletions lighter/utils/misc.py
@@ -2,6 +2,8 @@

 import inspect
 
+from torch.optim.optimizer import Optimizer
+
 
 def ensure_list(vals: Any) -> List:
     """Wrap the input into a list if it is not a list. If it is a None, return an empty list.
@@ -114,3 +116,35 @@ def apply_fns(data: Any, fns: Union[Callable, List[Callable]]) -> Any:
     for fn in ensure_list(fns):
         data = fn(data)
     return data
+
+
+def get_optimizer_stats(optimizer: Optimizer) -> Dict[str, float]:
+    """
+    Extract learning rates and momentum values from each parameter group of the optimizer.
+
+    Args:
+        optimizer (Optimizer): A PyTorch optimizer.
+
+    Returns:
+        Dict[str, float]: Dictionary with formatted keys and values for learning rates and momentum.
+    """
+    stats_dict = {}
+    for group_idx, group in enumerate(optimizer.param_groups):
+        lr_key = f"optimizer/{optimizer.__class__.__name__}/lr"
+        momentum_key = f"optimizer/{optimizer.__class__.__name__}/momentum"
+
+        # Add group index to the key if there are multiple parameter groups
+        if len(optimizer.param_groups) > 1:
+            lr_key += f"/group{group_idx+1}"
+            momentum_key += f"/group{group_idx+1}"
+
+        # Extracting learning rate
+        stats_dict[lr_key] = group["lr"]
+
+        # Extracting momentum or betas[0] if available
+        if "momentum" in group:
+            stats_dict[momentum_key] = group["momentum"]
+        if "betas" in group:
+            stats_dict[momentum_key] = group["betas"][0]
+
+    return stats_dict
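
Note: a quick usage sketch (not part of the diff) of get_optimizer_stats with an SGD optimizer that has two parameter groups; the module names are illustrative, and the "/group{i}" key suffix appears only because more than one parameter group is present.

import torch
from torch.optim import SGD

from lighter.utils.misc import get_optimizer_stats

# Two parameter groups with different learning rates; momentum is shared.
backbone = torch.nn.Linear(4, 2)
head = torch.nn.Linear(2, 1)
optimizer = SGD(
    [{"params": backbone.parameters(), "lr": 0.1}, {"params": head.parameters(), "lr": 0.01}],
    momentum=0.9,
)

print(get_optimizer_stats(optimizer))
# {'optimizer/SGD/lr/group1': 0.1, 'optimizer/SGD/momentum/group1': 0.9,
#  'optimizer/SGD/lr/group2': 0.01, 'optimizer/SGD/momentum/group2': 0.9}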
