chore: Update RL to use megatron-bridge tot#1358
Conversation
📝 Walkthrough

Adds a pre-finalize step to the Megatron community import. Introduces a CustomFloat16Module and mixed-precision-wrapper selection logic in the Megatron policy worker, including FSDP detection, vocab padding, and stricter config assertions. Updates refit_verifier to explicitly pass TP/PP/EP to vLLM, set temperatures, add train_iters, and fix boolean types.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant CommunityImport as Community Import
    participant Provider as Megatron Provider
    Caller->>CommunityImport: import_model_from_hf_name(...)
    CommunityImport->>Provider: finalize()
    Note over Provider: Prepares global state before parallel init
    CommunityImport->>Provider: initialize_model_parallel(...)
    Provider-->>Caller: Initialized model
```

```mermaid
sequenceDiagram
    autonumber
    participant Worker as PolicyWorker
    participant Tokenizer
    participant Vocab as VocabUtil
    participant Wrapper as MixedPrecisionWrapper
    participant Model as MegatronModel
    participant Router as MoE Routers
    participant FSDP as torch_FSDP (optional)
    Worker->>Tokenizer: load tokenizer
    Worker->>Vocab: calculate_padded_vocab_size(vocab_size)
    Vocab-->>Worker: final_padded_vocab_size
    Worker->>Wrapper: select wrapper (Float16 / CustomFloat16 / None)
    alt FSDP available
        Worker->>FSDP: set HAVE_FSDP2 flag
    end
    Worker->>Model: get_model(..., mixed_precision_wrapper=Wrapper, vocab_size=asserted)
    opt Using CustomFloat16
        Worker->>Wrapper: re_enable_float32_expert_bias()
        Wrapper->>Router: _maintain_float32_expert_bias()
    end
    Note over Worker,Model: Same flow applied to reference model (pre/post load)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🧹 Nitpick comments (2)
nemo_rl/models/policy/megatron_policy_worker.py (1)

2026-2060: CustomFloat16Module correctly maintains float32 MoE router bias.

The new CustomFloat16Module class properly extends Float16Module to ensure the MoE router expert bias stays in float32 for numerical stability. The re_enable_float32_expert_bias() method correctly:

- Handles VLM models by unwrapping language_model
- Walks decoder layers to find routers
- Invokes _maintain_float32_expert_bias() when available

Consider adding defensive checks for robustness:

```diff
 def re_enable_float32_expert_bias(self) -> None:
     """Ensure MoE router expert bias stays in float32 for numerical stability.

     Walks the wrapped module to find MoE routers and invokes the
     `_maintain_float32_expert_bias()` helper which recreates or casts the
     expert bias tensors to float32 as required by Megatron-LM.
     """
     module = self.module
     # Handle VLM models where language model is nested
     if hasattr(module, "language_model"):
         module = module.language_model
-    if hasattr(module, "decoder") and hasattr(module.decoder, "layers"):
-        for layer in module.decoder.layers:
-            mlp = getattr(layer, "mlp", None)
-            router = getattr(mlp, "router", None) if mlp is not None else None
-            if router is not None and hasattr(
-                router, "_maintain_float32_expert_bias"
-            ):
-                router._maintain_float32_expert_bias()
+    # Only process if the model has the expected decoder structure
+    if not (hasattr(module, "decoder") and hasattr(module.decoder, "layers")):
+        return
+    for layer in module.decoder.layers:
+        mlp = getattr(layer, "mlp", None)
+        router = getattr(mlp, "router", None) if mlp is not None else None
+        if router is not None and hasattr(router, "_maintain_float32_expert_bias"):
+            router._maintain_float32_expert_bias()
```

nemo_rl/models/megatron/community_import.py (1)
72-72: Update docstring to document the finalize() call.

Add that model_provider.finalize() runs deferred post-init logic, validates the provider/config, and must be called after config modifications and before initialize_model_parallel().
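The ordering contract described above (finalize after config edits, before parallel init) can be sketched with a toy provider. The `ModelProvider` class and its methods here are illustrative stand-ins, not the Megatron-Bridge API:

```python
class ModelProvider:
    """Toy provider that defers validation until finalize().

    Hypothetical stand-in for the Megatron provider used to illustrate
    the call ordering only; none of these names are the real API.
    """

    def __init__(self) -> None:
        self.vocab_size = None
        self.finalized = False

    def finalize(self) -> None:
        # Deferred post-init: validate the config after all modifications.
        assert self.vocab_size is not None, "set vocab_size before finalize()"
        self.finalized = True

    def initialize_model_parallel(self) -> None:
        # Parallel init must only run on a finalized provider.
        assert self.finalized, "call finalize() before initialize_model_parallel()"


provider = ModelProvider()
provider.vocab_size = 151_936   # config modifications happen first
provider.finalize()             # then deferred validation
provider.initialize_model_parallel()
```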
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)

- nemo_rl/models/megatron/community_import.py (1 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (9 hunks)
- tools/refit_verifier.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:

- nemo_rl/models/megatron/community_import.py
- tools/refit_verifier.py
- nemo_rl/models/policy/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:

- nemo_rl/models/megatron/community_import.py
- nemo_rl/models/policy/megatron_policy_worker.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (9)
tools/refit_verifier.py (3)
161-161: LGTM! Consistent temperature configuration.

Adding temperature: 1.0 to the Megatron generation config ensures consistency with the vLLM configuration (line 271), which is important for accurate logprob comparison in this verification script.

216-217: LGTM! Required Megatron config additions.

Adding train_iters: 1 and bias_activation_fusion: False aligns with new Megatron configuration requirements. The assertion at line 651 in megatron_policy_worker.py confirms that train_iters is now mandatory.

277-279: LGTM! Explicit parallelism configuration.

Explicitly passing tensor_parallel_size, pipeline_parallel_size, and expert_parallel_size to the vLLM config (instead of computing them) improves clarity and aligns with the updated VllmGeneration expectations noted in the comment at line 262.

nemo_rl/models/policy/megatron_policy_worker.py (6)
134-139: LGTM! Proper FSDP2 feature detection.

The try/except block correctly detects FSDP2 availability without requiring a hard dependency. The ImportError-specific exception handling follows best practices and is used appropriately at lines 316 and 771 to conditionally adjust checkpoint loading behavior.
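The detection pattern generalizes to any optional dependency; a minimal sketch (the actual FSDP2 import path is omitted here since it varies across torch versions):

```python
import importlib


def have_module(module_name: str) -> bool:
    """Return True if `module_name` imports cleanly, else False.

    The try body stays minimal and only ImportError is caught, matching
    the guideline of limiting except clauses to specific exceptions.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    else:
        return True


# In the worker this pattern sets a module-level flag such as HAVE_FSDP2.
HAVE_JSON = have_module("json")
HAVE_MISSING = have_module("definitely_not_a_real_module_xyz")
```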
249-249: Stricter vocab_size validation.

The assertion now requires vocab_size to be explicitly specified in the model config, replacing any previous fallback behavior. This is a breaking change that ensures explicit configuration. Ensure all configuration files and documentation specify vocab_size explicitly.
814-818: LGTM! Explicit padded vocab size calculation.

Calculating final_padded_vocab_size using the imported calculate_padded_vocab_size utility ensures correct vocab padding for tensor parallelism. The calculated value is used at line 1466 for inference configuration.
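The padding rule such a utility typically applies — round the vocab up so every tensor-parallel rank owns an equal slice of the embedding — can be sketched as follows; the exact formula of the real calculate_padded_vocab_size may differ:

```python
def padded_vocab_size(
    vocab_size: int,
    make_vocab_size_divisible_by: int,
    tensor_model_parallel_size: int,
) -> int:
    """Round vocab_size up to a multiple of divisible_by * TP size.

    Sketch of Megatron-style vocab padding; the real utility may apply
    additional constraints.
    """
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    # Ceiling division, then scale back up to the nearest multiple.
    return ((vocab_size + multiple - 1) // multiple) * multiple


# 32000 tokens, divisible-by 128, TP=8 -> padded to a multiple of 1024
print(padded_vocab_size(32000, 128, 8))  # 32768
```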
254-276: Clarify precedence when both freeze_moe_router and defer_fp32_logits are enabled.

The logic sets mixed_precision_wrapper = CustomFloat16Module when freeze_moe_router is true (lines 271-272), then overrides it to None when defer_fp32_logits is enabled (lines 275-276). This means defer_fp32_logits takes precedence.

Verify that the precedence is intentional. If both options can be enabled simultaneously, consider adding a comment or assertion to clarify the expected behavior:

```python
# If deferring fp32 logits, disable the mixed-precision wrapper entirely.
# This takes precedence over freeze_moe_router, which also sets the wrapper.
if policy_cfg["megatron_cfg"].get("defer_fp32_logits", None):
    mixed_precision_wrapper = None
```
745-758: LGTM! Consistent wrapper configuration for reference model.

The reference model uses the same mixed precision wrapper selection logic as the main model (lines 254-276), ensuring consistency. The ref_mixed_precision_wrapper is correctly passed to get_model at line 758.
93-93: LGTM! Required import for CustomFloat16Module.

The TransformerConfig import is necessary for the new CustomFloat16Module class definition at line 2038.
@ZhiyuLi-Nvidia can you double check the fp32 expert bias change in this PR?

@yaoyu-33 could you help me run a test on moonshotai/Moonlight-16B-A3B-Instruct for verification? I think we are good to go if the experiment is successful.
❌ Submodule Fast-Forward Check Failed

Check based on commit: 0f93ad0 (PR #1358 from …)

❌ Submodules that need attention:
- Megatron-Bridge: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
❌ Submodule Fast-Forward Check Failed

Check based on commit: 371e458 (PR #1358 from …)

✅ Submodules that are properly updated:
- Megatron-LM: ✅ PR branch is ahead of main branch (fast-forward)

❌ Submodules that need attention:
- Megatron-Bridge: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
ZhiyuLi-Nvidia left a comment:
Thank you @yaoyu-33. LGTM!
Sync up offline.
Correct convergence/logprob numbers on the Moonlight model should verify the effectiveness of re_enable_float32_expert_bias.
❌ Submodule Fast-Forward Check FailedCheck based on commit: d7a3e40 (PR #1358 from ✅ Submodules that are properly updated:Megatron-LM: ✅ PR branch is ahead of main branch (fast-forward) ❌ Submodules that need attention:Megatron-Bridge: ❌ Commits have DIVERGED from a common ancestor Please ensure all submodule commits are fast-forwards of the main branch before merging. |
❌ Submodule Fast-Forward Check FailedCheck based on commit: 3674c3f (PR #1358 from ✅ Submodules that are properly updated:Megatron-LM: ✅ PR branch is ahead of main branch (fast-forward) ❌ Submodules that need attention:Megatron-Bridge: ❌ Commits have DIVERGED from a common ancestor Please ensure all submodule commits are fast-forwards of the main branch before merging. |
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Co-authored-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Co-authored-by: Terry Kong <terryk@nvidia.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?
As title
Tests
The 70b failure is a slight memory bump that exists in main.
The dpo failure dpo-llama3.1-8b-instruct-4n8g-megatron.v2 is due to the num_workers change, which caused the shuffling order to change since the default is now 1 instead of 0.

Issues
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Chores