1 change: 1 addition & 0 deletions nemo/collections/common/metrics/perf_metrics.py
@@ -142,6 +142,7 @@ def eval_model_flops(self):
"gpt3": flops_formulas.gpt3,
"llama2": flops_formulas.llama2,
"llama3": flops_formulas.llama3,
"llama4": flops_formulas.llama4,
"nemotron": flops_formulas.nemotron,
"mixtral": flops_formulas.mixtral,
"bert": flops_formulas.bert,
3 changes: 2 additions & 1 deletion nemo/collections/common/parts/perf_metrics_utils.py
@@ -23,6 +23,7 @@
"gpt3": 51200,
"llama2": 32000,
"llama3": 128256,
"llama4": 202048,
"nemotron": 256000,
"bert": 29000,
"mixtral": 32000,
@@ -45,7 +46,7 @@ def read_tb_log(path: str, summary_name: str) -> List:
files = glob.glob(f"{path}/events*tfevents*")
files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x)))
if len(files) == 0 or not os.path.isfile(files[0]):
- raise FileNotFoundError(f"Missing TensorBoard log file.")
+ raise FileNotFoundError("Missing TensorBoard log file.")

events_file = files[0]
try:
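The new vocabulary size feeds the embedding/vocabulary-projection term of the formula added below (6 * gbs * seq_len * hidden_size * vocab_size). A rough, illustrative check with assumed values, not taken from this PR:

# Illustrative values only (assumed): one sample of 8,192 tokens at hidden size 5,120.
gbs, seq_len, hidden_size, vocab_size = 1, 8192, 5120, 202048
vocab_flops = 6 * gbs * seq_len * hidden_size * vocab_size  # ~5.08e13 FLOPs, i.e. ~50.8 TFLOPs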
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/callbacks/flops_callback.py
@@ -31,7 +31,7 @@
"gpt3": flops_formulas.gpt3,
"llama2": flops_formulas.llama2,
"llama3": flops_formulas.llama3,
"llama4": flops_formulas.llama3, # TODO: add llama4 flops formulas
"llama4": flops_formulas.llama4,
"nemotron3": flops_formulas.nemotron,
"nemotron4": flops_formulas.nemotron,
"mixtral": flops_formulas.mixtral,
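With the placeholder removed, the "llama4" key now resolves to its own formula in both registries. A minimal sanity check, assuming NeMo with this change is installed (the registry dict's name is not visible in this hunk, so the formulas module is checked directly):

from nemo.utils import flops_formulas

# After this change, the "llama4" entries no longer alias the Llama 3 formula.
assert flops_formulas.llama4 is not flops_formulas.llama3
assert callable(flops_formulas.llama4)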
83 changes: 83 additions & 0 deletions nemo/utils/flops_formulas.py
@@ -117,6 +117,89 @@ def llama3(config: FLOPSConfig):
)


def llama4(config: FLOPSConfig):
"""Model FLOPs for Llama4 family (MoE architecture)

Llama4 models:
- Scout (16E): All 48 layers use MoE
- Maverick (128E): Alternating dense/MoE pattern via moe_layer_freq

The formula accounts for:
1. Attention computation (GQA) - same across all layers
2. Mixed FFN layers - dense and MoE based on moe_layer_freq pattern
3. Shared expert computation (for MoE layers)
4. Embedding/vocabulary projection
"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama4"]
causal_self_attn = True
seq_len = config.enc_seq_len
hidden_size = config.hs

# Attention FLOPs (same for all layers, using GQA like Llama 3)
# QKV projections + attention computation + output projection
attention_flops = (
config.gbs
* seq_len
* config.layers
* hidden_size
* hidden_size
* (
12 # Q and output projections
+ (12 * config.query_groups / config.attention_heads) # KV projections (GQA)
+ (12 * seq_len / hidden_size) * (0.5 if causal_self_attn else 1) # Attention computation
)
)

# FFN FLOPs - need to account for both dense and MoE layers
# Create moe_layer_pattern: 0=dense, 1=MoE
if config.moe_layer_freq is None:
# If no pattern specified (e.g., 16E model), all layers are MoE
moe_layer_pattern = [1] * config.layers
elif isinstance(config.moe_layer_freq, int):
# If integer, create pattern (e.g., freq=2 means every 2nd layer is MoE)
moe_layer_pattern = [1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.layers)]
else:
# If list, use it directly (e.g., [0,1]*24 for 128E model)
moe_layer_pattern = config.moe_layer_freq

# Calculate FLOPs for each layer type
num_dense_layers = sum(1 for x in moe_layer_pattern if x == 0)
num_moe_layers = sum(1 for x in moe_layer_pattern if x == 1)

# Dense layer FFN FLOPs (standard gated FFN)
# Factor of 18 = 6 (fwd+bwd) * 3 (up+gate+down projections)
dense_ffn_flops = 6 * config.gbs * seq_len * num_dense_layers * hidden_size * config.ffn_hs * 3

# MoE layer FFN FLOPs
# Shared experts (always active) + routed experts (top-k)
# Each expert has gated FFN: up_proj, gate_proj, down_proj
moe_ffn_flops = 0
if num_moe_layers > 0:
# Shared expert FLOPs (always computed)
shared_expert_flops = (
6 * config.gbs * seq_len * num_moe_layers * hidden_size * config.moe_shared_expert_intermediate_size * 3
)

# Routed expert FLOPs (only top-k experts per token)
routed_expert_flops = (
6
* config.gbs
* seq_len
* num_moe_layers
* hidden_size
* config.moe_ffn_hidden_size
* config.moe_router_topk
* 3
)

moe_ffn_flops = shared_expert_flops + routed_expert_flops

# Vocabulary/Embedding FLOPs
vocab_flops = 6 * config.gbs * seq_len * hidden_size * vocab_size

return attention_flops + dense_ffn_flops + moe_ffn_flops + vocab_flops


def nemotron(config: FLOPSConfig):
"""Model FLOPs for nemotron family"""
vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"]
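A minimal sketch of how the new formula might be exercised, assuming NeMo with this change is installed. The attribute names are the FLOPSConfig fields the function reads above; the numeric values are illustrative only, not an official Llama 4 configuration, and a plain namespace stands in for FLOPSConfig to keep the sketch self-contained:

from types import SimpleNamespace

from nemo.utils import flops_formulas

# Illustrative, Scout-like settings; every numeric value here is an assumption.
# A SimpleNamespace works because llama4() only reads attributes from the config object.
cfg = SimpleNamespace(
    gbs=128,                                   # global batch size
    enc_seq_len=8192,                          # sequence length
    hs=5120,                                   # hidden size
    layers=48,                                 # transformer layers
    ffn_hs=16384,                              # dense FFN size (unused when every layer is MoE)
    attention_heads=40,                        # query heads
    query_groups=8,                            # GQA KV groups
    moe_layer_freq=None,                       # None -> every layer treated as MoE
    moe_ffn_hidden_size=8192,                  # routed-expert FFN size
    moe_shared_expert_intermediate_size=8192,  # shared-expert FFN size
    moe_router_topk=1,                         # experts activated per token
)

model_tflops_per_step = flops_formulas.llama4(cfg) / 1e12
print(f"Model TFLOPs per training step: {model_tflops_per_step:,.0f}")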