Merged
Changes from all commits · 30 commits
629bd51
[Scheduler] Add support for cosine and wsd scheduler
yzhangcs Mar 6, 2025
3c120e3
[Misc.] Log learning rate
yzhangcs Mar 6, 2025
2fc78e2
Unify the three decay lambda fns
yzhangcs Mar 7, 2025
fce4a14
Remove the default value in function signature
yzhangcs Mar 7, 2025
29281a6
Update toml configs
yzhangcs Mar 7, 2025
ed6e1e1
Configurable `lr_decay_ratio`
yzhangcs Mar 7, 2025
12e83be
[Scheduler] Rename `lr_decay_fn` to `linear_warmup_stable_decay`
yzhangcs Mar 7, 2025
d9b91a5
Delete `lr_decay_type` check in `build_lr_schedulers`
yzhangcs Mar 7, 2025
bbc82b2
Revert changes on train.py
yzhangcs Mar 7, 2025
2230d3a
[Config] Move scheduler-related params to [scheduler] section
yzhangcs Mar 10, 2025
01b4b62
Update train.py
yzhangcs Mar 10, 2025
e246428
Update train.py
yzhangcs Mar 10, 2025
3a14cf5
Add all scheduler configs in debug config
yzhangcs Mar 10, 2025
69b05df
Add warnings if warmup_stable_steps < warmup_steps
yzhangcs Mar 10, 2025
827395c
Revert changes on train.py
yzhangcs Mar 10, 2025
f3293ab
Obey the code format
yzhangcs Mar 10, 2025
72a0286
int type warmup_stable_steps
yzhangcs Mar 10, 2025
2e2b6b4
Rename `training.warmup_steps` to `scheduler.warmup_steps`
yzhangcs Mar 10, 2025
698d63c
Rename `scheduler` to `lr_scheduler`
yzhangcs Mar 10, 2025
5f742f5
[Legal] Modifications requested by legal for adding additional datase…
lessw2020 Mar 6, 2025
e9fe2e5
[FSDP2][doc] highlight set_requires_gradient_sync and ignored_params …
weifengpy Mar 6, 2025
f5a9abe
[PP] Ensure loss is visible on console for users (#946)
lessw2020 Mar 7, 2025
6d8da38
Make MetricsLogger as a component (#945)
fegin Mar 10, 2025
f395ed2
[Misc.] Log learning rate
yzhangcs Mar 6, 2025
af00afb
Update train.py
yzhangcs Mar 10, 2025
65f5f66
Add warnings if warmup_stable_steps < warmup_steps
yzhangcs Mar 10, 2025
1e61236
Revert changes on train.py
yzhangcs Mar 10, 2025
6328fc7
Rename `training.warmup_steps` to `scheduler.warmup_steps`
yzhangcs Mar 10, 2025
f378b2f
Merge branch 'main' into main
yzhangcs Mar 10, 2025
1eb7c71
Fix code formats
yzhangcs Mar 10, 2025
16 changes: 8 additions & 8 deletions docs/converging.md
@@ -42,14 +42,14 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor
- Base config: [torchtitan/models/llama/train_configs/llama3_8b.toml](../torchtitan/models/llama/train_configs/llama3_8b.toml)
- `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
- `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
- `training.steps = 3000`, `training.warmup_steps = 600`

| Parallelism | Techniques | Remarks |
| ----- | ----- | ----- |
| FSDP 8 | default | 1D control set |
| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set |
| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set |
| FSDP 8, CP 8 | default | to verify CP with a larger degree |
- `training.steps = 3000`, `lr_scheduler.warmup_steps = 600`

| Parallelism | Techniques | Remarks |
| ------------------------ | ------------------------------------------------- | --------------------------------- |
| FSDP 8 | default | 1D control set |
| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set |
| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set |
| FSDP 8, CP 8 | default | to verify CP with a larger degree |

### Test results
![image](../assets/images/loss_curves.png)
75 changes: 58 additions & 17 deletions torchtitan/components/optimizer.py
@@ -6,7 +6,8 @@

import copy
import functools
from typing import Any, Callable, Dict, Generic, List, TypeVar
import math
from typing import Any, Callable, Dict, Generic, List, TypeVar, Union

import torch
import torch.nn as nn
@@ -21,6 +22,7 @@

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

__all__ = [
"OptimizersContainer",
@@ -362,7 +364,7 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Load the same state_dict for all schedulers. The key value we're concerned
# within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer
# that is immutable. As long as ``training.steps`` and ``training.warmup_steps``
# that is immutable. As long as ``training.steps`` and ``lr_scheduler.warmup_steps``
# in ``job_config`` remain unchanged when resuming from a checkpoint, this
# approach is safe. We call ``copy()`` here to ensure extra safety.
for scheduler in self.schedulers:
@@ -388,30 +390,69 @@ def build_lr_schedulers(
optimizers (OptimizersContainer): The corresponding optimizers for the
lr_schedulers.
"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))

def linear_warmup_linear_decay(
warmup_steps: int, decay_steps: int, current_step: int
) -> float:
"""Computes linear warmup followed by linear decay.
training_steps = job_config.training.steps
warmup_steps = int(job_config.lr_scheduler.warmup_steps)
lr_decay_ratio = job_config.lr_scheduler.decay_ratio
lr_decay_type = job_config.lr_scheduler.decay_type
lr_min = job_config.lr_scheduler.lr_min

def linear_warmup_stable_decay(
current_step: int,
warmup_steps: int,
lr_decay_ratio: Union[float, None],
lr_decay_type: str,
lr_min: float,
):
"""
Computes linear warmup followed by stable learning rate for a while,
then some type of decay.

Per LambdaLR requirement, this is accomplished by returning
a multiplicative factor to adjust the learning rate to
create the desired schedule.
a multiplicative factor `curr_adjustment` ranging from 1 to 0
to adjust the learning rate to create the desired schedule.

We offer three types of learning rate decay schedules:
1. `linear`: decays linearly from 1 to 0 over the decay period.
2. `sqrt`: decays as 1 minus the square root of the decay progress.
3. `cosine`: follows the first half-period of a cosine curve, decaying smoothly from 1 to 0.

If `lr_min` is specified, the decay range is scaled from 1 to `lr_min`
to ensure the learning rate does not drop below this minimum value.
"""
if lr_decay_ratio is None:
warmup_stable_steps = warmup_steps
else:
warmup_stable_steps = training_steps * (1 - lr_decay_ratio)
if warmup_stable_steps < warmup_steps:
logger.warning(
f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). "
f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})."
)
if current_step < warmup_steps:
# linear warmup
# 0-indexed step, hence + 1 adjustments
current_step += 1
curr_adjustment = float(current_step / (warmup_steps + 1))

elif current_step < warmup_stable_steps:
curr_adjustment = 1.0
else:
# linear decay
normalized_step = decay_steps - (current_step - warmup_steps)
curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps

decay_steps = float(max(1, training_steps - warmup_stable_steps))
progress = float(current_step - warmup_stable_steps) / decay_steps

if lr_decay_type == "linear":
curr_adjustment = 1 - progress
elif lr_decay_type == "sqrt":
curr_adjustment = 1 - math.sqrt(progress)
elif lr_decay_type == "cosine":
curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
return curr_adjustment

lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
lr_lambda = functools.partial(
linear_warmup_stable_decay,
warmup_steps=warmup_steps,
lr_decay_ratio=lr_decay_ratio,
lr_decay_type=lr_decay_type,
lr_min=lr_min,
)
return LRSchedulersContainer(optimizers, lr_lambda)
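
For intuition, below is a minimal standalone sketch of the warmup-stable-decay multiplier that `build_lr_schedulers` now constructs, wired into `torch.optim.lr_scheduler.LambdaLR`. The function name `wsd_lambda`, the toy model, and the hyperparameter values are illustrative only, not part of the PR.

```python
# Standalone sketch of the WSD multiplier introduced above (illustrative names,
# not torchtitan code): linear warmup -> stable -> linear/sqrt/cosine decay.
import functools
import math
from typing import Optional

import torch


def wsd_lambda(
    current_step: int,
    *,
    training_steps: int,
    warmup_steps: int,
    decay_ratio: Optional[float],
    decay_type: str,
    lr_min: float,
) -> float:
    # With no decay_ratio, decay starts right after warmup (plain warmup + decay).
    stable_end = (
        warmup_steps if decay_ratio is None else training_steps * (1 - decay_ratio)
    )
    if current_step < warmup_steps:
        return (current_step + 1) / (warmup_steps + 1)  # linear warmup
    if current_step < stable_end:
        return 1.0  # stable phase at the peak learning rate
    decay_steps = max(1, training_steps - stable_end)
    progress = (current_step - stable_end) / decay_steps
    if decay_type == "linear":
        factor = 1 - progress
    elif decay_type == "sqrt":
        factor = 1 - math.sqrt(progress)
    else:  # "cosine"
        factor = 0.5 * (1 + math.cos(math.pi * progress))
    return lr_min + (1 - lr_min) * factor  # multiplier never drops below lr_min


# Toy usage: 10 steps, 2 warmup steps, decay over the last 80% of training.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    functools.partial(
        wsd_lambda,
        training_steps=10,
        warmup_steps=2,
        decay_ratio=0.8,
        decay_type="cosine",
        lr_min=0.0,
    ),
)
for _ in range(10):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())
```
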
51 changes: 45 additions & 6 deletions torchtitan/config_manager.py
@@ -251,6 +251,51 @@ def __init__(self):
register_post_accumulate_grad_hook after the optimizer is built.""",
)

# lr scheduler configs
self.parser.add_argument(
"--lr_scheduler.warmup_steps",
type=int,
default=200,
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
)
self.parser.add_argument(
"--lr_scheduler.decay_ratio",
type=float,
default=None,
help="""
Controls the proportion of the training steps allocated to the learning rate decay phase.

If `None`, the learning rate will begin decaying immediately after the warmup period.
Otherwise, the learning rate will remain stable after the warmup period and
only start decaying during the last `decay_ratio` portion of the total training steps.

This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
""",
)
self.parser.add_argument(
"--lr_scheduler.decay_type",
type=str,
default="linear",
choices=["linear", "sqrt", "cosine"],
help="""
Learning rate decay type to use during training:
- 'linear': linearly decays learning rate from initial to final value
- 'sqrt': decays learning rate following a 1 minus square root curve
- 'cosine': smoothly decays learning rate following a cosine curve
""",
)
self.parser.add_argument(
"--lr_scheduler.lr_min",
type=float,
default=0.0,
help="""
Min lr ratio for lr scheduler.

If provided, the range of decay factor is scaled from 1 to `lr_min`
to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
""",
)

# training configs
self.parser.add_argument(
"--training.dataset", type=str, default="c4_test", help="Dataset to use"
@@ -268,12 +313,6 @@ def __init__(self):
self.parser.add_argument(
"--training.seq_len", type=int, default=2048, help="Sequence length"
)
self.parser.add_argument(
"--training.warmup_steps",
type=int,
default=200,
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
)
self.parser.add_argument(
"--training.max_norm",
type=Union[float, int],
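
As a quick sanity check on the three `--lr_scheduler.decay_type` choices documented above, here is a tiny self-contained snippet (illustration only, not part of torchtitan) that tabulates the decay factor at a few points of the decay phase:

```python
import math

# Decay factor as a function of decay-phase progress p in [0, 1],
# matching the 'linear', 'sqrt', and 'cosine' options.
decay = {
    "linear": lambda p: 1 - p,
    "sqrt": lambda p: 1 - math.sqrt(p),
    "cosine": lambda p: 0.5 * (1 + math.cos(math.pi * p)),
}

for name, fn in decay.items():
    row = ", ".join(f"{fn(p):.2f}" for p in (0.0, 0.25, 0.5, 0.75, 1.0))
    print(f"{name:>6}: {row}")
# linear: 1.00, 0.75, 0.50, 0.25, 0.00
#   sqrt: 1.00, 0.50, 0.29, 0.13, 0.00
# cosine: 1.00, 0.85, 0.50, 0.15, 0.00
```
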
7 changes: 6 additions & 1 deletion torchtitan/models/llama/train_configs/debug_model.toml
@@ -33,10 +33,15 @@ name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_replicate_degree = 1
4 changes: 3 additions & 1 deletion torchtitan/models/llama/train_configs/llama3_405b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
lr = 8e-5
eps = 1e-8

[lr_scheduler]
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 2
seq_len = 8192
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 3000
data_parallel_replicate_degree = 1
4 changes: 3 additions & 1 deletion torchtitan/models/llama/train_configs/llama3_70b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
lr = 1.5e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 8
seq_len = 8192
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1
4 changes: 3 additions & 1 deletion torchtitan/models/llama/train_configs/llama3_8b.toml
@@ -27,10 +27,12 @@ name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1
2 changes: 1 addition & 1 deletion torchtitan/train.py
@@ -266,7 +266,7 @@ def main(job_config: JobConfig):
f"global batch size {job_config.training.batch_size * dp_degree}, "
f"sequence length {job_config.training.seq_len}, "
f"total steps {job_config.training.steps} "
f"(warmup {job_config.training.warmup_steps})"
f"(warmup {job_config.lr_scheduler.warmup_steps})"
)
with maybe_enable_profiling(
job_config, global_step=train_state.step