
Commit 0f9d2b5

MOE: Update global norm calculation for pipeline
When using MoE with MoE-TP disabled, use the pipeline parallel group to max or sum MoE gradient norms. This also fixes the behavior for the following configuration: no pipeline, TP enabled, MoE-TP disabled.

Signed-off-by: Moshe Island <[email protected]>
1 parent: d8ecc22

2 files changed (+46, -3 lines)

deepspeed/runtime/utils.py (+20, -3)

@@ -25,7 +25,8 @@
 from torch import inf
 
 from deepspeed.utils import groups, logger
-from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
+                                 bwc_pipeline_parallel_group)
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from numpy import prod
 from deepspeed.accelerator import get_accelerator
@@ -857,8 +858,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
             all_norms.append(t.data.abs().max().float())
         total_norm = torch.stack(all_norms).max()
         device_total_norm = total_norm.to(get_accelerator().current_device_name())
+        # Max across model parallel
         if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # For MoE grads, max over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, max over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: max across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
         total_norm = device_total_norm.to(input_tensors[0].device)
@@ -885,8 +894,16 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
 
         device_total_norm = compute_buffer[0].float().detach()
 
+        # Sum across model parallel
         if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # For MoE grads, sum over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, sum over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: sum across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
         total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
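
Note: the trailing .pow(1. / norm_type) implies that device_total_norm accumulates per-tensor norms raised to the p-th power, which is why a SUM reduction composes correctly across groups. A tiny local check of that identity (plain PyTorch, no distributed setup):

import torch

norm_type = 2.0
g1, g2 = torch.randn(10), torch.randn(7)

# Per-"rank" contributions: norm(t) ** p, then summed (this is what the SUM all_reduce does).
partial = torch.norm(g1, norm_type)**norm_type + torch.norm(g2, norm_type)**norm_type
combined = partial**(1.0 / norm_type)

# Direct global norm over the concatenated tensors matches the composed result.
reference = torch.norm(torch.cat([g1, g2]), norm_type)
assert torch.allclose(combined, reference)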

deepspeed/utils/bwc.py (+26)

@@ -76,3 +76,29 @@ def bwc_tensor_model_parallel_group(mpu=None):
     else:
         # Deprecated Megatron and DeepSpeed convention
         return mpu.get_model_parallel_group()
+
+
+def bwc_pipeline_parallel_world_size(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel world size."""
+    world_size = 1
+    if mpu is not None:
+        if hasattr(mpu, 'get_pipeline_model_parallel_world_size'):
+            # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+            world_size = mpu.get_pipeline_model_parallel_world_size()
+        elif hasattr(mpu, 'get_pipe_parallel_world_size'):
+            # DeepSpeed Topology
+            world_size = mpu.get_pipe_parallel_world_size()
+    return world_size
+
+
+def bwc_pipeline_parallel_group(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel group."""
+    if mpu is None:
+        return None
+    if hasattr(mpu, 'get_pipeline_model_parallel_group'):
+        # Megatron
+        return mpu.get_pipeline_model_parallel_group()
+    elif hasattr(mpu, 'get_pipe_parallel_group'):
+        # DeepSpeed Topology
+        return mpu.get_pipe_parallel_group()
+    assert False, 'mpu does not support pipeline parallel group'
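
Note: a minimal usage sketch for the two new helpers follows; the stand-in mpu class is hypothetical and only mirrors the hasattr checks above (a real mpu would return an actual torch.distributed ProcessGroup).

from deepspeed.utils.bwc import bwc_pipeline_parallel_world_size, bwc_pipeline_parallel_group

class _FakePipeMPU:
    """Hypothetical stand-in exposing the DeepSpeed Topology method names."""

    def get_pipe_parallel_world_size(self):
        return 4

    def get_pipe_parallel_group(self):
        return 'pipe-group-handle'  # placeholder for a real ProcessGroup

assert bwc_pipeline_parallel_world_size(None) == 1   # no mpu -> world size of 1
assert bwc_pipeline_parallel_group(None) is None     # no mpu -> no group

mpu = _FakePipeMPU()
assert bwc_pipeline_parallel_world_size(mpu) == 4    # resolved via get_pipe_parallel_world_size
assert bwc_pipeline_parallel_group(mpu) == 'pipe-group-handle'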
