Commit a645f89

Merge branch 'cherry-pick-c4d12e26' into 'core_r0.7.0'
Merge branch 'xuwenc/moe_gmm_infer_fix' into 'main'

See merge request ADLR/megatron-lm!1519

2 parents a967adf + 015e427; commit a645f89


4 files changed: +18, -7 lines

megatron/core/tensor_parallel/layers.py

+2 -2
@@ -679,7 +679,7 @@ def __init__(
         self.disable_grad_reduce = disable_grad_reduce

         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
         if self.explicit_expert_comm and config.moe_extended_tp:
             world_size = get_tensor_and_expert_parallel_world_size()

@@ -941,7 +941,7 @@ def __init__(
             raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )

         # Divide the weight matrix along the last dimension.

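For reference, the changed condition can be read on its own. Below is a minimal sketch (a hypothetical standalone helper, not Megatron-LM's API) of the decision that makes expert linear layers handle their own gather/scatter, assuming only the fields visible in the diff above:

# Hypothetical helper mirroring the updated condition; not part of Megatron-LM.
def explicit_expert_comm(is_expert: bool, tensor_model_parallel_size: int,
                         expert_parallel: bool) -> bool:
    # Before this commit the first operand was `config.sequence_parallel`;
    # keying on tensor_model_parallel_size > 1 also covers runs (e.g. inference)
    # where TP is active but sequence parallelism is disabled.
    return is_expert and (tensor_model_parallel_size > 1 or expert_parallel)

# TP=2 without sequence parallelism now triggers explicit expert communication.
assert explicit_expert_comm(is_expert=True, tensor_model_parallel_size=2, expert_parallel=False)
assert not explicit_expert_comm(is_expert=False, tensor_model_parallel_size=2, expert_parallel=False)
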
megatron/core/transformer/moe/moe_layer.py

+10
@@ -90,6 +90,16 @@ def __init__(
         self.moe_layer_recompute = config.moe_layer_recompute

     def forward(self, hidden_states: torch.Tensor):
+        if (
+            self.training
+            and self.config.tensor_model_parallel_size > 1
+            and not self.config.sequence_parallel
+        ):
+            raise ValueError(
+                "During training, performance may degrade if MoE and tensor parallelism "
+                "are enabled without also enabling sequence parallelism."
+            )
+
         # process MoE
         def custom_forward(hidden_states):
             probs, indices = self.router(hidden_states)

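The added guard only fires during training; inference with MoE and tensor parallelism but without sequence parallelism is now allowed to proceed. A minimal sketch reproducing that behavior, with a hypothetical DummyConfig standing in for Megatron-LM's TransformerConfig:

# Sketch of the guard added to MoELayer.forward, reproduced against a
# hypothetical DummyConfig rather than Megatron-LM's TransformerConfig.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    tensor_model_parallel_size: int = 1
    sequence_parallel: bool = False


def check_moe_forward(training: bool, config: DummyConfig) -> None:
    if (
        training
        and config.tensor_model_parallel_size > 1
        and not config.sequence_parallel
    ):
        raise ValueError(
            "During training, performance may degrade if MoE and tensor parallelism "
            "are enabled without also enabling sequence parallelism."
        )


cfg = DummyConfig(tensor_model_parallel_size=2, sequence_parallel=False)
check_moe_forward(training=False, config=cfg)     # inference path: now permitted
try:
    check_moe_forward(training=True, config=cfg)  # training path: still rejected
except ValueError as err:
    print(err)

The check is keyed on self.training so that the constraint formerly enforced in validate_args (see the arguments.py change below) still applies to training runs while no longer blocking inference.
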
megatron/core/transformer/moe/token_dispatcher.py

+6 -2
@@ -107,7 +107,9 @@ def token_permutation(
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

         # Permute the tokens across the expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             with torch.no_grad():
                 global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                     max_ind

@@ -214,7 +216,9 @@ def token_unpermutation(
             output_bias_total = unpermuted_local_bias

         # Unpermute the tokens across expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             assert (
                 self.global_local_map is not None
             ), "global_local_map is necessary for `AllGather`."

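The dispatcher change follows the same pattern: tokens are gathered across ranks whenever tensor or expert parallelism is active, independent of sequence parallelism. A sketch of the updated predicate (hypothetical helper name, not part of token_dispatcher.py):

# Hypothetical predicate mirroring the updated dispatch condition; not part of
# Megatron-LM's token_dispatcher module.
def needs_cross_rank_token_gather(tensor_model_parallel_size: int,
                                  expert_model_parallel_size: int) -> bool:
    # Previously the first operand was `sequence_parallel`, so a TP>1 run with
    # sequence parallelism turned off (e.g. inference) skipped the gather even
    # though tokens were still sharded across TP ranks.
    return (tensor_model_parallel_size > 1) or (expert_model_parallel_size > 1)

# The case this commit fixes: TP=2, EP=1, no sequence parallelism.
assert needs_cross_rank_token_gather(2, 1)
# No tensor or expert parallelism: nothing to gather.
assert not needs_cross_rank_token_gather(1, 1)
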
megatron/training/arguments.py

-3
@@ -498,9 +498,6 @@ def validate_args(args, defaults={}):
     # MoE Spec check
     if args.num_experts is not None:
         assert args.spec is None, "Model Spec must be None when using MoEs"
-        if args.tensor_model_parallel_size > 1:
-            assert args.sequence_parallel, \
-                "When using MoE and tensor parallelism, sequence parallelism must be used."

     # Expert parallelism check
     if args.expert_model_parallel_size > 1:

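With the assertion removed, validate_args no longer rejects MoE configurations that use tensor parallelism without sequence parallelism; during training the constraint now surfaces as the ValueError added in MoELayer.forward. If fail-fast behavior at launch time is still wanted, a user-side check along these lines could be run on the parsed args in a training script (hypothetical helper, not part of Megatron-LM):

import warnings
from types import SimpleNamespace


# Hypothetical user-side helper; Megatron-LM itself no longer enforces this
# at argument-validation time.
def warn_if_moe_tp_without_sp(args) -> None:
    if (
        getattr(args, "num_experts", None) is not None
        and args.tensor_model_parallel_size > 1
        and not args.sequence_parallel
    ):
        warnings.warn(
            "MoE with tensor parallelism but without sequence parallelism is "
            "only intended for inference; training will raise in MoELayer.forward."
        )


# Example with a stand-in args namespace (the real one comes from Megatron-LM's
# argument parsing in megatron/training/arguments.py).
args = SimpleNamespace(num_experts=8, tensor_model_parallel_size=2, sequence_parallel=False)
warn_if_moe_tp_without_sp(args)  # emits the warning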