From 629bd5139487b0d416cf0c97995a737f2da5257b Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 00:47:43 -0800 Subject: [PATCH 01/29] [Scheduler] Add support for cosine and wsd scheduler --- torchtitan/components/optimizer.py | 78 ++++++++++++++++++- torchtitan/config_manager.py | 18 +++++ .../llama/train_configs/debug_model.toml | 2 + .../llama/train_configs/llama3_405b.toml | 2 + .../llama/train_configs/llama3_70b.toml | 2 + .../models/llama/train_configs/llama3_8b.toml | 2 + 6 files changed, 101 insertions(+), 3 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 673c945785..de362bba9e 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -6,9 +6,11 @@ import copy import functools +import math from typing import Any, Callable, Dict, Generic, List, TypeVar import torch + import torch.nn as nn from torch.distributed.checkpoint.state_dict import ( get_optimizer_state_dict, @@ -389,10 +391,11 @@ def build_lr_schedulers( lr_schedulers. """ warmup_steps = int(job_config.training.warmup_steps) - decay_steps = float(max(1, job_config.training.steps - warmup_steps)) + training_steps = job_config.training.steps + min_lr_ratio = job_config.optimizer.min_lr_ratio def linear_warmup_linear_decay( - warmup_steps: int, decay_steps: int, current_step: int + current_step: int, warmup_steps: int, min_lr_ratio: float ) -> float: """Computes linear warmup followed by linear decay. @@ -400,6 +403,7 @@ def linear_warmup_linear_decay( a multiplicative factor to adjust the learning rate to create the desired schedule. """ + decay_steps = float(max(1, training_steps - warmup_steps)) if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments @@ -410,8 +414,76 @@ def linear_warmup_linear_decay( # linear decay normalized_step = decay_steps - (current_step - warmup_steps) curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps + curr_adjustment = curr_adjustment * (1 - min_lr_ratio) + min_lr_ratio + + return curr_adjustment + + def linear_warmup_cosine_decay( + current_step: int, warmup_steps: int, min_lr_ratio: float = 0.0 + ): + decay_steps = float(max(1, training_steps - warmup_steps)) + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + else: + # cosine decay + progress = (current_step - warmup_steps) / decay_steps + curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) + curr_adjustment = min_lr_ratio + (1 - min_lr_ratio) * curr_adjustment + + return curr_adjustment + def linear_warmup_stable_decay( + current_step: int, + warmup_steps: int, + decay_ratio: float = 0.1, + min_lr_ratio: float = 0.0, + decay_type: str = "sqrt", + ): + warmup_stable_steps = training_steps * (1 - decay_ratio) + decay_steps = float(max(1, training_steps - warmup_stable_steps)) + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + elif current_step < warmup_stable_steps: + curr_adjustment = 1.0 + else: + progress = float(current_step - warmup_stable_steps) / decay_steps + if decay_type == "linear": + curr_adjustment = 1 - progress + elif decay_type == "sqrt": + curr_adjustment = 1 - math.sqrt(progress) + elif decay_type == "cosine": + curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) + else: + raise ValueError( + f"decay type 
{decay_type} is not in ['linear', 'sqrt', 'cosine']" + ) + curr_adjustment = min_lr_ratio + (1 - min_lr_ratio) * curr_adjustment return curr_adjustment - lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) + if job_config.optimizer.scheduler == "linear": + lr_lambda = functools.partial( + linear_warmup_linear_decay, + warmup_steps=warmup_steps, + min_lr_ratio=min_lr_ratio, + ) + elif job_config.optimizer.scheduler == "cosine": + lr_lambda = functools.partial( + linear_warmup_cosine_decay, + warmup_steps=warmup_steps, + min_lr_ratio=min_lr_ratio, + ) + elif job_config.optimizer.scheduler == "wsd": + lr_lambda = functools.partial( + linear_warmup_stable_decay, + warmup_steps=warmup_steps, + min_lr_ratio=min_lr_ratio, + ) + else: + raise ValueError(f"Scheduler {job_config.optimizer.scheduler} not supported") return LRSchedulersContainer(optimizers, lr_lambda) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 79bd2d9199..d95f4826a0 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -233,9 +233,27 @@ def __init__(self): self.parser.add_argument( "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use" ) + self.parser.add_argument( + "--optimizer.min_lr_ratio", + type=float, + default=0.0, + help="Min lr ratio for lr scheduler", + ) self.parser.add_argument( "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" ) + self.parser.add_argument( + "--optimizer.scheduler", + type=str, + default="linear", + choices=["linear", "cosine", "wsd"], + help=""" + Learning rate scheduler to use during training: + - 'linear': Linear scheduler that linearly decays learning rate from initial to final value + - 'cosine': Cosine annealing scheduler that smoothly decays learning rate following a cosine curve + - 'wsd': Warmup-Stable-Decay scheduler that follows warmup, stable, and decay phases (see https://arxiv.org/abs/2404.06395) + """, + ) self.parser.add_argument( "--optimizer.implementation", type=str, diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index ae9264322e..151daceae1 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -32,6 +32,8 @@ tokenizer_path = "./tests/assets/test_tiktoken.model" name = "AdamW" lr = 8e-4 eps = 1e-8 +scheduler = "linear" +min_lr_ratio = 0.0 [training] batch_size = 8 diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml index 9067fe3001..fb756a3b19 100644 --- a/torchtitan/models/llama/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama/train_configs/llama3_405b.toml @@ -26,6 +26,8 @@ converters = "float8" name = "AdamW" lr = 8e-5 eps = 1e-8 +scheduler = "linear" +min_lr_ratio = 0.0 [training] batch_size = 2 diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml index c065e8409e..0bf7f21cf1 100644 --- a/torchtitan/models/llama/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama/train_configs/llama3_70b.toml @@ -26,6 +26,8 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 1.5e-4 eps = 1e-8 +scheduler = "linear" +min_lr_ratio = 0.0 [training] batch_size = 8 diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml index f47615a42e..a5491cdd11 100644 --- 
a/torchtitan/models/llama/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama/train_configs/llama3_8b.toml @@ -26,6 +26,8 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 3e-4 eps = 1e-8 +scheduler = "linear" +min_lr_ratio = 0.0 [training] batch_size = 1 From 3c120e392b3a2c9f4612e884d117b9ed84ad7dde Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 01:30:34 -0800 Subject: [PATCH 02/29] [Misc.] Log learning rate --- torchtitan/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 630626dbec..3a24160c03 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -370,6 +370,7 @@ def main(job_config: JobConfig): time_delta = time.perf_counter() - time_last_log + last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] # tokens per second per device, abbreviated as tps tps = ntokens_since_last_log / ( time_delta * parallel_dims.non_data_parallel_size @@ -392,6 +393,7 @@ def main(job_config: JobConfig): "throughput(tps)": tps, "tflops": tflops, "mfu(%)": mfu, + "optim/learning_rate": last_lr, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -410,8 +412,8 @@ def main(job_config: JobConfig): f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}tflops: {tflops:,.2f} " - f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + f"{color.cyan}lr: {last_lr:.4e} " + f"{color.magenta}tflops: {tflops:,.2f} mfu: {mfu:.2f}%{color.reset}" ) ntokens_since_last_log = 0 From 2fc78e27a93c76aed45de8110dc6203423708de4 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 21:34:15 -0800 Subject: [PATCH 03/29] Unify the three decay lambda fns --- torchtitan/components/optimizer.py | 109 ++++++++++------------------- torchtitan/config_manager.py | 38 ++++++---- 2 files changed, 63 insertions(+), 84 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index de362bba9e..b984055c87 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -392,58 +392,34 @@ def build_lr_schedulers( """ warmup_steps = int(job_config.training.warmup_steps) training_steps = job_config.training.steps - min_lr_ratio = job_config.optimizer.min_lr_ratio + lr_decay_ratio = job_config.optimizer.lr_decay_ratio + lr_decay_type = job_config.optimizer.lr_decay_type + lr_min = job_config.optimizer.lr_min - def linear_warmup_linear_decay( - current_step: int, warmup_steps: int, min_lr_ratio: float - ) -> float: - """Computes linear warmup followed by linear decay. - - Per LambdaLR requirement, this is accomplished by returning - a multiplicative factor to adjust the learning rate to - create the desired schedule. 
+ def lr_decay_fn( + current_step: int, + warmup_steps: int, + lr_decay_ratio: float = 0.1, + lr_decay_type: str = "sqrt", + lr_min: float = 0.0, + ): """ - decay_steps = float(max(1, training_steps - warmup_steps)) - if current_step < warmup_steps: - # linear warmup - # 0-indexed step, hence + 1 adjustments - current_step += 1 - curr_adjustment = float(current_step / (warmup_steps + 1)) - - else: - # linear decay - normalized_step = decay_steps - (current_step - warmup_steps) - curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps - curr_adjustment = curr_adjustment * (1 - min_lr_ratio) + min_lr_ratio + Computes linear warmup followed by stable learning rate for a while, + then some type of decay. - return curr_adjustment - - def linear_warmup_cosine_decay( - current_step: int, warmup_steps: int, min_lr_ratio: float = 0.0 - ): - decay_steps = float(max(1, training_steps - warmup_steps)) - if current_step < warmup_steps: - # linear warmup - # 0-indexed step, hence + 1 adjustments - current_step += 1 - curr_adjustment = float(current_step / (warmup_steps + 1)) - else: - # cosine decay - progress = (current_step - warmup_steps) / decay_steps - curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) - curr_adjustment = min_lr_ratio + (1 - min_lr_ratio) * curr_adjustment + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor `curr_adjustment` ranging from 1 to 0 + to adjust the learning rate to create the desired schedule. - return curr_adjustment + We offer three types of learning rate decay schedules: + 1. `linear`: decreases linearly from 1 to 0 over the decay period. + 2. `sqrt`: decreases as 1 minus the square root of the decay progress. + 3. `cosine`: follows a cosine curve, decreasing according to the values of the half-period of the cosine function. - def linear_warmup_stable_decay( - current_step: int, - warmup_steps: int, - decay_ratio: float = 0.1, - min_lr_ratio: float = 0.0, - decay_type: str = "sqrt", - ): - warmup_stable_steps = training_steps * (1 - decay_ratio) - decay_steps = float(max(1, training_steps - warmup_stable_steps)) + If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` + to ensure the learning rate does not drop below this minimum value. 
+ """ + warmup_stable_steps = training_steps * (1 - lr_decay_ratio) if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments @@ -452,38 +428,27 @@ def linear_warmup_stable_decay( elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: + decay_steps = float(max(1, training_steps - warmup_stable_steps)) progress = float(current_step - warmup_stable_steps) / decay_steps - if decay_type == "linear": + + if lr_decay_type == "linear": curr_adjustment = 1 - progress - elif decay_type == "sqrt": + elif lr_decay_type == "sqrt": curr_adjustment = 1 - math.sqrt(progress) - elif decay_type == "cosine": + elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) else: raise ValueError( - f"decay type {decay_type} is not in ['linear', 'sqrt', 'cosine']" + f"LR decay type {lr_decay_type} is not in ['linear', 'sqrt', 'cosine']" ) - curr_adjustment = min_lr_ratio + (1 - min_lr_ratio) * curr_adjustment + curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment return curr_adjustment - if job_config.optimizer.scheduler == "linear": - lr_lambda = functools.partial( - linear_warmup_linear_decay, - warmup_steps=warmup_steps, - min_lr_ratio=min_lr_ratio, - ) - elif job_config.optimizer.scheduler == "cosine": - lr_lambda = functools.partial( - linear_warmup_cosine_decay, - warmup_steps=warmup_steps, - min_lr_ratio=min_lr_ratio, - ) - elif job_config.optimizer.scheduler == "wsd": - lr_lambda = functools.partial( - linear_warmup_stable_decay, - warmup_steps=warmup_steps, - min_lr_ratio=min_lr_ratio, - ) - else: - raise ValueError(f"Scheduler {job_config.optimizer.scheduler} not supported") + lr_lambda = functools.partial( + lr_decay_fn, + warmup_steps=warmup_steps, + lr_decay_ratio=lr_decay_ratio, + lr_decay_type=lr_decay_type, + lr_min=lr_min, + ) return LRSchedulersContainer(optimizers, lr_lambda) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d95f4826a0..fbdbe33e46 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -234,26 +234,40 @@ def __init__(self): "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use" ) self.parser.add_argument( - "--optimizer.min_lr_ratio", + "--optimizer.lr_decay_ratio", type=float, - default=0.0, - help="Min lr ratio for lr scheduler", - ) - self.parser.add_argument( - "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" + default=0.1, + help=""" + The ratio of the learning rate decay period. + + If specified, the learning rate decay will only occur during the last `lr_decay_ratio` portion of the total training steps. 
+ """, ) self.parser.add_argument( - "--optimizer.scheduler", + "--optimizer.lr_decay_type", type=str, default="linear", - choices=["linear", "cosine", "wsd"], + choices=["linear", "sqrt", "cosine"], help=""" - Learning rate scheduler to use during training: - - 'linear': Linear scheduler that linearly decays learning rate from initial to final value - - 'cosine': Cosine annealing scheduler that smoothly decays learning rate following a cosine curve - - 'wsd': Warmup-Stable-Decay scheduler that follows warmup, stable, and decay phases (see https://arxiv.org/abs/2404.06395) + Learning rate decay type to use during training: + - 'linear': linearly decays learning rate from initial to final value + - 'sqrt': decays learning rate following a 1 minus square root curve + - 'cosine': smoothly decays learning rate following a cosine curve """, ) + self.parser.add_argument( + "--optimizer.lr_min", + type=float, + default=0.0, + help=""" + Min lr ratio for lr scheduler. + + If provided, the range of decay factor is scaled from 1 to `lr_min` to ensure the learning rate does not drop below `lr * lr_min`. + """, + ) + self.parser.add_argument( + "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" + ) self.parser.add_argument( "--optimizer.implementation", type=str, From fce4a144093d1c286eb8cfb1295d5de36a40d57e Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 21:38:03 -0800 Subject: [PATCH 04/29] Remove the default value in function signature --- torchtitan/components/optimizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index b984055c87..c2fa883ca0 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -399,9 +399,9 @@ def build_lr_schedulers( def lr_decay_fn( current_step: int, warmup_steps: int, - lr_decay_ratio: float = 0.1, - lr_decay_type: str = "sqrt", - lr_min: float = 0.0, + lr_decay_ratio: float, + lr_decay_type: str, + lr_min: float, ): """ Computes linear warmup followed by stable learning rate for a while, @@ -412,9 +412,9 @@ def lr_decay_fn( to adjust the learning rate to create the desired schedule. We offer three types of learning rate decay schedules: - 1. `linear`: decreases linearly from 1 to 0 over the decay period. - 2. `sqrt`: decreases as 1 minus the square root of the decay progress. - 3. `cosine`: follows a cosine curve, decreasing according to the values of the half-period of the cosine function. + 1. `linear`: decays linearly from 1 to 0 over the decay period. + 2. `sqrt`: decays as 1 minus the square root of the decay progress. + 3. `cosine`: follows a cosine curve, decaying according to the values of the half-period of the cosine function. If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` to ensure the learning rate does not drop below this minimum value. 
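
As an aside on the unified schedule introduced in patches 03 and 04 above: the whole warmup / stable / decay behaviour reduces to a single multiplier handed to LambdaLR. Below is a minimal standalone sketch of that multiplier, not code from the patches themselves; the concrete numbers (1000 training steps, 100 warmup steps, decay_ratio = 0.2, lr_min = 0.1) are illustrative assumptions, not defaults taken from the series.

    import math

    # Illustrative sketch of the warmup -> stable -> decay multiplier for LambdaLR.
    # The step counts and ratios below are made-up example values, not patch defaults.
    TRAINING_STEPS = 1000
    WARMUP_STEPS = 100

    def wsd_multiplier(step: int, decay_ratio: float = 0.2,
                       decay_type: str = "cosine", lr_min: float = 0.1) -> float:
        stable_end = TRAINING_STEPS * (1 - decay_ratio)  # decay only in the last `decay_ratio` of training
        if step < WARMUP_STEPS:                          # linear warmup (0-indexed step, hence +1)
            return (step + 1) / (WARMUP_STEPS + 1)
        if step < stable_end:                            # stable phase at the full learning rate
            return 1.0
        progress = (step - stable_end) / max(1.0, TRAINING_STEPS - stable_end)
        if decay_type == "linear":
            factor = 1 - progress
        elif decay_type == "sqrt":
            factor = 1 - math.sqrt(progress)
        else:                                            # "cosine": half period of the cosine curve
            factor = 0.5 * (1.0 + math.cos(math.pi * progress))
        return lr_min + (1 - lr_min) * factor            # multiplier never drops below lr_min

    # Multiplier during warmup, the stable phase, and near the end of decay:
    print([round(wsd_multiplier(s), 3) for s in (0, 50, 400, 900, 999)])

A later patch in this series makes the decay ratio optional: when it is None, the stable phase collapses and decay starts immediately after warmup, so a linear decay type with lr_min = 0.0 recovers the original linear-warmup-linear-decay schedule.
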
From 29281a6046173c6cfc76042e2c36ad8d3dd8395a Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 21:55:25 -0800 Subject: [PATCH 05/29] Update toml configs --- torchtitan/models/llama/train_configs/debug_model.toml | 4 ++-- torchtitan/models/llama/train_configs/llama3_405b.toml | 4 ++-- torchtitan/models/llama/train_configs/llama3_70b.toml | 4 ++-- torchtitan/models/llama/train_configs/llama3_8b.toml | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index 151daceae1..1305d1b58c 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -32,8 +32,8 @@ tokenizer_path = "./tests/assets/test_tiktoken.model" name = "AdamW" lr = 8e-4 eps = 1e-8 -scheduler = "linear" -min_lr_ratio = 0.0 +lr_decay_type = "linear" +lr_min = 0.0 [training] batch_size = 8 diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml index fb756a3b19..b533271fb3 100644 --- a/torchtitan/models/llama/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama/train_configs/llama3_405b.toml @@ -26,8 +26,8 @@ converters = "float8" name = "AdamW" lr = 8e-5 eps = 1e-8 -scheduler = "linear" -min_lr_ratio = 0.0 +lr_decay_type = "linear" +lr_min = 0.0 [training] batch_size = 2 diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml index 0bf7f21cf1..235b252b0c 100644 --- a/torchtitan/models/llama/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama/train_configs/llama3_70b.toml @@ -26,8 +26,8 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 1.5e-4 eps = 1e-8 -scheduler = "linear" -min_lr_ratio = 0.0 +lr_decay_type = "linear" +lr_min = 0.0 [training] batch_size = 8 diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml index a5491cdd11..35b341bd40 100644 --- a/torchtitan/models/llama/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama/train_configs/llama3_8b.toml @@ -26,8 +26,8 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 3e-4 eps = 1e-8 -scheduler = "linear" -min_lr_ratio = 0.0 +lr_decay_type = "linear" +lr_min = 0.0 [training] batch_size = 1 From ed6e1e1b25b6a064388e017e4b7459492e674f98 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 22:03:49 -0800 Subject: [PATCH 06/29] Configurable `lr_decay_ratio` --- torchtitan/components/optimizer.py | 9 ++++++--- torchtitan/config_manager.py | 10 +++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index c2fa883ca0..14e2d6bfc5 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -7,7 +7,7 @@ import copy import functools import math -from typing import Any, Callable, Dict, Generic, List, TypeVar +from typing import Any, Callable, Dict, Generic, List, TypeVar, Union import torch @@ -399,7 +399,7 @@ def build_lr_schedulers( def lr_decay_fn( current_step: int, warmup_steps: int, - lr_decay_ratio: float, + lr_decay_ratio: Union[float, None], lr_decay_type: str, lr_min: float, ): @@ -419,7 +419,10 @@ def lr_decay_fn( If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` to ensure the learning rate does not drop below this minimum value. 
""" - warmup_stable_steps = training_steps * (1 - lr_decay_ratio) + if lr_decay_ratio is None: + warmup_stable_steps = warmup_steps + else: + warmup_stable_steps = training_steps * (1 - lr_decay_ratio) if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index fbdbe33e46..5639466409 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -236,11 +236,15 @@ def __init__(self): self.parser.add_argument( "--optimizer.lr_decay_ratio", type=float, - default=0.1, + default=None, help=""" - The ratio of the learning rate decay period. + Controls the proportion of the training steps allocated to the learning rate decay phase. + + If `None`, the learning rate will begin decaying immediately after the warmup period. + Otherwise, the learning rate will remain stable after the warmup period and + only start decaying during the last `lr_decay_ratio` portion of the total training steps. - If specified, the learning rate decay will only occur during the last `lr_decay_ratio` portion of the total training steps. + This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395. """, ) self.parser.add_argument( From 12e83be9ebce73ccea80a3fde88b7ecff1290da8 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Fri, 7 Mar 2025 00:32:44 -0800 Subject: [PATCH 07/29] [Scheduler] Rename `lr_decay_fn` to `linear_warmup_stable_decay` --- torchtitan/components/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 14e2d6bfc5..0c6e77d36d 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -396,7 +396,7 @@ def build_lr_schedulers( lr_decay_type = job_config.optimizer.lr_decay_type lr_min = job_config.optimizer.lr_min - def lr_decay_fn( + def linear_warmup_stable_decay( current_step: int, warmup_steps: int, lr_decay_ratio: Union[float, None], @@ -448,7 +448,7 @@ def lr_decay_fn( return curr_adjustment lr_lambda = functools.partial( - lr_decay_fn, + linear_warmup_stable_decay, warmup_steps=warmup_steps, lr_decay_ratio=lr_decay_ratio, lr_decay_type=lr_decay_type, From d9b91a5b0bbb2acce5b20b755ba79f02ec607af3 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Fri, 7 Mar 2025 00:47:44 -0800 Subject: [PATCH 08/29] Delete `lr_decay_type` check in `build_lr_schedulers` --- torchtitan/components/optimizer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 0c6e77d36d..8939a3399d 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -440,10 +440,6 @@ def linear_warmup_stable_decay( curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) - else: - raise ValueError( - f"LR decay type {lr_decay_type} is not in ['linear', 'sqrt', 'cosine']" - ) curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment return curr_adjustment From bbc82b2a6b455649618ef0da6e0e4d241ef1f9ea Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Fri, 7 Mar 2025 09:48:42 -0800 Subject: [PATCH 09/29] Revert changes on train.py --- torchtitan/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 3a24160c03..630626dbec 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -370,7 +370,6 @@ 
def main(job_config: JobConfig): time_delta = time.perf_counter() - time_last_log - last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] # tokens per second per device, abbreviated as tps tps = ntokens_since_last_log / ( time_delta * parallel_dims.non_data_parallel_size @@ -393,7 +392,6 @@ def main(job_config: JobConfig): "throughput(tps)": tps, "tflops": tflops, "mfu(%)": mfu, - "optim/learning_rate": last_lr, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -412,8 +410,8 @@ def main(job_config: JobConfig): f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}lr: {last_lr:.4e} " - f"{color.magenta}tflops: {tflops:,.2f} mfu: {mfu:.2f}%{color.reset}" + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) ntokens_since_last_log = 0 From 2230d3ae32fb0c1f4d23d71fc81cbf1116970b0f Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 01:38:49 -0700 Subject: [PATCH 10/29] [Config] Move scheduler-related params to [scheduler] section --- docs/converging.md | 16 ++-- torchtitan/components/optimizer.py | 10 +-- torchtitan/config_manager.py | 73 ++++++++++--------- .../llama/train_configs/debug_model.toml | 6 +- .../llama/train_configs/llama3_405b.toml | 6 +- .../llama/train_configs/llama3_70b.toml | 6 +- .../models/llama/train_configs/llama3_8b.toml | 6 +- torchtitan/train.py | 8 +- 8 files changed, 69 insertions(+), 62 deletions(-) diff --git a/docs/converging.md b/docs/converging.md index b098e35991..39b5154597 100644 --- a/docs/converging.md +++ b/docs/converging.md @@ -42,14 +42,14 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor - Base config: [torchtitan/models/llama/train_configs/llama3_8b.toml](../torchtitan/models/llama/train_configs/llama3_8b.toml) - `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"` - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32 -- `training.steps = 3000`, `training.warmup_steps = 600` - -| Parallelism | Techniques | Remarks | -| ----- | ----- | ----- | -| FSDP 8 | default | 1D control set | -| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set | -| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set | -| FSDP 8, CP 8 | default | to verify CP with a larger degree | +- `training.steps = 3000`, `scheduler.warmup_steps = 600` + +| Parallelism | Techniques | Remarks | +| ------------------------ | ------------------------------------------------- | --------------------------------- | +| FSDP 8 | default | 1D control set | +| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 3D test set | +| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set | +| FSDP 8, CP 8 | default | to verify CP with a larger degree | ### Test results ![image](../assets/images/loss_curves.png) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 8939a3399d..3aca906d36 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -364,7 +364,7 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # Load the same 
state_dict for all schedulers. The key value we're concerned # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer - # that is immutable. As long as ``training.steps`` and ``training.warmup_steps`` + # that is immutable. As long as ``training.steps`` and ``scheduler.warmup_steps`` # in ``job_config`` remain unchanged when resuming from a checkpoint, this # approach is safe. We call ``copy()`` here to ensure extra safety. for scheduler in self.schedulers: @@ -390,11 +390,11 @@ def build_lr_schedulers( optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers. """ - warmup_steps = int(job_config.training.warmup_steps) training_steps = job_config.training.steps - lr_decay_ratio = job_config.optimizer.lr_decay_ratio - lr_decay_type = job_config.optimizer.lr_decay_type - lr_min = job_config.optimizer.lr_min + warmup_steps = int(job_config.scheduler.warmup_steps) + lr_decay_ratio = job_config.scheduler.decay_ratio + lr_decay_type = job_config.scheduler.decay_type + lr_min = job_config.scheduler.lr_min def linear_warmup_stable_decay( current_step: int, diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 5639466409..7fa0c5d594 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -234,7 +234,39 @@ def __init__(self): "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use" ) self.parser.add_argument( - "--optimizer.lr_decay_ratio", + "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" + ) + self.parser.add_argument( + "--optimizer.implementation", + type=str, + default="fused", + choices=["for-loop", "foreach", "fused"], + help=""" + Specify which optimizer implementation to use: + - 'fused': Use fused implementation (CUDA only) for best performance. + - 'foreach': Use some horizontal fusion of tensors for better performance. + - 'for-loop': Use the default implementation for the optimizer (slowest). + - more info: https://pytorch.org/docs/stable/optim.html + """, + ) + self.parser.add_argument( + "--optimizer.early_step_in_backward", + action="store_true", + help=""" + Whether to apply optimizer in the backward. Caution, optimizer_in_backward + is not compatible with gradients clipping, users should not call + register_post_accumulate_grad_hook after the optimizer is built.""", + ) + + # lr scheduler configs + self.parser.add_argument( + "--scheduler.warmup_steps", + type=int, + default=200, + help="Steps for lr scheduler warmup, normally 1/5 of --training.steps", + ) + self.parser.add_argument( + "--scheduler.decay_ratio", type=float, default=None, help=""" @@ -242,13 +274,13 @@ def __init__(self): If `None`, the learning rate will begin decaying immediately after the warmup period. Otherwise, the learning rate will remain stable after the warmup period and - only start decaying during the last `lr_decay_ratio` portion of the total training steps. + only start decaying during the last `decay_ratio` portion of the total training steps. This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395. """, ) self.parser.add_argument( - "--optimizer.lr_decay_type", + "--scheduler.decay_type", type=str, default="linear", choices=["linear", "sqrt", "cosine"], @@ -260,39 +292,16 @@ def __init__(self): """, ) self.parser.add_argument( - "--optimizer.lr_min", + "--scheduler.lr_min", type=float, default=0.0, help=""" Min lr ratio for lr scheduler. 
- If provided, the range of decay factor is scaled from 1 to `lr_min` to ensure the learning rate does not drop below `lr * lr_min`. + If provided, the range of decay factor is scaled from 1 to `lr_min` + to ensure the learning rate does not drop below `optimizer.lr * scheduler.lr_min`. """, ) - self.parser.add_argument( - "--optimizer.eps", type=float, default=1e-8, help="Epsilon value to use" - ) - self.parser.add_argument( - "--optimizer.implementation", - type=str, - default="fused", - choices=["for-loop", "foreach", "fused"], - help=""" - Specify which optimizer implementation to use: - - 'fused': Use fused implementation (CUDA only) for best performance. - - 'foreach': Use some horizontal fusion of tensors for better performance. - - 'for-loop': Use the default implementation for the optimizer (slowest). - - more info: https://pytorch.org/docs/stable/optim.html - """, - ) - self.parser.add_argument( - "--optimizer.early_step_in_backward", - action="store_true", - help=""" - Whether to apply optimizer in the backward. Caution, optimizer_in_backward - is not compatible with gradients clipping, users should not call - register_post_accumulate_grad_hook after the optimizer is built.""", - ) # training configs self.parser.add_argument( @@ -311,12 +320,6 @@ def __init__(self): self.parser.add_argument( "--training.seq_len", type=int, default=2048, help="Sequence length" ) - self.parser.add_argument( - "--training.warmup_steps", - type=int, - default=200, - help="Steps for lr scheduler warmup, normally 1/5 of --training.steps", - ) self.parser.add_argument( "--training.max_norm", type=Union[float, int], diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index 1305d1b58c..8050d4b4a4 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -32,13 +32,15 @@ tokenizer_path = "./tests/assets/test_tiktoken.model" name = "AdamW" lr = 8e-4 eps = 1e-8 -lr_decay_type = "linear" + +[scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_type = "linear" lr_min = 0.0 [training] batch_size = 8 seq_len = 2048 -warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_replicate_degree = 1 diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml index b533271fb3..dc14d3e2c7 100644 --- a/torchtitan/models/llama/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama/train_configs/llama3_405b.toml @@ -26,13 +26,13 @@ converters = "float8" name = "AdamW" lr = 8e-5 eps = 1e-8 -lr_decay_type = "linear" -lr_min = 0.0 + +[scheduler] +warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps [training] batch_size = 2 seq_len = 8192 -warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 3000 data_parallel_replicate_degree = 1 diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml index 235b252b0c..a323203005 100644 --- a/torchtitan/models/llama/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama/train_configs/llama3_70b.toml @@ -26,13 +26,13 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 1.5e-4 eps = 1e-8 -lr_decay_type = "linear" -lr_min = 0.0 + +[scheduler] +warmup_steps = 200 # lr scheduler 
warm up, normally 20% of the train steps [training] batch_size = 8 seq_len = 8192 -warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_replicate_degree = 1 diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml index 35b341bd40..b219f529ac 100644 --- a/torchtitan/models/llama/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama/train_configs/llama3_8b.toml @@ -26,13 +26,13 @@ tokenizer_path = "./assets/tokenizer/original/tokenizer.model" name = "AdamW" lr = 3e-4 eps = 1e-8 -lr_decay_type = "linear" -lr_min = 0.0 + +[scheduler] +warmup_steps = 200 # lr scheduler warm up [training] batch_size = 1 seq_len = 8192 -warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_replicate_degree = 1 diff --git a/torchtitan/train.py b/torchtitan/train.py index 630626dbec..e5bbad21cc 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -270,7 +270,7 @@ def main(job_config: JobConfig): f"global batch size {job_config.training.batch_size * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.training.warmup_steps})" + f"(warmup {job_config.scheduler.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step @@ -370,6 +370,7 @@ def main(job_config: JobConfig): time_delta = time.perf_counter() - time_last_log + last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] # tokens per second per device, abbreviated as tps tps = ntokens_since_last_log / ( time_delta * parallel_dims.non_data_parallel_size @@ -392,6 +393,7 @@ def main(job_config: JobConfig): "throughput(tps)": tps, "tflops": tflops, "mfu(%)": mfu, + "optim/learning_rate": last_lr, "time_metrics/end_to_end(s)": time_end_to_end, "time_metrics/data_loading(s)": time_data_loading, "time_metrics/data_loading(%)": time_data_loading_pct, @@ -410,8 +412,8 @@ def main(job_config: JobConfig): f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}tflops: {tflops:,.2f} " - f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + f"{color.cyan}lr: {last_lr:.4e} " + f"{color.magenta}tflops: {tflops:,.2f} mfu: {mfu:.2f}%{color.reset}" ) ntokens_since_last_log = 0 From 01b4b6291e9da9aa7467b1bde8c0b1ba3959f9c4 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 16:41:21 +0800 Subject: [PATCH 11/29] Update train.py --- torchtitan/train.py | 122 +++++++++++--------------------------------- 1 file changed, 31 insertions(+), 91 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index e5bbad21cc..dbc1ca01f1 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,6 +13,10 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils @@ -21,7 +25,6 @@ from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, 
maybe_enable_profiling, @@ -39,9 +42,6 @@ def main(job_config: JobConfig): if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - # used for colorful printing - color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color - # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -75,10 +75,6 @@ def main(job_config: JobConfig): ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) - # initialize device memory monitor and get peak flops for MFU calculation - device_memory_monitor = build_device_memory_monitor() - gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) - logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) @@ -132,9 +128,18 @@ def main(job_config: JobConfig): model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) + # metrics logging + build_metrics_processor_fn = ( + build_metrics_processor + if train_spec.build_metrics_processor_fn is None + else train_spec.build_metrics_processor_fn + ) + metrics_processor = build_metrics_processor_fn(job_config, parallel_dims) + color = metrics_processor.color + # log model size model_param_count = utils.get_num_params(model) - num_flop_per_token = utils.get_num_flop_per_token( + metrics_processor.num_flop_per_token = utils.get_num_flop_per_token( utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, @@ -185,6 +190,10 @@ def main(job_config: JobConfig): with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() + + # confirm that user will be able to view loss metrics on the console + ensure_pp_loss_visible(parallel_dims, job_config, color) + else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) @@ -195,6 +204,10 @@ def main(job_config: JobConfig): model_parts = [model] + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " @@ -237,18 +250,6 @@ def main(job_config: JobConfig): return checkpoint.load(step=job_config.checkpoint.load_step) - metric_logger = build_metric_logger(job_config, parallel_dims) - - # plot losses loaded from checkpoint (if any) to TensorBoard - # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. 
- # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq - if train_state.step > 0 and not job_config.metrics.disable_logging_from_checkpoint: - for idx, step in enumerate(train_state.log_steps): - metrics = { - "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], - "loss_metrics/global_max_loss": train_state.global_max_losses[idx], - } - metric_logger.log(metrics, step=step) data_iterator = iter(dataloader) @@ -257,12 +258,6 @@ def main(job_config: JobConfig): job_config.experimental.enable_compiled_autograd, ) - # variables used to keep info for metrics logging - ntokens_since_last_log = 0 - data_loading_times = [] - time_last_log = time.perf_counter() - device_memory_monitor.reset_peak_stats() - # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -270,7 +265,7 @@ def main(job_config: JobConfig): f"global batch size {job_config.training.batch_size * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.scheduler.warmup_steps})" + f"(warmup {job_config.training.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step @@ -285,8 +280,10 @@ def main(job_config: JobConfig): data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch - ntokens_since_last_log += labels.numel() - data_loading_times.append(time.perf_counter() - data_load_start) + metrics_processor.ntokens_since_last_log += labels.numel() + metrics_processor.data_loading_times.append( + time.perf_counter() - data_load_start + ) input_ids = input_ids.to(device_type) labels = labels.to(device_type) @@ -346,10 +343,7 @@ def main(job_config: JobConfig): lr_schedulers.step() # log metrics - if ( - train_state.step == 1 - or train_state.step % job_config.metrics.log_freq == 0 - ): + if metrics_processor.should_log(train_state.step): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -363,63 +357,9 @@ def main(job_config: JobConfig): else: global_avg_loss = global_max_loss = loss.item() - # update train state - train_state.log_steps.append(train_state.step) - train_state.global_avg_losses.append(global_avg_loss) - train_state.global_max_losses.append(global_max_loss) - - time_delta = time.perf_counter() - time_last_log - - last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] - # tokens per second per device, abbreviated as tps - tps = ntokens_since_last_log / ( - time_delta * parallel_dims.non_data_parallel_size + metrics_processor.log( + train_state.step, global_avg_loss, global_max_loss ) - # model FLOPS utilization - # For its definition and calculation, please refer to the PaLM paper: - # https://arxiv.org/abs/2204.02311 - mfu = 100 * num_flop_per_token * tps / gpu_peak_flops - tflops = num_flop_per_token * tps / 1e12 - - time_end_to_end = time_delta / job_config.metrics.log_freq - time_data_loading = sum(data_loading_times) / len(data_loading_times) - time_data_loading_pct = 100 * sum(data_loading_times) / time_delta - - device_mem_stats = device_memory_monitor.get_peak_stats() - - metrics = { - "loss_metrics/global_avg_loss": global_avg_loss, - "loss_metrics/global_max_loss": global_max_loss, - "throughput(tps)": tps, - "tflops": tflops, - "mfu(%)": mfu, - "optim/learning_rate": last_lr, - "time_metrics/end_to_end(s)": time_end_to_end, - "time_metrics/data_loading(s)": time_data_loading, - "time_metrics/data_loading(%)": time_data_loading_pct, - "memory/max_active(GiB)": 
device_mem_stats.max_active_gib, - "memory/max_active(%)": device_mem_stats.max_active_pct, - "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, - "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, - "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, - "memory/num_ooms": device_mem_stats.num_ooms, - } - metric_logger.log(metrics, step=train_state.step) - - logger.info( - f"{color.red}step: {train_state.step:2} " - f"{color.green}loss: {global_avg_loss:7.4f} " - f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" - f"({device_mem_stats.max_reserved_pct:.2f}%) " - f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}lr: {last_lr:.4e} " - f"{color.magenta}tflops: {tflops:,.2f} mfu: {mfu:.2f}%{color.reset}" - ) - - ntokens_since_last_log = 0 - data_loading_times.clear() - time_last_log = time.perf_counter() - device_memory_monitor.reset_peak_stats() checkpoint.save( train_state.step, force=(train_state.step == job_config.training.steps) @@ -443,7 +383,7 @@ def main(job_config: JobConfig): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - metric_logger.close() + metrics_processor.close() logger.info("Training completed") From e24642839e9647673a92345d39a7f34042e243e6 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 16:42:41 +0800 Subject: [PATCH 12/29] Update train.py From 3a14cf50a4660ee346b9afe1d2b674edba916f2f Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 01:59:55 -0700 Subject: [PATCH 13/29] Add all scheduler configs in debug config --- torchtitan/models/llama/train_configs/debug_model.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index 8050d4b4a4..bb54d56a82 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -35,6 +35,7 @@ eps = 1e-8 [scheduler] warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" lr_min = 0.0 From 69b05dff5c652bfbab7dbc655a1daf4c4b75ebab Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 02:09:05 -0700 Subject: [PATCH 14/29] Add warnings if warmup_stable_steps < warmup_steps --- torchtitan/components/optimizer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 3aca906d36..9869daada5 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -10,19 +10,17 @@ from typing import Any, Callable, Dict, Generic, List, TypeVar, Union import torch - import torch.nn as nn -from torch.distributed.checkpoint.state_dict import ( - get_optimizer_state_dict, - set_optimizer_state_dict, - StateDictOptions, -) +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_optimizer_state_dict, + set_optimizer_state_dict) from torch.distributed.checkpoint.stateful import Stateful from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger __all__ = [ "OptimizersContainer", @@ -423,6 +421,11 @@ def linear_warmup_stable_decay( warmup_stable_steps = warmup_steps else: warmup_stable_steps = training_steps * (1 - lr_decay_ratio) 
+ if warmup_stable_steps < warmup_steps: + logger.warning( + f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " + f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})." + ) if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments From 827395ce43ff03f50f59c094a19ec875baa7e1df Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 17:11:17 +0800 Subject: [PATCH 15/29] Revert changes on train.py --- torchtitan/train.py | 118 +++++++++++++++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 30 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index dbc1ca01f1..630626dbec 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,10 +13,6 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager -from torchtitan.components.metrics import ( - build_metrics_processor, - ensure_pp_loss_visible, -) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils @@ -25,6 +21,7 @@ from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -42,6 +39,9 @@ def main(job_config: JobConfig): if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") + # used for colorful printing + color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color + # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -75,6 +75,10 @@ def main(job_config: JobConfig): ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = build_device_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) @@ -128,18 +132,9 @@ def main(job_config: JobConfig): model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) - # metrics logging - build_metrics_processor_fn = ( - build_metrics_processor - if train_spec.build_metrics_processor_fn is None - else train_spec.build_metrics_processor_fn - ) - metrics_processor = build_metrics_processor_fn(job_config, parallel_dims) - color = metrics_processor.color - # log model size model_param_count = utils.get_num_params(model) - metrics_processor.num_flop_per_token = utils.get_num_flop_per_token( + num_flop_per_token = utils.get_num_flop_per_token( utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, @@ -190,10 +185,6 @@ def main(job_config: JobConfig): with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() - - # confirm that user will be able to view loss metrics on the console - ensure_pp_loss_visible(parallel_dims, job_config, color) - else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) @@ -204,10 +195,6 @@ def main(job_config: JobConfig): 
model_parts = [model] - # initialize device memory monitor and get peak flops for MFU calculation - device_memory_monitor = metrics_processor.device_memory_monitor - gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) - logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " @@ -250,6 +237,18 @@ def main(job_config: JobConfig): return checkpoint.load(step=job_config.checkpoint.load_step) + metric_logger = build_metric_logger(job_config, parallel_dims) + + # plot losses loaded from checkpoint (if any) to TensorBoard + # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. + # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq + if train_state.step > 0 and not job_config.metrics.disable_logging_from_checkpoint: + for idx, step in enumerate(train_state.log_steps): + metrics = { + "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], + "loss_metrics/global_max_loss": train_state.global_max_losses[idx], + } + metric_logger.log(metrics, step=step) data_iterator = iter(dataloader) @@ -258,6 +257,12 @@ def main(job_config: JobConfig): job_config.experimental.enable_compiled_autograd, ) + # variables used to keep info for metrics logging + ntokens_since_last_log = 0 + data_loading_times = [] + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() + # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -280,10 +285,8 @@ def main(job_config: JobConfig): data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch - metrics_processor.ntokens_since_last_log += labels.numel() - metrics_processor.data_loading_times.append( - time.perf_counter() - data_load_start - ) + ntokens_since_last_log += labels.numel() + data_loading_times.append(time.perf_counter() - data_load_start) input_ids = input_ids.to(device_type) labels = labels.to(device_type) @@ -343,7 +346,10 @@ def main(job_config: JobConfig): lr_schedulers.step() # log metrics - if metrics_processor.should_log(train_state.step): + if ( + train_state.step == 1 + or train_state.step % job_config.metrics.log_freq == 0 + ): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -357,9 +363,61 @@ def main(job_config: JobConfig): else: global_avg_loss = global_max_loss = loss.item() - metrics_processor.log( - train_state.step, global_avg_loss, global_max_loss + # update train state + train_state.log_steps.append(train_state.step) + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + + time_delta = time.perf_counter() - time_last_log + + # tokens per second per device, abbreviated as tps + tps = ntokens_since_last_log / ( + time_delta * parallel_dims.non_data_parallel_size ) + # model FLOPS utilization + # For its definition and calculation, please refer to the PaLM paper: + # https://arxiv.org/abs/2204.02311 + mfu = 100 * num_flop_per_token * tps / gpu_peak_flops + tflops = num_flop_per_token * tps / 1e12 + + time_end_to_end = time_delta / job_config.metrics.log_freq + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta + + device_mem_stats = device_memory_monitor.get_peak_stats() + + metrics = { + "loss_metrics/global_avg_loss": global_avg_loss, + 
"loss_metrics/global_max_loss": global_max_loss, + "throughput(tps)": tps, + "tflops": tflops, + "mfu(%)": mfu, + "time_metrics/end_to_end(s)": time_end_to_end, + "time_metrics/data_loading(s)": time_data_loading, + "time_metrics/data_loading(%)": time_data_loading_pct, + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + "memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, + } + metric_logger.log(metrics, step=train_state.step) + + logger.info( + f"{color.red}step: {train_state.step:2} " + f"{color.green}loss: {global_avg_loss:7.4f} " + f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: {round(tps):,} " + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + ) + + ntokens_since_last_log = 0 + data_loading_times.clear() + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() checkpoint.save( train_state.step, force=(train_state.step == job_config.training.steps) @@ -383,7 +441,7 @@ def main(job_config: JobConfig): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - metrics_processor.close() + metric_logger.close() logger.info("Training completed") From f3293ab2b712aa21fc7a40ef482cfd8d87c5c993 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 02:13:01 -0700 Subject: [PATCH 16/29] Obey the code format --- torchtitan/components/optimizer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 9869daada5..9916e8f9a0 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -11,9 +11,11 @@ import torch import torch.nn as nn -from torch.distributed.checkpoint.state_dict import (StateDictOptions, - get_optimizer_state_dict, - set_optimizer_state_dict) +from torch.distributed.checkpoint.state_dict import ( + get_optimizer_state_dict, + set_optimizer_state_dict, + StateDictOptions, +) from torch.distributed.checkpoint.stateful import Stateful from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler From 72a0286b0d341c7ca080f6f938c69a6372e98064 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 02:48:50 -0700 Subject: [PATCH 17/29] int type warmup_stable_steps --- torchtitan/components/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 9916e8f9a0..b5a4a954d6 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -422,7 +422,7 @@ def linear_warmup_stable_decay( if lr_decay_ratio is None: warmup_stable_steps = warmup_steps else: - warmup_stable_steps = training_steps * (1 - lr_decay_ratio) + warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) if warmup_stable_steps < warmup_steps: logger.warning( f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). 
" From 2e2b6b427b40e7baca0bdaadc92e66ebc3492383 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 17:57:16 +0800 Subject: [PATCH 18/29] Rename `training.warmup_steps` to `scheduler.warmup_steps` --- torchtitan/train.py | 120 ++++++++++++-------------------------------- 1 file changed, 31 insertions(+), 89 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 630626dbec..5cf9b0cf25 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,6 +13,10 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils @@ -21,7 +25,6 @@ from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -39,9 +42,6 @@ def main(job_config: JobConfig): if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - # used for colorful printing - color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color - # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -75,10 +75,6 @@ def main(job_config: JobConfig): ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) - # initialize device memory monitor and get peak flops for MFU calculation - device_memory_monitor = build_device_memory_monitor() - gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) - logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) @@ -132,9 +128,18 @@ def main(job_config: JobConfig): model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) + # metrics logging + build_metrics_processor_fn = ( + build_metrics_processor + if train_spec.build_metrics_processor_fn is None + else train_spec.build_metrics_processor_fn + ) + metrics_processor = build_metrics_processor_fn(job_config, parallel_dims) + color = metrics_processor.color + # log model size model_param_count = utils.get_num_params(model) - num_flop_per_token = utils.get_num_flop_per_token( + metrics_processor.num_flop_per_token = utils.get_num_flop_per_token( utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, @@ -185,6 +190,10 @@ def main(job_config: JobConfig): with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() + + # confirm that user will be able to view loss metrics on the console + ensure_pp_loss_visible(parallel_dims, job_config, color) + else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) @@ -195,6 +204,10 @@ def main(job_config: JobConfig): model_parts = [model] + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") 
device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " @@ -237,18 +250,6 @@ def main(job_config: JobConfig): return checkpoint.load(step=job_config.checkpoint.load_step) - metric_logger = build_metric_logger(job_config, parallel_dims) - - # plot losses loaded from checkpoint (if any) to TensorBoard - # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. - # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq - if train_state.step > 0 and not job_config.metrics.disable_logging_from_checkpoint: - for idx, step in enumerate(train_state.log_steps): - metrics = { - "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], - "loss_metrics/global_max_loss": train_state.global_max_losses[idx], - } - metric_logger.log(metrics, step=step) data_iterator = iter(dataloader) @@ -257,12 +258,6 @@ def main(job_config: JobConfig): job_config.experimental.enable_compiled_autograd, ) - # variables used to keep info for metrics logging - ntokens_since_last_log = 0 - data_loading_times = [] - time_last_log = time.perf_counter() - device_memory_monitor.reset_peak_stats() - # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -270,7 +265,7 @@ def main(job_config: JobConfig): f"global batch size {job_config.training.batch_size * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.training.warmup_steps})" + f"(warmup {job_config.scheduler.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step @@ -285,8 +280,10 @@ def main(job_config: JobConfig): data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch - ntokens_since_last_log += labels.numel() - data_loading_times.append(time.perf_counter() - data_load_start) + metrics_processor.ntokens_since_last_log += labels.numel() + metrics_processor.data_loading_times.append( + time.perf_counter() - data_load_start + ) input_ids = input_ids.to(device_type) labels = labels.to(device_type) @@ -346,10 +343,7 @@ def main(job_config: JobConfig): lr_schedulers.step() # log metrics - if ( - train_state.step == 1 - or train_state.step % job_config.metrics.log_freq == 0 - ): + if metrics_processor.should_log(train_state.step): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -363,61 +357,9 @@ def main(job_config: JobConfig): else: global_avg_loss = global_max_loss = loss.item() - # update train state - train_state.log_steps.append(train_state.step) - train_state.global_avg_losses.append(global_avg_loss) - train_state.global_max_losses.append(global_max_loss) - - time_delta = time.perf_counter() - time_last_log - - # tokens per second per device, abbreviated as tps - tps = ntokens_since_last_log / ( - time_delta * parallel_dims.non_data_parallel_size + metrics_processor.log( + train_state.step, global_avg_loss, global_max_loss ) - # model FLOPS utilization - # For its definition and calculation, please refer to the PaLM paper: - # https://arxiv.org/abs/2204.02311 - mfu = 100 * num_flop_per_token * tps / gpu_peak_flops - tflops = num_flop_per_token * tps / 1e12 - - time_end_to_end = time_delta / job_config.metrics.log_freq - time_data_loading = sum(data_loading_times) / len(data_loading_times) - time_data_loading_pct = 100 * sum(data_loading_times) / time_delta - - device_mem_stats = device_memory_monitor.get_peak_stats() - - metrics = { - 
"loss_metrics/global_avg_loss": global_avg_loss, - "loss_metrics/global_max_loss": global_max_loss, - "throughput(tps)": tps, - "tflops": tflops, - "mfu(%)": mfu, - "time_metrics/end_to_end(s)": time_end_to_end, - "time_metrics/data_loading(s)": time_data_loading, - "time_metrics/data_loading(%)": time_data_loading_pct, - "memory/max_active(GiB)": device_mem_stats.max_active_gib, - "memory/max_active(%)": device_mem_stats.max_active_pct, - "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, - "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, - "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, - "memory/num_ooms": device_mem_stats.num_ooms, - } - metric_logger.log(metrics, step=train_state.step) - - logger.info( - f"{color.red}step: {train_state.step:2} " - f"{color.green}loss: {global_avg_loss:7.4f} " - f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" - f"({device_mem_stats.max_reserved_pct:.2f}%) " - f"{color.blue}tps: {round(tps):,} " - f"{color.cyan}tflops: {tflops:,.2f} " - f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" - ) - - ntokens_since_last_log = 0 - data_loading_times.clear() - time_last_log = time.perf_counter() - device_memory_monitor.reset_peak_stats() checkpoint.save( train_state.step, force=(train_state.step == job_config.training.steps) @@ -441,7 +383,7 @@ def main(job_config: JobConfig): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - metric_logger.close() + metrics_processor.close() logger.info("Training completed") From 698d63c8fb8616156bc5598c243d63589a8d418e Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 14:41:03 -0700 Subject: [PATCH 19/29] Rename `scheduler` to `lr_scheduler` --- docs/converging.md | 2 +- torchtitan/components/optimizer.py | 10 +++++----- torchtitan/config_manager.py | 10 +++++----- torchtitan/models/llama/train_configs/debug_model.toml | 2 +- torchtitan/models/llama/train_configs/llama3_405b.toml | 2 +- torchtitan/models/llama/train_configs/llama3_70b.toml | 2 +- torchtitan/models/llama/train_configs/llama3_8b.toml | 2 +- torchtitan/train.py | 4 +--- 8 files changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/converging.md b/docs/converging.md index 39b5154597..c078df0fa9 100644 --- a/docs/converging.md +++ b/docs/converging.md @@ -42,7 +42,7 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor - Base config: [torchtitan/models/llama/train_configs/llama3_8b.toml](../torchtitan/models/llama/train_configs/llama3_8b.toml) - `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"` - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32 -- `training.steps = 3000`, `scheduler.warmup_steps = 600` +- `training.steps = 3000`, `lr_scheduler.warmup_steps = 600` | Parallelism | Techniques | Remarks | | ------------------------ | ------------------------------------------------- | --------------------------------- | diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index b5a4a954d6..9e0676b881 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -364,7 +364,7 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # Load the same state_dict for all schedulers. 
The key value we're concerned # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer - # that is immutable. As long as ``training.steps`` and ``scheduler.warmup_steps`` + # that is immutable. As long as ``training.steps`` and ``lr_scheduler.warmup_steps`` # in ``job_config`` remain unchanged when resuming from a checkpoint, this # approach is safe. We call ``copy()`` here to ensure extra safety. for scheduler in self.schedulers: @@ -391,10 +391,10 @@ def build_lr_schedulers( lr_schedulers. """ training_steps = job_config.training.steps - warmup_steps = int(job_config.scheduler.warmup_steps) - lr_decay_ratio = job_config.scheduler.decay_ratio - lr_decay_type = job_config.scheduler.decay_type - lr_min = job_config.scheduler.lr_min + warmup_steps = int(job_config.lr_scheduler.warmup_steps) + lr_decay_ratio = job_config.lr_scheduler.decay_ratio + lr_decay_type = job_config.lr_scheduler.decay_type + lr_min = job_config.lr_scheduler.lr_min def linear_warmup_stable_decay( current_step: int, diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 7fa0c5d594..9939067437 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -260,13 +260,13 @@ def __init__(self): # lr scheduler configs self.parser.add_argument( - "--scheduler.warmup_steps", + "--lr_scheduler.warmup_steps", type=int, default=200, help="Steps for lr scheduler warmup, normally 1/5 of --training.steps", ) self.parser.add_argument( - "--scheduler.decay_ratio", + "--lr_scheduler.decay_ratio", type=float, default=None, help=""" @@ -280,7 +280,7 @@ def __init__(self): """, ) self.parser.add_argument( - "--scheduler.decay_type", + "--lr_scheduler.decay_type", type=str, default="linear", choices=["linear", "sqrt", "cosine"], @@ -292,14 +292,14 @@ def __init__(self): """, ) self.parser.add_argument( - "--scheduler.lr_min", + "--lr_scheduler.lr_min", type=float, default=0.0, help=""" Min lr ratio for lr scheduler. If provided, the range of decay factor is scaled from 1 to `lr_min` - to ensure the learning rate does not drop below `optimizer.lr * scheduler.lr_min`. + to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`. 
""", ) diff --git a/torchtitan/models/llama/train_configs/debug_model.toml b/torchtitan/models/llama/train_configs/debug_model.toml index bb54d56a82..3d21fbff67 100644 --- a/torchtitan/models/llama/train_configs/debug_model.toml +++ b/torchtitan/models/llama/train_configs/debug_model.toml @@ -33,7 +33,7 @@ name = "AdamW" lr = 8e-4 eps = 1e-8 -[scheduler] +[lr_scheduler] warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" diff --git a/torchtitan/models/llama/train_configs/llama3_405b.toml b/torchtitan/models/llama/train_configs/llama3_405b.toml index dc14d3e2c7..c70ddcbe9b 100644 --- a/torchtitan/models/llama/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama/train_configs/llama3_405b.toml @@ -27,7 +27,7 @@ name = "AdamW" lr = 8e-5 eps = 1e-8 -[scheduler] +[lr_scheduler] warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps [training] diff --git a/torchtitan/models/llama/train_configs/llama3_70b.toml b/torchtitan/models/llama/train_configs/llama3_70b.toml index a323203005..a025a7513a 100644 --- a/torchtitan/models/llama/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama/train_configs/llama3_70b.toml @@ -27,7 +27,7 @@ name = "AdamW" lr = 1.5e-4 eps = 1e-8 -[scheduler] +[lr_scheduler] warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] diff --git a/torchtitan/models/llama/train_configs/llama3_8b.toml b/torchtitan/models/llama/train_configs/llama3_8b.toml index b219f529ac..6086967fa7 100644 --- a/torchtitan/models/llama/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama/train_configs/llama3_8b.toml @@ -27,7 +27,7 @@ name = "AdamW" lr = 3e-4 eps = 1e-8 -[scheduler] +[lr_scheduler] warmup_steps = 200 # lr scheduler warm up [training] diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cf9b0cf25..6634b3da0d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -19,10 +19,8 @@ ) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils - from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec - from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( @@ -265,7 +263,7 @@ def main(job_config: JobConfig): f"global batch size {job_config.training.batch_size * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.scheduler.warmup_steps})" + f"(warmup {job_config.lr_scheduler.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step From 5f742f5ca57e9d98e9328c9a7a872e6aa8806c51 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 5 Mar 2025 23:14:48 -0800 Subject: [PATCH 20/29] [Legal] Modifications requested by legal for adding additional datasets (#936) Two very minor changes required by Meta legal as part of adding two new datasets. 1 - License verbiage update in readme 2 - copyright header change in BSD-License. --- LICENSE | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LICENSE b/LICENSE index bc559a98b9..6d1df98ff6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright 2024 Meta +(c) Meta Platforms, Inc. and affiliates. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md index bec2a74dd4..90e4a14ce7 100644 --- a/README.md +++ b/README.md @@ -131,4 +131,4 @@ We provide a detailed look into the parallelisms and optimizations available in ## License -This code is made available under [BSD 3 license](./LICENSE). However you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models, data, etc. +Source code is made available under a [BSD 3 license](./LICENSE), however you may have other legal obligations that govern your use of other content linked in this repository, such as the license or terms of service for third-party data and models. From e9fe2e5ec0a2aedb1f4a90a63726856ae26b97ec Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:17:09 -0800 Subject: [PATCH 21/29] [FSDP2][doc] highlight set_requires_gradient_sync and ignored_params (#940) * people asks about the FSDP2 equivalance of no_sync, that's `set_requires_gradient_sync` * ignored_params is recently implemented. people start using it already. update the doc --- docs/fsdp.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/fsdp.md b/docs/fsdp.md index 3f632ae683..51d8f03fbf 100644 --- a/docs/fsdp.md +++ b/docs/fsdp.md @@ -45,6 +45,8 @@ def fully_shard( | `forward_prefetch` | not yet implemented | | `limit_all_gathers` | removed | | `use_orig_params` | removed | +| `no_sync` | `set_requires_gradient_sync` | +| `ignored_modules`, `ignored_states` | `ignored_params` | - `fully_shard(module)` is similar to `FullyShardedDataParallel(module)`, constructing one communication bucket from `module.parameters()` except those already assigned to a nested `fully_shard`/`FullyShardedDataParallel` call. - `fully_shard(module)` adds an `FSDPState` object on `module`, accessible via `fully_shard.state(module)`, instead of being an `nn.Module` wrapper. This is done via the `@contract` decorator. @@ -71,7 +73,7 @@ def fully_shard( - FSDP2 always moves managed parameters/buffers to the `mesh`'s corresponding device, removing the need for `device_id`. For example, if `mesh.device_type` is `"cuda"`, then FSDP2 uses the current CUDA device. - FSDP2 uses a new memory management system that preserves communication/computation overlap while achieving deterministic and lower memory usage than FSDP1. This system does not require any CPU synchronization, so there is no need for `limit_all_gathers`. - FSDP2 always "uses the original parameters" since there is no more `FlatParameter`, removing the need for `use_orig_params`. -- How to implement `ignored_modules`/`ignored_states` and `forward_prefetch` in FSDP2 is under discussion. +- How to implement `forward_prefetch` in FSDP2 is under discussion. | FSDP1 | FSDP2 | | ----- | ----- | From f5a9abef497b4d77440f975f9a222846f13c6d77 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Fri, 7 Mar 2025 14:20:08 -0800 Subject: [PATCH 22/29] [PP] Ensure loss is visible on console for users (#946) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is similar in spirit to [PR_944](https://github.com/pytorch/torchtitan/pull/944) (cc @lkhphuc) but takes a slightly different approach. Problem - users that default turn on PP training will get -1 for their loss. This is b/c by default, rank 0 is the only one logged. 
However, for *most* PP schedules, the loss is output on the last rank. Thus, users see -1 for loss and it's a bad/confusing experience. This PR adds a check to review both the current PP schedule (b/c for VBlocks, loss is returned on 0) and if it is a last rank loss schedule, then it checks that the first rank of the last stage is visible in the LOG_RANK environment variable. If not, it warns the user, using Red for the warning if color is enabled, and highlights the rank they should add in yellow: Screenshot 2025-03-07 at 11 51 46 AM Note that I attempted to then modify the LOG_RANK to add the missing last rank...but it has no effect. This is b/c the --log_rank_filter passed into torchrun is fixed and thus the env has no effect. We can fix this by moving to our own filtering via python log filtering (thanks to @d4l3k for this idea) and then it would auto-update. The tradeoff is that we have to init distributed first (to understand the ranks) meaning that at launch, there's a bit of delay before the first logging. From there, then NCCL warnings are not suppressed b/c they are emitted from .cpp file vs torchrun filtering controls that...so we get some additional console spam. This PR thus sticks to a simple warning with Red highlight (assuming color is on) and provide the user how to fix. --- torchtitan/tools/metrics.py | 34 +++++++++++++++++++++++++++++++++- torchtitan/train.py | 18 ++++++++---------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/torchtitan/tools/metrics.py b/torchtitan/tools/metrics.py index f302d25c92..9a99b586c1 100644 --- a/torchtitan/tools/metrics.py +++ b/torchtitan/tools/metrics.py @@ -15,7 +15,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger -from torchtitan.tools.utils import device_module, device_type +from torchtitan.tools.utils import Color, device_module, device_type # named tuple for passing device memory stats for logging DeviceMemStats = namedtuple( @@ -154,6 +154,38 @@ def close(self) -> None: self.wandb.finish() +def ensure_pp_loss_visible( + parallel_dims: ParallelDims, job_config: JobConfig, color: Color +) -> None: + """ + Ensures that the loss is visible on the console for pipeline-parallel training. + + For pipeline-parallel training, the loss is only visible on the last pipeline stage. + This function checks if the appropriate rank is included in the LOG_RANK environment + variable and warns if it's not. + """ + + # V Block Schedules return loss on rank 0 + if job_config.experimental.pipeline_parallel_schedule == "ZBVZeroBubble": + return + + # Calculate the rank where loss is visible (first rank of the last pipeline stage) + world_size = parallel_dims.world_size + pp_size = parallel_dims.pp + loss_visible_rank = (world_size // pp_size) * (pp_size - 1) + + # Check if the loss-visible rank is included in LOG_RANK environment variable + env_logged_ranks = os.environ.get("LOG_RANK", "").split(",") + if env_logged_ranks == [""]: + env_logged_ranks = [] + + if str(loss_visible_rank) not in env_logged_ranks: + logger.warning( + f"{color.red}Pipeline parallel loss is not visible. 
" + f"Add {color.yellow}rank {loss_visible_rank}{color.red} to LOG_RANK environment variable in run_train.sh.{color.reset}" + ) + + def _get_metrics_rank( parallel_dims: ParallelDims, job_config: JobConfig, diff --git a/torchtitan/train.py b/torchtitan/train.py index 6634b3da0d..e8dc035c5d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,20 +13,17 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager -from torchtitan.components.metrics import ( - build_metrics_processor, - ensure_pp_loss_visible, -) +from torchtitan.components.metrics import (build_metrics_processor, + ensure_pp_loss_visible) from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed import ParallelDims +from torchtitan.distributed import utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.profiling import ( - maybe_enable_memory_snapshot, - maybe_enable_profiling, -) +from torchtitan.tools.profiling import (maybe_enable_memory_snapshot, + maybe_enable_profiling) # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -222,6 +219,7 @@ def main(job_config: JobConfig): optimizers.register_step_post_hook( lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts) ) + metrics_processor.optimizers = optimizers train_state = TrainState() @@ -390,4 +388,4 @@ def main(job_config: JobConfig): config = JobConfig() config.parse_args() main(config) - torch.distributed.destroy_process_group() + torch.distributed.destroy_process_group() \ No newline at end of file From 6d8da38376f71ff22604bfe7941c943056bf261e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 10 Mar 2025 00:46:29 -0700 Subject: [PATCH 23/29] Make MetricsLogger as a component (#945) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #945 MetricsLogger should be a component as its role is similar to CheckpointManager, which provides some functions and has its own states. More importantly, users may want to customize the metrics. Make it a component and can be customized through TrainSpec. Change the name of `MetricsLogger` to `MetricsProcessor` as it not only log but also process metrics. 
--- scripts/generate/test_generate.py | 2 +- torchtitan/components/dataloader.py | 5 +- torchtitan/{tools => components}/metrics.py | 139 +++++++++++++++++++- torchtitan/config_manager.py | 7 - torchtitan/protocols/train_spec.py | 11 +- torchtitan/train.py | 17 ++- 6 files changed, 155 insertions(+), 26 deletions(-) rename torchtitan/{tools => components}/metrics.py (63%) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 735a6a0e35..0fd4e7b00f 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -24,13 +24,13 @@ parallelize_module, RowwiseParallel, ) +from torchtitan.components.metrics import build_device_memory_monitor from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.metrics import build_device_memory_monitor from torchtitan.tools.utils import device_module, device_type # support running w/o installing as package diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 2f2bec9fc8..38aab8fae6 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -8,7 +8,7 @@ import pickle from abc import ABC, abstractmethod -from typing import Any, Callable, TypeAlias +from typing import Any from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset @@ -86,6 +86,3 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: # We don't have to use pickle as DCP will serialize the state_dict. However, we have to # keep this for backward compatibility. super().load_state_dict(pickle.loads(state_dict[self._rank_id])) - - -DataLoaderBuilder: TypeAlias = Callable[[...], BaseDataLoader] diff --git a/torchtitan/tools/metrics.py b/torchtitan/components/metrics.py similarity index 63% rename from torchtitan/tools/metrics.py rename to torchtitan/components/metrics.py index 9a99b586c1..294f35f1ed 100644 --- a/torchtitan/tools/metrics.py +++ b/torchtitan/components/metrics.py @@ -5,15 +5,17 @@ # LICENSE file in the root directory of this source tree. import os +import time from collections import namedtuple from datetime import datetime from typing import Any, Dict, Optional import torch from torch.utils.tensorboard import SummaryWriter - +from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.tools import utils from torchtitan.tools.logging import logger from torchtitan.tools.utils import Color, device_module, device_type @@ -213,7 +215,7 @@ def _get_metrics_rank( return (world_size // pp_size) * (pp_size - 1) -def build_metric_logger( +def _build_metric_logger( job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None ) -> BaseLogger: """ @@ -276,3 +278,136 @@ def build_metric_logger( logger.debug("No loggers enabled, returning BaseLogger") return BaseLogger() + + +class MetricsProcessor: + """Metrics processor to processes the metrics and log metrics. + + The current MetricsProcessor log some metrics to STDOUT and some metrics to + TensorBoard or WandB. + + Args: + job_config (JobConfig): Job configuration. + parallel_dims (ParallelDims): Parallel dimensions. + tag (Optional[str]): Tag to use for TensorBoard or WandB. 
Defaults to None. + """ + + logger: BaseLogger + parallel_dims: ParallelDims + job_config: JobConfig + device_memory_monitor: DeviceMemoryMonitor + color: utils.Color + + gpu_peak_flops: int + ntokens_since_last_log: int + data_loading_times: list[float] + time_last_log: float + + num_flop_per_token: int + optimizers: Optional[OptimizersContainer] + lr_schedulers: Optional[LRSchedulersContainer] + + def __init__( + self, + job_config: JobConfig, + parallel_dims: ParallelDims, + tag: Optional[str] = None, + ): + self.logger = _build_metric_logger(job_config, parallel_dims, tag) + self.parallel_dims = parallel_dims + self.job_config = job_config + self.device_memory_monitor = build_device_memory_monitor() + # used for colorful printing + self.color = ( + utils.NoColor if job_config.metrics.disable_color_printing else utils.Color + ) + + self.gpu_peak_flops = utils.get_peak_flops( + self.device_memory_monitor.device_name + ) + self.ntokens_since_last_log = 0 + self.data_loading_times = [] + self.time_last_log = time.perf_counter() + self.device_memory_monitor.reset_peak_stats() + + # These variables have to be set later as they depend on other components or model. + self.num_flop_per_token = -1 + self.optimizers = None + self.lr_schedulers = None + + def should_log(self, step: int) -> bool: + return step == 1 or step % self.job_config.metrics.log_freq == 0 + + def log(self, step: int, global_avg_loss: float, global_max_loss: float): + assert self.num_flop_per_token > 0, "num_flop_per_token must be set" + + time_delta = time.perf_counter() - self.time_last_log + + # tokens per second per device, abbreviated as tps + tps = self.ntokens_since_last_log / ( + time_delta * self.parallel_dims.non_data_parallel_size + ) + # model FLOPS utilization + # For its definition and calculation, please refer to the PaLM paper: + # https://arxiv.org/abs/2204.02311 + mfu = 100 * self.num_flop_per_token * tps / self.gpu_peak_flops + tflops = self.num_flop_per_token * tps / 1e12 + + time_end_to_end = time_delta / self.job_config.metrics.log_freq + time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) + time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta + + device_mem_stats = self.device_memory_monitor.get_peak_stats() + + metrics = { + "loss_metrics/global_avg_loss": global_avg_loss, + "loss_metrics/global_max_loss": global_max_loss, + "throughput(tps)": tps, + "tflops": tflops, + "mfu(%)": mfu, + "time_metrics/end_to_end(s)": time_end_to_end, + "time_metrics/data_loading(s)": time_data_loading, + "time_metrics/data_loading(%)": time_data_loading_pct, + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + "memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, + } + self.logger.log(metrics, step) + + color = self.color + logger.info( + f"{color.red}step: {step:2} " + f"{color.green}loss: {global_avg_loss:7.4f} " + f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: {round(tps):,} " + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + ) + + self.ntokens_since_last_log = 0 + self.data_loading_times.clear() + self.time_last_log = time.perf_counter() + self.device_memory_monitor.reset_peak_stats() + + 
def close(self): + self.logger.close() + + +def build_metrics_processor( + job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None +) -> MetricsProcessor: + """Create a metrics processor. + + Args: + job_config (JobConfig): Job configuration. + parallel_dims (ParallelDims): Parallel dimensions. + tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None. + + Returns: + MetricsProcessor: A metrics processor. + """ + return MetricsProcessor(job_config, parallel_dims, tag) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 9939067437..0de00287d6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -165,13 +165,6 @@ def __init__(self): which is the only stage that computes loss metrics. """, ) - self.parser.add_argument( - "--metrics.disable_logging_from_checkpoint", - action="store_true", - help=""" - Whether to log metrics from scratch for each checkpoint load. We have seen this feature - leading to nccl watchdog timeout issue when testing with tb. This flag disables it.""", - ) self.parser.add_argument( "--metrics.enable_wandb", action="store_true", diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index aaef723c1b..4e2b9c7781 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -7,12 +7,13 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. from dataclasses import dataclass -from typing import Callable, Protocol, Type, TypeAlias +from typing import Callable, Optional, Protocol, Type, TypeAlias import torch import torch.nn as nn from torch.distributed.pipelining.schedules import _PipelineSchedule -from torchtitan.components.dataloader import DataLoaderBuilder +from torchtitan.components.dataloader import BaseDataLoader +from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig @@ -41,6 +42,8 @@ def from_model_args(cls, args: BaseModelArgs) -> nn.Module: ... +DataLoaderBuilder: TypeAlias = Callable[[...], BaseDataLoader] +MetricsProcessorBuilder: TypeAlias = Callable[[...], MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ [list[nn.Module], JobConfig], OptimizersContainer ] @@ -62,9 +65,7 @@ class TrainSpec: build_dataloader_fn: DataLoaderBuilder tokenizer_cls: Type[Tokenizer] loss_fn: LossFunction - - # TODO: Add a FQN convert fn to allow users to load checkpoints from - # HuggingFace or other sources that have different FQN conventions. 
+ build_metrics_processor_fn: Optional[MetricsProcessorBuilder] = None _train_specs = {} diff --git a/torchtitan/train.py b/torchtitan/train.py index e8dc035c5d..004ab83ee9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,17 +13,20 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager -from torchtitan.components.metrics import (build_metrics_processor, - ensure_pp_loss_visible) +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims -from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.profiling import (maybe_enable_memory_snapshot, - maybe_enable_profiling) +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -388,4 +391,4 @@ def main(job_config: JobConfig): config = JobConfig() config.parse_args() main(config) - torch.distributed.destroy_process_group() \ No newline at end of file + torch.distributed.destroy_process_group() From f395ed2246e113d131b87e7b061c83ebfc47cd8f Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 6 Mar 2025 01:30:34 -0800 Subject: [PATCH 24/29] [Misc.] Log learning rate --- torchtitan/train.py | 58 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 004ab83ee9..d29cf1159e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -356,9 +356,63 @@ def main(job_config: JobConfig): else: global_avg_loss = global_max_loss = loss.item() - metrics_processor.log( - train_state.step, global_avg_loss, global_max_loss + # update train state + train_state.log_steps.append(train_state.step) + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + + time_delta = time.perf_counter() - time_last_log + + last_lr = lr_schedulers.schedulers[0].get_last_lr()[0] + # tokens per second per device, abbreviated as tps + tps = ntokens_since_last_log / ( + time_delta * parallel_dims.non_data_parallel_size ) + # model FLOPS utilization + # For its definition and calculation, please refer to the PaLM paper: + # https://arxiv.org/abs/2204.02311 + mfu = 100 * num_flop_per_token * tps / gpu_peak_flops + tflops = num_flop_per_token * tps / 1e12 + + time_end_to_end = time_delta / job_config.metrics.log_freq + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta + + device_mem_stats = device_memory_monitor.get_peak_stats() + + metrics = { + "loss_metrics/global_avg_loss": global_avg_loss, + "loss_metrics/global_max_loss": global_max_loss, + "throughput(tps)": tps, + "tflops": tflops, + "mfu(%)": mfu, + "optim/learning_rate": last_lr, + "time_metrics/end_to_end(s)": time_end_to_end, + "time_metrics/data_loading(s)": time_data_loading, + "time_metrics/data_loading(%)": time_data_loading_pct, + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + 
"memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, + } + metric_logger.log(metrics, step=train_state.step) + + logger.info( + f"{color.red}step: {train_state.step:2} " + f"{color.green}loss: {global_avg_loss:7.4f} " + f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: {round(tps):,} " + f"{color.cyan}lr: {last_lr:.4e} " + f"{color.magenta}tflops: {tflops:,.2f} mfu: {mfu:.2f}%{color.reset}" + ) + + ntokens_since_last_log = 0 + data_loading_times.clear() + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() checkpoint.save( train_state.step, force=(train_state.step == job_config.training.steps) From af00afb64c54131efc2ca0c746d045f1a962bcc4 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 16:42:41 +0800 Subject: [PATCH 25/29] Update train.py From 65f5f66d7ab4145d482b70180adb22cd5b0a0f0c Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 02:09:05 -0700 Subject: [PATCH 26/29] Add warnings if warmup_stable_steps < warmup_steps --- torchtitan/components/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 9e0676b881..d4f994b471 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -422,7 +422,7 @@ def linear_warmup_stable_decay( if lr_decay_ratio is None: warmup_stable_steps = warmup_steps else: - warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) + warmup_stable_steps = training_steps * (1 - lr_decay_ratio) if warmup_stable_steps < warmup_steps: logger.warning( f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). 
" From 1e6123600430cafb0eeb08c1c11f6606fcb8416c Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 17:11:17 +0800 Subject: [PATCH 27/29] Revert changes on train.py --- torchtitan/train.py | 62 +++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index d29cf1159e..85185e2e66 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,16 +13,13 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager -from torchtitan.components.metrics import ( - build_metrics_processor, - ensure_pp_loss_visible, -) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -40,6 +37,9 @@ def main(job_config: JobConfig): if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") + # used for colorful printing + color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color + # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -73,6 +73,10 @@ def main(job_config: JobConfig): ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = build_device_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) @@ -126,18 +130,9 @@ def main(job_config: JobConfig): model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) - # metrics logging - build_metrics_processor_fn = ( - build_metrics_processor - if train_spec.build_metrics_processor_fn is None - else train_spec.build_metrics_processor_fn - ) - metrics_processor = build_metrics_processor_fn(job_config, parallel_dims) - color = metrics_processor.color - # log model size model_param_count = utils.get_num_params(model) - metrics_processor.num_flop_per_token = utils.get_num_flop_per_token( + num_flop_per_token = utils.get_num_flop_per_token( utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, @@ -188,10 +183,6 @@ def main(job_config: JobConfig): with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() - - # confirm that user will be able to view loss metrics on the console - ensure_pp_loss_visible(parallel_dims, job_config, color) - else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) @@ -202,10 +193,6 @@ def main(job_config: JobConfig): model_parts = [model] - # initialize device memory monitor and get peak flops for MFU calculation - device_memory_monitor = metrics_processor.device_memory_monitor - gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) - 
logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " @@ -249,6 +236,18 @@ def main(job_config: JobConfig): return checkpoint.load(step=job_config.checkpoint.load_step) + metric_logger = build_metric_logger(job_config, parallel_dims) + + # plot losses loaded from checkpoint (if any) to TensorBoard + # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. + # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq + if train_state.step > 0 and not job_config.metrics.disable_logging_from_checkpoint: + for idx, step in enumerate(train_state.log_steps): + metrics = { + "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], + "loss_metrics/global_max_loss": train_state.global_max_losses[idx], + } + metric_logger.log(metrics, step=step) data_iterator = iter(dataloader) @@ -257,6 +256,12 @@ def main(job_config: JobConfig): job_config.experimental.enable_compiled_autograd, ) + # variables used to keep info for metrics logging + ntokens_since_last_log = 0 + data_loading_times = [] + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() + # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -279,10 +284,8 @@ def main(job_config: JobConfig): data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch - metrics_processor.ntokens_since_last_log += labels.numel() - metrics_processor.data_loading_times.append( - time.perf_counter() - data_load_start - ) + ntokens_since_last_log += labels.numel() + data_loading_times.append(time.perf_counter() - data_load_start) input_ids = input_ids.to(device_type) labels = labels.to(device_type) @@ -342,7 +345,10 @@ def main(job_config: JobConfig): lr_schedulers.step() # log metrics - if metrics_processor.should_log(train_state.step): + if ( + train_state.step == 1 + or train_state.step % job_config.metrics.log_freq == 0 + ): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -436,7 +442,7 @@ def main(job_config: JobConfig): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - metrics_processor.close() + metric_logger.close() logger.info("Training completed") From 6328fc75e5049a95a0ff6efdc0025673cdcb4fe5 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 17:57:16 +0800 Subject: [PATCH 28/29] Rename `training.warmup_steps` to `scheduler.warmup_steps` --- torchtitan/train.py | 62 ++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 85185e2e66..d29cf1159e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,13 +13,16 @@ from torchtitan.components.checkpoint import CheckpointManager, TrainState from torchtitan.components.ft import FTParallelDims, init_ft_manager +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger from 
torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -37,9 +40,6 @@ def main(job_config: JobConfig): if job_config.job.print_args: logger.info(f"Running with args: {job_config.to_dict()}") - # used for colorful printing - color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color - # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -73,10 +73,6 @@ def main(job_config: JobConfig): ft_manager=ft_manager, ) dist_utils.init_distributed(job_config) - # initialize device memory monitor and get peak flops for MFU calculation - device_memory_monitor = build_device_memory_monitor() - gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) - logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") # build meshes world_mesh = parallel_dims.build_mesh(device_type=device_type) @@ -130,9 +126,18 @@ def main(job_config: JobConfig): model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) + # metrics logging + build_metrics_processor_fn = ( + build_metrics_processor + if train_spec.build_metrics_processor_fn is None + else train_spec.build_metrics_processor_fn + ) + metrics_processor = build_metrics_processor_fn(job_config, parallel_dims) + color = metrics_processor.color + # log model size model_param_count = utils.get_num_params(model) - num_flop_per_token = utils.get_num_flop_per_token( + metrics_processor.num_flop_per_token = utils.get_num_flop_per_token( utils.get_num_params(model, exclude_embedding=True), model_config, job_config.training.seq_len, @@ -183,6 +188,10 @@ def main(job_config: JobConfig): with torch.no_grad(): m.init_weights(buffer_device=buffer_device) m.train() + + # confirm that user will be able to view loss metrics on the console + ensure_pp_loss_visible(parallel_dims, job_config, color) + else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) @@ -193,6 +202,10 @@ def main(job_config: JobConfig): model_parts = [model] + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( f"{device_type.upper()} memory usage for model: " @@ -236,18 +249,6 @@ def main(job_config: JobConfig): return checkpoint.load(step=job_config.checkpoint.load_step) - metric_logger = build_metric_logger(job_config, parallel_dims) - - # plot losses loaded from checkpoint (if any) to TensorBoard - # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. 
- # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq - if train_state.step > 0 and not job_config.metrics.disable_logging_from_checkpoint: - for idx, step in enumerate(train_state.log_steps): - metrics = { - "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], - "loss_metrics/global_max_loss": train_state.global_max_losses[idx], - } - metric_logger.log(metrics, step=step) data_iterator = iter(dataloader) @@ -256,12 +257,6 @@ def main(job_config: JobConfig): job_config.experimental.enable_compiled_autograd, ) - # variables used to keep info for metrics logging - ntokens_since_last_log = 0 - data_loading_times = [] - time_last_log = time.perf_counter() - device_memory_monitor.reset_peak_stats() - # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -284,8 +279,10 @@ def main(job_config: JobConfig): data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch - ntokens_since_last_log += labels.numel() - data_loading_times.append(time.perf_counter() - data_load_start) + metrics_processor.ntokens_since_last_log += labels.numel() + metrics_processor.data_loading_times.append( + time.perf_counter() - data_load_start + ) input_ids = input_ids.to(device_type) labels = labels.to(device_type) @@ -345,10 +342,7 @@ def main(job_config: JobConfig): lr_schedulers.step() # log metrics - if ( - train_state.step == 1 - or train_state.step % job_config.metrics.log_freq == 0 - ): + if metrics_processor.should_log(train_state.step): if ( parallel_dims.dp_replicate_enabled or parallel_dims.dp_shard_enabled @@ -442,7 +436,7 @@ def main(job_config: JobConfig): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - metric_logger.close() + metrics_processor.close() logger.info("Training completed") From 1eb7c71afed09529d16d3895127487693866ee8b Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 15:20:22 -0700 Subject: [PATCH 29/29] Fix code formats --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2598e76efd..9841e4ca70 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -393,4 +393,4 @@ def main(job_config: JobConfig): config = JobConfig() config.parse_args() main(config) - torch.distributed.destroy_process_group() \ No newline at end of file + torch.distributed.destroy_process_group()