Add option to skip optim steps for 0 grad params (#636)
epwalsh authored Jul 9, 2024
1 parent cbc7c25 commit bc60b8a
Showing 3 changed files with 38 additions and 14 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added caching to disk of HF datasets used in downstream evals
- Added FLOPs logging
- Added configs for OLMo tiny set of models
- Added configuration field `optimizer.record_update_metrics`, which defaults to `False`, but when set to `True` will trigger AdamW to collect the step size norm and absolute max for each parameter.
- Added configuration field `optimizer.selective_updates`, which defaults to `False`, but when set to `True` will tell the optimizer to skip updating the parameter and state when the corresponding gradient is 0.
- Added configuration field `optimizer.record_update_metrics`, which defaults to `False`, but when set to True will trigger AdamW to collect the step size norm and absolute max for each parameter.
- Added `olmo_data`, a package holding data files like tokenizers.
- Added ability to load tokenizers from `olmo_data` package data.
5 changes: 5 additions & 0 deletions olmo/config.py
@@ -480,6 +480,11 @@ class OptimizerConfig(BaseConfig):
Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
"""

selective_updates: bool = False
"""
If ``True``, optimizer parameter and state updates are skipped when the corresponding gradient is 0.
"""

decay_norm_and_bias: bool = False
decay_embeddings: bool = False
metrics_log_interval: Optional[int] = None
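For context, a minimal sketch of how the new field could be switched on when building an `OptimizerConfig` in code. Only `selective_updates` and `record_update_metrics` are taken from this diff; everything else, including whether the remaining fields have usable defaults, is an assumption.

```python
# Hedged sketch: assumes OptimizerConfig's other fields have workable defaults.
from olmo.config import OptimizerConfig

opt_cfg = OptimizerConfig(
    selective_updates=True,      # skip parameter/state updates where grad == 0
    record_update_metrics=True,  # also collect step-size norm and abs-max per parameter
)
```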
45 changes: 31 additions & 14 deletions olmo/optim.py
@@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, replace
from math import cos, pi, sqrt
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -35,10 +35,11 @@


class Optimizer(OptimizerBase):
def __init__(self, *args, record_update_metrics: bool = False, **kwargs):
def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._record_update_metrics = record_update_metrics
self._collecting_metrics = False
self._selective_updates = selective_updates

def _clean_param_name(self, name: str) -> str:
return name.replace("_fsdp_wrapped_module.", "")
@@ -372,12 +373,15 @@ def __init__(
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
record_update_metrics: bool = False,
selective_updates: bool = False,
device: Optional[torch.device] = None,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults, record_update_metrics=record_update_metrics)
super().__init__(
params, defaults, record_update_metrics=record_update_metrics, selective_updates=selective_updates
)
for group in self.param_groups:
group["initial_lr"] = group["lr"]
self._update_total_dot_prod: Optional[torch.Tensor] = None
@@ -440,15 +444,16 @@ def step(self, closure=None) -> None:

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
grad = p.grad
if grad is None:
continue

# Perform step weight decay
p.data.mul_(1 - group["lr"] * group["weight_decay"])

grad = p.grad
state = self.state[p]

# Perform step weight decay
mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
p.data.mul_(1 - mask * (group["lr"] * group["weight_decay"]))

# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
@@ -459,11 +464,15 @@ def step(self, closure=None) -> None:

# Weight update
update = exp_avg * beta1 + grad * (1 - beta1)
if isinstance(mask, torch.Tensor):
# When mask isn't a tensor it's just a literal `1` (python int), so there's
# no point in calling this op.
update.mul_(mask)
signed_update = torch.sign(update)
p.add_(signed_update, alpha=-group["lr"])

# Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
exp_avg.mul_(1 - mask * (1 - beta2)).add_(grad, alpha=1 - beta2)

# Track dot product and norms of update vs signed update in order to calculate
# their cosine similarity.
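The masking above is what makes the update selective: wherever the gradient is exactly zero, the weight-decay multiplier collapses to 1, the sign update is zeroed out, and the momentum decay factor collapses to 1, so neither the parameter nor `exp_avg` moves. A standalone toy check of that arithmetic (a sketch mirroring the lines above, not code from this commit):

```python
import torch

lr, wd, beta1, beta2 = 0.1, 0.01, 0.9, 0.99
p = torch.tensor([1.0, 2.0, 3.0])
grad = torch.tensor([0.5, 0.0, -0.5])    # middle entry has a zero gradient
exp_avg = torch.tensor([0.2, 0.3, 0.4])
p_before, exp_avg_before = p.clone(), exp_avg.clone()

mask = grad != 0  # the selective_updates path; otherwise mask is the literal 1

# Masked weight decay: zero-grad entries are multiplied by exactly 1.
p.mul_(1 - mask * (lr * wd))

# Masked sign update: the update is zeroed where grad == 0, so its sign is 0 there.
update = exp_avg * beta1 + grad * (1 - beta1)
update.mul_(mask)
p.add_(torch.sign(update), alpha=-lr)

# Masked momentum decay: exp_avg stays put where grad == 0.
exp_avg.mul_(1 - mask * (1 - beta2)).add_(grad, alpha=1 - beta2)

assert p[1] == p_before[1] and exp_avg[1] == exp_avg_before[1]   # untouched
assert p[0] != p_before[0] and exp_avg[0] != exp_avg_before[0]   # updated
```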
@@ -494,21 +503,22 @@


class AdamW(torch.optim.AdamW, Optimizer):
def __init__(self, *args, record_update_metrics: bool = False, **kwargs):
def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
super().__init__(*args, **kwargs)

# Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
# won't be called.
self._record_update_metrics = record_update_metrics
self._collecting_metrics = False
self._selective_updates = selective_updates

self._step_size_param_names: Optional[List[str]] = None
self._step_size_norms: Optional[List[torch.Tensor]] = None
self._step_size_maxs: Optional[List[torch.Tensor]] = None

@torch.no_grad()
def step(self, closure=None) -> None:
if not (self._record_update_metrics and self._collecting_metrics):
if not (self._record_update_metrics and self._collecting_metrics) and not self._selective_updates:
return super().step(closure=closure)

device = get_default_device()
@@ -554,11 +564,12 @@ def step(self, closure=None) -> None:
step_t += 1

# Perform step weight decay.
param.mul_(1 - lr * weight_decay)
mask: Union[torch.Tensor, int] = grad != 0 if self._selective_updates else 1
param.mul_(1 - mask * (lr * weight_decay))

# Decay the first and second moment running average coefficient.
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.lerp_(grad, mask * (1 - beta1))
exp_avg_sq.mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)

step = step_t.item()

@@ -580,6 +591,10 @@ def step(self, closure=None) -> None:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

update = -step_size * torch.div(exp_avg, denom)
if isinstance(mask, torch.Tensor):
# When mask isn't a tensor it's just a literal `1` (python int), so there's
# no point in calling this op.
update.mul_(mask)
param.add_(update)
step_size_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32).unsqueeze(0))
step_size_maxs.append(update.abs().max().unsqueeze(0))
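The AdamW path applies the same idea to its moments: where the gradient is zero, the weight-decay multiplier and the second-moment decay factor collapse to 1, the `lerp_` weight collapses to 0, and the final update is zeroed by the mask (the `addcmul_` term is already zero there because `grad * grad == 0`). A toy sketch of that reduction, not code from this commit:

```python
import torch

beta1, beta2, lr, wd = 0.9, 0.999, 1e-3, 0.01
param = torch.tensor([1.0])
grad = torch.tensor([0.0])         # zero gradient -> mask == 0
exp_avg = torch.tensor([0.5])
exp_avg_sq = torch.tensor([0.25])
mask = grad != 0

param.mul_(1 - mask * (lr * wd))          # multiply by 1: unchanged
exp_avg.lerp_(grad, mask * (1 - beta1))   # lerp weight 0: unchanged
exp_avg_sq.mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)

assert param.item() == 1.0 and exp_avg.item() == 0.5 and exp_avg_sq.item() == 0.25
```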
@@ -899,6 +914,7 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
betas=cfg.optimizer.betas,
weight_decay=cfg.optimizer.weight_decay,
record_update_metrics=cfg.optimizer.record_update_metrics,
selective_updates=cfg.optimizer.selective_updates,
)
elif cfg.optimizer.name == OptimizerType.adamw:
return AdamW(
@@ -907,6 +923,7 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
betas=cfg.optimizer.betas,
weight_decay=cfg.optimizer.weight_decay,
record_update_metrics=cfg.optimizer.record_update_metrics,
selective_updates=cfg.optimizer.selective_updates,
eps=cfg.optimizer.eps,
)
else:
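End to end, the flag travels from the training config through `build_optimizer` into the optimizer. A hedged usage sketch, assuming a `TrainConfig.load` helper and a config path that are illustrative rather than taken from this diff:

```python
# Illustrative only: everything except `build_optimizer`,
# `optimizer.selective_updates`, and `optimizer.record_update_metrics`
# is an assumption about the surrounding codebase.
from olmo.config import TrainConfig
from olmo.optim import build_optimizer

cfg = TrainConfig.load("configs/my-run.yaml")  # hypothetical path/loader
cfg.optimizer.selective_updates = True         # skip updates where grad == 0
cfg.optimizer.record_update_metrics = True

optimizer = build_optimizer(cfg, model)        # `model` is an nn.Module built elsewhere
```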
