Update the LoRa TFLOPs Formula Fix without any hardcoding of values #2571
📝 Walkthrough

This PR adds utility functions for computing sequence-length statistics from packed datasets and integrates them into LoRA-specific FLOP computation paths. It introduces data-driven sequence-statistics calculation with caching to enable accurate FLOP estimates for LoRA scenarios with specific model and dataset combinations.
Sequence Diagram

```mermaid
sequenceDiagram
    participant flop_calc as FLOP Calculator<br/>(flop_utils)
    participant cache as Sequence Stats<br/>Cache
    participant seq_calc as Sequence<br/>Calculator
    participant dataset as Dataset<br/>Loader

    rect rgba(100, 150, 255, 0.5)
        note over flop_calc: LoRA FLOP Computation
        flop_calc->>cache: Check cache for<br/>sequence stats
        alt Cache Hit
            cache-->>flop_calc: Return cached stats
        else Cache Miss
            flop_calc->>seq_calc: Call calculate_avg_seqlen()
            seq_calc->>dataset: Load dataset
            dataset-->>seq_calc: Dataset elements
            seq_calc->>seq_calc: Iterate & compute<br/>length statistics
            seq_calc-->>cache: Store stats in cache
            cache-->>flop_calc: Return stats
        end
        flop_calc->>flop_calc: Compute model_flops<br/>(frozen & unfrozen)
        flop_calc->>flop_calc: Return scaled<br/>FLOP result
    end
```
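The cache-hit/miss flow above can be sketched as follows. The names here are illustrative stand-ins (the PR's actual cache is `_lora_seq_stats_cache` inside flop_utils), and the stats computation is injected as a callable rather than calling the real `calculate_avg_seqlen`:

```python
from typing import Callable, Hashable

# Illustrative module-level cache, analogous to _lora_seq_stats_cache in flop_utils.
_seq_stats_cache: dict[Hashable, tuple] = {}

def get_seq_stats(dataset_file: str, gbs: int,
                  compute: Callable[[str, int], tuple]) -> tuple:
    """Return sequence-length stats, computing them at most once per (file, gbs) key."""
    key = (dataset_file, gbs)
    if key not in _seq_stats_cache:                 # cache miss: compute and store
        _seq_stats_cache[key] = compute(dataset_file, gbs)
    return _seq_stats_cache[key]                    # cache hit: reuse stored stats
```

Repeated FLOP computations for the same packed dataset then avoid re-reading the .npy file.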
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (2 passed)
Actionable comments posted: 4
🧹 Nitpick comments (4)
src/megatron/bridge/training/utils/flop_utils.py (2)
24-24: Consider renaming the global cache to follow the naming convention.

Per coding guidelines, global variables should use upper snake_case with a `G_` prefix (e.g., `G_LORA_SEQ_STATS_CACHE`). However, if this is intentionally module-private (prefixed with `_`), document that it is a cache for memoization purposes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/utils/flop_utils.py` at line 24, The global cache _lora_seq_stats_cache should follow project naming conventions: either rename it to G_LORA_SEQ_STATS_CACHE (upper snake_case with G_ prefix) to mark it as a true global, or if you intended it to be module-private keep the underscore but add a short comment/docstring above it explaining "module-private cache for memoization of LoRA sequence stats" so intent is clear; update any references to _lora_seq_stats_cache (e.g., in functions that read/write the cache) to use the new name if you rename it.
230-243: Document the LoRA FLOPs formula derivation.

The formula computes frozen and unfrozen model FLOPs separately with specific coefficients (12, 18, 6, 2/3) that aren't explained. Add a comment or reference explaining the derivation for maintainability.

Suggested documentation

```diff
+    # LoRA FLOPs computation:
+    # - model_flops_frozen: FLOPs for frozen layers (forward pass only, hence 2/3 factor)
+    #   Components: attention QKV (12), GQA factor, FFN (18), logits (6)
+    # - model_flops_unfrozen: FLOPs for LoRA adapter weights (full fwd+bwd)
+    #   Uses avg_seqlen2 (squared sequence lengths) for attention cost
     model_flops_frozen = (
         avg_tokens
         * n_layers
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 230 - 243, The FLOPs calculation in model_flops_frozen and model_flops_unfrozen (using avg_tokens, n_layers, hs, num_query_groups, n_heads, ffn_hs, vocab_size, avg_seqlen2, batch_size and coefficients 12, 18, 6, 2/3) lacks documentation; add a concise comment immediately above these calculations that states the mathematical derivation/assumptions (e.g., which terms correspond to attention, MLP, embedding/VOC operations and why coefficients like 12, 18, 6 and the 2/3 multiplier are used) and include a reference link or citation to the paper/notes where the formula originates so future maintainers can verify the coefficients and assumptions.

src/megatron/bridge/data/datasets/packing_utils.py (2)
274-286: Add type hints and docstring for public utility function.

This function is used by other modules (`flop_utils.py`) and should have type hints and a docstring per coding guidelines.

Suggested improvement

```diff
-def get_seqlen_list(elem):
+def get_seqlen_list(elem: dict) -> tuple[list[int], int]:
+    """Compute per-sequence token counts from a packed dataset element.
+
+    Args:
+        elem: A dictionary containing 'seq_start_id' (list of sequence start indices)
+            and 'input_ids' (list of token IDs).
+
+    Returns:
+        A tuple of (token_counts, tokens_minus_eos) where token_counts is a list of
+        per-sequence lengths (excluding EOS) and tokens_minus_eos is the total.
+    """
     num_seq = len(elem['seq_start_id'])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/datasets/packing_utils.py` around lines 274 - 286, Update the public utility function get_seqlen_list to include a clear docstring and Python type hints: annotate the parameter elem as a mapping/dict containing 'seq_start_id' and 'input_ids' (e.g., Mapping[str, Sequence[int]] or Dict[str, List[int]]), and the return type as Tuple[List[int], int]; in the docstring describe the expected keys in elem, what the function computes (per-sequence token counts excluding EOS and total tokens minus EOS), and note the AssertionError raised when counts mismatch; ensure references to elem['seq_start_id'] and elem['input_ids'] remain unchanged so callers (e.g., flop_utils.py) keep compatibility.
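The suggested docstring pins down the contract; a minimal sketch consistent with that description follows (the upstream implementation in packing_utils.py may differ in details such as the assertion it raises):

```python
def get_seqlen_list(elem: dict) -> tuple[list[int], int]:
    """Per-sequence token counts (each excluding one EOS) from a packed element."""
    boundaries = list(elem["seq_start_id"]) + [len(elem["input_ids"])]
    # Each packed sequence spans [start, next_start); subtract 1 for its EOS token.
    token_counts = [boundaries[i + 1] - boundaries[i] - 1
                    for i in range(len(boundaries) - 1)]
    tokens_minus_eos = sum(token_counts)
    return token_counts, tokens_minus_eos
```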
299-300: Use `logger` instead of `print()` for consistency.

The rest of this file uses `logger.info()` and `logger.debug()` for output. Using `print()` here is inconsistent.

Proposed fix

```diff
     if count != rows_total:
-        print(f'Dropping {rows_total - count}, total was {rows_total}')
+        logger.info(f"Dropping {rows_total - count}, total was {rows_total}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/datasets/packing_utils.py` around lines 299 - 300, Replace the print call with the module logger to keep output consistent: change print(f'Dropping {rows_total - count}, total was {rows_total}') to logger.info(...) (or logger.debug(...) if more appropriate) so it logs the same message and include the same interpolated values; make sure this location (where variables count and rows_total are used) is within the scope that has logger defined/imported in this module (e.g., the existing logger used elsewhere in packing_utils.py).
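On the FLOPs-coefficient nitpick above: the 2/3 factor for frozen layers is consistent with the standard per-parameter accounting, where the forward pass costs roughly 2 FLOPs per parameter per token and the backward pass roughly 4, split evenly between activation gradients and weight gradients. Frozen layers must still propagate activation gradients so they can reach the LoRA adapters, but skip weight gradients, leaving 4 of 6 units. This split is the usual convention, not something stated in the PR itself:

```python
# Per-parameter, per-token FLOP units under the common 2-FLOPs-per-MAC convention.
FWD = 2              # forward matmul
BWD_ACT_GRAD = 2     # backward gradient w.r.t. activations (still flows through frozen layers)
BWD_WEIGHT_GRAD = 2  # backward gradient w.r.t. weights (skipped when weights are frozen)

full_training = FWD + BWD_ACT_GRAD + BWD_WEIGHT_GRAD  # 6: trainable (LoRA adapter) params
frozen = FWD + BWD_ACT_GRAD                           # 4: frozen base-model params

ratio = frozen / full_training  # 2/3 — the coefficient applied to model_flops_frozen
```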
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/data/datasets/packing_utils.py`:
- Around line 311-314: The four average calculations (avg_seqlen_count,
avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row) can divide by
zero when count or seq_count_accum is 0; update each assignment to guard the
denominator (use a conditional or ternary) so when count == 0 or seq_count_accum
== 0 you return a safe default (e.g., 0 or None) instead of performing the
division. Specifically change the expressions that use count and seq_count_accum
to conditional forms referencing those exact variables and keep the same
variable names (avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual,
avg_seqlen_sq_per_row).
- Line 289: The calculate_avg_seqlen function currently accepts an unused
max_seq_len, uses print() instead of the module logger, lacks type
hints/docstring, and can divide by zero when the dataset is empty; remove the
max_seq_len parameter from the function signature, add concise type hints and a
docstring describing inputs/outputs, replace any print(...) calls with the
module logger (logger.info/debug/error as appropriate), and guard against
division by zero by checking for zero total tokens/sequences and returning safe
defaults (e.g., 0 or None) or raising a clear error; also update the call site
that does _lora_seq_stats_cache[cache_key] =
calculate_avg_seqlen(packed_data_path, gbs, seq_len, drop_remainder=True) to
call calculate_avg_seqlen(packed_data_path, gbs, drop_remainder=True) (remove
the seq_len argument).
In `@src/megatron/bridge/training/utils/flop_utils.py`:
- Around line 201-207: The current logic that sets packed_data_path by taking
the first element of
sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy")) is brittle
and may pick the wrong file; update the block that reads
cfg.dataset.dataset_root, packed_specs.packed_sequence_size and builds matches
to handle multiple hits explicitly: if multiple matches exist, either (a) choose
deterministically by a clear criterion (e.g., natural sort, newest by mtime, or
a specific directory attribute) and document that choice, or (b) emit a warning
via the project logger including the full list of matches and which path is
selected (or raise an error to force user disambiguation); ensure the final
assignment to packed_data_path and any logs include the selected path so
consumers of packed_data_path can be traced.
- Around line 195-196: The detection for Llama 3 70B uses hardcoded magic
numbers in is_llama3_70b (cfg.model.hidden_size == 8192 and cfg.model.num_layers
== 80); replace this brittle check with an explicit model identifier from the
config (e.g., check getattr(cfg.model, "model_name", None) == "llama3_70b") or,
if no identifier exists, add a clear comment above the is_llama3_70b assignment
documenting why those exact hidden_size/num_layers values are required and
consider falling back to a named config field (e.g., cfg.model.version or
cfg.model.family) to avoid magic numbers.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/megatron/bridge/data/datasets/packing_utils.py
- src/megatron/bridge/training/utils/flop_utils.py
```python
    return token_counts, tokens_minus_eos


def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder):
```
🧩 Analysis chain
🏁 Scripts executed (Repository: NVIDIA-NeMo/Megatron-Bridge):

```shell
# First, let's look at the function implementation in packing_utils.py
sed -n '289,320p' src/megatron/bridge/data/datasets/packing_utils.py

# Check how many lines the function actually spans
wc -l src/megatron/bridge/data/datasets/packing_utils.py

# Search for all usages of calculate_avg_seqlen
rg -n "calculate_avg_seqlen" --type py

# Check the actual call site at flop_utils.py line 219
sed -n '215,225p' src/megatron/bridge/training/utils/flop_utils.py

# Check if logger is used elsewhere in packing_utils.py
rg -n "logger\|logging\|print" src/megatron/bridge/data/datasets/packing_utils.py | head -20

# Check how logging is typically done in other files in the same directory
rg -n "import.*log\|from.*log" src/megatron/bridge/data/datasets/ --type py | head -10

# Check imports at the top of packing_utils.py
head -30 src/megatron/bridge/data/datasets/packing_utils.py

# Also check flop_utils.py to see how they handle logging there
head -40 src/megatron/bridge/training/utils/flop_utils.py

# Check the complete context of where division happens
sed -n '310,316p' src/megatron/bridge/data/datasets/packing_utils.py

# Verify the exact line numbers of the function definition and print statement
rg -n "def calculate_avg_seqlen|print\(" src/megatron/bridge/data/datasets/packing_utils.py
```
Remove unused parameter max_seq_len and replace print() with logger.
The max_seq_len parameter is never used in the function body. The function also uses print() at line 300, which is inconsistent with the file's logging pattern (logger is available and initialized in the file).
Additionally, the function has no type hints or docstring, and contains potential division by zero risks at lines 311-314 if the dataset is empty or contains no sequences.
Proposed fix
```diff
-def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder):
+def calculate_avg_seqlen(dataset_file: str, gbs: int, drop_remainder: bool) -> tuple[float, float, float, float]:
+    """Calculate average sequence length statistics from a packed dataset.
+
+    Args:
+        dataset_file: Path to the .npy packed dataset file.
+        gbs: Global batch size for dropping remainder rows.
+        drop_remainder: If True, drop rows that don't fill a complete batch.
+
+    Returns:
+        A tuple of (avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row).
+    """
     data = np.load(dataset_file, allow_pickle=True)
     total_len_accum = 0
@@ -300,7 +307,7 @@ def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder):
     rows_total = len(data)
     count = (rows_total // gbs)*gbs if drop_remainder else rows_total
     if count != rows_total:
-        print(f'Dropping {rows_total - count}, total was {rows_total}')
+        logger.info(f'Dropping {rows_total - count}, total was {rows_total}')
```

Update the call site at flop_utils.py (line 219) to remove the seq_len argument:

```python
_lora_seq_stats_cache[cache_key] = calculate_avg_seqlen(
    packed_data_path, gbs, drop_remainder=True
)
```

🧰 Tools
🪛 Ruff (0.15.2)
[warning] 289-289: Unused function argument: max_seq_len
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/data/datasets/packing_utils.py` at line 289, The
calculate_avg_seqlen function currently accepts an unused max_seq_len, uses
print() instead of the module logger, lacks type hints/docstring, and can divide
by zero when the dataset is empty; remove the max_seq_len parameter from the
function signature, add concise type hints and a docstring describing
inputs/outputs, replace any print(...) calls with the module logger
(logger.info/debug/error as appropriate), and guard against division by zero by
checking for zero total tokens/sequences and returning safe defaults (e.g., 0 or
None) or raising a clear error; also update the call site that does
_lora_seq_stats_cache[cache_key] = calculate_avg_seqlen(packed_data_path, gbs,
seq_len, drop_remainder=True) to call calculate_avg_seqlen(packed_data_path,
gbs, drop_remainder=True) (remove the seq_len argument).
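Putting the reviewer's suggestions together, a runnable sketch of the fixed function might look like the following. The per-element statistics (sequence lengths taken from `seq_start_id` boundaries) and the exact four averages are assumptions based on the review's description, not the verbatim upstream code:

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)

def calculate_avg_seqlen(dataset_file: str, gbs: int,
                         drop_remainder: bool) -> tuple[float, float, float, float]:
    """Average sequence-length statistics from a packed .npy dataset (sketch)."""
    data = np.load(dataset_file, allow_pickle=True)
    rows_total = len(data)
    count = (rows_total // gbs) * gbs if drop_remainder else rows_total
    if count != rows_total:
        logger.info("Dropping %d, total was %d", rows_total - count, rows_total)
    if count == 0:
        return 0.0, 0.0, 0.0, 0.0  # guard: nothing left after dropping the remainder

    total_tokens = 0
    total_sq = 0
    seq_count = 0
    for elem in data[:count]:
        # Sequence lengths from the start offsets, closed by the row's total length.
        lengths = np.diff(np.append(elem["seq_start_id"], len(elem["input_ids"])))
        total_tokens += int(lengths.sum())
        total_sq += int((lengths ** 2).sum())
        seq_count += len(lengths)

    avg_seqlen_count = total_tokens / seq_count if seq_count else 0.0
    avg_seqlen_total = total_tokens / count
    avg_seqlen_sq_individual = total_sq / seq_count if seq_count else 0.0
    avg_seqlen_sq_per_row = total_sq / count
    return avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row
```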
```python
if packed_data_path is None and packed_specs is not None:
    dataset_root = getattr(cfg.dataset, "dataset_root", None)
    seq_size = getattr(packed_specs, "packed_sequence_size", None)
    if dataset_root is not None and seq_size is not None:
        matches = sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy"))
        if matches:
            packed_data_path = str(matches[0])
```
Path discovery via glob is brittle.
Using sorted(Path(...).glob(...)) and taking matches[0] relies on alphabetical ordering, which may not correspond to the intended file. Consider adding a warning if multiple matches exist or making the selection criteria more explicit.
Suggested improvement
```diff
         if dataset_root is not None and seq_size is not None:
             matches = sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy"))
             if matches:
+                if len(matches) > 1:
+                    import logging
+                    logging.getLogger(__name__).warning(
+                        f"Multiple packed data files found: {matches}. Using {matches[0]}"
+                    )
                 packed_data_path = str(matches[0])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/utils/flop_utils.py` around lines 201 - 207, The
current logic that sets packed_data_path by taking the first element of
sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy")) is brittle
and may pick the wrong file; update the block that reads
cfg.dataset.dataset_root, packed_specs.packed_sequence_size and builds matches
to handle multiple hits explicitly: if multiple matches exist, either (a) choose
deterministically by a clear criterion (e.g., natural sort, newest by mtime, or
a specific directory attribute) and document that choice, or (b) emit a warning
via the project logger including the full list of matches and which path is
selected (or raise an error to force user disambiguation); ensure the final
assignment to packed_data_path and any logs include the selected path so
consumers of packed_data_path can be traced.
Force-pushed 66414ac to 5b6e74b (Compare)

Force-pushed 5b6e74b to f0443ae (Compare)

/ok to test f0443ae

/ok to test 945f9d0

Force-pushed 945f9d0 to 8bfc1af (Compare)
Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
Force-pushed 8bfc1af to 201730c (Compare)
/ok to test 5f07396
…ps_llama3_lora_fix Signed-off-by: oliver könig <okoenig@nvidia.com>
/ok to test bcddf05