diff --git a/docs/converging.md b/docs/converging.md
index b098e35991..c078df0fa9 100644
--- a/docs/converging.md
+++ b/docs/converging.md
@@ -42,14 +42,14 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor
 - Base config: [torchtitan/models/llama/train_configs/llama3_8b.toml](../torchtitan/models/llama/train_configs/llama3_8b.toml)
 - `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
 - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
-- `training.steps = 3000`, `training.warmup_steps = 600`
-
-| Parallelism | Techniques | Remarks |
-| ----- | ----- | ----- |
-| FSDP 8 | default | 1D control set |
-| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set |
-| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set |
-| FSDP 8, CP 8 | default | to verify CP with a larger degree |
+- `training.steps = 3000`, `lr_scheduler.warmup_steps = 600`
+
+| Parallelism              | Techniques                                        | Remarks                           |
+| ------------------------ | ------------------------------------------------- | --------------------------------- |
+| FSDP 8                   | default                                           | 1D control set                    |
+| FSDP 8, TP 2, PP 2       | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set                       |
+| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set                       |
+| FSDP 8, CP 8             | default                                           | to verify CP with a larger degree |
 
 ### Test results
 ![image](../assets/images/loss_curves.png)
diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 673c945785..d4f994b471 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -6,7 +6,8 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Generic, List, TypeVar
+import math
+from typing import Any, Callable, Dict, Generic, List, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -21,6 +22,7 @@
 
 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
 
 __all__ = [
     "OptimizersContainer",
@@ -362,7 +364,7 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Load the same state_dict for all schedulers. The key value we're concerned
         # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer
-        # that is immutable. As long as ``training.steps`` and ``training.warmup_steps``
+        # that is immutable. As long as ``training.steps`` and ``lr_scheduler.warmup_steps``
         # in ``job_config`` remain unchanged when resuming from a checkpoint, this
         # approach is safe. We call ``copy()`` here to ensure extra safety.
         for scheduler in self.schedulers:
@@ -388,30 +390,69 @@ def build_lr_schedulers(
         optimizers (OptimizersContainer): The corresponding optimizers for the
             lr_schedulers.
     """
-    warmup_steps = int(job_config.training.warmup_steps)
-    decay_steps = float(max(1, job_config.training.steps - warmup_steps))
-
-    def linear_warmup_linear_decay(
-        warmup_steps: int, decay_steps: int, current_step: int
-    ) -> float:
-        """Computes linear warmup followed by linear decay.
+    training_steps = job_config.training.steps
+    warmup_steps = int(job_config.lr_scheduler.warmup_steps)
+    lr_decay_ratio = job_config.lr_scheduler.decay_ratio
+    lr_decay_type = job_config.lr_scheduler.decay_type
+    lr_min = job_config.lr_scheduler.lr_min
+
+    def linear_warmup_stable_decay(
+        current_step: int,
+        warmup_steps: int,
+        lr_decay_ratio: Union[float, None],
+        lr_decay_type: str,
+        lr_min: float,
+    ):
+        """
+        Computes linear warmup, followed by a stable learning rate phase,
+        followed by the selected decay schedule.
 
         Per LambdaLR requirement, this is accomplished by returning
-        a multiplicative factor to adjust the learning rate to
-        create the desired schedule.
+        a multiplicative factor `curr_adjustment` ranging from 1 to 0
+        to adjust the learning rate to create the desired schedule.
+
+        We offer three types of learning rate decay schedules:
+        1. `linear`: decays linearly from 1 to 0 over the decay period.
+        2. `sqrt`: decays as 1 minus the square root of the decay progress.
+        3. `cosine`: follows the first half-period of a cosine curve, decaying smoothly from 1 to 0.
+
+        If `lr_min` is specified, the decay range is scaled from 1 to `lr_min`
+        to ensure the learning rate does not drop below this minimum value.
         """
+        if lr_decay_ratio is None:
+            warmup_stable_steps = warmup_steps
+        else:
+            warmup_stable_steps = training_steps * (1 - lr_decay_ratio)
+            if warmup_stable_steps < warmup_steps:
+                logger.warning(
+                    f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). "
+                    f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})."
+                )
         if current_step < warmup_steps:
             # linear warmup
             # 0-indexed step, hence + 1 adjustments
             current_step += 1
             curr_adjustment = float(current_step / (warmup_steps + 1))
-
+        elif current_step < warmup_stable_steps:
+            curr_adjustment = 1.0
         else:
-            # linear decay
-            normalized_step = decay_steps - (current_step - warmup_steps)
-            curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps
-
+            decay_steps = float(max(1, training_steps - warmup_stable_steps))
+            progress = float(current_step - warmup_stable_steps) / decay_steps
+
+            if lr_decay_type == "linear":
+                curr_adjustment = 1 - progress
+            elif lr_decay_type == "sqrt":
+                curr_adjustment = 1 - math.sqrt(progress)
+            elif lr_decay_type == "cosine":
+                curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
+            curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
         return curr_adjustment
 
-    lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
+    lr_lambda = functools.partial(
+        linear_warmup_stable_decay,
+        warmup_steps=warmup_steps,
+        lr_decay_ratio=lr_decay_ratio,
+        lr_decay_type=lr_decay_type,
+        lr_min=lr_min,
+    )
     return LRSchedulersContainer(optimizers, lr_lambda)
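The new schedule is easiest to see as a plain function of the step index. The sketch below is not part of the patch: it mirrors `linear_warmup_stable_decay` in a standalone `wsd_multiplier` helper (an illustrative name) and prints the multiplier at a few points of a hypothetical 1000-step run with 200 warmup steps and `decay_ratio = 0.5`.

```python
import math
from typing import Union


def wsd_multiplier(
    step: int,
    training_steps: int,
    warmup_steps: int,
    decay_ratio: Union[float, None],
    decay_type: str = "linear",
    lr_min: float = 0.0,
) -> float:
    """LR multiplier over training, mirroring linear_warmup_stable_decay above."""
    # With decay_ratio=None there is no stable phase: decay starts right after warmup.
    stable_end = warmup_steps if decay_ratio is None else training_steps * (1 - decay_ratio)
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)        # linear warmup, 0-indexed step
    if step < stable_end:
        return 1.0                                    # stable phase
    decay_steps = max(1, training_steps - stable_end)
    progress = (step - stable_end) / decay_steps      # goes from 0 to ~1 over the decay window
    if decay_type == "linear":
        adj = 1 - progress
    elif decay_type == "sqrt":
        adj = 1 - math.sqrt(progress)
    else:  # "cosine"
        adj = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + (1 - lr_min) * adj                # rescale so the floor is lr_min


# Hypothetical run: 1000 steps, 200 warmup, cosine decay over the last 50% of training.
for s in (0, 100, 199, 200, 499, 500, 750, 999):
    print(s, round(wsd_multiplier(s, 1000, 200, decay_ratio=0.5, decay_type="cosine"), 4))
```

Note the multiplier never quite reaches 1 during warmup (it is `warmup_steps / (warmup_steps + 1)` at the last warmup step), which mirrors the `+ 1` adjustment in the patch.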
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 014c0f251e..0de00287d6 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -251,6 +251,51 @@ def __init__(self):
             register_post_accumulate_grad_hook after the optimizer is built.""",
         )
 
+        # lr scheduler configs
+        self.parser.add_argument(
+            "--lr_scheduler.warmup_steps",
+            type=int,
+            default=200,
+            help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
+        )
+        self.parser.add_argument(
+            "--lr_scheduler.decay_ratio",
+            type=float,
+            default=None,
+            help="""
+            Controls the proportion of the training steps allocated to the learning rate decay phase.
+
+            If `None`, the learning rate will begin decaying immediately after the warmup period.
+            Otherwise, the learning rate will remain stable after the warmup period and
+            only start decaying during the last `decay_ratio` portion of the total training steps.
+
+            This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
+            """,
+        )
+        self.parser.add_argument(
+            "--lr_scheduler.decay_type",
+            type=str,
+            default="linear",
+            choices=["linear", "sqrt", "cosine"],
+            help="""
+            Learning rate decay type to use during training:
+            - 'linear': linearly decays learning rate from initial to final value
+            - 'sqrt': decays learning rate following a 1 minus square root curve
+            - 'cosine': smoothly decays learning rate following a cosine curve
+            """,
+        )
+        self.parser.add_argument(
+            "--lr_scheduler.lr_min",
+            type=float,
+            default=0.0,
+            help="""
+            Minimum lr ratio for the lr scheduler.
+
+            If provided, the range of the decay factor is scaled from 1 to `lr_min`
+            to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
+            """,
+        )
+
         # training configs
         self.parser.add_argument(
             "--training.dataset", type=str, default="c4_test", help="Dataset to use"
@@ -268,12 +313,6 @@ def __init__(self):
         self.parser.add_argument(
             "--training.seq_len", type=int, default=2048, help="Sequence length"
         )
-        self.parser.add_argument(
-            "--training.warmup_steps",
-            type=int,
-            default=200,
-            help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
-        )
         self.parser.add_argument(
             "--training.max_norm",
             type=Union[float, int],
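To make the interaction of the four new flags concrete, here is a small worked example, not from the patch, using hypothetical values. Note that per the help text above, `lr_min` is a ratio of `optimizer.lr`, not an absolute learning rate.

```python
# Hypothetical values for the new [lr_scheduler] flags, alongside optimizer.lr.
training_steps = 3000   # --training.steps
warmup_steps = 600      # --lr_scheduler.warmup_steps
decay_ratio = 0.2       # --lr_scheduler.decay_ratio (None would mean: decay right after warmup)
lr_min = 0.1            # --lr_scheduler.lr_min, a ratio of the base lr
lr = 3e-4               # --optimizer.lr

stable_end = int(training_steps * (1 - decay_ratio))  # 2400: first step of the decay phase
decay_steps = training_steps - stable_end             # 600: length of the decay phase
floor_lr = lr * lr_min                                 # 3e-5: the lr never decays below this

print(f"warmup: steps 0-{warmup_steps - 1}")
print(f"stable: steps {warmup_steps}-{stable_end - 1}")
print(f"decay : steps {stable_end}-{training_steps - 1} ({decay_steps} steps), floor {floor_lr:.1e}")
```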
diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml
index ae9264322e..3d21fbff67 100644
--- a/torchtitan/models/llama/train_configs/debug_model.toml
+++ b/torchtitan/models/llama/train_configs/debug_model.toml
@@ -33,10 +33,15 @@ name = "AdamW"
 lr = 8e-4
 eps = 1e-8
 
+[lr_scheduler]
+warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+lr_min = 0.0
+
 [training]
 batch_size = 8
 seq_len = 2048
-warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_replicate_degree = 1
diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml
index 9067fe3001..c70ddcbe9b 100644
--- a/torchtitan/models/llama/train_configs/llama3_405b.toml
+++ b/torchtitan/models/llama/train_configs/llama3_405b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
 lr = 8e-5
 eps = 1e-8
 
+[lr_scheduler]
+warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps
+
 [training]
 batch_size = 2
 seq_len = 8192
-warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 3000
 data_parallel_replicate_degree = 1
diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml
index c065e8409e..a025a7513a 100644
--- a/torchtitan/models/llama/train_configs/llama3_70b.toml
+++ b/torchtitan/models/llama/train_configs/llama3_70b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
 lr = 1.5e-4
 eps = 1e-8
 
+[lr_scheduler]
+warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
+
 [training]
 batch_size = 8
 seq_len = 8192
-warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_replicate_degree = 1
diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml
index f47615a42e..6086967fa7 100644
--- a/torchtitan/models/llama/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama/train_configs/llama3_8b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
 lr = 3e-4
 eps = 1e-8
 
+[lr_scheduler]
+warmup_steps = 200  # lr scheduler warm up
+
 [training]
 batch_size = 1
 seq_len = 8192
-warmup_steps = 200  # lr scheduler warm up
 max_norm = 1.0  # grad norm clipping
 steps = 1000
 data_parallel_replicate_degree = 1
diff --git a/torchtitan/train.py b/torchtitan/train.py
index cfd55fd9aa..9841e4ca70 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -266,7 +266,7 @@ def main(job_config: JobConfig):
         f"global batch size {job_config.training.batch_size * dp_degree}, "
         f"sequence length {job_config.training.seq_len}, "
         f"total steps {job_config.training.steps} "
-        f"(warmup {job_config.training.warmup_steps})"
+        f"(warmup {job_config.lr_scheduler.warmup_steps})"
     )
     with maybe_enable_profiling(
         job_config, global_step=train_state.step
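For context on how the pieces fit together at runtime: `build_lr_schedulers` binds the schedule parameters with `functools.partial` and hands the resulting callable to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by the returned factor at each step. The toy loop below is not from the repo; the `lr_multiplier` helper, the linear-only schedule, and the tiny model are stand-ins for what torchtitan actually builds, using the `debug_model.toml` values (lr 8e-4, 10 steps, 2 warmup steps, decay_ratio 0.8).

```python
import functools

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_multiplier(step, *, training_steps, warmup_steps, decay_ratio, lr_min):
    """Simplified stand-in for linear_warmup_stable_decay (linear decay only)."""
    stable_end = warmup_steps if decay_ratio is None else training_steps * (1 - decay_ratio)
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)                 # linear warmup
    if step < stable_end:
        return 1.0                                             # stable phase
    progress = (step - stable_end) / max(1, training_steps - stable_end)
    return lr_min + (1 - lr_min) * (1 - progress)              # linear decay down to lr_min


model = torch.nn.Linear(8, 8)                                  # toy stand-in for the Llama model
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
lr_lambda = functools.partial(
    lr_multiplier, training_steps=10, warmup_steps=2, decay_ratio=0.8, lr_min=0.0
)
scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(10):
    optimizer.step()                                           # backward() omitted in this toy loop
    scheduler.step()
    print(f"step {step}: lr {scheduler.get_last_lr()[0]:.2e}")
```

With `decay_ratio = 0.8` and only 10 steps, the stable phase collapses to nothing: decay begins right after the 2 warmup steps, which is exactly the shape the debug config above exercises.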