Merged
34 commits
9ac259b
[DRAFT] Refactor provider_bridge for Llama and Qwen models
yaoyu-33 Jan 23, 2026
ab8a5e4
refactor(bridge): Introduce MLAModelProvider for DeepSeek/Kimi MLA mo…
yaoyu-33 Jan 23, 2026
0a12eb7
refactor(bridge): Refactor Gemma bridges to use specialized providers
yaoyu-33 Jan 24, 2026
aa54966
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Jan 24, 2026
4ade385
update
yaoyu-33 Jan 26, 2026
3ec0b0a
remove MEGATRON_DEFAULTS
yaoyu-33 Jan 26, 2026
c100b4a
remove testing scripts
yaoyu-33 Jan 26, 2026
3c709c5
yarn fix
yaoyu-33 Jan 26, 2026
9eb8542
clean ups
yaoyu-33 Jan 27, 2026
2cc3b35
fix
yaoyu-33 Jan 27, 2026
ca54e4f
fix unit tests
yaoyu-33 Jan 27, 2026
1ed2069
unit test fix
yaoyu-33 Jan 27, 2026
9facb3e
functional test fix
yaoyu-33 Jan 27, 2026
b753ab8
code rabbit
yaoyu-33 Jan 27, 2026
ab24acf
fix functional tests
yaoyu-33 Jan 28, 2026
b6f8e29
remove KimiK2Bridge from init
yaoyu-33 Jan 28, 2026
167055a
clean up deprecated functional tests
yaoyu-33 Jan 28, 2026
22a6321
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Feb 2, 2026
f9a3231
olmoe update
yaoyu-33 Feb 3, 2026
d5b7890
Fix kv_channels calculation for OLMoE bridge
yaoyu-33 Feb 3, 2026
ada7d05
fix: always set yarn params with None defaults for MCoreGPTModel comp…
yaoyu-33 Feb 3, 2026
5f24f9b
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Feb 3, 2026
057175f
lint
yaoyu-33 Feb 3, 2026
54e0971
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 4, 2026
6c6a3b9
nemotron H bridge update
yaoyu-33 Feb 4, 2026
0c94698
nemotron bridge update
yaoyu-33 Feb 4, 2026
4d827e6
nemotron bridge update
yaoyu-33 Feb 5, 2026
db38206
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 5, 2026
9618f91
fix: remove generation_config from GPTModelProvider and tests
yaoyu-33 Feb 5, 2026
2b0bb06
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 5, 2026
58b2c6e
fix: remove additional generation_config tests from bridge test files
yaoyu-33 Feb 5, 2026
2310230
fix: add relu2 activation support and fix Nemotron/NemotronH tests
yaoyu-33 Feb 5, 2026
77a567b
fix: update Nemotron bridge tests for correct provider types
yaoyu-33 Feb 5, 2026
dabe9d5
remove deprecated provider test
yaoyu-33 Feb 5, 2026
9 changes: 3 additions & 6 deletions src/megatron/bridge/models/conversion/model_bridge.py
@@ -36,7 +36,7 @@
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.activations import fast_gelu
from megatron.core.activations import fast_gelu, squared_relu
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
@@ -313,6 +313,7 @@ def mapping_registry(self) -> MegatronMappingRegistry:
"silu": F.silu,
"gelu": F.gelu,
"relu": F.relu,
"relu2": squared_relu,
"tanh": torch.tanh,
"gelu_pytorch_tanh": fast_gelu,
}
@@ -425,8 +426,7 @@ def provider_bridge(self, hf_pretrained: HFPreTrained) -> ModelProviderTarget:

Default implementation that:
1. Converts HF config to provider kwargs using CONFIG_MAPPING
2. Adds generation_config
3. Creates and returns a GPTModelProvider
2. Creates and returns a GPTModelProvider

Subclasses should override this to add model-specific configuration
by calling super().provider_bridge() then setting properties directly
@@ -449,9 +449,6 @@ def provider_bridge(self, hf_pretrained: HFPreTrained) -> ModelProviderTarget:
yarn_params = provider_kwargs.pop("_yarn_params", None)
mla_rope_params = provider_kwargs.pop("_mla_rope_params", None)

# Add generation config
provider_kwargs["generation_config"] = hf_pretrained.generation_config

# Use specified provider class, defaulting to GPTModelProvider
provider_class = self.PROVIDER_CLASS if self.PROVIDER_CLASS is not None else GPTModelProvider
provider = provider_class(**provider_kwargs)
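The hunk above teaches the shared provider_bridge to recognize `hidden_act="relu2"` by mapping it to `megatron.core.activations.squared_relu`. A minimal, self-contained sketch of that lookup (the dict name below is illustrative, not the bridge's actual variable):

```python
import torch
import torch.nn.functional as F


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """relu(x) ** 2, matching megatron.core.activations.squared_relu."""
    return torch.pow(F.relu(x), 2)


# Illustrative name-to-callable table mirroring the mapping extended in this PR.
HF_ACT_TO_MEGATRON = {
    "silu": F.silu,
    "gelu": F.gelu,
    "relu": F.relu,
    "relu2": squared_relu,  # new in this PR (used by Nemotron-style models)
    "tanh": torch.tanh,
}

activation_func = HF_ACT_TO_MEGATRON["relu2"]
print(activation_func(torch.tensor([-1.0, 2.0])))  # tensor([0., 4.])
```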
1 change: 0 additions & 1 deletion src/megatron/bridge/models/deepseek/common.py
@@ -67,7 +67,6 @@ def get_common_configs(hf_pretrained: PreTrainedCausalLM) -> dict:

# Ensure MLA is enabled
configs["multi_latent_attention"] = True
configs["generation_config"] = hf_pretrained.generation_config
configs["vocab_size"] = hf_config.vocab_size
configs["rotary_base"] = hf_config.rope_theta
configs["init_method_std"] = hf_config.initializer_range
2 changes: 0 additions & 2 deletions src/megatron/bridge/models/gpt_provider.py
@@ -158,8 +158,6 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]):
hf_model_id: str | None = None
"""Optional HuggingFace model identifier associated with this provider."""

generation_config: Optional[Any] = None

# This represents the unpadded vocab size
# The padded vocab size is automatically calculated in the provide() method.
vocab_size: Optional[int] = None
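With the `generation_config` field gone from `GPTModelProvider` (and no longer injected by `provider_bridge`), any caller that used to read it off the provider would need to take it from the HF-side wrapper instead. A hedged migration sketch, assuming the caller still holds the `hf_pretrained` object the bridge was built from:

```python
from typing import Any, Optional


def get_generation_config(hf_pretrained: Any) -> Optional[Any]:
    """Read generation_config from the HF wrapper rather than the Megatron provider.

    Before this PR, provider_bridge copied hf_pretrained.generation_config onto the
    provider; after it, the provider no longer carries that field.
    """
    return getattr(hf_pretrained, "generation_config", None)
```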
@@ -104,7 +104,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Llama31ModelProv
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
generation_config=hf_pretrained.generation_config,
vocab_size=hf_config.vocab_size,
)

1 change: 0 additions & 1 deletion src/megatron/bridge/models/mistral/mistral_bridge.py
@@ -74,7 +74,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MistralModelProv
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
generation_config=hf_pretrained.generation_config,
vocab_size=hf_config.vocab_size,
kv_channels=getattr(hf_config, "head_dim", None),
)
58 changes: 33 additions & 25 deletions src/megatron/bridge/models/nemotron/nemotron_bridge.py
@@ -22,11 +22,21 @@
AutoMapping,
QKVMapping,
)
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider


@MegatronModelBridge.register_bridge(source=NemotronForCausalLM, target=GPTModel)
def squared_relu(x):
"""Squared ReLU activation function."""
return torch.pow(torch.nn.functional.relu(x), 2)


@MegatronModelBridge.register_bridge(
source=NemotronForCausalLM,
target=GPTModel,
provider=GPTModelProvider,
model_type="nemotron",
)
class NemotronBridge(MegatronModelBridge):
"""
Megatron Bridge for Nemotron Causal LM.
@@ -39,30 +49,28 @@ class NemotronBridge(MegatronModelBridge):
>>> provider = bridge.to_megatron_provider()
"""

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronModelProvider:
hf_config = hf_pretrained.config

provider = NemotronModelProvider(
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
ffn_hidden_size=hf_config.intermediate_size,
num_attention_heads=hf_config.num_attention_heads,
init_method_std=hf_config.initializer_range,
layernorm_epsilon=hf_config.norm_eps,
num_query_groups=hf_config.num_key_value_heads,
seq_length=hf_config.max_position_embeddings,
rotary_base=hf_config.rope_theta,
rotary_percent=hf_config.partial_rotary_factor,
kv_channels=getattr(hf_config, "head_dim", None),
make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size),
share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
generation_config=hf_pretrained.generation_config,
vocab_size=hf_config.vocab_size,
)
CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
# Nemotron uses norm_eps instead of rms_norm_eps
("norm_eps", "layernorm_epsilon"),
]

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
"""Convert HuggingFace Nemotron config to GPTModelProvider."""
# Use base class for common config conversion
provider = super().provider_bridge(hf_pretrained)

provider.normalization = "LayerNorm"
provider.activation_func = squared_relu
provider.position_embedding_type = "rope"
provider.add_bias_linear = False
provider.hidden_dropout = 0.0
provider.attention_dropout = 0.0
provider.masked_softmax_fusion = True
provider.persist_layer_norm = True
provider.bias_dropout_add_fusion = False
provider.layernorm_zero_centered_gamma = True
provider.cross_entropy_loss_fusion = True
provider.apply_rope_fusion = True
return provider

def mapping_registry(self) -> MegatronMappingRegistry:
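Taken together, the NemotronBridge changes show the pattern this refactor moves every bridge toward: declare the provider class in `register_bridge`, extend the shared `CONFIG_MAPPING` only for renamed HF fields, and set model-specific defaults after calling `super().provider_bridge()`. A hedged sketch of a hypothetical new bridge written in that style (class and model names are placeholders; only the shape is taken from this PR):

```python
from megatron.core.models.gpt import GPTModel

from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM


@MegatronModelBridge.register_bridge(
    source="MyModelForCausalLM",  # placeholder HF architecture name
    target=GPTModel,
    provider=GPTModelProvider,
    model_type="my_model",
)
class MyModelBridge(MegatronModelBridge):
    """Hypothetical bridge illustrating the post-refactor structure."""

    # Extend the shared mapping only where the HF field name differs.
    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
        ("norm_eps", "layernorm_epsilon"),
    ]

    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
        # The base class performs the common HF-config -> provider conversion.
        provider = super().provider_bridge(hf_pretrained)
        # Model-specific defaults are then set as plain attributes.
        provider.normalization = "LayerNorm"
        provider.add_bias_linear = False
        return provider
```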
@@ -14,7 +14,6 @@

import copy
from dataclasses import dataclass
from typing import Any, Optional

from megatron.core.activations import fast_gelu
from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
@@ -39,7 +38,6 @@ class NemotronNano12Bv2VLModelProvider(NemotronNano12Bv2Provider):

vision_model_type: str = "radio"
language_model_type: str = "nemotron5-hybrid-12b"
generation_config: Optional[Any] = None

# Freeze knobs useful for transfer-learning scenarios
freeze_language_model: bool = False
1 change: 1 addition & 0 deletions src/megatron/bridge/models/nemotronh/__init__.py
@@ -32,6 +32,7 @@


__all__ = [
"NemotronHBridge",
"NemotronHModelProvider",
"NemotronHModelProvider4B",
"NemotronHModelProvider8B",
94 changes: 48 additions & 46 deletions src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py
@@ -14,7 +14,7 @@

import logging

import torch
from megatron.core.activations import squared_relu
from megatron.core.models.mamba import MambaModel

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
@@ -28,13 +28,18 @@
RowParallelMapping,
)
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider


logger = logging.getLogger(__name__)


@MegatronModelBridge.register_bridge(source="NemotronHForCausalLM", target=MambaModel)
@MegatronModelBridge.register_bridge(
source="NemotronHForCausalLM",
target=MambaModel,
provider=MambaModelProvider,
model_type="nemotron_h",
)
class NemotronHBridge(MegatronModelBridge):
"""
Megatron Bridge for Nemotron-H Causal LM.
@@ -49,52 +54,49 @@ class NemotronHBridge(MegatronModelBridge):
>>> provider = bridge.to_megatron_provider()
"""

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronHModelProvider:
# Extend CONFIG_MAPPING with Nemotron-H/Mamba-specific fields
CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
# Mamba-specific fields
("mamba_head_dim", "mamba_head_dim"),
("mamba_num_heads", "mamba_num_heads"),
("n_groups", "mamba_num_groups"),
("ssm_state_size", "mamba_state_dim"),
("hybrid_override_pattern", "hybrid_override_pattern"),
("residual_in_fp32", "fp32_residual_connection"),
("use_bias", "add_bias_linear"),
("layer_norm_epsilon", "layernorm_epsilon"),
# MoE-specific fields (already in base but with different HF names)
("moe_shared_expert_intermediate_size", "moe_shared_expert_intermediate_size"),
]

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MambaModelProvider:
"""Convert HuggingFace Nemotron-H config to MambaModelProvider."""
# Use base class for common config conversion
provider = super().provider_bridge(hf_pretrained)
hf_config = hf_pretrained.config

configs = {}
# MoE configurations
if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:
configs.update(
{
"num_moe_experts": hf_config.n_routed_experts,
"moe_ffn_hidden_size": hf_config.moe_intermediate_size,
"moe_shared_expert_intermediate_size": hf_config.moe_shared_expert_intermediate_size,
"moe_router_topk": hf_config.num_experts_per_tok,
"moe_router_num_groups": hf_config.n_group,
"moe_router_group_topk": hf_config.topk_group,
"moe_router_topk_scaling_factor": hf_config.routed_scaling_factor,
}
)
# Nemotron-H specific defaults
provider.activation_func = squared_relu
provider.masked_softmax_fusion = True
provider.apply_query_key_layer_scaling = False
provider.persist_layer_norm = True
provider.attention_softmax_in_fp32 = False
provider.first_last_layers_bf16 = True
provider.is_hybrid_model = True

return NemotronHModelProvider(
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
ffn_hidden_size=hf_config.intermediate_size,
add_bias_linear=hf_config.use_bias,
num_attention_heads=hf_config.num_attention_heads,
num_query_groups=hf_config.num_key_value_heads,
kv_channels=getattr(hf_config, "head_dim", None) or getattr(hf_config, "attention_head_dim", None),
init_method_std=hf_config.initializer_range,
layernorm_epsilon=hf_config.layer_norm_epsilon,
make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size),
vocab_size=hf_config.vocab_size,
share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
seq_length=hf_config.max_position_embeddings,
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
fp32_residual_connection=hf_config.residual_in_fp32,
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
attention_dropout=hf_config.attention_dropout,
hidden_dropout=hf_config.hidden_dropout,
hybrid_override_pattern=hf_config.hybrid_override_pattern,
mamba_head_dim=hf_config.mamba_head_dim,
mamba_num_heads=hf_config.mamba_num_heads,
mamba_num_groups=hf_config.n_groups,
mamba_state_dim=hf_config.ssm_state_size,
add_qkv_bias=hf_config.attention_bias,
**configs,
)
# MoE-specific defaults (only if MoE is enabled)
if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:
provider.moe_aux_loss_coeff = 0.0001
provider.moe_router_score_function = "sigmoid"
provider.moe_router_enable_expert_bias = True
provider.moe_router_load_balancing_type = "seq_aux_loss"
provider.moe_router_dtype = "fp32"
provider.moe_grouped_gemm = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_permute_fusion = True
provider.moe_shared_expert_overlap = True

return provider

def mapping_registry(self) -> MegatronMappingRegistry:
# Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format
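The Nemotron-H bridge relies on the same tuple-based `CONFIG_MAPPING` convention: each `(hf_name, provider_name)` pair copies one field from the HF config onto the provider kwargs, so renamed fields (e.g. `n_groups` -> `mamba_num_groups`) only need a new tuple. A standalone sketch of that translation step, with a fake config and an illustrative helper name (the real conversion lives in the base `provider_bridge`):

```python
from dataclasses import dataclass


@dataclass
class FakeHFConfig:
    num_hidden_layers: int = 52
    n_groups: int = 8
    ssm_state_size: int = 128


# Illustrative subset of the mapping extended in this PR.
CONFIG_MAPPING = [
    ("num_hidden_layers", "num_layers"),
    ("n_groups", "mamba_num_groups"),
    ("ssm_state_size", "mamba_state_dim"),
]


def to_provider_kwargs(hf_config: FakeHFConfig, mapping: list[tuple[str, str]]) -> dict:
    """Translate HF config attribute names into Megatron provider kwarg names."""
    return {provider_key: getattr(hf_config, hf_key) for hf_key, provider_key in mapping}


print(to_provider_kwargs(FakeHFConfig(), CONFIG_MAPPING))
# {'num_layers': 52, 'mamba_num_groups': 8, 'mamba_state_dim': 128}
```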
1 change: 0 additions & 1 deletion src/megatron/bridge/models/qwen/qwen3_next_bridge.py
@@ -69,7 +69,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3NextModelPr
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
generation_config=hf_pretrained.generation_config,
qk_layernorm=True, # Qwen3 MoE uses QK layernorm
moe_grouped_gemm=True,
kv_channels=hf_config.head_dim,
1 change: 0 additions & 1 deletion src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py
@@ -63,7 +63,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen25VLModelProvider
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
generation_config=hf_pretrained.generation_config,
add_qkv_bias=True, # Qwen2 has bias in QKV projections
vision_config=hf_config.vision_config,
# VL-specific token IDs
2 changes: 0 additions & 2 deletions src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
@@ -97,7 +97,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLModelProvider:
fp16=(model_dtype == torch.float16),
bf16=(model_dtype == torch.bfloat16),
params_dtype=model_dtype,
generation_config=hf_pretrained.generation_config,
# Qwen3 specific parameters
add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV
qk_layernorm=True, # Qwen3 uses QK layernorm
@@ -263,7 +262,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLMoEModelProvid
fp16=(model_dtype == torch.float16),
bf16=(model_dtype == torch.bfloat16),
params_dtype=model_dtype,
generation_config=hf_pretrained.generation_config,
# Qwen3 specific parameters
add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV
qk_layernorm=True, # Qwen3 uses QK layernorm
Contributor

why is this test removed?

This file was deleted.