diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index d4f994b471..9d04ce4f58 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -7,7 +7,7 @@
 import copy
 import functools
 import math
-from typing import Any, Callable, Dict, Generic, List, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -392,14 +392,26 @@ def build_lr_schedulers(
     """
     training_steps = job_config.training.steps
    warmup_steps = int(job_config.lr_scheduler.warmup_steps)
-    lr_decay_ratio = job_config.lr_scheduler.decay_ratio
+    if job_config.lr_scheduler.decay_ratio is not None:
+        decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio)
+        if warmup_steps + decay_steps > training_steps:
+            logger.warning(
+                f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed "
+                f"total training steps ({training_steps}). "
+                f"Adjusting decay steps to {training_steps - warmup_steps}."
+            )
+            decay_steps = training_steps - warmup_steps
+    else:
+        decay_steps = training_steps - warmup_steps
+    stable_steps = training_steps - warmup_steps - decay_steps
     lr_decay_type = job_config.lr_scheduler.decay_type
     lr_min = job_config.lr_scheduler.lr_min
 
     def linear_warmup_stable_decay(
         current_step: int,
         warmup_steps: int,
-        lr_decay_ratio: Union[float, None],
+        stable_steps: int,
+        decay_steps: int,
         lr_decay_type: str,
         lr_min: float,
     ):
@@ -419,15 +431,7 @@ def linear_warmup_stable_decay(
         If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` to
         ensure the learning rate does not drop below this minimum value.
         """
-        if lr_decay_ratio is None:
-            warmup_stable_steps = warmup_steps
-        else:
-            warmup_stable_steps = training_steps * (1 - lr_decay_ratio)
-            if warmup_stable_steps < warmup_steps:
-                logger.warning(
-                    f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). "
-                    f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})."
-                )
+        warmup_stable_steps = warmup_steps + stable_steps
         if current_step < warmup_steps:
             # linear warmup
             # 0-indexed step, hence + 1 adjustments
@@ -436,7 +440,6 @@ def linear_warmup_stable_decay(
         elif current_step < warmup_stable_steps:
             curr_adjustment = 1.0
         else:
-            decay_steps = float(max(1, training_steps - warmup_stable_steps))
             progress = float(current_step - warmup_stable_steps) / decay_steps
 
             if lr_decay_type == "linear":
@@ -451,7 +454,8 @@
     lr_lambda = functools.partial(
         linear_warmup_stable_decay,
         warmup_steps=warmup_steps,
-        lr_decay_ratio=lr_decay_ratio,
+        stable_steps=stable_steps,
+        decay_steps=decay_steps,
         lr_decay_type=lr_decay_type,
         lr_min=lr_min,
     )
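
For context, here is a minimal standalone sketch (not part of the patch) of the warmup/stable/decay step accounting that the rewritten build_lr_schedulers sets up. The concrete step counts, the exact warmup ramp, and the lr_min scaling below are illustrative assumptions; only the decay-step and stable-step bookkeeping mirrors the diff above.

# Standalone sketch of the warmup -> stable -> decay split introduced above.
# The numbers and the warmup/lr_min formulas are assumptions for illustration,
# not torchtitan's exact implementation.
training_steps = 100
warmup_steps = 10
decay_ratio = 0.2  # hypothetical lr_scheduler.decay_ratio

# Same bookkeeping as the patch: decay steps come from the ratio, capped so
# warmup + decay never exceeds the total; the stable phase fills the gap.
decay_steps = round(training_steps * decay_ratio)
decay_steps = min(decay_steps, training_steps - warmup_steps)
stable_steps = training_steps - warmup_steps - decay_steps


def lr_multiplier(step: int, lr_min: float = 0.0) -> float:
    """LR multiplier for a 0-indexed step (linear decay only in this sketch)."""
    warmup_stable_steps = warmup_steps + stable_steps
    if step < warmup_steps:
        return float(step + 1) / warmup_steps  # assumed linear ramp
    if step < warmup_stable_steps:
        return 1.0  # stable plateau at the full learning rate
    progress = float(step - warmup_stable_steps) / decay_steps
    return (1.0 - progress) * (1.0 - lr_min) + lr_min  # assumed lr_min scaling


# One sample step from each phase: warmup, stable plateau, decay.
for step in (0, 5, 50, 90, 99):
    print(step, round(lr_multiplier(step), 3))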