diff --git a/torchtitan/train.py b/torchtitan/train.py
index 58fc69ac2b..ec5fdd340d 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -446,6 +446,8 @@ def train_step(
         self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
     ):
         self.optimizers.zero_grad()
+        # Save the current step learning rate for logging
+        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
 
         # Keep these variables local to shorten the code as these are
         # the major variables that are used in the training loop.
@@ -493,12 +495,16 @@ def train_step(
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
 
+        extra_metrics = {
+            "n_tokens_seen": self.ntokens_seen,
+            "lr": lr,
+        }
         self.metrics_processor.log(
             self.step,
             global_avg_loss,
             global_max_loss,
             grad_norm.item(),
-            extra_metrics={"ntokens_seen": self.ntokens_seen},
+            extra_metrics=extra_metrics,
         )
 
     @record
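
Note on the scheduler API the new line relies on (a minimal standalone sketch; the toy model, optimizer, and warmup lambda below are illustrative assumptions, not torchtitan code): PyTorch's LRScheduler.get_last_lr() returns a list with one learning rate per parameter group, which is why the diff indexes [0]. The schedulers[0] index assumes torchtitan's scheduler container holds one scheduler per model part, all stepped on the same schedule, so reading the first is representative.

import torch

model = torch.nn.Linear(4, 4)  # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 10)  # toy linear warmup
)

for step in range(3):
    optimizer.step()
    scheduler.step()
    # get_last_lr() -> [lr_group_0, lr_group_1, ...]; the diff logs group 0.
    print(f"step={step} lr={scheduler.get_last_lr()[0]:.2e}")

Capturing lr at the top of train_step, before optimizers.step() and lr_schedulers.step() run, means the value logged with this step's loss is the learning rate that was actually applied during this step.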