From 89b68dbbe178968150d88d1c7333391ab769d4c6 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 11 Mar 2025 13:58:03 +0800 Subject: [PATCH 1/7] [Scheduler] Ensure `warmup_stable_steps` is an integer --- torchtitan/components/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d4f994b471..9e0676b881 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -422,7 +422,7 @@ def linear_warmup_stable_decay( if lr_decay_ratio is None: warmup_stable_steps = warmup_steps else: - warmup_stable_steps = training_steps * (1 - lr_decay_ratio) + warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) if warmup_stable_steps < warmup_steps: logger.warning( f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " From 9f99a4151df34c35b3acb72170cc3189de4c0d8b Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 10 Mar 2025 23:35:19 -0700 Subject: [PATCH 2/7] Conduct lr check only once --- torchtitan/components/optimizer.py | 8 ++------ torchtitan/config_manager.py | 9 +++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 9e0676b881..9247f7a828 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -22,7 +22,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config_manager import JobConfig -from torchtitan.tools.logging import logger + __all__ = [ "OptimizersContainer", @@ -423,11 +423,7 @@ def linear_warmup_stable_decay( warmup_stable_steps = warmup_steps else: warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) - if warmup_stable_steps < warmup_steps: - logger.warning( - f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " - f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})." - ) + if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 0de00287d6..8e7a7aef6b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -825,6 +825,15 @@ def _validate_config(self) -> None: "Please update your config." ) + if self.lr_scheduler.decay_ratio is not None: + warmup_stable_steps = round(self.training.steps * (1 - self.lr_scheduler.decay_ratio)) + if warmup_stable_steps < self.lr_scheduler.warmup_steps: + logger.warning( + f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " + f"Consider reducing either the decay ratio ({self.lr_scheduler.decay_ratio}) " + f"or the warmup steps ({self.lr_scheduler.warmup_steps})." + ) + def _get_string_list_argument_names(self) -> list[str]: """Get the parser argument names of type `string_list`.""" string_list_args = [ From 1127c74c8d5ef8d15e4235a02a569d66420c16de Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Wed, 12 Mar 2025 11:01:57 -0700 Subject: [PATCH 3/7] Delete `decay_ratio` checks in config manager --- torchtitan/config_manager.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 8e7a7aef6b..0de00287d6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -825,15 +825,6 @@ def _validate_config(self) -> None: "Please update your config." 
) - if self.lr_scheduler.decay_ratio is not None: - warmup_stable_steps = round(self.training.steps * (1 - self.lr_scheduler.decay_ratio)) - if warmup_stable_steps < self.lr_scheduler.warmup_steps: - logger.warning( - f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " - f"Consider reducing either the decay ratio ({self.lr_scheduler.decay_ratio}) " - f"or the warmup steps ({self.lr_scheduler.warmup_steps})." - ) - def _get_string_list_argument_names(self) -> list[str]: """Get the parser argument names of type `string_list`.""" string_list_args = [ From 5f189edea9fb1aebd94e906033ca0ac246405877 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Wed, 12 Mar 2025 11:12:23 -0700 Subject: [PATCH 4/7] Raise warning only once for abnormal decay ratio --- torchtitan/components/optimizer.py | 7 ++++++- torchtitan/tools/logging.py | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 9247f7a828..d7b4d32a3d 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -22,7 +22,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config_manager import JobConfig - +from torchtitan.tools.logging import logger __all__ = [ "OptimizersContainer", @@ -424,6 +424,11 @@ def linear_warmup_stable_decay( else: warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) + if warmup_stable_steps < warmup_steps: + logger.warning_once( + f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " + f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})." + ) if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments diff --git a/torchtitan/tools/logging.py b/torchtitan/tools/logging.py index 22b302f97d..a89dc207c1 100644 --- a/torchtitan/tools/logging.py +++ b/torchtitan/tools/logging.py @@ -4,10 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools import logging import os - logger = logging.getLogger() @@ -23,3 +23,17 @@ def init_logger(): # suppress verbose torch.profiler logging os.environ["KINETO_LOG_LEVEL"] = "5" + + +@functools.lru_cache(None) +def warning_once(self, *args, **kwargs): + """ + Emit a warning message only once for unique arguments. + + This method is similar to `logger.warning()`, but will emit the warning + with the same message only once for a given set of arguments. 
+ """ + self.warning(*args, **kwargs) + + +logging.Logger.warning_once = warning_once From a8ff29ea2977764fc10a798ae0b48ac64bb76534 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 13 Mar 2025 09:32:02 -0700 Subject: [PATCH 5/7] Pass warm/stable/decay steps as args of `linear_warmup_stable_decay` --- torchtitan/components/optimizer.py | 31 ++++++++++++++++-------------- torchtitan/tools/logging.py | 15 --------------- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d7b4d32a3d..78a571ae8c 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -392,14 +392,26 @@ def build_lr_schedulers( """ training_steps = job_config.training.steps warmup_steps = int(job_config.lr_scheduler.warmup_steps) - lr_decay_ratio = job_config.lr_scheduler.decay_ratio + if job_config.lr_scheduler.decay_ratio is not None: + decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio) + if warmup_steps + decay_steps > training_steps: + logger.warning( + f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed " + f"total training steps ({training_steps}). " + f"Adjusting decay steps to {training_steps - warmup_steps}." + ) + decay_steps = training_steps - warmup_steps + else: + decay_steps = training_steps - warmup_steps + stable_steps = training_steps - warmup_steps - decay_steps lr_decay_type = job_config.lr_scheduler.decay_type lr_min = job_config.lr_scheduler.lr_min def linear_warmup_stable_decay( current_step: int, warmup_steps: int, - lr_decay_ratio: Union[float, None], + stable_steps: int, + decay_steps: int, lr_decay_type: str, lr_min: float, ): @@ -419,16 +431,7 @@ def linear_warmup_stable_decay( If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` to ensure the learning rate does not drop below this minimum value. """ - if lr_decay_ratio is None: - warmup_stable_steps = warmup_steps - else: - warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) - - if warmup_stable_steps < warmup_steps: - logger.warning_once( - f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " - f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})." - ) + warmup_stable_steps = warmup_steps + stable_steps if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments @@ -437,7 +440,6 @@ def linear_warmup_stable_decay( elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: - decay_steps = float(max(1, training_steps - warmup_stable_steps)) progress = float(current_step - warmup_stable_steps) / decay_steps if lr_decay_type == "linear": @@ -452,7 +454,8 @@ def linear_warmup_stable_decay( lr_lambda = functools.partial( linear_warmup_stable_decay, warmup_steps=warmup_steps, - lr_decay_ratio=lr_decay_ratio, + stable_steps=stable_steps, + decay_steps=decay_steps, lr_decay_type=lr_decay_type, lr_min=lr_min, ) diff --git a/torchtitan/tools/logging.py b/torchtitan/tools/logging.py index a89dc207c1..3995e10171 100644 --- a/torchtitan/tools/logging.py +++ b/torchtitan/tools/logging.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import functools import logging import os @@ -23,17 +22,3 @@ def init_logger(): # suppress verbose torch.profiler logging os.environ["KINETO_LOG_LEVEL"] = "5" - - -@functools.lru_cache(None) -def warning_once(self, *args, **kwargs): - """ - Emit a warning message only once for unique arguments. - - This method is similar to `logger.warning()`, but will emit the warning - with the same message only once for a given set of arguments. - """ - self.warning(*args, **kwargs) - - -logging.Logger.warning_once = warning_once From 55e76a1fd4eef1119ba0b8fa0aa9ac1a52e8d694 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 13 Mar 2025 09:33:22 -0700 Subject: [PATCH 6/7] Revert changes in logging --- torchtitan/tools/logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/tools/logging.py b/torchtitan/tools/logging.py index 3995e10171..22b302f97d 100644 --- a/torchtitan/tools/logging.py +++ b/torchtitan/tools/logging.py @@ -7,6 +7,7 @@ import logging import os + logger = logging.getLogger() From 1994c63d4f371306e1f3be2b89042d991ff4f2ac Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Fri, 14 Mar 2025 16:38:53 +0800 Subject: [PATCH 7/7] Fix code formats --- torchtitan/components/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 78a571ae8c..9d04ce4f58 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -7,7 +7,7 @@ import copy import functools import math -from typing import Any, Callable, Dict, Generic, List, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, TypeVar import torch import torch.nn as nn
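
For reference, the schedule this series converges on can be exercised standalone. The sketch below mirrors the final `linear_warmup_stable_decay` (patch 5) together with the step bookkeeping from `build_lr_schedulers`. It is a simplified reconstruction, not the torchtitan module itself: the signature, the "linear" branch, and the warmup/stable/decay step arithmetic come directly from the diffs, but the non-"linear" decay branch and the exact `lr_min` rescaling are truncated there, so the cosine fallback and the "scale the factor from 1 down to lr_min" step are assumptions based on the docstring.

    import functools
    import math

    def linear_warmup_stable_decay(
        current_step: int,
        warmup_steps: int,
        stable_steps: int,
        decay_steps: int,
        lr_decay_type: str,
        lr_min: float,
    ) -> float:
        # Multiplicative LR factor: linear warmup, constant plateau, then decay.
        warmup_stable_steps = warmup_steps + stable_steps
        if current_step < warmup_steps:
            # Linear warmup; steps are 0-indexed, hence the +1 adjustments.
            return float(current_step + 1) / (warmup_steps + 1)
        if current_step < warmup_stable_steps:
            return 1.0  # stable phase: full learning rate
        progress = float(current_step - warmup_stable_steps) / decay_steps
        if lr_decay_type == "linear":
            curr_adjustment = 1.0 - progress
        else:
            # Assumed branch: cosine decay (this branch is not shown in the diff).
            curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Assumed rescaling so the factor decays from 1.0 to lr_min, not to 0,
        # per the docstring's "decay range is scaled from 1 to lr_min".
        return lr_min + (1.0 - lr_min) * curr_adjustment

    # Step bookkeeping as in patch 5's build_lr_schedulers: decay_steps derives
    # from decay_ratio and is clamped so warmup + decay never exceeds training_steps.
    training_steps, warmup_steps, decay_ratio = 1000, 200, 0.2
    decay_steps = min(round(training_steps * decay_ratio), training_steps - warmup_steps)
    stable_steps = training_steps - warmup_steps - decay_steps

    lr_lambda = functools.partial(
        linear_warmup_stable_decay,
        warmup_steps=warmup_steps,
        stable_steps=stable_steps,
        decay_steps=decay_steps,
        lr_decay_type="linear",
        lr_min=0.0,
    )
    assert lr_lambda(0) < lr_lambda(warmup_steps)  # still warming up at step 0
    assert lr_lambda(warmup_steps) == 1.0          # plateau at full LR
    assert lr_lambda(training_steps - 1) < 1.0     # decaying at the end

Note the design choice the series lands on: by computing `decay_steps` and `stable_steps` once in `build_lr_schedulers` and clamping there, the per-step lambda stays a pure function of its arguments, and the warmup-versus-decay consistency warning fires once at scheduler construction rather than on every step (the problem patches 2 through 4 experimented with before settling on this).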