4 files changed, +18 -7 lines changed

@@ -679,7 +679,7 @@ def __init__(
         self.disable_grad_reduce = disable_grad_reduce

         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
         if self.explicit_expert_comm and config.moe_extended_tp:
             world_size = get_tensor_and_expert_parallel_world_size()
@@ -941,7 +941,7 @@ def __init__(
             raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )

         # Divide the weight matrix along the last dimension.
@@ -90,6 +90,16 @@ def __init__(
         self.moe_layer_recompute = config.moe_layer_recompute

     def forward(self, hidden_states: torch.Tensor):
+        if (
+            self.training
+            and self.config.tensor_model_parallel_size > 1
+            and not self.config.sequence_parallel
+        ):
+            raise ValueError(
+                "During training, performance may degrade if MoE and tensor parallelism"
+                " are enabled without also enabling sequence parallelism."
+            )
+
         # process MoE
         def custom_forward(hidden_states):
             probs, indices = self.router(hidden_states)
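For reference, here is a minimal, self-contained sketch of the guard the new `forward` check expresses. `SimpleParallelConfig` and `check_moe_tp_sp` are hypothetical stand-ins introduced only for illustration; the two fields `tensor_model_parallel_size` and `sequence_parallel` are the ones used in the hunk above.

from dataclasses import dataclass


@dataclass
class SimpleParallelConfig:
    # Hypothetical stand-in for the parallelism fields referenced above;
    # not Megatron's real config class.
    tensor_model_parallel_size: int = 1
    sequence_parallel: bool = False


def check_moe_tp_sp(config: SimpleParallelConfig, training: bool) -> None:
    # Same condition as the added forward() check: during training,
    # tensor parallelism without sequence parallelism is rejected for MoE.
    if training and config.tensor_model_parallel_size > 1 and not config.sequence_parallel:
        raise ValueError(
            "During training, performance may degrade if MoE and tensor parallelism"
            " are enabled without also enabling sequence parallelism."
        )


# TP=2 with sequence parallelism enabled: passes.
check_moe_tp_sp(
    SimpleParallelConfig(tensor_model_parallel_size=2, sequence_parallel=True), training=True
)
# TP=2 without sequence parallelism during training: raises ValueError.
# check_moe_tp_sp(SimpleParallelConfig(tensor_model_parallel_size=2), training=True)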
@@ -107,7 +107,9 @@ def token_permutation(
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

         # Permute the tokens across the expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             with torch.no_grad():
                 global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                     max_ind
@@ -214,7 +216,9 @@ def token_unpermutation(
             output_bias_total = unpermuted_local_bias

         # Unpermute the tokens across expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             assert (
                 self.global_local_map is not None
             ), "global_local_map is necessary for `AllGather`."
@@ -498,9 +498,6 @@ def validate_args(args, defaults={}):
     # MoE Spec check
     if args.num_experts is not None:
         assert args.spec is None, "Model Spec must be None when using MoEs"
-        if args.tensor_model_parallel_size > 1:
-            assert args.sequence_parallel, \
-                "When using MoE and tensor parallelism, sequence parallelism must be used."

     # Expert parallelism check
     if args.expert_model_parallel_size > 1:
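A hedged sketch of how the MoE Spec check reads after this hunk: only the spec assert remains at argument-validation time, so MoE with tensor parallelism and without sequence parallelism is no longer rejected here, and the stricter condition is enforced at MoE forward time during training (see the earlier hunk). The `Namespace` below is hypothetical and carries only the fields this hunk touches.

from argparse import Namespace

# Hypothetical args object; real Megatron args have many more attributes.
args = Namespace(
    num_experts=8,
    spec=None,
    tensor_model_parallel_size=2,
    sequence_parallel=False,
)

# MoE Spec check after the removal above: only the spec assert remains,
# so TP > 1 without sequence parallelism passes validation.
if args.num_experts is not None:
    assert args.spec is None, "Model Spec must be None when using MoEs"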