From d19395cade70f845aaaecd96952b18b48438b228 Mon Sep 17 00:00:00 2001
From: runame
Date: Thu, 31 Jul 2025 15:43:52 -0700
Subject: [PATCH 1/2] All-reduce ntokens_seen before logging

---
 torchtitan/distributed/utils.py | 10 ++++++++++
 torchtitan/train.py             |  8 +++++++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 3c9e20ffb5..aa25149d66 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -62,6 +62,16 @@ def dist_max(
     )
 
 
+def dist_sum(
+    x: torch.Tensor,
+    mesh: DeviceMesh,
+    extra_pg: dist.ProcessGroup | None = None,
+) -> float:
+    return _dist_reduce(
+        x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg
+    )
+
+
 def dist_mean(
     x: torch.Tensor,
     mesh: DeviceMesh,
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 369c409a81..77734ebbca 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -502,11 +502,17 @@ def train_step(
                 dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
             )
+            global_ntokens_seen = dist_utils.dist_sum(
+                torch.tensor(self.ntokens_seen, dtype=torch.int64, device=self.device),
+                parallel_dims.world_mesh["dp_cp"],
+                ft_pg,
+            )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
+            global_ntokens_seen = self.ntokens_seen
 
         extra_metrics = {
-            "n_tokens_seen": self.ntokens_seen,
+            "n_tokens_seen": global_ntokens_seen,
             "lr": lr,
         }
         self.metrics_processor.log(

From 540eab8a666250ba9b4d69f6dc9bdeec6b49bfd0 Mon Sep 17 00:00:00 2001
From: runame
Date: Fri, 1 Aug 2025 15:20:14 -0700
Subject: [PATCH 2/2] Style fix

---
 torchtitan/train.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 77734ebbca..0955bbb2cb 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -498,14 +498,16 @@ def train_step(
         if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
             ft_pg = self.ft_manager.loss_sync_pg
-            global_avg_loss, global_max_loss = (
+            global_avg_loss, global_max_loss, global_ntokens_seen = (
                 dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
-            )
-            global_ntokens_seen = dist_utils.dist_sum(
-                torch.tensor(self.ntokens_seen, dtype=torch.int64, device=self.device),
-                parallel_dims.world_mesh["dp_cp"],
-                ft_pg,
+                dist_utils.dist_sum(
+                    torch.tensor(
+                        self.ntokens_seen, dtype=torch.int64, device=self.device
+                    ),
+                    parallel_dims.world_mesh["dp_cp"],
+                    ft_pg,
+                ),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()