diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 48d5f3df9..268d339e9 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1133,6 +1133,12 @@ def forward(self, x, attention_mask, layer_past=None): 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype ) + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term + if self.num_experts == 1: mlp_output, mlp_bias = self.mlp(layernorm_output) else: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 0e951c811..786ad0b3b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1285,3 +1285,58 @@ class NeoXArgsTextgen(NeoXArgsTemplate): """ Number of parallel experts in MoE """ + + moe_top_k: int = 1 + """ + Activate top K experts in MoE + """ + + use_tutel: bool = False + """ + Use Tutel optimizations in MoE + """ + + num_experts: int = 1 + """ + Number of MoE experts + """ + + moe_loss_coeff: float = 0.1 + """ + Coefficient for MoE loss + """ + + moe_train_capacity_factor: float = 1.0 + """ + The capacity of the expert at train time + """ + + moe_eval_capacity_factor: float = 1.0 + """ + The capacity of the expert at eval time + """ + + moe_min_capacity: int = 4 + """ + The minimum capacity per expert regardless of the capacity_factor + """ + + moe_token_dropping: bool = True + """ + Whether to drop tokens when exceeding capacity + """ + + create_moe_param_group: bool = True + """ + Whether to create a separate parameter group for MoE parameters + """ + + moe_use_residual: bool = True + """ + Whether to use residual in MoE + """ + + moe_expert_parallel_size: int = 1 + """ + Number of parallel experts in MoE + """ diff --git a/megatron/training.py b/megatron/training.py index 98404c4ea..d94ef6112 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -387,6 +387,7 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") + # Sequential returns moe_losses, but this is not yet supported by pipe parallel maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) if type(maybe_tuple) is tuple: