Merged
Changes from all commits
44 commits
9ac259b
[DRAFT] Refactor provider_bridge for Llama and Qwen models
yaoyu-33 Jan 23, 2026
ab8a5e4
refactor(bridge): Introduce MLAModelProvider for DeepSeek/Kimi MLA mo…
yaoyu-33 Jan 23, 2026
0a12eb7
refactor(bridge): Refactor Gemma bridges to use specialized providers
yaoyu-33 Jan 24, 2026
aa54966
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Jan 24, 2026
4ade385
update
yaoyu-33 Jan 26, 2026
3ec0b0a
remove MEGATRON_DEFAULTS
yaoyu-33 Jan 26, 2026
c100b4a
remove testing scripts
yaoyu-33 Jan 26, 2026
3c709c5
yarn fix
yaoyu-33 Jan 26, 2026
9eb8542
clean ups
yaoyu-33 Jan 27, 2026
2cc3b35
fix
yaoyu-33 Jan 27, 2026
ca54e4f
fix unit tests
yaoyu-33 Jan 27, 2026
1ed2069
unit test fix
yaoyu-33 Jan 27, 2026
9facb3e
functional test fix
yaoyu-33 Jan 27, 2026
b753ab8
code rabbit
yaoyu-33 Jan 27, 2026
ab24acf
fix functional tests
yaoyu-33 Jan 28, 2026
b6f8e29
remove KimiK2Bridge from init
yaoyu-33 Jan 28, 2026
167055a
clean up deprecated functional tests
yaoyu-33 Jan 28, 2026
22a6321
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Feb 2, 2026
f9a3231
olmoe update
yaoyu-33 Feb 3, 2026
d5b7890
Fix kv_channels calculation for OLMoE bridge
yaoyu-33 Feb 3, 2026
ada7d05
fix: always set yarn params with None defaults for MCoreGPTModel comp…
yaoyu-33 Feb 3, 2026
5f24f9b
Merge branch 'main' into feature/provider-bridge-refactor
yaoyu-33 Feb 3, 2026
057175f
lint
yaoyu-33 Feb 3, 2026
54e0971
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 4, 2026
6c6a3b9
nemotron H bridge update
yaoyu-33 Feb 4, 2026
0c94698
nemotron bridge update
yaoyu-33 Feb 4, 2026
4d827e6
nemotron bridge update
yaoyu-33 Feb 5, 2026
0cb6cb5
qwen 25 vl provider rename
yaoyu-33 Feb 5, 2026
eb9656f
remove autocast_dtype in config
yaoyu-33 Feb 5, 2026
07a4e6c
update gemma3 vl and nemotron vl
yaoyu-33 Feb 5, 2026
eb53b93
revert autocast_dtype change
yaoyu-33 Feb 5, 2026
db38206
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 5, 2026
9618f91
fix: remove generation_config from GPTModelProvider and tests
yaoyu-33 Feb 5, 2026
2b0bb06
Merge branch 'main' into feature/provider-bridge-refactor-2
yaoyu-33 Feb 5, 2026
58b2c6e
fix: remove additional generation_config tests from bridge test files
yaoyu-33 Feb 5, 2026
2310230
fix: add relu2 activation support and fix Nemotron/NemotronH tests
yaoyu-33 Feb 5, 2026
77a567b
fix: update Nemotron bridge tests for correct provider types
yaoyu-33 Feb 5, 2026
dabe9d5
remove deprecated provider test
yaoyu-33 Feb 5, 2026
36a79d6
Merge branch 'feature/provider-bridge-refactor-2' into feature/provid…
yaoyu-33 Feb 5, 2026
155f683
fix unit tests
yaoyu-33 Feb 5, 2026
1ac68fb
Merge branch 'main' into feature/provider-bridge-refactor-3
yaoyu-33 Feb 10, 2026
1144273
Merge branch 'main' into feature/provider-bridge-refactor-3
yaoyu-33 Feb 11, 2026
f022e1f
fix: correct default image_token_id for Gemma3VL from 151655 to 262144
yaoyu-33 Feb 12, 2026
c6bc339
test: fix gemma3_vl bridge test for image_token_id default
yaoyu-33 Feb 12, 2026
73 changes: 42 additions & 31 deletions src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py
@@ -30,48 +30,59 @@
from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM


@MegatronModelBridge.register_bridge(source=Gemma3ForConditionalGeneration, target=Gemma3VLModel)
@MegatronModelBridge.register_bridge(
source=Gemma3ForConditionalGeneration,
target=Gemma3VLModel,
provider=Gemma3VLModelProvider,
model_type="gemma3_vl",
)
class Gemma3VLBridge(MegatronModelBridge):
"""
Megatron Bridge for Gemma3 VL.

This bridge handles the conversion between HuggingFace Gemma3ForConditionalGeneration
and Megatron-Core Gemma3VLModel formats, including weight mappings and
configuration translation for vision-language models.

Example:
>>> from megatron.bridge import AutoBridge
>>> bridge = AutoBridge.from_hf_pretrained("google/gemma-3-4b-it")
>>> provider = bridge.to_megatron_provider()
"""

def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Gemma3VLModelProvider:
hf_config = hf_pretrained.config
text_config = hf_config.text_config
vision_config = hf_config.vision_config

provider = Gemma3VLModelProvider(
# Text configuration
init_method_std=text_config.initializer_range,
hidden_size=text_config.hidden_size,
ffn_hidden_size=text_config.intermediate_size,
kv_channels=text_config.head_dim,
seq_length=text_config.max_position_embeddings,
num_attention_heads=text_config.num_attention_heads,
num_layers=text_config.num_hidden_layers,
num_query_groups=text_config.num_key_value_heads,
window_size=text_config.sliding_window,
rotary_base=(text_config.rope_local_base_freq, text_config.rope_theta),
layernorm_epsilon=text_config.rms_norm_eps,
vocab_size=text_config.vocab_size,
softmax_scale=1.0 / math.sqrt(text_config.query_pre_attn_scalar),
rope_scaling_factor=text_config.rope_scaling["factor"] if text_config.rope_scaling else 1.0,
# Vision configuration
vision_config=vision_config,
mm_tokens_per_image=hf_config.mm_tokens_per_image,
# VL-specific token IDs
bos_token_id=getattr(hf_config, "bos_token_id", 0),
eos_token_id=getattr(hf_config, "eos_token_id", 1),
vision_start_token_id=getattr(hf_config, "vision_start_token_id", 255999),
vision_end_token_id=getattr(hf_config, "vision_end_token_id", 256000),
image_token_id=getattr(hf_config, "image_token_id", 151655),
# Precision configuration
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
)
# Use base class helper for common config conversion
provider_kwargs = self.hf_config_to_provider_kwargs(text_config)
provider = Gemma3VLModelProvider(**provider_kwargs)

# Gemma3-specific features not in CONFIG_MAPPING
provider.window_size = text_config.sliding_window
provider.rotary_base = (text_config.rope_local_base_freq, text_config.rope_theta)
provider.softmax_scale = 1.0 / math.sqrt(text_config.query_pre_attn_scalar)
provider.rope_scaling_factor = text_config.rope_scaling["factor"] if text_config.rope_scaling else 1.0

# Override dtype and vocab settings to match baseline
provider.bf16 = True
provider.params_dtype = torch.bfloat16
provider.autocast_dtype = torch.bfloat16
provider.make_vocab_size_divisible_by = 128

# Vision configuration
provider.vision_config = vision_config
provider.mm_tokens_per_image = hf_config.mm_tokens_per_image

# VL-specific token IDs
provider.bos_token_id = getattr(hf_config, "bos_token_id", 0)
provider.eos_token_id = getattr(hf_config, "eos_token_id", 1)
provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 255999)
provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 256000)
provider.image_token_id = getattr(hf_config, "image_token_id", 262144)

# Vision projector configuration
provider.vision_projector_config.input_size = vision_config.hidden_size
provider.vision_projector_config.hidden_size = text_config.hidden_size

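Reading the removed constructor against the new call, hf_config_to_provider_kwargs appears to absorb the common HF-to-Megatron field mapping that each bridge previously spelled out by hand, leaving only model-specific attributes to be set afterwards. A rough sketch of that mapping, with field names taken from the removed code (the real helper lives on MegatronModelBridge and its CONFIG_MAPPING may cover more or differently named fields):

# Hypothetical sketch only; the actual hf_config_to_provider_kwargs helper on
# MegatronModelBridge may map additional fields or apply extra conversions.
def hf_config_to_provider_kwargs_sketch(text_config):
    return {
        "num_layers": text_config.num_hidden_layers,
        "hidden_size": text_config.hidden_size,
        "ffn_hidden_size": text_config.intermediate_size,
        "num_attention_heads": text_config.num_attention_heads,
        "num_query_groups": text_config.num_key_value_heads,
        "kv_channels": getattr(text_config, "head_dim", None),
        "init_method_std": text_config.initializer_range,
        "layernorm_epsilon": text_config.rms_norm_eps,
        "vocab_size": text_config.vocab_size,
        "seq_length": text_config.max_position_embeddings,
    }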
46 changes: 26 additions & 20 deletions src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core.activations import squared_relu

from megatron.bridge.models import ColumnParallelMapping, RowParallelMapping
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
@@ -29,7 +29,12 @@
from megatron.bridge.models.nemotron_vl.nemotron_vl_provider import NemotronNano12Bv2VLModelProvider


@MegatronModelBridge.register_bridge(source="NemotronH_Nano_VL_V2", target=NemotronVLModel)
@MegatronModelBridge.register_bridge(
source="NemotronH_Nano_VL_V2",
target=NemotronVLModel,
provider=NemotronNano12Bv2VLModelProvider,
model_type="nemotron_vl",
)
class NemotronVLBridge(MegatronModelBridge):
"""Conversion utilities between HF Nemotron-VL and Megatron-Core format."""

@@ -39,25 +44,26 @@ class NemotronVLBridge(MegatronModelBridge):

def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> NemotronNano12Bv2VLModelProvider: # type: ignore[override]
hf_config = hf_pretrained.config
llm_config = hf_config.llm_config

# Use base class helper for common config mapping
provider_kwargs = self.hf_config_to_provider_kwargs(llm_config)

# Handle vocab size divisibility
provider_kwargs["make_vocab_size_divisible_by"] = self.make_vocab_size_divisible_by(llm_config.vocab_size)

provider = NemotronNano12Bv2VLModelProvider(**provider_kwargs)

# Nemotron VL-specific settings
# Note: Most defaults come from the provider class hierarchy (NemotronNano12Bv2Provider)
provider.scatter_embedding_sequence_parallel = False
provider.attention_softmax_in_fp32 = True

# Override fields that should use NemotronH provider's specialized defaults
# instead of HF config values
provider.activation_func = squared_relu # Nemotron uses squared_relu, not HF's hidden_act
provider.autocast_dtype = None # Not set in original code

provider = NemotronNano12Bv2VLModelProvider(
num_layers=hf_config.llm_config.num_hidden_layers,
hidden_size=hf_config.llm_config.hidden_size,
ffn_hidden_size=hf_config.llm_config.intermediate_size,
num_attention_heads=hf_config.llm_config.num_attention_heads,
num_query_groups=getattr(
hf_config.llm_config, "num_key_value_heads", hf_config.llm_config.num_attention_heads // 2
),
init_method_std=hf_config.llm_config.initializer_range,
layernorm_epsilon=getattr(hf_config.llm_config, "layer_norm_epsilon", 1e-5),
make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.llm_config.vocab_size),
share_embeddings_and_output_weights=getattr(hf_config.llm_config, "tie_word_embeddings", False),
vocab_size=hf_config.llm_config.vocab_size,
seq_length=hf_config.llm_config.max_position_embeddings,
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
)
return provider

# ------------------------------------------------------------------
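For reference, the squared_relu imported above from megatron.core.activations is simply ReLU followed by an elementwise square; a minimal equivalent sketch:

import torch
import torch.nn.functional as F

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # relu2: max(0, x) ** 2 applied elementwise, as used in Nemotron MLPs.
    return F.relu(x) ** 2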
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/qwen_vl/__init__.py
@@ -20,7 +20,7 @@
Qwen3VLMoEModelProvider,
)
from megatron.bridge.models.qwen_vl.qwen25_vl_bridge import Qwen25VLBridge
from megatron.bridge.models.qwen_vl.qwen_vl_provider import (
from megatron.bridge.models.qwen_vl.qwen25_vl_provider import (
Qwen25VLModelProvider,
)

59 changes: 28 additions & 31 deletions src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
@@ -25,10 +24,15 @@
)
from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM
from megatron.bridge.models.qwen_vl.modeling_qwen25_vl import Qwen25VLModel
from megatron.bridge.models.qwen_vl.qwen_vl_provider import Qwen25VLModelProvider
from megatron.bridge.models.qwen_vl.qwen25_vl_provider import Qwen25VLModelProvider


@MegatronModelBridge.register_bridge(source=Qwen2_5_VLForConditionalGeneration, target=Qwen25VLModel)
@MegatronModelBridge.register_bridge(
source=Qwen2_5_VLForConditionalGeneration,
target=Qwen25VLModel,
provider=Qwen25VLModelProvider,
model_type="qwen2_5_vl",
)
class Qwen25VLBridge(MegatronModelBridge):
"""
Megatron Bridge for Qwen2.5-VL Conditional Generation.
@@ -45,35 +49,28 @@ class Qwen25VLBridge(MegatronModelBridge):

def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen25VLModelProvider:
hf_config = hf_pretrained.config
text_config = hf_config # Qwen2.5-VL has text config fields directly on main config

provider = Qwen25VLModelProvider(
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
ffn_hidden_size=hf_config.intermediate_size,
num_attention_heads=hf_config.num_attention_heads,
num_query_groups=hf_config.num_key_value_heads,
init_method_std=hf_config.initializer_range,
layernorm_epsilon=hf_config.rms_norm_eps,
gated_linear_unit=True,
make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size),
rotary_base=hf_config.rope_theta,
share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
vocab_size=hf_config.vocab_size,
seq_length=hf_config.max_position_embeddings,
fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
add_qkv_bias=True, # Qwen2 has bias in QKV projections
vision_config=hf_config.vision_config,
# VL-specific token IDs
bos_token_id=getattr(hf_config, "bos_token_id", 151643),
eos_token_id=getattr(hf_config, "eos_token_id", 151645),
vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652),
vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653),
vision_token_id=getattr(hf_config, "vision_token_id", 151654),
image_token_id=getattr(hf_config, "image_token_id", 151655),
video_token_id=getattr(hf_config, "video_token_id", 151656),
)
provider_kwargs = self.hf_config_to_provider_kwargs(text_config)
provider = Qwen25VLModelProvider(**provider_kwargs)

# Qwen2-specific settings
provider.normalization = "RMSNorm"
provider.gated_linear_unit = True
provider.add_qkv_bias = True
provider.add_bias_linear = False
provider.hidden_dropout = 0.0

# VL-specific overrides
provider.position_embedding_type = "mrope"
provider.vision_config = hf_config.vision_config
provider.bos_token_id = getattr(hf_config, "bos_token_id", 151643)
provider.eos_token_id = getattr(hf_config, "eos_token_id", 151645)
provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 151652)
provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 151653)
provider.vision_token_id = getattr(hf_config, "vision_token_id", 151654)
provider.image_token_id = getattr(hf_config, "image_token_id", 151655)
provider.video_token_id = getattr(hf_config, "video_token_id", 151656)

return provider

src/megatron/bridge/models/qwen_vl/qwen_vl_provider.py → src/megatron/bridge/models/qwen_vl/qwen25_vl_provider.py
@@ -18,25 +18,16 @@
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionConfig

from megatron.bridge.models import (
Qwen2ModelProvider,
)

from .modeling_qwen25_vl import Qwen25VLModel


# =============================================================================
# Qwen 2.5 VL Model Providers
# =============================================================================
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.qwen_vl.modeling_qwen25_vl import Qwen25VLModel


@dataclass
class Qwen25VLModelProvider(Qwen2ModelProvider):
class Qwen25VLModelProvider(GPTModelProvider):
"""
Base model provider for Qwen 2.5 VL Models.
"""

# Language configuration inherited from Qwen25ModelProvider3B
# VL models shouldn't scatter embeddings across sequence parallel regions because
# the vision embeddings are going to be inserted into the language embeddings.
scatter_embedding_sequence_parallel: bool = False
@@ -74,4 +65,4 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen25V
return model

def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)
return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)
Contributor
why can't we use super() here?

Contributor Author
Can use super; just want to be more explicit here about what's being called. Going to change in Mannu's refactor anyway.
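A toy sketch (hypothetical classes, not the real providers) of the difference: super().provide() resolves through the MRO, so any override sitting between Qwen25VLModelProvider and GPTModelProvider would be picked up, while the explicit GPTModelProvider.provide(self, ...) always dispatches to that exact base implementation.

class GPTModelProvider:
    def provide(self):
        return "base GPT language model"


class IntermediateProvider(GPTModelProvider):
    # Hypothetical override inserted between the VL provider and the base.
    def provide(self):
        return "wrapped(" + super().provide() + ")"


class VLProvider(IntermediateProvider):
    def via_super(self):
        # MRO lookup: picks up IntermediateProvider.provide.
        return super().provide()

    def via_explicit_base(self):
        # Named base call: skips IntermediateProvider entirely.
        return GPTModelProvider.provide(self)


p = VLProvider()
print(p.via_super())          # wrapped(base GPT language model)
print(p.via_explicit_base())  # base GPT language model

With the class hierarchy as it stands in this PR, both calls land in the same implementation, which is why either form works here.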
