-
Notifications
You must be signed in to change notification settings - Fork 239
cp: Fix LLAMA3 LoRa TFLOPs Formula (2416) into r0.3.0
#2533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,14 +14,17 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.bridge.peft.lora import LoRA | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.bridge.training.config import ConfigContainer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return the number of floating point operations""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If the model provider has a custom TFLOPS calculation method, use it. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(cfg.model, "_get_num_floating_point_operations"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| peft = getattr(cfg, "peft", None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_lora = isinstance(peft, LoRA) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If the model provider has a custom TFLOPS calculation method, use it (non-LoRA only). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not is_lora and hasattr(cfg.model, "_get_num_floating_point_operations"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return cfg.model._get_num_floating_point_operations(batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def calculate_layer_counts(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -183,6 +186,37 @@ def transformer_flops(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_query_groups = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cfg.model.num_attention_heads if cfg.model.num_query_groups is None else cfg.model.num_query_groups | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_lora: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _LORA_SEQ_STATS = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 4096: (842603, 4096), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2048: (488991, 2030), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seq_len = cfg.model.seq_length | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if seq_len not in _LORA_SEQ_STATS: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"No LoRA stats for seq_length={seq_len}. Add it to _LORA_SEQ_STATS.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| avg_seqlen2, avg_tokens = _LORA_SEQ_STATS[seq_len] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hs = cfg.model.hidden_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_layers = cfg.model.num_layers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_heads = cfg.model.num_attention_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ffn_hs = cfg.model.ffn_hidden_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab_size = cfg.model.vocab_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_flops_frozen = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| avg_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * n_layers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * hs**2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 12 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + 12 * num_query_groups / n_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + 18 * ffn_hs / hs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + 6 * vocab_size / (n_layers * hs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+204
to
+215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use padded vocab size in LoRA FLOPs math for consistency. Line 204 uses Proposed fix- vocab_size = cfg.model.vocab_size
+ vocab_size = calculate_padded_vocab_size(
+ cfg.model.vocab_size,
+ cfg.model.make_vocab_size_divisible_by,
+ cfg.model.tensor_model_parallel_size,
+ logging_enabled=False,
+ )📝 Committable suggestion
Suggested change
🧰 Tools🪛 GitHub Actions: CICD NeMo[error] 207-214: pre-commit ruff-format hook failed: 1 file reformatted. Run 'pre-commit run --all-files' or commit again to apply formatting changes. 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_flops_unfrozen = n_layers * hs**2 * (12 * avg_seqlen2 / hs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return batch_size * (model_flops_frozen * (2.0 / 3.0) + model_flops_unfrozen) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+190
to
+219
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run formatting before merge ( CI already reports this file was reformatted by the pre-commit hook; please run As per coding guidelines: "Use ruff for linting and formatting Python code". 🧰 Tools🪛 GitHub Actions: CICD NeMo[error] 207-214: pre-commit ruff-format hook failed: 1 file reformatted. Run 'pre-commit run --all-files' or commit again to apply formatting changes. 🪛 Ruff (0.15.2)[warning] 197-197: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MoE. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if cfg.model.num_moe_experts is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Every Transformer MLP is dense. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid hard-failing on unsupported LoRA
seq_lengthvalues.Line 196–197 currently raises for any
seq_lengthnot present in_LORA_SEQ_STATS, which can stop otherwise valid training runs just for metrics computation. Prefer a graceful fallback to the standard transformer path when stats are missing.Proposed fix
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 197-197: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents