Merge pull request #691 from allenai/dave/cosine_linear_envelope
Add scheduler for cosine in linear envelope.
dwadden authored Aug 8, 2024
2 parents 6587ddb + 1cf3040 commit 4332c32
Showing 3 changed files with 37 additions and 0 deletions.
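
A note on the new schedule, derived from the `CosLinearEnvelope.get_lr` implementation in `olmo/optim.py` below: after warmup, the learning rate is the pointwise product of a linear decay and a cosine decay over the remaining steps. Writing `t` for steps past warmup, `T` for post-warmup max steps, and `eta_min = alpha_f * lr0`:

    lr(t) = eta_min + (1 - t/T) * (lr0 - eta_min) * (1 + cos(pi * t/T)) / 2

At the midpoint `t = T/2` both decay factors equal 1/2, so the LR sits at `eta_min + 0.25 * (lr0 - eta_min)` rather than the `eta_min + 0.5 * (lr0 - eta_min)` a plain cosine schedule would give; the linear envelope pulls the tail of the cosine curve down toward `eta_min` faster.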
CHANGELOG.md (1 addition, 0 deletions)
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `model.rope_theta` configuration option.
- Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
- Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.
- Added `CosLinearEnvelope` scheduler, which is a pointwise product of a cosine schedule and a linear decay.

### Changed

olmo/config.py (1 addition, 0 deletions)
@@ -556,6 +556,7 @@ class SchedulerType(StrEnum):
    inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
    max_scheduler = "max_scheduler"
    constant = "constant"
    cosine_linear_envelope = "cosine_linear_envelope"


class SchedulerUnits(StrEnum):
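
For context, a sketch of how the new enum value might be selected from Python. `SchedulerConfig` and the `t_warmup`, `alpha_f`, and `t_max` fields are the ones read by `build_scheduler` in `olmo/optim.py` below, but this constructor call and its values are illustrative assumptions, not part of this commit:

    from olmo.config import SchedulerConfig, SchedulerType

    # Hypothetical scheduler config; build_scheduler dispatches on its name field.
    sched_cfg = SchedulerConfig(
        name=SchedulerType.cosine_linear_envelope,
        t_warmup=1000,  # linear warmup steps
        alpha_f=0.1,    # final LR = alpha_f * initial LR
        t_max=None,     # None: decay over the run's max_steps
    )
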
olmo/optim.py (35 additions, 0 deletions)
@@ -25,6 +25,7 @@
    "InvSqrtWithWarmup",
    "MaxScheduler",
    "ConstantScheduler",
    "CosLinearEnvelope",
    "BoltOnWarmupScheduler",
    "build_optimizer",
    "build_scheduler",
@@ -788,6 +789,29 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
        return initial_lr


@dataclass
class CosLinearEnvelope(Scheduler):
    "Pointwise product of cosine schedule and linear decay; useful during annealing."
    warmup_steps: int
    alpha_f: float = 0.1
    t_max: Optional[int] = None

    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
        max_steps = max_steps if self.t_max is None else self.t_max
        eta_min = initial_lr * self.alpha_f

        if step < self.warmup_steps:
            return self._linear_warmup(initial_lr, step, self.warmup_steps)
        if step >= max_steps:
            return eta_min
        else:
            step = step - self.warmup_steps
            max_steps = max_steps - self.warmup_steps
            linear_envelope = 1 - (step / max_steps)
            cosine_schedule = (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
            return eta_min + linear_envelope * cosine_schedule


PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")


@@ -981,5 +1005,16 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
            warmup_min_lr=sched_cfg.warmup_min_lr,
        )
    elif sched_cfg.name == SchedulerType.cosine_linear_envelope:
        return CosLinearEnvelope(
            grad_clip_warmup_steps=(
                None if sched_cfg.grad_clip_warmup_steps is None else int(sched_cfg.grad_clip_warmup_steps)
            ),
            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
            warmup_steps=int(sched_cfg.t_warmup),
            alpha_f=sched_cfg.alpha_f,
            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
            warmup_min_lr=sched_cfg.warmup_min_lr,
        )
    else:
        raise NotImplementedError
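
A minimal usage sketch of the new scheduler. It assumes the three base `Scheduler` fields (`grad_clip_warmup_steps`, `grad_clip_warmup_factor`, `warmup_min_lr`) must be supplied explicitly, as `build_scheduler` does above; the numeric values are illustrative only:

    from olmo.optim import CosLinearEnvelope

    sched = CosLinearEnvelope(
        grad_clip_warmup_steps=None,
        grad_clip_warmup_factor=None,
        warmup_min_lr=None,
        warmup_steps=100,  # linear warmup for the first 100 steps
        alpha_f=0.1,       # eta_min = 0.1 * initial LR
        t_max=None,        # decay over the max_steps passed to get_lr
    )

    lr0, max_steps = 1e-3, 1000
    for step in (0, 100, 550, 1000):
        print(step, sched.get_lr(lr0, step, max_steps))
    # step 100 -> lr0 (warmup just finished); step 550 is the post-warmup midpoint,
    # where both decay factors equal 1/2, so lr = eta_min + 0.25 * (lr0 - eta_min) = 3.25e-4;
    # step 1000 -> eta_min = 1e-4.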
