
Fix LLAMA3 LoRa TFLOPs Formula #2416

Merged

ko3n1g merged 13 commits into main from rmukundan/update-llama3-lora-tflops-formula on Feb 25, 2026
Conversation

@rhmukundan
Contributor

@rhmukundan rhmukundan commented Feb 17, 2026

Summary by CodeRabbit

Release Notes

  • Improvements
    • Enhanced FLOPs calculation to support LoRA-based model training with sequence-length-specific computation paths and validation checks.

Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
@rhmukundan rhmukundan self-assigned this Feb 17, 2026
@copy-pr-bot

copy-pr-bot bot commented Feb 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@rhmukundan rhmukundan changed the title from "Update LLAMA3 LoRa TFLOPs Formula" to "Fix LLAMA3 LoRa TFLOPs Formula" on Feb 17, 2026
@guyueh1 guyueh1 requested a review from malay-nagda February 17, 2026 23:05
@rhmukundan rhmukundan requested a review from guyueh1 February 18, 2026 01:45
@rhmukundan rhmukundan force-pushed the rmukundan/update-llama3-lora-tflops-formula branch from 367bc67 to 44bba2e on February 18, 2026 17:54
rhmukundan and others added 6 commits February 18, 2026 11:30
@ko3n1g ko3n1g mentioned this pull request Feb 24, 2026
@rhmukundan rhmukundan marked this pull request as ready for review February 24, 2026 20:06
@coderabbitai
Contributor

coderabbitai bot commented Feb 24, 2026

📝 Walkthrough

Adds LoRA-aware FLOPs calculation logic to the transformer FLOPs computation function. When a LoRA PEFT scheme is detected, the function uses a hard-coded sequence-length statistics table to compute model FLOPs for both frozen and unfrozen parameters, instead of the standard transformer FLOPs calculation.
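The branch shape described in the walkthrough can be sketched roughly as follows. This is a hypothetical reconstruction for illustration: the `_LORA_SEQ_STATS` values are taken from the review thread below, but the helper name `lora_flops_lookup` and the exact validation message are assumptions, not copied from the PR.

```python
# Hypothetical sketch of the LoRA-aware lookup described in the walkthrough.
# The statistics values come from the review thread; everything else is assumed.

# Dataset-specific packing statistics keyed by configured seq_length:
# (average sum of squared sequence lengths, average non-padding tokens per sample).
_LORA_SEQ_STATS = {
    4096: (842603, 4096),
    2048: (488991, 2030),
}

def lora_flops_lookup(seq_length: int) -> tuple[int, int]:
    """Return the LoRA packing statistics for seq_length, validating presence."""
    if seq_length not in _LORA_SEQ_STATS:
        msg = (
            f"seq_length={seq_length} has no LoRA statistics entry; "
            f"supported values: {sorted(_LORA_SEQ_STATS)}"
        )
        raise ValueError(msg)
    return _LORA_SEQ_STATS[seq_length]
```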

Changes

LoRA FLOPs Calculation — src/megatron/bridge/training/utils/flop_utils.py
Added a LoRA detection branch in transformer_flops that looks up sequence-length-specific statistics from the _LORA_SEQ_STATS table and computes frozen/unfrozen model FLOPs using LoRA-specific parameters (hidden_size, num_layers, etc.). Validates that seq_length is present and raises a ValueError if it is missing. Preserves the existing standard FLOPs computation for non-LoRA cases.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Suggested labels

r0.3.0

Suggested reviewers

  • malay-nagda
  • erhoo82
🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Test Results For Major Changes — ⚠️ Warning. The PR modifies the LoRA TFLOPS calculation formula but lacks test validation of the new formula and documentation of the hardcoded values. Resolution: add LoRA-specific FLOP validation tests and include numerical validation in the PR description showing the fix is correct and that non-LoRA calculations remain unchanged.

✅ Passed checks (3 passed)
  • Description Check — ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed. The title 'Fix LLAMA3 LoRa TFLOPs Formula' is directly related to the main change: introducing LoRA-aware FLOPs calculation for LLAMA3 by detecting LoRA and using a specialized computation path with sequence-length-specific statistics.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 100.00%, which meets the required 80.00% threshold.


@coderabbitai coderabbitai bot left a comment (Contributor)

Actionable comments posted: 3

🧹 Nitpick comments (2)
src/megatron/bridge/training/utils/flop_utils.py (2)

191-202: _LORA_SEQ_STATS should be a module-level constant, and its dataset-specificity needs documentation.

Two distinct concerns here:

  1. Placement: _LORA_SEQ_STATS is a constant dict literal re-created inside a nested function on every call. It has no dependency on any local state and should live at module level alongside the other module-scope constants.

  2. Scope of applicability: The empirical values (842603, 2030, etc.) are clearly dataset-specific statistics (average sum-of-squared sequence lengths, average token count per packed sample) derived from a particular LLaMA3 training dataset. They are silently wrong for any other model or dataset, yet the error path only checks for unsupported seq_length values — it says nothing about unsupported models or datasets. The constant (and the ValueError message) should document this clearly.

♻️ Proposed refactor
+# Empirical per-sequence packing statistics for LLaMA3 LoRA training.
+# Keys are configured seq_length values. Values are (avg_seqlen_sq, avg_tokens) where:
+#   avg_seqlen_sq: average sum of squared individual sequence lengths in a packed sample
+#   avg_tokens:    average number of real (non-padding) tokens per packed sample
+# These values are dataset-specific; update them if the training dataset changes.
+_LORA_SEQ_STATS: dict[int, tuple[int, int]] = {
+    4096: (842603, 4096),
+    2048: (488991, 2030),
+}
+
 def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):
     ...
     def transformer_flops():
         ...
         if is_lora:
-            _LORA_SEQ_STATS = {
-                4096: (842603, 4096),
-                2048: (488991, 2030),
-            }
             seq_len = cfg.model.seq_length
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 191 - 202,
Move the _LORA_SEQ_STATS dict out of the function and declare it as a
module-level constant (e.g., _LORA_SEQ_STATS) alongside other module-scope
constants, and add a brief docstring comment explaining these values are
empirical statistics derived from a specific LLaMA3 training dataset (i.e.,
dataset/model-specific and not universally applicable). In the is_lora branch
where seq_len = cfg.model.seq_length and you lookup _LORA_SEQ_STATS, update the
ValueError message to explicitly state that the table contains
dataset/model-specific statistics and that seq_length is unsupported for the
available dataset statistics, prompting the user to add an entry or provide
appropriate dataset-specific stats. Ensure references in the code use the
module-level _LORA_SEQ_STATS and the existing symbols is_lora, seq_len, and
cfg.model.seq_length.

210-223: Document the derivation of the magic constants (12, 18, 6, 2/3) in the LoRA FLOPs formula.

The non-LoRA path carefully documents every numeric factor (lines 273–281, with an arXiv reference). The LoRA path introduces 12, 18 * ffn_hs / hs, 6 * vocab_size / (n_layers * hs), and the 2.0 / 3.0 frozen-pass multiplier with no explanation. Adding a short derivation comment (analogous to the existing arXiv reference block) would make the formula verifiable and maintainable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 210 - 223, Add
a short derivation comment above the LoRA FLOPs calculation (the block computing
model_flops_frozen and model_flops_unfrozen) that documents where the numeric
factors 12, 18, 6 and the frozen-pass multiplier 2.0/3.0 come from, referencing
the involved symbols (avg_tokens, n_layers, hs, num_query_groups, n_heads,
ffn_hs, vocab_size, avg_seqlen2, batch_size); mirror the style of the existing
arXiv-backed comment in the non-LoRA path and cite the paper(s) or algebra that
justify each term (e.g., attention q/k/v projections, output projection, FFN two
matmuls, vocab-embedding softmax cost split, and the reduced cost for frozen
forward/backward passes leading to 2/3 factor).
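For readers following along, here is one plausible way those factors fall out of the standard "training FLOPs ≈ 6 × tokens × params" rule (2 for the multiply-add, times 3 for the forward, input-gradient, and weight-gradient passes). This is a hedged reconstruction consistent with the factors named in the comment, not a verified transcription of the PR's code:

```python
# A plausible derivation, hedged: it follows standard dense-layer FLOPs
# accounting and is consistent with the factors named above, but it is not
# a verified transcription of the actual flop_utils.py code.

def dense_layer_train_flops(tokens: int, params: int, frozen: bool = False) -> int:
    fwd = 2 * tokens * params            # one multiply-add per weight per token
    bwd = fwd if frozen else 2 * fwd     # frozen weights skip the weight-grad pass
    return fwd + bwd

# The 2/3 frozen multiplier falls out directly: (fwd + bwd) drops from
# 3x forward to 2x forward when weight gradients are skipped.
assert 3 * dense_layer_train_flops(1, 1, frozen=True) == 2 * dense_layer_train_flops(1, 1)

# Per layer, with hs = hidden size and FLOPs written as k * tokens * hs^2:
#   Q + output projections: params = 2*hs^2           -> 6 * 2            = 12
#   gated FFN (3 matmuls):  params = 3*hs*ffn_hs      -> 6 * 3 * ffn/hs   = 18 * ffn_hs/hs
#   logits, amortized over n_layers: params = hs*V    -> 6 * V / (n_layers*hs)
# (K/V projections and attention-score terms would be carried by the
# num_query_groups/n_heads-dependent terms of the full formula.)
```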
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/training/utils/flop_utils.py`:
- Line 208: The LoRA branch sets vocab_size from cfg.model.vocab_size while the
non‑LoRA transformer_flops path uses calculate_padded_vocab_size(...), causing
inconsistent embedding/logit FLOP estimates; update the LoRA path to call
calculate_padded_vocab_size(cfg.model.vocab_size, tp_size, ...) (same args used
by transformer_flops) and use that padded value for embedding/logit FLOP
calculation (or, if intentional, add a clear comment in the LoRA branch
explaining why raw vocab_size is used). Locate the vocab_size assignment in the
LoRA code path and align it with the calculate_padded_vocab_size usage
referenced in transformer_flops.
- Around line 198-201: Ruff TRY003 flags the multi-line message being
constructed directly in the raise site; extract that long string into a variable
first and then raise with that variable. Locate the raise in flop_utils.py where
seq_len and _LORA_SEQ_STATS are referenced, create a local variable (e.g., msg
or error_msg) assigned to the formatted message using seq_len, and replace the
raise ValueError(...) with raise ValueError(msg).
- Around line 24-26: Replace the reflection-based LoRA detection in
flop_utils.py: instead of computing peft_scheme via
peft.__class__.__name__.lower() and checking "lora" in it, import the concrete
classes LoRA, VLMLoRA, CanonicalLoRA from megatron.bridge.peft.lora and set
is_lora = isinstance(peft, (LoRA, VLMLoRA, CanonicalLoRA)); keep the peft =
getattr(cfg, "peft", None) line and update the import list accordingly so
detection uses isinstance on the peft variable.
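A minimal sketch combining the two fixes above (isinstance-based detection plus the TRY003-friendly raise). The LoRA classes are stubbed here so the snippet stands alone; in the real code they would be imported from megatron.bridge.peft.lora as the comment suggests:

```python
# Sketch of the suggested fixes; the LoRA classes below are stand-in stubs.
# The real import would be:
#   from megatron.bridge.peft.lora import LoRA, VLMLoRA, CanonicalLoRA

class LoRA: ...
class VLMLoRA(LoRA): ...
class CanonicalLoRA(LoRA): ...

_LORA_SEQ_STATS = {4096: (842603, 4096), 2048: (488991, 2030)}

def detect_and_validate(peft, seq_length):
    """Return whether peft is a LoRA scheme, validating seq_length if it is."""
    # isinstance check instead of reflection on peft.__class__.__name__:
    is_lora = isinstance(peft, (LoRA, VLMLoRA, CanonicalLoRA))
    if is_lora and seq_length not in _LORA_SEQ_STATS:
        # Ruff TRY003: build the long message in a variable, then raise it.
        msg = (
            f"seq_length={seq_length} has no entry in _LORA_SEQ_STATS; the table "
            f"holds dataset-specific statistics (supported: {sorted(_LORA_SEQ_STATS)})."
        )
        raise ValueError(msg)
    return is_lora
```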



📥 Commits

Reviewing files that changed from the base of the PR and between 7cbcf4a and 4a1bb6f.

📒 Files selected for processing (1)
  • src/megatron/bridge/training/utils/flop_utils.py

Code under review (src/megatron/bridge/training/utils/flop_utils.py):

    n_layers = cfg.model.num_layers
    n_heads = cfg.model.num_attention_heads
    ffn_hs = cfg.model.ffn_hidden_size
    vocab_size = cfg.model.vocab_size

⚠️ Potential issue | 🟡 Minor

LoRA path uses raw vocab_size; non-LoRA path uses padded vocab size — inconsistency may skew estimates.

Line 208 reads cfg.model.vocab_size directly, while the non-LoRA transformer_flops path calls calculate_padded_vocab_size(...) (line 358) which accounts for tensor-parallel padding. The resulting FLOPs estimate for the embedding/logit term will be slightly lower than the padded-vocab equivalent, making LoRA FLOPs non-comparable to the non-LoRA baseline. If this is intentional, add a comment; otherwise align with the padded calculation.

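To illustrate the gap this comment describes: Megatron-style padding conventionally rounds the vocabulary up to a multiple of a divisor times the tensor-parallel size. The helper below mimics that rule for illustration only; the actual signature of calculate_padded_vocab_size in flop_utils.py may differ.

```python
# Illustrative rounding rule for Megatron-style vocab padding; the real
# calculate_padded_vocab_size helper may take different arguments.
def padded_vocab_size(vocab_size: int, divisible_by: int, tp_size: int) -> int:
    multiple = divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

# With e.g. divisible_by=128 and tp_size=8, a 32000-token vocab pads to 32768,
# so embedding/logit FLOP terms built on the raw size undercount slightly
# relative to the padded-vocab equivalent.
```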

@rhmukundan rhmukundan force-pushed the rmukundan/update-llama3-lora-tflops-formula branch from 4a1bb6f to 5557e4e on February 24, 2026 23:28
@malay-nagda malay-nagda added the performance and r0.3.0 (Cherry-pick label for r0.3.0 release branch) labels on Feb 25, 2026
@rhmukundan rhmukundan requested a review from erhoo82 February 25, 2026 16:29
@ko3n1g ko3n1g left a comment (Contributor)

ty!

@ko3n1g ko3n1g merged commit fa4a01b into main Feb 25, 2026
2 checks passed
@ko3n1g ko3n1g deleted the rmukundan/update-llama3-lora-tflops-formula branch February 25, 2026 20:37
svcnvidia-nemo-ci pushed a commit that referenced this pull request Feb 25, 2026
ko3n1g pushed a commit that referenced this pull request Feb 25, 2026
@maanug-nv maanug-nv mentioned this pull request Feb 25, 2026
ko3n1g added a commit that referenced this pull request Feb 25, 2026
ko3n1g added a commit that referenced this pull request Feb 25, 2026
pengdurice pushed a commit to pengdurice/Megatron-Bridge that referenced this pull request Feb 26, 2026
pengdurice pushed a commit to pengdurice/Megatron-Bridge that referenced this pull request Feb 26, 2026
copy-pr-bot bot pushed a commit that referenced this pull request Mar 19, 2026

Labels

performance · r0.3.0 (Cherry-pick label for r0.3.0 release branch)


3 participants