diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 3c9e20ffb5..aa25149d66 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -62,6 +62,16 @@ def dist_max(
     )
 
 
+def dist_sum(
+    x: torch.Tensor,
+    mesh: DeviceMesh,
+    extra_pg: dist.ProcessGroup | None = None,
+) -> float:
+    return _dist_reduce(
+        x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg
+    )
+
+
 def dist_mean(
     x: torch.Tensor,
     mesh: DeviceMesh,
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 369c409a81..0955bbb2cb 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -498,15 +498,23 @@ def train_step(
         if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
             ft_pg = self.ft_manager.loss_sync_pg
-            global_avg_loss, global_max_loss = (
+            global_avg_loss, global_max_loss, global_ntokens_seen = (
                 dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                 dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_sum(
+                    torch.tensor(
+                        self.ntokens_seen, dtype=torch.int64, device=self.device
+                    ),
+                    parallel_dims.world_mesh["dp_cp"],
+                    ft_pg,
+                ),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
+            global_ntokens_seen = self.ntokens_seen
 
         extra_metrics = {
-            "n_tokens_seen": self.ntokens_seen,
+            "n_tokens_seen": global_ntokens_seen,
             "lr": lr,
         }
         self.metrics_processor.log(
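
The new dist_sum helper mirrors dist_mean and dist_max but reduces with ReduceOp.SUM, so the per-rank ntokens_seen counters combine into a global token count instead of an average. As a rough, standalone illustration of the semantics the training loop relies on (this is not torchtitan code; the gloo backend, the fabricated per-rank counts, and the two-process launch are assumptions for the demo), the sketch below performs the same SUM all-reduce over an int64 scalar, after which every rank holds the global total:

# demo.py -- launch with: torchrun --nproc-per-node=2 demo.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Per-rank token count; int64 matches the dtype chosen in the diff,
    # avoiding float rounding once counts grow large.
    local_ntokens = torch.tensor(1000 * (rank + 1), dtype=torch.int64)

    # SUM all-reduce: afterwards every rank holds the global total. This is
    # the reduction dist_sum performs over the "dp_cp" mesh (and the optional
    # fault-tolerance process group) via _dist_reduce.
    dist.all_reduce(local_ntokens, op=dist.ReduceOp.SUM)

    print(f"rank {rank}: global n_tokens_seen = {local_ntokens.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Summing (rather than logging the local counter, as the old code did) makes the n_tokens_seen metric consistent across data-parallel and context-parallel ranks, since each rank only sees its own shard of the batch.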