Conversation
Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
📝 Walkthrough

Adds LoRA-aware FLOPs calculation logic to the transformer FLOPs computation function. When a LoRA PEFT scheme is detected, the function uses a hard-coded sequence-length statistics table to compute model FLOPs for both frozen and unfrozen parameters instead of the standard transformer calculation.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~15 minutes
Actionable comments posted: 3
🧹 Nitpick comments (2)
src/megatron/bridge/training/utils/flop_utils.py (2)
**191-202:** `_LORA_SEQ_STATS` should be a module-level constant, and its dataset-specificity needs documentation. Two distinct concerns here:

**Placement:** `_LORA_SEQ_STATS` is a constant dict literal re-created inside a nested function on every call. It has no dependency on any local state and should live at module level alongside the other module-scope constants.

**Scope of applicability:** The empirical values (`842603`, `2030`, etc.) are clearly dataset-specific statistics (average sum-of-squared sequence lengths, average token count per packed sample) derived from a particular LLaMA3 training dataset. They are silently wrong for any other model or dataset, yet the error path only checks for unsupported `seq_length` values — it says nothing about unsupported models or datasets. The constant (and the `ValueError` message) should document this clearly.

♻️ Proposed refactor

```diff
+# Empirical per-sequence packing statistics for LLaMA3 LoRA training.
+# Keys are configured seq_length values. Values are (avg_seqlen_sq, avg_tokens) where:
+#   avg_seqlen_sq: average sum of squared individual sequence lengths in a packed sample
+#   avg_tokens: average number of real (non-padding) tokens per packed sample
+# These values are dataset-specific; update them if the training dataset changes.
+_LORA_SEQ_STATS: dict[int, tuple[int, int]] = {
+    4096: (842603, 4096),
+    2048: (488991, 2030),
+}
+
 def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):
     ...
     def transformer_flops():
         ...
         if is_lora:
-            _LORA_SEQ_STATS = {
-                4096: (842603, 4096),
-                2048: (488991, 2030),
-            }
             seq_len = cfg.model.seq_length
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 191 - 202, Move the _LORA_SEQ_STATS dict out of the function and declare it as a module-level constant (e.g., _LORA_SEQ_STATS) alongside other module-scope constants, and add a brief docstring comment explaining these values are empirical statistics derived from a specific LLaMA3 training dataset (i.e., dataset/model-specific and not universally applicable). In the is_lora branch where seq_len = cfg.model.seq_length and you lookup _LORA_SEQ_STATS, update the ValueError message to explicitly state that the table contains dataset/model-specific statistics and that seq_length is unsupported for the available dataset statistics, prompting the user to add an entry or provide appropriate dataset-specific stats. Ensure references in the code use the module-level _LORA_SEQ_STATS and the existing symbols is_lora, seq_len, and cfg.model.seq_length.
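The refactor described above can be sketched in isolation. This is an illustrative stand-alone version, not the code from the PR: the table values are copied from the review diff (dataset-specific LLaMA3 packing statistics), the helper name `lookup_lora_seq_stats` is hypothetical, and the error message is pre-built in a variable so the raise site stays short, as the Ruff TRY003 finding later in this review asks:

```python
# Illustrative module-level table; values copied from the review diff.
# These are dataset-specific LLaMA3 packing statistics, not universal.
_LORA_SEQ_STATS: dict[int, tuple[int, int]] = {
    4096: (842603, 4096),  # (avg sum of squared seq lengths, avg tokens)
    2048: (488991, 2030),
}


def lookup_lora_seq_stats(seq_len: int) -> tuple[int, int]:
    """Fetch packing statistics for a configured seq_length (hypothetical helper)."""
    if seq_len not in _LORA_SEQ_STATS:
        # Build the message first so the raise site stays short (Ruff TRY003).
        msg = (
            f"No LoRA FLOPs statistics for seq_length={seq_len}; the table "
            f"holds dataset-specific values for {sorted(_LORA_SEQ_STATS)} only."
        )
        raise ValueError(msg)
    return _LORA_SEQ_STATS[seq_len]
```

The lookup-before-raise shape keeps the error path explicit about why a value is unsupported (missing dataset statistics, not an invalid configuration).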
**210-223:** Document the derivation of the magic constants (`12`, `18`, `6`, `2/3`) in the LoRA FLOPs formula.

The non-LoRA path carefully documents every numeric factor (lines 273–281, with an arXiv reference). The LoRA path introduces `12`, `18 * ffn_hs / hs`, `6 * vocab_size / (n_layers * hs)`, and the `2.0 / 3.0` frozen-pass multiplier with no explanation. Adding a short derivation comment (analogous to the existing arXiv reference block) would make the formula verifiable and maintainable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 210 - 223, Add a short derivation comment above the LoRA FLOPs calculation (the block computing model_flops_frozen and model_flops_unfrozen) that documents where the numeric factors 12, 18, 6 and the frozen-pass multiplier 2.0/3.0 come from, referencing the involved symbols (avg_tokens, n_layers, hs, num_query_groups, n_heads, ffn_hs, vocab_size, avg_seqlen2, batch_size); mirror the style of the existing arXiv-backed comment in the non-LoRA path and cite the paper(s) or algebra that justify each term (e.g., attention q/k/v projections, output projection, FFN two matmuls, vocab-embedding softmax cost split, and the reduced cost for frozen forward/backward passes leading to 2/3 factor).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/training/utils/flop_utils.py`:
- Line 208: The LoRA branch sets vocab_size from cfg.model.vocab_size while the
non‑LoRA transformer_flops path uses calculate_padded_vocab_size(...), causing
inconsistent embedding/logit FLOP estimates; update the LoRA path to call
calculate_padded_vocab_size(cfg.model.vocab_size, tp_size, ...) (same args used
by transformer_flops) and use that padded value for embedding/logit FLOP
calculation (or, if intentional, add a clear comment in the LoRA branch
explaining why raw vocab_size is used). Locate the vocab_size assignment in the
LoRA code path and align it with the calculate_padded_vocab_size usage
referenced in transformer_flops.
- Around line 198-201: Ruff TRY003 flags the multi-line message being
constructed directly in the raise site; extract that long string into a variable
first and then raise with that variable. Locate the raise in flop_utils.py where
seq_len and _LORA_SEQ_STATS are referenced, create a local variable (e.g., msg
or error_msg) assigned to the formatted message using seq_len, and replace the
raise ValueError(...) with raise ValueError(msg).
- Around line 24-26: Replace the reflection-based LoRA detection in
flop_utils.py: instead of computing peft_scheme via
peft.__class__.__name__.lower() and checking "lora" in it, import the concrete
classes LoRA, VLMLoRA, CanonicalLoRA from megatron.bridge.peft.lora and set
is_lora = isinstance(peft, (LoRA, VLMLoRA, CanonicalLoRA)); keep the peft =
getattr(cfg, "peft", None) line and update the import list accordingly so
detection uses isinstance on the peft variable.
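Why the isinstance recommendation matters can be shown with stand-in classes. The classes below are dummies that only mirror the names from `megatron.bridge.peft.lora` (and `FloraPruner` is an invented counter-example); the two helper functions contrast the fragile name-substring check with the recommended type check:

```python
# Stand-ins mirroring the real classes in megatron.bridge.peft.lora.
class LoRA: ...
class VLMLoRA(LoRA): ...
class CanonicalLoRA(LoRA): ...


class FloraPruner:
    """Invented non-LoRA class whose lowercased name contains 'lora'."""


def is_lora_by_name(peft) -> bool:
    # Reflection-based check the review flags: matches any class whose
    # name merely contains the substring "lora".
    return peft is not None and "lora" in peft.__class__.__name__.lower()


def is_lora_by_type(peft) -> bool:
    # isinstance-based check the review recommends: only true LoRA
    # variants (including subclasses) qualify.
    return isinstance(peft, (LoRA, VLMLoRA, CanonicalLoRA))
```

`"florapruner"` contains `"lora"`, so the name-based check misfires on `FloraPruner` while the type-based check does not, which is exactly the failure mode the isinstance refactor closes off.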
```python
n_layers = cfg.model.num_layers
n_heads = cfg.model.num_attention_heads
ffn_hs = cfg.model.ffn_hidden_size
vocab_size = cfg.model.vocab_size
```
**LoRA path uses raw `vocab_size`; non-LoRA path uses padded vocab size — inconsistency may skew estimates.**

Line 208 reads `cfg.model.vocab_size` directly, while the non-LoRA `transformer_flops` path calls `calculate_padded_vocab_size(...)` (line 358), which accounts for tensor-parallel padding. The resulting FLOPs estimate for the embedding/logit term will be slightly lower than the padded-vocab equivalent, making LoRA FLOPs non-comparable to the non-LoRA baseline. If this is intentional, add a comment; otherwise align with the padded calculation.