 from torch import inf
 
 from deepspeed.utils import groups, logger
-from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
+                                 bwc_pipeline_parallel_group)
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from numpy import prod
 from deepspeed.accelerator import get_accelerator
@@ -857,8 +858,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
             all_norms.append(t.data.abs().max().float())
         total_norm = torch.stack(all_norms).max()
         device_total_norm = total_norm.to(get_accelerator().current_device_name())
+        # Max across model parallel
         if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # For MoE grads, max over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, max over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: max across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
         total_norm = device_total_norm.to(input_tensors[0].device)
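For reference, the group-selection logic in this hunk (and in the analogous SUM hunk below) can be summarized as a minimal standalone sketch. This is not DeepSpeed code: the function and its parameters are illustrative stand-ins for mpu.get_model_parallel_group(), groups._get_expert_model_parallel_world_size(), bwc_pipeline_parallel_world_size(mpu) and bwc_pipeline_parallel_group(mpu).

def select_reduce_groups(model_parallel_group, pipeline_parallel_group, moe_ep_group,
                         moe_tp_world_size, pipeline_world_size):
    """Return the process groups a (possibly MoE) grad norm is reduced over."""
    reduce_groups = []
    if model_parallel_group is not None:
        if moe_ep_group is None or moe_tp_world_size > 1:
            # Dense grads, or MoE grads with MoE tensor parallelism enabled:
            # reduce across the model-parallel group.
            reduce_groups.append(model_parallel_group)
        elif pipeline_world_size > 1:
            # MoE grads with MoE-TP disabled: reduce across the pipeline-parallel group.
            reduce_groups.append(pipeline_parallel_group)
    if moe_ep_group is not None:
        # MoE grads are additionally reduced across the expert-parallel group.
        reduce_groups.append(moe_ep_group)
    return reduce_groups

The rationale implied by the comments: with MoE-TP disabled, expert gradients are presumably replicated across tensor-parallel ranks, so reducing them over the full model-parallel group would be redundant (and, for the SUM case below, would over-count); the pipeline-parallel group is used instead, followed in both cases by the expert-parallel reduction.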
@@ -885,8 +894,16 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
 
         device_total_norm = compute_buffer[0].float().detach()
 
+        # Sum across model parallel
         if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # For MoE grads, sum over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, sum over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: sum across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
         total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
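A hedged usage sketch follows, assuming this diff targets deepspeed/runtime/utils.py and that get_global_norm_of_tensors accepts the moe_ep_group keyword used above; expert_params, mpu and ep_group are placeholders for the caller's model and parallel setup, not names from this PR.

from deepspeed.runtime.utils import get_global_norm_of_tensors

def moe_expert_grad_norm(expert_params, mpu, ep_group):
    # L2 norm of expert grads, reduced over the groups selected above and then over EP.
    expert_grads = [p.grad for p in expert_params if p.grad is not None]
    return get_global_norm_of_tensors(expert_grads, norm_type=2, mpu=mpu, moe_ep_group=ep_group)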