src/megatron/bridge/training/utils/flop_utils.py (61 additions, 5 deletions)
@@ -25,24 +25,51 @@ def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):
return cfg.model._get_num_floating_point_operations(batch_size)

def calculate_layer_counts():
"""Calculate the number of attention, Mamba, and MLP layers."""
"""Calculate the number of attention, Mamba, MLP, and MoE layers."""
if hasattr(cfg.model, "hybrid_override_pattern") and cfg.model.hybrid_override_pattern:
counts = {"M": 0, "*": 0, "-": 0}
counts = {"M": 0, "*": 0, "-": 0, "E": 0}
for layer_type in cfg.model.hybrid_override_pattern:
if layer_type in counts:
counts[layer_type] += 1
return counts["*"], counts["M"], counts["-"]
return counts["*"], counts["M"], counts["-"], counts["E"]
else:
num_attn_layers = round(cfg.model.num_layers * getattr(cfg.model, "hybrid_attention_ratio", 0))
num_mlp_layers = round(cfg.model.num_layers * getattr(cfg.model, "hybrid_mlp_ratio", 0))
num_mamba_layers = cfg.model.num_layers - num_attn_layers - num_mlp_layers
- return num_attn_layers, num_mamba_layers, num_mlp_layers
+ num_moe_layers = 0
+ return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers
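For illustration, a minimal, self-contained sketch of the pattern-counting branch above, using a made-up hybrid_override_pattern ("M" = Mamba, "*" = attention, "-" = MLP, "E" = MoE):

    pattern = "M*M-M*ME"  # hypothetical 8-layer pattern, not taken from the PR
    counts = {"M": 0, "*": 0, "-": 0, "E": 0}
    for layer_type in pattern:
        if layer_type in counts:
            counts[layer_type] += 1
    # counts == {"M": 4, "*": 2, "-": 1, "E": 1}
    print(counts["*"], counts["M"], counts["-"], counts["E"])  # 2 4 1 1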

def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
"""Calculate FLOPs for an MLP layer."""
scale_factor = 3.0 / 2.0 if swiglu else 1.0
return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2
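Each of the up- and down-projection GEMMs costs 2 * batch_size * seq_len * hidden_size * (expansion * hidden_size) FLOPs, which gives the 4 * expansion * hidden_size**2 term; the 3/2 scale factor accounts for the extra gate projection a SwiGLU MLP carries. A quick sanity check with hypothetical sizes, assuming mlp_layer_flops from above is in scope:

    # Hypothetical sizes, only to sanity-check the formula above.
    b, s, h = 1, 4096, 4096
    flops = mlp_layer_flops(b, s, h, expansion=4.0, swiglu=True)
    # 4 * 4.0 * 1.5 * 1 * 4096 * 4096**2 == 24 * 4096**3 ~= 1.65e12 FLOPs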

def moe_layer_flops(
batch_size,
seq_len,
hidden_size,
moe_ffn_hidden_size,
shared_expert_ffn_hidden_size,
num_experts_routed_to,
moe_latent_size=None,
swiglu=False,
):
"""Calculate FLOPs for an MoE layer."""
scale_factor = 3.0 / 2.0 if swiglu else 1.0
if moe_latent_size is None:
routed_flops = (
4 * batch_size * seq_len * hidden_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor
)
else:
# Routed experts run on moe_latent_size.
routed_flops = (
4 * batch_size * seq_len * moe_latent_size * moe_ffn_hidden_size * num_experts_routed_to * scale_factor
)
# Up proj and down proj.
routed_flops += 4 * batch_size * seq_len * hidden_size * moe_latent_size
shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
return routed_flops + shared_flops
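A hedged usage sketch with made-up sizes, separating the routed-expert and shared-expert terms (moe_latent_size left at None, so no latent up/down projection is added); it assumes moe_layer_flops from above is in scope:

    # Hypothetical MoE sizes, only to illustrate moe_layer_flops above.
    b, s, h = 1, 4096, 4096
    moe_ffn, shared_ffn, top_k = 1024, 2048, 8
    routed = 4 * b * s * h * moe_ffn * top_k * 1.5  # swiglu=True path
    shared = 4 * b * s * h * shared_ffn * 1.5
    total = moe_layer_flops(b, s, h, moe_ffn, shared_ffn, top_k, swiglu=True)
    assert total == routed + shared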

def attn_layer_flops(
batch_size,
seq_len,
@@ -93,6 +120,7 @@ def hybrid_flops(
num_attn_layers,
num_mamba_layers,
num_mlp_layers,
num_moe_layers,
mamba_state_dim=128,
mamba_head_dim=64,
mamba_num_groups=8,
@@ -102,6 +130,10 @@
kv_channels=None,
mlp_expansion=4.0,
swiglu=False,
moe_latent_size=None,
moe_ffn_hidden_size=2048,
shared_expert_ffn_hidden_size=2048,
num_experts_routed_to=1,
vocab_size=256000,
):
"""Calculate total FLOPs for the hybrid model."""
@@ -126,6 +158,17 @@
mamba_num_groups,
mamba_num_heads,
)
+ num_moe_layers
* moe_layer_flops(
batch_size,
seq_len,
hidden_size,
moe_ffn_hidden_size,
shared_expert_ffn_hidden_size,
num_experts_routed_to,
moe_latent_size,
swiglu,
)
+ (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation
)
return flops_fwd * 3
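The trailing factor of 3 follows the usual training-cost convention: the backward pass is taken to cost roughly twice the forward pass, so total FLOPs per step are about 3x the forward-pass FLOPs.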
@@ -319,7 +362,7 @@ def transformer_flops():
# Main entrypoint for FLOPs calculation.
if getattr(cfg.model, "is_hybrid_model", False):
# Calculate the number of each type of layer.
- num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()
+ num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = calculate_layer_counts()
padded_vocab_size = calculate_padded_vocab_size(
cfg.model.vocab_size,
cfg.model.make_vocab_size_divisible_by,
@@ -338,6 +381,7 @@
num_attn_layers=num_attn_layers,
num_mamba_layers=num_mamba_layers,
num_mlp_layers=num_mlp_layers,
num_moe_layers=num_moe_layers,
mamba_state_dim=getattr(cfg.model, "mamba_state_dim", 128),
mamba_head_dim=getattr(cfg.model, "mamba_head_dim", 64),
mamba_num_groups=getattr(cfg.model, "mamba_num_groups", 8),
@@ -347,6 +391,18 @@
kv_channels=getattr(cfg.model, "kv_channels", None),
mlp_expansion=cfg.model.ffn_hidden_size / cfg.model.hidden_size,
swiglu=getattr(cfg.model, "gated_linear_unit", False),
moe_latent_size=getattr(cfg.model, "moe_latent_size", None),
moe_ffn_hidden_size=(
cfg.model.ffn_hidden_size
if getattr(cfg.model, "moe_ffn_hidden_size", None) is None
else cfg.model.moe_ffn_hidden_size
),
shared_expert_ffn_hidden_size=(
0
if getattr(cfg.model, "moe_shared_expert_intermediate_size", None) is None
else cfg.model.moe_shared_expert_intermediate_size
),
num_experts_routed_to=getattr(cfg.model, "moe_router_topk", 1),
vocab_size=padded_vocab_size,
)
else: