Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 2 additions & 36 deletions src/megatron/bridge/training/utils/flop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@

import torch.nn.functional as F

from megatron.bridge.peft.lora import LoRA
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size


def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):
"""Return the number of floating point operations"""
peft = getattr(cfg, "peft", None)
is_lora = isinstance(peft, LoRA)
# If the model provider has a custom TFLOPS calculation method, use it (non-LoRA only).
if not is_lora and hasattr(cfg.model, "_get_num_floating_point_operations"):
# If the model provider has a custom TFLOPS calculation method, use it.
if hasattr(cfg.model, "_get_num_floating_point_operations"):
return cfg.model._get_num_floating_point_operations(batch_size)

def calculate_layer_counts():
Expand Down Expand Up @@ -186,37 +183,6 @@ def transformer_flops():
num_query_groups = (
cfg.model.num_attention_heads if cfg.model.num_query_groups is None else cfg.model.num_query_groups
)

if is_lora:
_LORA_SEQ_STATS = {
4096: (842603, 4096),
2048: (488991, 2030),
}
seq_len = cfg.model.seq_length
if seq_len not in _LORA_SEQ_STATS:
raise ValueError(f"No LoRA stats for seq_length={seq_len}. Add it to _LORA_SEQ_STATS.")
avg_seqlen2, avg_tokens = _LORA_SEQ_STATS[seq_len]

hs = cfg.model.hidden_size
n_layers = cfg.model.num_layers
n_heads = cfg.model.num_attention_heads
ffn_hs = cfg.model.ffn_hidden_size
vocab_size = cfg.model.vocab_size

model_flops_frozen = (
avg_tokens
* n_layers
* hs**2
* (
12
+ 12 * num_query_groups / n_heads
+ 18 * ffn_hs / hs
+ 6 * vocab_size / (n_layers * hs)
)
)
model_flops_unfrozen = n_layers * hs**2 * (12 * avg_seqlen2 / hs)

return batch_size * (model_flops_frozen * (2.0 / 3.0) + model_flops_unfrozen)
# MoE.
if cfg.model.num_moe_experts is None:
# Every Transformer MLP is dense.
Expand Down
Loading