32 changes: 18 additions & 14 deletions torchtitan/components/optimizer.py
@@ -7,7 +7,7 @@
 import copy
 import functools
 import math
-from typing import Any, Callable, Dict, Generic, List, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -392,14 +392,26 @@ def build_lr_schedulers(
     """
     training_steps = job_config.training.steps
     warmup_steps = int(job_config.lr_scheduler.warmup_steps)
-    lr_decay_ratio = job_config.lr_scheduler.decay_ratio
+    if job_config.lr_scheduler.decay_ratio is not None:
+        decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio)
+        if warmup_steps + decay_steps > training_steps:
+            logger.warning(
+                f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed "
+                f"total training steps ({training_steps}). "
+                f"Adjusting decay steps to {training_steps - warmup_steps}."
+            )
+            decay_steps = training_steps - warmup_steps
+    else:
+        decay_steps = training_steps - warmup_steps
+    stable_steps = training_steps - warmup_steps - decay_steps
     lr_decay_type = job_config.lr_scheduler.decay_type
     lr_min = job_config.lr_scheduler.lr_min
 
     def linear_warmup_stable_decay(
         current_step: int,
         warmup_steps: int,
-        lr_decay_ratio: Union[float, None],
+        stable_steps: int,
+        decay_steps: int,
         lr_decay_type: str,
         lr_min: float,
     ):
@@ -419,15 +431,7 @@ def linear_warmup_stable_decay(
         If `lr_min` is specified, the decay range is scaled from 1 to `lr_min`
         to ensure the learning rate does not drop below this minimum value.
         """
-        if lr_decay_ratio is None:
-            warmup_stable_steps = warmup_steps
-        else:
-            warmup_stable_steps = training_steps * (1 - lr_decay_ratio)
-            if warmup_stable_steps < warmup_steps:
-                logger.warning(
-                    f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). "
-                    f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})."
-                )
+        warmup_stable_steps = warmup_steps + stable_steps
         if current_step < warmup_steps:
             # linear warmup
             # 0-indexed step, hence + 1 adjustments
@@ -436,7 +440,6 @@ def linear_warmup_stable_decay(
         elif current_step < warmup_stable_steps:
             curr_adjustment = 1.0
         else:
-            decay_steps = float(max(1, training_steps - warmup_stable_steps))
             progress = float(current_step - warmup_stable_steps) / decay_steps
 
             if lr_decay_type == "linear":
@@ -451,7 +454,8 @@
     lr_lambda = functools.partial(
         linear_warmup_stable_decay,
         warmup_steps=warmup_steps,
-        lr_decay_ratio=lr_decay_ratio,
+        stable_steps=stable_steps,
+        decay_steps=decay_steps,
         lr_decay_type=lr_decay_type,
         lr_min=lr_min,
     )
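Taken together, the change moves the phase arithmetic out of the per-step lambda and into `build_lr_schedulers`: `decay_steps` and `stable_steps` are computed once, and the lambda no longer re-derives them from a `Union[float, None]` decay ratio on every step. A minimal sketch of the new partitioning, with hypothetical values inlined in place of `job_config` (the logger call is dropped for brevity):

# Sketch of the new phase partitioning; the numbers are hypothetical,
# standing in for job_config.training.steps and job_config.lr_scheduler.*.
training_steps = 1000
warmup_steps = 100
decay_ratio = 0.2  # None means: decay over every post-warmup step

if decay_ratio is not None:
    decay_steps = round(training_steps * decay_ratio)  # 200
    if warmup_steps + decay_steps > training_steps:
        # mirrors the warning in the diff: clamp decay to the remaining budget
        decay_steps = training_steps - warmup_steps
else:
    decay_steps = training_steps - warmup_steps
stable_steps = training_steps - warmup_steps - decay_steps  # 700

assert warmup_steps + stable_steps + decay_steps == training_steps

By construction the three phases now always sum to `training_steps`, which the old `warmup_stable_steps = training_steps * (1 - lr_decay_ratio)` computation did not guarantee once warmup overlapped the decay window (it only warned).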
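For intuition, here is a self-contained sketch of the multiplier the new lambda produces. Only the `linear` decay branch is visible in the diff; the warmup formula, the `sqrt`/`cosine` branches, and the `lr_min` rescaling below are assumptions reconstructed from the surrounding comments and docstring, not the collapsed code itself:

import math

def wsd_multiplier(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    lr_decay_type: str,
    lr_min: float,
) -> float:
    """Warmup-stable-decay LR multiplier in [lr_min, 1]."""
    warmup_stable_steps = warmup_steps + stable_steps
    if current_step < warmup_steps:
        # assumed linear warmup; +1 because steps are 0-indexed (per the diff comment)
        return float(current_step + 1) / (warmup_steps + 1)
    if current_step < warmup_stable_steps:
        return 1.0
    progress = float(current_step - warmup_stable_steps) / decay_steps
    if lr_decay_type == "linear":
        curr = 1.0 - progress
    elif lr_decay_type == "sqrt":  # assumed branch, not shown in the diff
        curr = 1.0 - math.sqrt(progress)
    else:  # assumed: cosine
        curr = 0.5 * (1.0 + math.cos(math.pi * progress))
    # assumed lr_min handling, per the docstring: rescale the decay into [lr_min, 1]
    return lr_min + (1.0 - lr_min) * curr

# 100 warmup / 700 stable / 200 decay, as in the partitioning sketch above
for step in (0, 99, 100, 799, 800, 999):
    print(step, round(wsd_multiplier(step, 100, 700, 200, "linear", 0.0), 4))

With the precomputed phase lengths, the decay denominator is the fixed `decay_steps` rather than the old per-call `float(max(1, training_steps - warmup_stable_steps))`, so the multiplier hits exactly `lr_min` as `current_step` approaches `training_steps`.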