Merged
Changes from all commits (30 commits)
86dba34
add tensor parallelism
siddharth9820 Jun 29, 2022
aa08fc1
TP for non-experts, drop token before a2a
siddharth9820 Jul 1, 2022
44a1203
dropping tokens after gating
siddharth9820 Jul 6, 2022
92051b3
remove commented code
siddharth9820 Jul 6, 2022
441173d
remove spurious changes
siddharth9820 Jul 6, 2022
7ed0611
remove commented code
siddharth9820 Jul 6, 2022
2f29eb5
remove spurious code
siddharth9820 Jul 6, 2022
1aa8ce0
remove blank line
siddharth9820 Jul 11, 2022
e3b0105
change to schedules
siddharth9820 Jul 12, 2022
3f0cbe1
modified checks in mpu/layers
siddharth9820 Jul 12, 2022
e79106a
better named flag
siddharth9820 Jul 12, 2022
30e2234
migrate code to deepspeed
siddharth9820 Jul 12, 2022
71088c2
remove unnecessary changes to code formatting
siddharth9820 Jul 13, 2022
6364a28
remove unnecessary changes to code formatting
siddharth9820 Jul 13, 2022
3a0eb55
shift code to deepspeed
siddharth9820 Jul 13, 2022
ef3f77c
remove blank lines
siddharth9820 Jul 13, 2022
96b1cb6
remove blank lines
siddharth9820 Jul 13, 2022
961bf8d
remove blank lines
siddharth9820 Jul 13, 2022
4089188
restore mappings.py
siddharth9820 Jul 13, 2022
e1a345c
restore mappings.py
siddharth9820 Jul 13, 2022
18c0d84
remove unnecessary code
siddharth9820 Jul 13, 2022
1619965
restructure code and introduce tensor parallelism for experts
siddharth9820 Jul 19, 2022
066632b
correct ep_size
siddharth9820 Jul 21, 2022
c5e0f40
set ep size correctly
siddharth9820 Jul 22, 2022
8398bb0
correctly set ep_size
siddharth9820 Jul 22, 2022
3aa05d3
remove client side code that sets ep_size
siddharth9820 Jul 25, 2022
e20fc2e
correct ep_size
siddharth9820 Jul 26, 2022
92e8839
small fix
siddharth9820 Jul 26, 2022
ad593a0
small fix
siddharth9820 Jul 26, 2022
3e361a4
Merge branch 'main' of github.com:microsoft/Megatron-DeepSpeed into m…
siddharth9820 Jul 29, 2022
5 changes: 4 additions & 1 deletion megatron/arguments.py
@@ -635,6 +635,9 @@ def _add_distributed_args(parser):

group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--enable-expert-tensor-parallelism', action='store_true',
default=False,
help="use tensor parallelism for expert layers in MoE")
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--moe-expert-parallel-size', type=int, default=1,
@@ -911,4 +914,4 @@ def _add_distillation_args(parser):
group.add_argument('--load-teacher', type=str, default=None,
help='Directory containing a teacher model checkpoint.')

return parser
return parser
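
Note: as a rough, self-contained sketch of how the new flag behaves once parsed (the real parser in megatron/arguments.py carries far more options; only the two arguments shown in this diff are taken from the PR, everything else below is illustrative):

# Minimal sketch of the new flag; not the actual Megatron argument plumbing.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='distributed')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                   help='Degree of tensor model parallelism.')
group.add_argument('--enable-expert-tensor-parallelism', action='store_true',
                   default=False,
                   help='use tensor parallelism for expert layers in MoE')

# Without the flag, expert layers are not tensor-sliced.
args = parser.parse_args(['--tensor-model-parallel-size', '4'])
assert args.enable_expert_tensor_parallelism is False

# With the flag, expert weights follow the tensor-parallel group.
args = parser.parse_args(['--tensor-model-parallel-size', '4',
                          '--enable-expert-tensor-parallelism'])
assert args.enable_expert_tensor_parallelism is True
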
25 changes: 11 additions & 14 deletions megatron/model/transformer.py
@@ -59,7 +59,7 @@ class ParallelMLP(MegatronModule):
applied.
"""

def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size=1):
def __init__(self, init_method, output_layer_init_method, moe=False, enable_expert_tensor_parallelism=False):
super(ParallelMLP, self).__init__()
args = get_args()

@@ -70,8 +70,9 @@ def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size
gather_output=False,
init_method=init_method,
skip_bias_add=True,
MOE=MOE,
MoE_mp_size=MoE_mp_size)
moe=moe,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
)

self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
@@ -87,9 +88,8 @@ def __init__(self, init_method, output_layer_init_method, MOE=False, MoE_mp_size
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
MOE=MOE,
MoE_mp_size=MoE_mp_size)

moe=moe,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism)

def forward(self, hidden_states):

@@ -448,24 +448,21 @@ def __init__(self, init_method, output_layer_init_method,
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
else:
if not args.ds_inference or self.num_experts > dist.get_world_size():
moe_mp_size = 1
else:
moe_mp_size = dist.get_world_size() // self.num_experts

enable_expert_tensor_parallelism = args.enable_expert_tensor_parallelism
self.mlp = MoE(args.hidden_size,
ParallelMLP(init_method,
output_layer_init_method=output_layer_init_method,
MOE=True,
MoE_mp_size=moe_mp_size),
moe=True,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism),
num_experts=self.num_experts,
ep_size=args.moe_expert_parallel_size,
k=args.topk,
use_residual=(args.mlp_type == 'residual'),
capacity_factor=args.moe_train_capacity_factor,
eval_capacity_factor=args.moe_eval_capacity_factor,
min_capacity=args.moe_min_capacity,
drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel)
drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism)

def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
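
For context on what replaces the removed ds_inference-based moe_mp_size heuristic, a hedged sketch of the partitioning decision the flag now drives; the helper name expert_partition_world_size is illustrative and does not exist in the repository:

def expert_partition_world_size(moe, enable_expert_tensor_parallelism,
                                tensor_model_parallel_size):
    """How many ranks an expert's weight matrices are sliced across."""
    if moe and not enable_expert_tensor_parallelism:
        # Default: experts are not tensor-sliced; each expert-parallel rank
        # holds a full copy of its experts' weights.
        return 1
    # With --enable-expert-tensor-parallelism (or for dense MLPs), weights
    # follow the regular tensor-model-parallel degree.
    return tensor_model_parallel_size

# Examples with tensor-model-parallel size 4:
assert expert_partition_world_size(True, False, 4) == 1   # MoE expert, flag off
assert expert_partition_world_size(True, True, 4) == 4    # MoE expert, flag on
assert expert_partition_world_size(False, False, 4) == 4  # dense (non-MoE) MLP
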
37 changes: 29 additions & 8 deletions megatron/mpu/layers.py
@@ -230,15 +230,21 @@ class ColumnParallelLinear(torch.nn.Module):
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False, MOE=False, MoE_mp_size=1):
skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False):
super(ColumnParallelLinear, self).__init__()

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = MoE_mp_size if MOE else get_tensor_model_parallel_world_size()
if moe and (not enable_expert_tensor_parallelism):
world_size = 1
self.is_expert_without_slicing = True
else:
world_size = get_tensor_model_parallel_world_size()
self.is_expert_without_slicing = False

self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add

@@ -282,12 +288,16 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True,

def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.is_expert_without_slicing: # non-expert only tensor parallelism
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)

# Matrix multiply.

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
if self.gather_output and not self.is_expert_without_slicing:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
@@ -330,15 +340,22 @@ def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False, MOE=False, MoE_mp_size=1):
skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False):
super(RowParallelLinear, self).__init__()

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = MoE_mp_size if MOE else get_tensor_model_parallel_world_size()

if moe and (not enable_expert_tensor_parallelism):
world_size = 1
else:
world_size = get_tensor_model_parallel_world_size()

self.is_expert_without_slicing = moe and world_size==1

self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add

@@ -379,14 +396,18 @@ def __init__(self, input_size, output_size, bias=True,

def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
if self.input_is_parallel or self.is_expert_without_slicing:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.is_expert_without_slicing: # non-expert only tensor-parallelism
output_ = output_parallel
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)

if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
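
To summarize the new is_expert_without_slicing branches in both linear layers, here is a self-contained sketch (PyTorch only) with Megatron's mpu collectives stubbed out as identities; the stub functions and toy shapes are illustrative, not the real implementations:

import torch
import torch.nn.functional as F

# Stubs standing in for the megatron/mpu collectives (identity here; in the
# real code they perform the tensor-model-parallel communication).
def copy_to_tensor_model_parallel_region(x):      # identity fwd, all-reduce bwd
    return x

def gather_from_tensor_model_parallel_region(x):  # all-gather along last dim
    return x

def scatter_to_tensor_model_parallel_region(x):   # split along last dim
    return x

def reduce_from_tensor_model_parallel_region(x):  # all-reduce of partial sums
    return x

def column_parallel_forward(input_, weight, bias, gather_output,
                            is_expert_without_slicing):
    # An unsliced expert holds its full weight, so both the backprop
    # all-reduce on the input and the output all-gather are skipped.
    if is_expert_without_slicing:
        input_parallel = input_
    else:
        input_parallel = copy_to_tensor_model_parallel_region(input_)
    output_parallel = F.linear(input_parallel, weight, bias)
    if gather_output and not is_expert_without_slicing:
        return gather_from_tensor_model_parallel_region(output_parallel)
    return output_parallel

def row_parallel_forward(input_, weight, input_is_parallel,
                         is_expert_without_slicing):
    if input_is_parallel or is_expert_without_slicing:
        input_parallel = input_
    else:
        input_parallel = scatter_to_tensor_model_parallel_region(input_)
    output_parallel = F.linear(input_parallel, weight)
    # An unsliced expert already has the complete result on this rank,
    # so the all-reduce over partial sums is skipped as well.
    if is_expert_without_slicing:
        return output_parallel
    return reduce_from_tensor_model_parallel_region(output_parallel)

# Toy expert MLP (hidden=8, ffn=16) on a single rank, no slicing:
h = torch.randn(2, 8)
w1, w2 = torch.randn(16, 8), torch.randn(8, 16)
inter = column_parallel_forward(h, w1, None, gather_output=False,
                                is_expert_without_slicing=True)
out = row_parallel_forward(inter, w2, input_is_parallel=True,
                           is_expert_without_slicing=True)
print(out.shape)  # torch.Size([2, 8])
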