diff --git a/megatron/training/training.py b/megatron/training/training.py
index d8b57326f67..741b8f536e7 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -159,22 +159,32 @@ def num_floating_point_operations(args, batch_size):
     def calculate_layer_counts():
-        """Calculate the number of attention, Mamba, and MLP layers."""
+        """Calculate the number of attention, Mamba, MLP, and MoE layers."""
         if args.hybrid_override_pattern:
-            counts = {'M': 0, '*': 0, '-': 0}
+            counts = {'M': 0, '*': 0, '-': 0, 'E': 0}
             for layer_type in args.hybrid_override_pattern:
                 if layer_type in counts:
                     counts[layer_type] += 1
-            return counts['*'], counts['M'], counts['-']
+            return counts['*'], counts['M'], counts['-'], counts['E']
         else:
             num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
             num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
             num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
-            return num_attn_layers, num_mamba_layers, num_mlp_layers
+            num_moe_layers = 0
+            return num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers
 
     def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
         """Calculate FLOPs for an MLP layer."""
         scale_factor = 3.0 / 2.0 if swiglu else 1.0
         return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2
 
+    def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
+                        shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
+        """Calculate FLOPs for an MoE layer."""
+        scale_factor = 3.0 / 2.0 if swiglu else 1.0
+        routed_flops = (4 * batch_size * seq_len * hidden_size *
+                        moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+        shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
+        return routed_flops + shared_flops
+
     def attn_layer_flops(
         batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None
     ):
@@ -213,12 +223,13 @@ def mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16,
         )
 
     def hybrid_flops(batch_size, seq_len, hidden_size,
-                     num_attn_layers, num_mamba_layers, num_mlp_layers,
+                     num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers,
                      mamba_state_dim=128, mamba_head_dim=64,
                      mamba_num_groups=8, mamba_num_heads=128,
-                     num_attn_heads=32,gqa=True,
+                     num_attn_heads=32, gqa=True,
                      gqa_groups=8, kv_channels=None,
                      mlp_expansion=4.0, swiglu=False,
+                     moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
                      vocab_size=256000):
         """Calculate total FLOPs for the hybrid model."""
         flops_fwd = (
@@ -229,6 +240,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
             num_mamba_layers * mamba_layer_flops(batch_size, seq_len, hidden_size,
                                                  mamba_state_dim, mamba_head_dim,
                                                  mamba_num_groups, mamba_num_heads)
+            + num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
+                                               shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu)
             + (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
         )
         return flops_fwd * 3
@@ -412,7 +425,7 @@ def transformer_flops():
     # Main entrypoint for FLOPs calculation.
     if args.is_hybrid_model:
         # Calculate the number of each type of layer.
-        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()
+        num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = calculate_layer_counts()
 
         # Compute hybrid model FLOPs.
         return hybrid_flops(
@@ -422,6 +435,7 @@ def transformer_flops():
             num_attn_layers=num_attn_layers,
             num_mamba_layers=num_mamba_layers,
             num_mlp_layers=num_mlp_layers,
+            num_moe_layers=num_moe_layers,
             mamba_state_dim=args.mamba_state_dim,
             mamba_head_dim=args.mamba_head_dim,
             mamba_num_groups=args.mamba_num_groups,
@@ -432,6 +446,11 @@ def transformer_flops():
             kv_channels=args.kv_channels,
             mlp_expansion=args.ffn_hidden_size / args.hidden_size,
             swiglu=args.swiglu,
+            moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
+                                 else args.ffn_hidden_size),
+            shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None
+                                           else args.moe_shared_expert_intermediate_size),
+            num_experts_routed_to=args.moe_router_topk,
             vocab_size=args.padded_vocab_size,
         )
     else:
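
For context, the MoE term added above can be sanity-checked in isolation. Below is a minimal, self-contained sketch of the same accounting as `moe_layer_flops`; the shapes in the example are illustrative, not defaults taken from this patch:

```python
# Sketch of the per-layer MoE FLOPs accounting introduced in this diff.
# The factor of 4 is 2 FLOPs per multiply-accumulate times the two expert
# GEMMs (up- and down-projection); SwiGLU adds a gate projection, turning
# two matmuls into three, hence the 3/2 scale (same convention as
# mlp_layer_flops). Routed cost scales with the router top-k, while
# shared experts process every token exactly once.
def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
                    shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
    scale_factor = 3.0 / 2.0 if swiglu else 1.0
    routed_flops = (4 * batch_size * seq_len * hidden_size *
                    moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
    shared_flops = (4 * batch_size * seq_len * hidden_size *
                    shared_expert_ffn_hidden_size * scale_factor)
    return routed_flops + shared_flops

# Illustrative shapes: batch 8, sequence 4096, hidden 4096, expert FFN 2048,
# shared-expert FFN 2048, top-2 routing, SwiGLU enabled.
flops = moe_layer_flops(8, 4096, 4096, 2048, 2048, 2, swiglu=True)
print(f"{flops / 1e12:.2f} TFLOPs per MoE layer (forward pass)")
```

Note that, like the patch's helper, this counts only the expert GEMMs; router projection and token permutation costs are not modeled.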