[model, refactor] refactor: Centralize provider_bridge config mapping in base class #2052
Conversation
This refactoring centralizes model-specific configurations within the provider_bridge method of each model bridge.

Changes:
- Add MoE-related field mappings to base class CONFIG_MAPPING:
  - num_experts -> num_moe_experts
  - num_experts_per_tok -> moe_router_topk
  - moe_intermediate_size -> moe_ffn_hidden_size
- Refactor LlamaBridge:
  - Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes
  - Override provider_bridge only for RoPE scaling (Llama 3.1/3.2)
- Refactor Qwen2Bridge:
  - Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3Bridge:
  - Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3MoEBridge:
  - Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS
  - No provider_bridge override needed
- Update tests to expect GPTModelProvider instead of model-specific providers
- Add verification scripts for both Llama and Qwen bridges

Verified on remote server:
- Qwen/Qwen2-0.5B: PASS
- Qwen/Qwen2-7B: PASS
- Qwen/Qwen3-0.6B: PASS
- Qwen/Qwen3-1.7B: PASS
- Qwen/Qwen3-30B-A3B: PASS
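The MoE field mappings above can be sketched as a simple list of (HF field, Megatron field) pairs applied by the base class. This is an illustrative stand-in, not the project's actual helper; only the field names come from the commit message:

```python
from types import SimpleNamespace

# (HF config field, Megatron provider field) pairs from the commit message;
# the apply_config_mapping helper below is a hypothetical stand-in
CONFIG_MAPPING = [
    ("num_experts", "num_moe_experts"),
    ("num_experts_per_tok", "moe_router_topk"),
    ("moe_intermediate_size", "moe_ffn_hidden_size"),
]

def apply_config_mapping(hf_config, provider_kwargs):
    """Copy every mapped HF field that exists on hf_config into provider kwargs."""
    for hf_field, megatron_field in CONFIG_MAPPING:
        if hasattr(hf_config, hf_field):
            provider_kwargs[megatron_field] = getattr(hf_config, hf_field)
    return provider_kwargs

hf = SimpleNamespace(num_experts=128, num_experts_per_tok=8, moe_intermediate_size=768)
kwargs = apply_config_mapping(hf, {})  # {"num_moe_experts": 128, ...}
```

Because the loop checks `hasattr`, dense models whose HF configs lack the MoE fields simply get no MoE kwargs, which is what lets one mapping table serve every bridge.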
…dels
- Add MLAModelProvider as unified base for Multi-Latent Attention models
- Refactor DeepSeek V2/V3 bridges to use MLAModelProvider
- Refactor Kimi K2 bridge to use MLAModelProvider
- Move model-specific defaults from providers to MEGATRON_DEFAULTS in bridges
- Add model_type parameter to @register_bridge decorator for auto HF config
- Simplify provider files to deprecated backward-compatible aliases

Verified: DeepSeek-V2-Lite, DeepSeek-V2, DeepSeek-V3, Moonlight-16B, Kimi-K2
- Register GemmaModelProvider, Gemma2ModelProvider, Gemma3ModelProvider via decorator
- Add MEGATRON_DEFAULTS to Gemma/Gemma2 bridges for explicit config defaults
- Add gelu_pytorch_tanh -> fast_gelu to ACTIVATION_MAPPING in model_bridge.py
- Add verification script for Gemma provider refactoring
Verified: gemma-2b, gemma-7b, gemma-2-2b, gemma-2-9b, gemma-2-27b,
gemma-3-4b-it, gemma-3-12b-it, gemma-3-27b-it
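The `gelu_pytorch_tanh -> fast_gelu` mapping above can be illustrated with a minimal lookup. `hf_to_megatron_activation` is a helper named in this PR, but the body below, and the `fast_gelu` tanh-approximation stand-in, are assumptions for illustration only:

```python
import math

def fast_gelu(x: float) -> float:
    # tanh-approximation GELU, the form HF calls gelu_pytorch_tanh
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

ACTIVATION_MAPPING = {
    "gelu_pytorch_tanh": fast_gelu,
    # other entries (silu, gelu, ...) elided
}

def hf_to_megatron_activation(name: str):
    """Resolve an HF activation name to a Megatron activation callable."""
    try:
        return ACTIVATION_MAPPING[name]
    except KeyError:
        raise ValueError(f"Unsupported HF activation: {name}")
```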
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test ca54e4f
📝 Walkthrough

This PR introduces bidirectional HuggingFace-to-Megatron configuration translation.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/megatron/bridge/models/deepseek/deepseek_provider.py (2)
85-85: Type annotation mismatch: `int = None` should be `Optional[int]`.

The annotation `q_lora_rank: int = None` is incorrect since `None` is not a valid `int`. This should be `Optional[int] = None` or `int | None = None`.

Proposed fix

```diff
- q_lora_rank: int = None
+ q_lora_rank: Optional[int] = None
```

Also update the imports at line 16:

```diff
-from typing import List, Union
+from typing import List, Optional, Union
```
165-165: Same type annotation issue: `int = None` should be `Optional[int]`.

Same issue as line 85: `q_lora_rank: int = None` should use `Optional[int]`.
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/gemma/gemma3_bridge.py`:
- Around line 56-65: After you set
provider.fp16/provider.bf16/provider.params_dtype from dtype_from_hf, also set
provider.autocast_dtype to the same derived dtype so autocasting matches the VL
precision override; locate where provider.fp16/bf16/params_dtype are assigned
(using dtype_from_hf with hf_vl_config) and set provider.autocast_dtype =
self.dtype_from_hf(hf_vl_config, default=torch.float32) (or reuse the computed
params dtype) to keep autocast consistent with params.
In `@src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py`:
- Around line 83-86: When forcing BF16 in gpt_oss_bridge, clear the FP16 flag to
avoid conflicting dtype settings: in the same location where you set
provider.bf16 = True and provider.params_dtype = torch.bfloat16 (near
provider.hidden_dropout), also set provider.fp16 = False so the provider does
not have both fp16 and bf16 enabled; update the assignment in the GPT-OSS bridge
initialization (the code that modifies provider.hidden_dropout / provider.bf16 /
provider.params_dtype) to explicitly clear provider.fp16 when enabling BF16.
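The fix amounts to keeping the two dtype flags mutually exclusive. A minimal sketch with a stand-in provider dataclass (the real provider stores torch dtypes, e.g. `torch.bfloat16`, in `params_dtype`; strings stand in here):

```python
from dataclasses import dataclass

@dataclass
class Provider:
    # stand-in for the real provider config
    fp16: bool = False
    bf16: bool = False
    params_dtype: str = "float32"  # torch.float32 in the real code

def force_bf16(provider: Provider) -> Provider:
    """Enable BF16 and explicitly clear FP16 so the flags never conflict."""
    provider.bf16 = True
    provider.fp16 = False  # the missing line the review comment asks for
    provider.params_dtype = "bfloat16"  # torch.bfloat16 in the real bridge
    return provider
```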
In `@src/megatron/bridge/models/gpt_provider.py`:
- Around line 139-145: The parameters rope_scaling, rope_scaling_factor, and
seq_len_interpolation_factor are being passed unconditionally to Megatron APIs;
to match the existing defensive pattern used for mtp_block_spec, check the
target function/class signatures (use inspect.signature where mtp_block_spec is
handled) before adding these kwargs, and only include them in the kwargs dict if
the inspected signature has those parameters; remove direct positional/keyword
assignments for rope_scaling/rope_scaling_factor/seq_len_interpolation_factor
and instead add them to the **kwargs conditionally (referencing the same
inspection logic used around mtp_block_spec).
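The defensive pattern this comment references can be sketched as follows; `make_model` is a hypothetical stand-in for a Megatron constructor whose signature may lack the newer rope parameters:

```python
import inspect

def build_kwargs(target, candidate_kwargs):
    """Keep only the kwargs that actually appear in target's signature."""
    params = inspect.signature(target).parameters
    return {k: v for k, v in candidate_kwargs.items() if k in params}

# hypothetical stand-in: an older constructor without seq_len_interpolation_factor
def make_model(hidden_size, rope_scaling_factor=1.0):
    return {"hidden_size": hidden_size, "rope_scaling_factor": rope_scaling_factor}

kwargs = build_kwargs(
    make_model,
    {
        "hidden_size": 4096,
        "rope_scaling_factor": 8.0,
        "seq_len_interpolation_factor": 32,  # dropped: not in the signature
    },
)
model = make_model(**kwargs)
```

Passing the full dict directly would raise `TypeError: unexpected keyword argument` on older Megatron versions; filtering through `inspect.signature` keeps the bridge forward- and backward-compatible.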
In `@src/megatron/bridge/models/kimi/kimi_bridge.py`:
- Around line 33-75: The Kimi bridge's provider_bridge (method provider_bridge)
currently omits an explicit moe_aux_loss_coeff; add an explicit assignment
setting provider.moe_aux_loss_coeff = 1e-3 in provider_bridge (alongside the
other provider.* assignments) so the KimiK2Provider's auxiliary MoE loss
coefficient is set consistently with GLM45/Qwen3/DeepSeek V3 bridges.
In `@src/megatron/bridge/models/qwen/qwen2_bridge.py`:
- Around line 37-39: Update the class docstring in Qwen2 (which currently
mentions MEGATRON_DEFAULTS) to instead state that model-specific settings are
applied in provider_bridge; locate the docstring on the Qwen2 class in
qwen2_bridge.py (and any nearby reference to MegatronModelBridge’s
CONFIG_MAPPING/ACTIVATION_MAPPING) and replace the outdated reference to
MEGATRON_DEFAULTS with a concise note that provider_bridge handles
model-specific defaults.
In `@tests/unit_tests/models/llama/test_llama_bridge.py`:
- Around line 218-227: The test function
test_provider_bridge_rope_scaling_params currently accepts an unused fixture
parameter llama_config; remove that parameter from the test signature so it
becomes def test_provider_bridge_rope_scaling_params(self,
mock_pretrained_llama): and ensure any references to llama_config inside the
test are not present, then run the linter to confirm Ruff no longer flags the
unused argument; locate this test in
tests/unit_tests/models/llama/test_llama_bridge.py and update the function
signature accordingly (function name: test_provider_bridge_rope_scaling_params,
class uses LlamaBridge and mock_pretrained_llama).
🧹 Nitpick comments (1)
src/megatron/bridge/models/conversion/model_bridge.py (1)
244-310: Consider annotating mutable class attributes with `ClassVar`.

`CONFIG_MAPPING`, `YARN_ROPE_SCALING_MAPPING`, and `ACTIVATION_MAPPING` are mutable class attributes that should ideally be annotated with `typing.ClassVar` to indicate they are class-level and not instance-level attributes.

Proposed fix

Add `ClassVar` to imports and annotate:

```diff
 from typing import (
     Callable,
+    ClassVar,
     Dict,
     Generic,
     ...
 )

-CONFIG_MAPPING = [
+CONFIG_MAPPING: ClassVar[list[tuple[str, str]]] = [
     # Core architecture
     ...
 ]

-YARN_ROPE_SCALING_MAPPING = [
+YARN_ROPE_SCALING_MAPPING: ClassVar[list[tuple[str, str]]] = [
     ...
 ]

-ACTIVATION_MAPPING = {
+ACTIVATION_MAPPING: ClassVar[dict[str, Callable]] = {
     ...
 }
```
/ok to test 1ed2069

/ok to test ab24acf

/ok to test b6f8e29

/ok to test fa03b2c
@yaoyu-33, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

/ok to test 167055a

/ok to test f9a3231
OLMoE HF config doesn't have head_dim attribute, so kv_channels was left as None. This fix calculates it as hidden_size // num_attention_heads (2048 // 16 = 128 for OLMoE-1B-7B). This follows the pattern used by MistralBridge and NemotronHBridge.
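The fallback described above as a small sketch; `derive_kv_channels` is an illustrative name (the real bridges assign `provider.kv_channels` inline), and the numbers are the OLMoE-1B-7B values quoted in the comment:

```python
from types import SimpleNamespace

def derive_kv_channels(hf_config) -> int:
    """Use head_dim when the HF config provides it, else derive it."""
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is None:
        head_dim = hf_config.hidden_size // hf_config.num_attention_heads
    return head_dim

# OLMoE-1B-7B: 2048 // 16 == 128
olmoe = SimpleNamespace(hidden_size=2048, num_attention_heads=16)
```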
/ok to test d5b7890

/ok to test ada7d05

/ok to test 5f24f9b

/ok to test 057175f
[model, refactor] refactor: Centralize provider_bridge config mapping in base class
Summary
This PR implements the `provider_bridge` refactoring proposal (docs/proposals/provider_bridge_refactor.md). It centralizes common HF to Megatron configuration mappings in the `MegatronModelBridge` base class and refactors model bridges to use the new pattern with direct property assignment.

This is partial work: model-specific provider classes (e.g., `LlamaModelProvider`, `Qwen2ModelProvider`) are NOT removed yet. That cleanup will come in a follow-up PR.

Motivation
Before (scattered model-specific logic):
After (centralized mapping + direct property assignment):
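A minimal sketch of the centralized pattern, using a dict as a stand-in for the real provider object (class names follow this PR; the bodies and the two-entry mapping are illustrative):

```python
from types import SimpleNamespace

class MegatronModelBridge:
    # common (HF field, Megatron field) pairs; the real list is much longer
    CONFIG_MAPPING = [
        ("hidden_size", "hidden_size"),
        ("num_attention_heads", "num_attention_heads"),
    ]

    def provider_bridge(self, hf_config):
        provider = {}  # the real base class instantiates PROVIDER_CLASS
        for hf_field, mg_field in self.CONFIG_MAPPING:
            if hasattr(hf_config, hf_field):
                provider[mg_field] = getattr(hf_config, hf_field)
        return provider

class Qwen2Bridge(MegatronModelBridge):
    pass  # no override needed: common mapping plus class defaults suffice

class LlamaBridge(MegatronModelBridge):
    def provider_bridge(self, hf_config):
        provider = super().provider_bridge(hf_config)
        # override only for the model-specific bit: RoPE scaling (Llama 3.1/3.2)
        if getattr(hf_config, "rope_scaling", None):
            provider["rope_scaling_factor"] = hf_config.rope_scaling["factor"]
        return provider

hf = SimpleNamespace(hidden_size=4096, num_attention_heads=32, rope_scaling={"factor": 8.0})
p = LlamaBridge().provider_bridge(hf)
```

The point of the pattern: a bridge overrides `provider_bridge` only when a model deviates from the common mapping, and always starts from `super().provider_bridge()` so shared fields stay in one place.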
Key Changes
Base Class Enhancements (`MegatronModelBridge`)

Added centralized mappings and helper methods:

- `CONFIG_MAPPING`
- `ACTIVATION_MAPPING`
- `YARN_ROPE_SCALING_MAPPING`
- `hf_config_to_provider_kwargs()`
- `megatron_to_hf_config()`
- `hf_to_megatron_activation()`
- `megatron_to_hf_activation()`
- Default `provider_bridge()` driven by `CONFIG_MAPPING`
- `PROVIDER_CLASS` support via `@register_bridge(provider=...)`

New:
`MLAModelProvider`

Added a minimal MLA (Multi-Latent Attention) provider class that combines `MLATransformerConfig` with `GPTModelProvider`. Used by DeepSeek V2/V3 and Kimi K2.

Refactored Bridges
All refactored bridges now call `super().provider_bridge()` to get a provider with common settings from `CONFIG_MAPPING`.

| Bridge | Provider | Notes |
|---|---|---|
| `LlamaBridge` | `GPTModelProvider` | |
| `Qwen2Bridge` | `GPTModelProvider` | `add_qkv_bias=True` |
| `Qwen3Bridge` | `GPTModelProvider` | `qk_layernorm=True`, no QKV bias |
| `Qwen3MoEBridge` | `GPTModelProvider` | |
| `DeepSeekV2Bridge` | `MLAModelProvider` | |
| `DeepSeekV3Bridge` | `MLAModelProvider` | |
| `KimiK2Bridge` | `MLAModelProvider` | |
| `GemmaBridge` | `GPTModelProvider` | |
| `Gemma2Bridge` | `Gemma2ModelProvider` | |
| `Gemma3Bridge` | `Gemma2ModelProvider` | |
| `GLM45Bridge` | `GPTModelProvider` | |
| `GPTOSSBridge` | `GPTModelProvider` | |

Simplified Providers
- `DeepSeekModelProvider` family
- `LlamaModelProvider`

Not Included in This PR (Future Work)
- `Llama2ModelProvider7B`, `Qwen2ModelProvider7B`

Files Changed
Core
- `src/megatron/bridge/models/conversion/model_bridge.py` - Added `CONFIG_MAPPING`, `ACTIVATION_MAPPING`, `YARN_ROPE_SCALING_MAPPING`, helper methods, default `provider_bridge()`, `PROVIDER_CLASS` support
- `src/megatron/bridge/models/mla_provider.py` - New `MLAModelProvider` for MLA-based models
- `src/megatron/bridge/models/kimi/__init__.py` - Kimi module init
- `src/megatron/bridge/models/kimi/kimi_bridge.py` - New Kimi K2 bridge

Refactored Bridges
- `src/megatron/bridge/models/llama/llama_bridge.py`
- `src/megatron/bridge/models/qwen/qwen2_bridge.py`
- `src/megatron/bridge/models/qwen/qwen3_bridge.py`
- `src/megatron/bridge/models/qwen/qwen3_moe_bridge.py`
- `src/megatron/bridge/models/deepseek/deepseek_v2_bridge.py`
- `src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py`
- `src/megatron/bridge/models/gemma/gemma_bridge.py`
- `src/megatron/bridge/models/gemma/gemma2_bridge.py`
- `src/megatron/bridge/models/gemma/gemma3_bridge.py`
- `src/megatron/bridge/models/glm/glm45_bridge.py`
- `src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py`

Simplified Providers
- `src/megatron/bridge/models/deepseek/deepseek_provider.py`
- `src/megatron/bridge/models/llama/llama_provider.py`

Tests
- `tests/unit_tests/models/llama/test_llama_bridge.py` - Updated to expect `GPTModelProvider`
- `tests/unit_tests/models/qwen/test_qwen3_bridge.py` - Updated
- `tests/unit_tests/models/qwen/test_qwen3_moe_bridge.py` - Updated

Design Principles
Following docs/proposals/provider_bridge_refactor.md:

- `CONFIG_MAPPING` - Common field mappings are handled automatically
- `@register_bridge(provider=MLAModelProvider)` for MLA models

Breaking Changes
- `GPTModelProvider` instead of some model-specific providers

Checklist
- `CONFIG_MAPPING` covers common HF to Megatron field mappings
- `ACTIVATION_MAPPING` covers common activation functions
- `MLAModelProvider` added for MLA-based models
- `LlamaBridge` refactored
- `Qwen2Bridge` refactored
- `Qwen3Bridge` refactored
- `Qwen3MoEBridge` refactored
- `DeepSeekV2Bridge` refactored
- `DeepSeekV3Bridge` refactored
- `KimiK2Bridge` added (new)
- `GemmaBridge` refactored
- `Gemma2Bridge` refactored
- `Gemma3Bridge` refactored
- `GLM45Bridge` refactored
- `GPTOSSBridge` refactored

Related
- docs/proposals/provider_bridge_refactor.md

Summary by CodeRabbit
New Features
Improvements
Deprecations