diff --git a/src/megatron/bridge/training/utils/flop_utils.py b/src/megatron/bridge/training/utils/flop_utils.py
index d6be885066..2dfdc7f054 100644
--- a/src/megatron/bridge/training/utils/flop_utils.py
+++ b/src/megatron/bridge/training/utils/flop_utils.py
@@ -14,17 +14,14 @@
 
 import torch.nn.functional as F
 
-from megatron.bridge.peft.lora import LoRA
 from megatron.bridge.training.config import ConfigContainer
 from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
 
 
 def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):
     """Return the number of floating point operations"""
-    peft = getattr(cfg, "peft", None)
-    is_lora = isinstance(peft, LoRA)
-    # If the model provider has a custom TFLOPS calculation method, use it (non-LoRA only).
-    if not is_lora and hasattr(cfg.model, "_get_num_floating_point_operations"):
+    # If the model provider has a custom TFLOPS calculation method, use it.
+    if hasattr(cfg.model, "_get_num_floating_point_operations"):
         return cfg.model._get_num_floating_point_operations(batch_size)
 
     def calculate_layer_counts():
@@ -186,37 +183,6 @@ def transformer_flops():
         num_query_groups = (
             cfg.model.num_attention_heads if cfg.model.num_query_groups is None else cfg.model.num_query_groups
         )
-
-        if is_lora:
-            _LORA_SEQ_STATS = {
-                4096: (842603, 4096),
-                2048: (488991, 2030),
-            }
-            seq_len = cfg.model.seq_length
-            if seq_len not in _LORA_SEQ_STATS:
-                raise ValueError(f"No LoRA stats for seq_length={seq_len}. Add it to _LORA_SEQ_STATS.")
-            avg_seqlen2, avg_tokens = _LORA_SEQ_STATS[seq_len]
-
-            hs = cfg.model.hidden_size
-            n_layers = cfg.model.num_layers
-            n_heads = cfg.model.num_attention_heads
-            ffn_hs = cfg.model.ffn_hidden_size
-            vocab_size = cfg.model.vocab_size
-
-            model_flops_frozen = (
-                avg_tokens
-                * n_layers
-                * hs**2
-                * (
-                    12
-                    + 12 * num_query_groups / n_heads
-                    + 18 * ffn_hs / hs
-                    + 6 * vocab_size / (n_layers * hs)
-                )
-            )
-            model_flops_unfrozen = n_layers * hs**2 * (12 * avg_seqlen2 / hs)
-
-            return batch_size * (model_flops_frozen * (2.0 / 3.0) + model_flops_unfrozen)
 
         # MoE.
         if cfg.model.num_moe_experts is None:
             # Every Transformer MLP is dense.
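
The generic path kept by this change dispatches to a provider-defined `_get_num_floating_point_operations` hook whenever one exists. A minimal sketch of such a hook follows; the class name, attribute set, and the 6 * params * tokens estimate are illustrative assumptions, not part of this patch or of the actual Megatron Bridge providers:

from dataclasses import dataclass


@dataclass
class MyModelProvider:
    # Hypothetical provider; real providers carry many more fields.
    num_layers: int = 32
    hidden_size: int = 4096
    ffn_hidden_size: int = 14336
    seq_length: int = 4096
    vocab_size: int = 128256

    def _get_num_floating_point_operations(self, batch_size: int = 1) -> float:
        """Rough 6 * params * tokens estimate (forward + backward),
        ignoring attention-score FLOPs for brevity."""
        # Square attention projections (q, k, v, out) plus a gated MLP.
        per_layer = 4 * self.hidden_size**2 + 3 * self.hidden_size * self.ffn_hidden_size
        params = self.vocab_size * self.hidden_size + self.num_layers * per_layer
        return 6.0 * params * batch_size * self.seq_length

With such a hook present, num_floating_point_operations(cfg, batch_size) returns the provider's own estimate and never reaches the generic dense/MoE accounting below it.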