Skip to content

Commit

Permalink
Add DeepSpeed MoE
Browse files Browse the repository at this point in the history
Thanks to dayofthepenguin for extensive testing

Closes EleutherAI#479
  • Loading branch information
yang committed Mar 4, 2024
1 parent 6399155 commit bab8e4f
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 33 deletions.
99 changes: 99 additions & 0 deletions configs/125M-moe.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# GPT-2 pretraining setup
{
# Have 4 experts per layer (every 2 layers by default)
"num_experts": 4,

# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,
"moe_expert_parallel_size": 1,

# model settings
"num_layers": 12,
"hidden_size": 768,
"num_attention_heads": 12,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",


# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,

# precision settings
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,

# networking
"hostfile": "/mock_path"
}
113 changes: 89 additions & 24 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2021 EleutherAI
# Copyright (c) 2024 EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,7 @@
bias_dropout_add_fused_inference,
)
from megatron.model.utils import configure_sparse_attention
from deepspeed.moe.layer import MoE

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
Expand Down Expand Up @@ -82,7 +83,13 @@ class ParallelMLP(nn.Module):
"""

def __init__(
self, neox_args, init_method, output_layer_init_method, parallel_output=False
self,
neox_args,
init_method,
output_layer_init_method,
parallel_output=False,
MOE=False,
MoE_mp_size=1,
):
super().__init__()

Expand All @@ -104,6 +111,8 @@ def __init__(
gather_output=False,
init_method=init_method,
skip_bias_add=True,
MOE=MOE,
MoE_mp_size=MoE_mp_size,
)
ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
# Project back to h.
Expand All @@ -113,8 +122,10 @@ def __init__(
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
skip_bias_add=True,
MOE=MOE,
MoE_mp_size=MoE_mp_size,
)

def forward(self, hidden_states):
Expand Down Expand Up @@ -156,6 +167,8 @@ def __init__(
output_layer_init_method,
parallel_output=False,
multiple_of=256,
MOE=False,
MoE_mp_size=1,
):
super().__init__()

Expand All @@ -174,6 +187,8 @@ def __init__(
init_method=init_method,
skip_bias_add=True,
bias=False,
MOE=MOE,
MoE_mp_size=MoE_mp_size,
)
self.w3 = mpu.ColumnParallelLinear(
neox_args=neox_args,
Expand All @@ -183,6 +198,8 @@ def __init__(
init_method=init_method,
skip_bias_add=True,
bias=False,
MOE=MOE,
MoE_mp_size=MoE_mp_size,
)
self.w2 = mpu.RowParallelLinear(
neox_args=neox_args,
Expand All @@ -193,6 +210,8 @@ def __init__(
skip_bias_add=True,
parallel_output=parallel_output,
bias=False,
MOE=MOE,
MoE_mp_size=MoE_mp_size,
)

def forward(self, hidden_states):
Expand Down Expand Up @@ -800,22 +819,55 @@ def __init__(
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
if neox_args.mlp_type == "regular":
self.mlp = ParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
)
elif neox_args.mlp_type == "llama":
self.mlp = LLaMAParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
)
def get_mlp(mlp_type, **kw):
if mlp_type == "regular":
return ParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
**kw
)
elif mlp_type == "llama":
return LLaMAParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
**kw
)
else:
raise KeyError(mlp_type)

self.num_experts = neox_args.num_experts
args = neox_args
if self.num_experts <= 1:
self.mlp = get_mlp(neox_args.mlp_type)
else:
raise KeyError(neox_args.mlp_type)
from torch import distributed as dist

if self.num_experts > dist.get_world_size():
moe_mp_size = 1
else:
moe_mp_size = dist.get_world_size() // self.num_experts

self.mlp = MoE(
args.hidden_size,
get_mlp(
"regular",
MOE=True,
MoE_mp_size=moe_mp_size,
),
num_experts=self.num_experts,
ep_size=args.moe_expert_parallel_size,
k=args.moe_top_k,
use_residual=args.moe_use_residual,
capacity_factor=args.moe_train_capacity_factor,
eval_capacity_factor=args.moe_eval_capacity_factor,
min_capacity=args.moe_min_capacity,
drop_tokens=args.moe_token_dropping,
use_tutel=args.use_tutel,
)

self.layer_past = None # used to cache k/v pairs in inference

Expand Down Expand Up @@ -913,12 +965,22 @@ def forward(self, x, attention_mask, layer_past=None):
)

# output = x + mlp(ln2(x))
mlp_output, mlp_bias = self.mlp(
self.post_attention_layernorm(attention_output)
layernorm_output = self.post_attention_layernorm(attention_output)
moe_loss = torch.tensor(
0.0, device=layernorm_output.device, dtype=layernorm_output.dtype
)
mlp_bias = torch.tensor(
0.0, device=layernorm_output.device, dtype=layernorm_output.dtype
)

if self.num_experts == 1:
mlp_output, mlp_bias = self.mlp(layernorm_output)
else:
mlp_output, moe_loss, _ = self.mlp(layernorm_output)
mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term

with torch.enable_grad():
if self.mlp_type == "llama":
if self.mlp_type == "llama" or self.num_experts > 1:
# No dropout either
assert mlp_bias is None
output = mlp_output + attention_output
Expand All @@ -930,7 +992,7 @@ def forward(self, x, attention_mask, layer_past=None):
prob=self.hidden_dropout,
)

return output
return output, moe_loss


class ParallelTransformerLayerPipe(ParallelTransformerLayer):
Expand All @@ -942,7 +1004,10 @@ def forward(self, args):
), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask"
hidden_states, attention_mask = args
# we are returning just [hidden_states, mask]
return super().forward(hidden_states, attention_mask), attention_mask
output, moe_loss = super().forward(hidden_states, attention_mask)
# auxiliary output
self.last_moe_loss = moe_loss
return output, attention_mask


class ParallelLinearPipe(ParallelLinear):
Expand Down
14 changes: 11 additions & 3 deletions megatron/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
weight_decay_params = {"params": [], "name": "weight_decay_params"}
no_weight_decay_params = {
"params": [],
"weight_decay": 0.0,
"name": "no_weight_decay_params",
}
for module_ in module.modules():
if any(
[
Expand Down Expand Up @@ -162,6 +166,8 @@ def forward(
].contiguous()
forward_input = (tokens, input_ids, attention_mask)

moe_losses = []

def exec_range_func(start, end):
"""Helper function to be used with checkpoint()
Adapted from torch.utils.checkpoint:checkpoint_sequential()
Expand All @@ -173,6 +179,8 @@ def exec_func(*inputs):
inputs = inputs[0]
for idx, layer in enumerate(self.sequential[start:end]):
inputs = layer(inputs)
if hasattr(layer, 'last_moe_loss'):
moe_losses.append(layer.last_moe_loss)
return inputs

return exec_func
Expand Down Expand Up @@ -200,7 +208,7 @@ def exec_func(*inputs):
)
else:
x = exec_range_func(start_idx, end_idx)(*x)
return x
return x, moe_losses

def clear_cache(self):
"""
Expand Down
3 changes: 3 additions & 0 deletions megatron/mpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from .initialize import get_pipe_parallel_group
from .initialize import get_pipe_parallel_rank
from .initialize import get_pipe_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_io_parallel_group
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
Expand Down
32 changes: 32 additions & 0 deletions megatron/mpu/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,38 @@ def get_pipe_parallel_world_size():
return torch.distributed.get_world_size(group=get_pipe_parallel_group())


def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
set_model_parallel_world_size(world_size)


def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
return get_model_parallel_group()


def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
return get_model_parallel_rank()


# Needed for MOE. True tensor parallelism todo.
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_model_parallel_world_size()


def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
set_model_parallel_rank(rank)


def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_model_parallel_rank()


def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
Expand Down
Loading

0 comments on commit bab8e4f

Please sign in to comment.