diff --git a/megatron/arguments.py b/megatron/arguments.py
index b65a2f0e073..c0911dbbbc4 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -635,6 +635,9 @@ def _add_distributed_args(parser):

     group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                        help='Degree of tensor model parallelism.')
+    group.add_argument('--enable-expert-tensor-parallelism', action='store_true',
+                       default=False,
+                       help="use tensor parallelism for expert layers in MoE")
     group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                        help='Degree of pipeline model parallelism.')
     group.add_argument('--moe-expert-parallel-size', type=int, default=1,
@@ -911,4 +914,4 @@ def _add_distillation_args(parser):
     group.add_argument('--load-teacher', type=str, default=None,
                        help='Directory containing a teacher model checkpoint.')

-    return parser
\ No newline at end of file
+    return parser
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 06a10b8f941..72641445d17 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -59,7 +59,7 @@ class ParallelMLP(MegatronModule):
     applied.
     """

-    def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size=1):
+    def __init__(self, init_method, output_layer_init_method, moe=False, enable_expert_tensor_parallelism=False):
         super(ParallelMLP, self).__init__()
         args = get_args()

@@ -70,8 +70,9 @@ def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
-            MOE=MOE,
-            MoE_mp_size=MoE_mp_size)
+            moe=moe,
+            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
+            )

         self.bias_gelu_fusion = args.bias_gelu_fusion
         self.activation_func = F.gelu
@@ -87,9 +88,8 @@ def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
-            MOE=MOE,
-            MoE_mp_size=MoE_mp_size)
-
+            moe=moe,
+            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism)


     def forward(self, hidden_states):
@@ -448,16 +448,12 @@ def __init__(self, init_method, output_layer_init_method,
             self.mlp = ParallelMLP(init_method,
                                    output_layer_init_method)
         else:
-            if not args.ds_inference or self.num_experts > dist.get_world_size():
-                moe_mp_size = 1
-            else:
-                moe_mp_size = dist.get_world_size() // self.num_experts
-
+            enable_expert_tensor_parallelism = args.enable_expert_tensor_parallelism
             self.mlp = MoE(args.hidden_size,
                            ParallelMLP(init_method,
                                        output_layer_init_method=output_layer_init_method,
-                                       MOE=True,
-                                       MoE_mp_size=moe_mp_size),
+                                       moe=True,
+                                       enable_expert_tensor_parallelism=enable_expert_tensor_parallelism),
                            num_experts=self.num_experts,
                            ep_size=args.moe_expert_parallel_size,
                            k=args.topk,
@@ -465,7 +461,8 @@ def __init__(self, init_method, output_layer_init_method,
                            capacity_factor=args.moe_train_capacity_factor,
                            eval_capacity_factor=args.moe_eval_capacity_factor,
                            min_capacity=args.moe_min_capacity,
-                           drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel)
+                           drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel,
+                           enable_expert_tensor_parallelism=enable_expert_tensor_parallelism)

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py
index c837160a371..0d81d562238 100644
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False,
-                 skip_bias_add=False, MOE=False, MoE_mp_size=1):
+                 skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False):
         super(ColumnParallelLinear, self).__init__()

         # Keep input parameters
@@ -238,7 +238,13 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True,
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = MoE_mp_size if MOE else get_tensor_model_parallel_world_size()
+        if moe and (not enable_expert_tensor_parallelism):
+            world_size = 1
+            self.is_expert_without_slicing = True
+        else:
+            world_size = get_tensor_model_parallel_world_size()
+            self.is_expert_without_slicing = False
+
         self.output_size_per_partition = divide(output_size, world_size)
         self.skip_bias_add = skip_bias_add

@@ -282,12 +288,16 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True,

     def forward(self, input_):
         # Set up backprop all-reduce.
-        input_parallel = copy_to_tensor_model_parallel_region(input_)
+        if self.is_expert_without_slicing: # non-expert only tensor parallelism
+            input_parallel = input_
+        else:
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+
         # Matrix multiply.

         bias = self.bias if not self.skip_bias_add else None
         output_parallel = F.linear(input_parallel, self.weight, bias)
-        if self.gather_output:
+        if self.gather_output and not self.is_expert_without_slicing:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
@@ -330,7 +340,7 @@ def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False,
-                 skip_bias_add=False, MOE=False, MoE_mp_size=1):
+                 skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False):
         super(RowParallelLinear, self).__init__()

         # Keep input parameters
@@ -338,7 +348,14 @@ def __init__(self, input_size, output_size, bias=True,
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
         # Divide the weight matrix along the last dimension.
-        world_size = MoE_mp_size if MOE else get_tensor_model_parallel_world_size()
+
+        if moe and (not enable_expert_tensor_parallelism):
+            world_size = 1
+        else:
+            world_size = get_tensor_model_parallel_world_size()
+
+        self.is_expert_without_slicing = moe and world_size==1
+
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add

@@ -379,14 +396,18 @@ def __init__(self, input_size, output_size, bias=True,

     def forward(self, input_):
         # Set up backprop all-reduce.
-        if self.input_is_parallel:
+        if self.input_is_parallel or self.is_expert_without_slicing:
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.is_expert_without_slicing: # non-expert only tensor-parallelism
+            output_ = output_parallel
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
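
As a reading aid (not part of the patch), a minimal standalone sketch of the partitioning rule the patch adds to ColumnParallelLinear and RowParallelLinear: an expert (MoE) layer keeps its weights unsliced unless --enable-expert-tensor-parallelism is set, while non-expert layers always follow the tensor-model-parallel group size. The function name below is hypothetical.

# Sketch only: mirrors the world-size selection added in megatron/mpu/layers.py.
# `expert_partition_world_size` is an illustrative name, not a function in the patch.
def expert_partition_world_size(moe: bool,
                                enable_expert_tensor_parallelism: bool,
                                tensor_model_parallel_world_size: int) -> int:
    # Expert layers bypass tensor slicing unless the new flag is enabled.
    if moe and not enable_expert_tensor_parallelism:
        return 1
    return tensor_model_parallel_world_size

# With a tensor-parallel group of 4: expert weights stay whole by default and
# are split 4 ways only when --enable-expert-tensor-parallelism is passed.
assert expert_partition_world_size(True, False, 4) == 1
assert expert_partition_world_size(True, True, 4) == 4
assert expert_partition_world_size(False, False, 4) == 4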