From da9fc6ad8463ffd3d0e50d81ae17541e71c5daa7 Mon Sep 17 00:00:00 2001
From: Ido Hakimi
Date: Fri, 18 Jul 2025 09:55:56 +0000
Subject: [PATCH 1/4] Add logging for learning rates in MetricsProcessor

---
 torchtitan/components/metrics.py | 14 ++++++++++++++
 torchtitan/train.py              |  1 +
 2 files changed, 15 insertions(+)

diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index dcd8782810..ce94b3f182 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -399,6 +399,20 @@ def log(
             "memory/num_ooms": device_mem_stats.num_ooms,
         }
 
+        if self.lr_schedulers:
+            # Log learning rate for each scheduler
+            lr_metrics = {}
+            for i, scheduler in enumerate(self.lr_schedulers.schedulers):
+                for j, lr in enumerate(scheduler.get_last_lr()):
+                    lr_metrics[f"lr_scheduler/{i}/param_group_{j}/lr"] = lr
+
+            if len(lr_metrics) == 1:
+                # If there's only one learning rate, log it directly
+                metrics.update({"lr": list(lr_metrics.values())[0]})
+            else:
+                # Otherwise, log all learning rates under the lr_scheduler key
+                metrics.update(lr_metrics)
+
         if extra_metrics:
             metrics.update(extra_metrics)
 
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 58fc69ac2b..94d9dbbf4e 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -291,6 +291,7 @@ def __init__(self, job_config: JobConfig):
             )
         )
         self.metrics_processor.optimizers = self.optimizers
+        self.metrics_processor.lr_schedulers = self.lr_schedulers
 
         # Initialize trainer states that will be saved in checkpoint.
         # These attributes must be initialized before checkpoint loading.

From b5e5b2fa573bfc2d89e3bc2afc976d1a2c46ed88 Mon Sep 17 00:00:00 2001
From: Ido Hakimi
Date: Mon, 28 Jul 2025 09:30:40 +0000
Subject: [PATCH 2/4] Simplify learning rate logging to log only the first scheduler's learning rate

---
 torchtitan/components/metrics.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index ce94b3f182..35b8feeb5e 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -400,18 +400,8 @@ def log(
         }
 
         if self.lr_schedulers:
-            # Log learning rate for each scheduler
-            lr_metrics = {}
-            for i, scheduler in enumerate(self.lr_schedulers.schedulers):
-                for j, lr in enumerate(scheduler.get_last_lr()):
-                    lr_metrics[f"lr_scheduler/{i}/param_group_{j}/lr"] = lr
-
-            if len(lr_metrics) == 1:
-                # If there's only one learning rate, log it directly
-                metrics.update({"lr": list(lr_metrics.values())[0]})
-            else:
-                # Otherwise, log all learning rates under the lr_scheduler key
-                metrics.update(lr_metrics)
+            # Log the learning rate from the first scheduler
+            metrics.update({"lr": self.lr_schedulers.schedulers[0].get_last_lr()[0]})
 
         if extra_metrics:
             metrics.update(extra_metrics)

From f37a69e64d43516b751aa6c7e31f188718d6f4b8 Mon Sep 17 00:00:00 2001
From: Ido Hakimi
Date: Tue, 29 Jul 2025 20:25:16 +0000
Subject: [PATCH 3/4] Add note to clarify learning rate logging behavior in MetricsProcessor

---
 torchtitan/components/metrics.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index 35b8feeb5e..8e2d842467 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -401,6 +401,7 @@ def log(
 
         if self.lr_schedulers:
             # Log the learning rate from the first scheduler
+            # Note: This logs the LR for step i+1 when logging is done at the end of step i
             metrics.update({"lr": self.lr_schedulers.schedulers[0].get_last_lr()[0]})
 
         if extra_metrics:

From 69916b602ca261799667c04500423ac0002f9730 Mon Sep 17 00:00:00 2001
From: Ido Hakimi
Date: Wed, 30 Jul 2025 19:19:38 +0000
Subject: [PATCH 4/4] Refactor learning rate logging in MetricsProcessor and Trainer

---
 torchtitan/components/metrics.py | 5 -----
 torchtitan/train.py              | 9 +++++++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index 8e2d842467..dcd8782810 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -399,11 +399,6 @@ def log(
             "memory/num_ooms": device_mem_stats.num_ooms,
         }
 
-        if self.lr_schedulers:
-            # Log the learning rate from the first scheduler
-            # Note: This logs the LR for step i+1 when logging is done at the end of step i
-            metrics.update({"lr": self.lr_schedulers.schedulers[0].get_last_lr()[0]})
-
         if extra_metrics:
             metrics.update(extra_metrics)
 
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 94d9dbbf4e..ec5fdd340d 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -291,7 +291,6 @@ def __init__(self, job_config: JobConfig):
             )
         )
         self.metrics_processor.optimizers = self.optimizers
-        self.metrics_processor.lr_schedulers = self.lr_schedulers
 
         # Initialize trainer states that will be saved in checkpoint.
         # These attributes must be initialized before checkpoint loading.
@@ -447,6 +446,8 @@ def train_step(
         self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
     ):
         self.optimizers.zero_grad()
+        # Save the current step learning rate for logging
+        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
 
         # Keep these variables local to shorten the code as these are
         # the major variables that are used in the training loop.
@@ -494,12 +495,16 @@ def train_step(
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
 
+        extra_metrics = {
+            "n_tokens_seen": self.ntokens_seen,
+            "lr": lr,
+        }
         self.metrics_processor.log(
             self.step,
             global_avg_loss,
             global_max_loss,
             grad_norm.item(),
-            extra_metrics={"ntokens_seen": self.ntokens_seen},
+            extra_metrics=extra_metrics,
         )
 
 @record
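
Note: a minimal, standalone sketch (not part of the patches above) of the
get_last_lr() timing that motivates the comment in PATCH 3/4 and the refactor
in PATCH 4/4. Once scheduler.step() has run at the end of step i, get_last_lr()
already reports the learning rate for step i+1, so the LR must be captured at
the start of train_step() to log the value the current step actually trained
with. The 10-step linear-warmup LambdaLR below is a hypothetical stand-in for
the real torchtitan schedule; any PyTorch LRScheduler behaves the same way.

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    # Toy setup: one parameter, base LR 1.0, hypothetical 10-step linear warmup.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 10))

    for step in range(3):
        # What PATCH 4/4 logs: the LR captured before this step runs.
        lr = scheduler.get_last_lr()[0]
        optimizer.step()
        scheduler.step()
        # What end-of-step logging would see: the LR already advanced to step i+1.
        next_lr = scheduler.get_last_lr()[0]
        print(f"step {step}: trained with lr={lr:.2f}, get_last_lr() now {next_lr:.2f}")

Running the sketch prints 0.10/0.20, 0.20/0.30, 0.30/0.40, showing the
one-step offset that the PATCH 3/4 comment describes.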