Add DeepSpeed MoE
Thanks to dayofthepenguin for extensive testing

Closes EleutherAI#479
yang committed Mar 6, 2024
1 parent 4825f1a commit 60ba057
Showing 3 changed files with 62 additions and 0 deletions.
6 changes: 6 additions & 0 deletions megatron/model/transformer.py
@@ -1133,6 +1133,12 @@ def forward(self, x, attention_mask, layer_past=None):
        0.0, device=layernorm_output.device, dtype=layernorm_output.dtype
    )

    if self.num_experts == 1:
        mlp_output, mlp_bias = self.mlp(layernorm_output)
    else:
        mlp_output, moe_loss, _ = self.mlp(layernorm_output)
        mlp_bias = None  # deepspeed.moe.layer.MoE.forward ignores the bias term

    if self.num_experts == 1:
        mlp_output, mlp_bias = self.mlp(layernorm_output)
    else:
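In the new `else` branch above, `self.mlp` is a DeepSpeed MoE layer rather than the usual dense MLP. The sketch below is not the commit's code: it only illustrates how such a layer is typically built with deepspeed.moe.layer.MoE and why its forward pass returns three values (output, auxiliary load-balancing loss, per-expert token counts). The helper name `build_moe_mlp` and the plain `torch.nn.Sequential` expert are illustrative assumptions, and constructing the layer presumes an already-initialized torch.distributed / DeepSpeed environment.

import torch
from deepspeed.moe.layer import MoE


def build_moe_mlp(hidden_size: int, num_experts: int, top_k: int) -> MoE:
    """Illustrative helper (not NeoX code): wrap a dense MLP in a DeepSpeed MoE layer.

    Constructing MoE sets up expert-parallel process groups, so it assumes
    torch.distributed has been initialized (e.g. via deepspeed.init_distributed()).
    """
    # Any torch.nn.Module can serve as the expert; GPT-NeoX passes its own MLP here.
    expert = torch.nn.Sequential(
        torch.nn.Linear(hidden_size, 4 * hidden_size),
        torch.nn.GELU(),
        torch.nn.Linear(4 * hidden_size, hidden_size),
    )
    return MoE(
        hidden_size=hidden_size,
        expert=expert,
        num_experts=num_experts,   # cf. the num_experts argument added in neox_args.py below
        k=top_k,                   # cf. moe_top_k (top-k gating)
        capacity_factor=1.0,       # cf. moe_train_capacity_factor
        eval_capacity_factor=1.0,  # cf. moe_eval_capacity_factor
        min_capacity=4,            # cf. moe_min_capacity
        drop_tokens=True,          # cf. moe_token_dropping
        use_residual=True,         # cf. moe_use_residual
    )


# MoE.forward(hidden_states) returns (output, l_aux, exp_counts), which is why the
# transformer layer unpacks three values and discards the per-expert token counts:
#     mlp_output, moe_loss, _ = self.mlp(layernorm_output)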
55 changes: 55 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -1285,3 +1285,58 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
"""
Number of parallel experts in MoE
"""

moe_top_k: int = 1
"""
Number of experts each token is routed to (top-k gating) in MoE
"""

use_tutel: bool = False
"""
Use Tutel optimizations in MoE
"""

num_experts: int = 1
"""
Number of MoE experts
"""

moe_loss_coeff: float = 0.1
"""
Scaling coefficient for the MoE auxiliary (load-balancing) loss
"""

moe_train_capacity_factor: float = 1.0
"""
The capacity factor for experts at train time
"""

moe_eval_capacity_factor: float = 1.0
"""
The capacity factor for experts at eval time
"""

moe_min_capacity: int = 4
"""
The minimum capacity per expert regardless of the capacity_factor
"""

moe_token_dropping: bool = True
"""
Whether to drop tokens that exceed an expert's capacity
"""

create_moe_param_group: bool = True
"""
Whether to create a separate parameter group for MoE parameters
"""

moe_use_residual: bool = True
"""
Whether to use Residual MoE, where a dense MLP is evaluated alongside the routed experts and the outputs are combined
"""

moe_expert_parallel_size: int = 1
"""
Degree of expert parallelism: the number of ranks the MoE experts are distributed across
"""
1 change: 1 addition & 0 deletions megatron/training.py
@@ -387,6 +387,7 @@ def forward_step(

    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_push(f"Forward pass")

    # Sequential returns moe_losses, but this is not yet supported by pipe parallel
    maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
    if type(maybe_tuple) is tuple:
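The branch truncated above is where the Sequential model's auxiliary MoE losses are unpacked alongside the model output. A minimal sketch of how such losses are commonly folded into the language-modelling loss follows; the names lm_loss and moe_losses and the direct use of a moe_loss_coeff value are assumptions for illustration, not the commit's exact code.

import torch


def combine_losses(lm_loss: torch.Tensor, moe_losses, moe_loss_coeff: float) -> torch.Tensor:
    """Illustrative sketch: add the scaled MoE load-balancing losses to the LM loss."""
    if moe_losses:
        # Each MoE layer contributes one auxiliary loss; moe_loss_coeff (see
        # neox_args.py above) scales their total contribution.
        lm_loss = lm_loss + moe_loss_coeff * sum(moe_losses)
    return lm_loss


# Example: two MoE layers, each returning an auxiliary load-balancing loss.
total = combine_losses(torch.tensor(2.5), [torch.tensor(0.1), torch.tensor(0.2)], moe_loss_coeff=0.1)
print(total)  # tensor(2.5300)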
