megatron/training/training.py (31 changes: 25 additions & 6 deletions)
@@ -159,22 +159,32 @@ def num_floating_point_operations(args, batch_size):
     def calculate_layer_counts():
         """Calculate the number of attention, Mamba, and MLP layers."""
         if args.hybrid_override_pattern:
-            counts = {'M': 0, '*': 0, '-': 0}
+            counts = {'M': 0, '*': 0, '-': 0, 'E': 0}
             for layer_type in args.hybrid_override_pattern:
                 if layer_type in counts:
                     counts[layer_type] += 1
-            return counts['*'], counts['M'], counts['-']
+            return counts['*'], counts['M'], counts['-'], counts['E']
         else:
             num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
             num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
             num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
-            return num_attn_layers, num_mamba_layers, num_mlp_layers
+            num_moe_layers = 0
+            return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers

     def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
         """Calculate FLOPs for an MLP layer."""
         scale_factor = 3.0 / 2.0 if swiglu else 1.0
         return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

+    def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
+                        shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
+        """Calculate FLOPs for an MoE layer."""
+        scale_factor = 3.0 / 2.0 if swiglu else 1.0
+        routed_flops = (4 * batch_size * seq_len * hidden_size *
+                        moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+        shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
+        return routed_flops + shared_flops
+
     def attn_layer_flops(
         batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None
     ):
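With 'E' added to the pattern counts, a `hybrid_override_pattern` such as "M*-E" now reports one Mamba, one attention, one MLP, and one MoE layer. The new `moe_layer_flops` follows the same accounting as `mlp_layer_flops`: 4 FLOPs per weight per token (two GEMMs, up- and down-projection, at 2 FLOPs per multiply-add), a 3/2 correction for SwiGLU's extra gate projection, and the routed term additionally scaled by `num_experts_routed_to` (the router's top-k). A standalone sanity check of that arithmetic, an editor's sketch with hypothetical sizes rather than part of the PR:

# Editor's sketch: standalone check of the new MoE FLOPs formula.
# All sizes below are hypothetical examples, not values from the PR.
def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
                    shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
    scale_factor = 3.0 / 2.0 if swiglu else 1.0
    routed = (4 * batch_size * seq_len * hidden_size *
              moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
    shared = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
    return routed + shared

# One token, hidden_size=4096, 2048-wide experts with top-2 routing,
# one 2048-wide shared expert, SwiGLU enabled:
#   routed = 4 * 4096 * 2048 * 2 * 1.5 = 100,663,296
#   shared = 4 * 4096 * 2048 * 1.5     =  50,331,648
assert moe_layer_flops(1, 1, 4096, 2048, 2048, 2, swiglu=True) == 150_994_944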
@@ -213,12 +223,13 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16,
         )

     def hybrid_flops(batch_size, seq_len, hidden_size,
-                     num_attn_layers, num_mamba_layers, num_mlp_layers,
+                     num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers,
                      mamba_state_dim=128, mamba_head_dim=64,
                      mamba_num_groups=8, mamba_num_heads=128,
-                     num_attn_heads=32,gqa=True,
+                     num_attn_heads=32, gqa=True,
                      gqa_groups=8, kv_channels=None,
                      mlp_expansion=4.0, swiglu=False,
+                     moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
                      vocab_size=256000):
         """Calculate total FLOPs for the hybrid model."""
         flops_fwd = (
@@ -229,6 +240,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
             num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size,
                                                  mamba_state_dim, mamba_head_dim,
                                                  mamba_num_groups, mamba_num_heads) +
+            num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
+                                             shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) +
             (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
         )
         return flops_fwd * 3
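Two details of the total are easy to miss: the final additive term, `2 * batch_size * seq_len * hidden_size * vocab_size`, is the logits matmul over the vocabulary (2 FLOPs per multiply-add), and the closing `* 3` applies the common estimate that the backward pass costs about twice the forward pass. A minimal sketch with stand-in layer numbers:

# Editor's sketch of how hybrid_flops assembles its total. The per-layer
# sum is a stand-in value; the logits term and the 3x rule mirror the diff.
batch_size, seq_len, hidden_size, vocab_size = 1, 4096, 4096, 256_000

per_layer_flops = 3.0e12  # stand-in for the attn + Mamba + MLP + MoE sums
logit_flops = 2 * batch_size * seq_len * hidden_size * vocab_size

flops_fwd = per_layer_flops + logit_flops
train_flops = flops_fwd * 3  # forward (1x) plus backward (~2x)
print(f"{train_flops:.3e} FLOPs per step")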
@@ -412,7 +425,7 @@ def transformer_flops():
     # Main entrypoint for FLOPs calculation.
     if args.is_hybrid_model:
         # Calculate the number of each type of layer.
-        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()
+        num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = calculate_layer_counts()

         # Compute hybrid model FLOPs.
         return hybrid_flops(
@@ -422,6 +435,7 @@ def transformer_flops():
             num_attn_layers=num_attn_layers,
             num_mamba_layers=num_mamba_layers,
             num_mlp_layers=num_mlp_layers,
+            num_moe_layers=num_moe_layers,
             mamba_state_dim=args.mamba_state_dim,
             mamba_head_dim=args.mamba_head_dim,
             mamba_num_groups=args.mamba_num_groups,
@@ -432,6 +446,11 @@ def transformer_flops():
             kv_channels=args.kv_channels,
             mlp_expansion=args.ffn_hidden_size / args.hidden_size,
             swiglu=args.swiglu,
+            moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
+                                 else args.ffn_hidden_size),
+            shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None
+                                           else args.moe_shared_expert_intermediate_size),
+            num_experts_routed_to=args.moe_router_topk,
             vocab_size=args.padded_vocab_size,
         )
     else:
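At the call site, two fallbacks keep the MoE terms safe for non-MoE configs: `moe_ffn_hidden_size` defaults to the dense `ffn_hidden_size` when unset, and a missing shared expert maps to width 0, which zeroes the `shared_flops` term. A minimal sketch of that mapping, using a hypothetical `args` namespace rather than Megatron's real argument parser:

# Editor's sketch: the argument fallbacks used at the call site, applied
# to a hypothetical args namespace (not Megatron's real argument parser).
from types import SimpleNamespace

args = SimpleNamespace(
    ffn_hidden_size=16384,
    moe_ffn_hidden_size=None,                  # unset: fall back to ffn_hidden_size
    moe_shared_expert_intermediate_size=None,  # unset: no shared expert
    moe_router_topk=2,
)

moe_ffn_hidden_size = (args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
                       else args.ffn_hidden_size)
shared_expert_ffn_hidden_size = (0 if args.moe_shared_expert_intermediate_size is None
                                 else args.moe_shared_expert_intermediate_size)

assert moe_ffn_hidden_size == 16384
assert shared_expert_ffn_hidden_size == 0  # contributes no FLOPs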