Merged
2 changes: 2 additions & 0 deletions src/megatron/bridge/models/__init__.py
@@ -106,6 +106,7 @@
LlamaNemotronHeterogeneousProvider,
)
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
from megatron.bridge.models.mimo.mimo_bridge import MimoBridge
from megatron.bridge.models.ministral3 import (
Ministral3Bridge,
Ministral3Model,
@@ -312,6 +313,7 @@
"NemotronNano12Bv2Provider",
"Nemotron3NanoProvider",
"MambaModelProvider",
"MimoBridge",
# Nemotron Models
"NemotronBridge",
"NemotronModelProvider",
155 changes: 155 additions & 0 deletions src/megatron/bridge/models/mimo/mimo_bridge.py
@@ -0,0 +1,155 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping

import torch
from megatron.core.models.gpt.gpt_model import GPTModel

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask
from megatron.bridge.models.conversion.param_mapping import (
AutoMapping,
GatedMLPMapping,
QKVMapping,
)
from megatron.bridge.models.qwen.qwen2_bridge import Qwen2Bridge


@MegatronModelBridge.register_bridge(source="MiMoForCausalLM", target=GPTModel, model_type="mimo")
class MimoBridge(Qwen2Bridge):
"""Megatron Bridge for MiMo Causal LM."""

def provider_bridge(self, hf_pretrained):
provider = super().provider_bridge(hf_pretrained)
hf_config = hf_pretrained.config

        # MiMo follows Qwen2 attention behavior and adds Multi-Token Prediction (MTP) on top.
provider.qk_layernorm = False
provider.add_qkv_bias = True

num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
if num_mtp_layers > 0:
provider.mtp_num_layers = num_mtp_layers
provider.mtp_loss_scaling_factor = 0.1

return provider

def mapping_registry(self) -> MegatronMappingRegistry:
mapping_list = list(super().mapping_registry().mappings)

mapping_list.extend(
[
AutoMapping(
megatron_param="mtp.layers.*.enorm.weight",
hf_param="model.mtp_layers.*.token_layernorm.weight",
),
AutoMapping(
megatron_param="mtp.layers.*.hnorm.weight",
hf_param="model.mtp_layers.*.hidden_layernorm.weight",
),
AutoMapping(
megatron_param="mtp.layers.*.eh_proj.weight",
hf_param="model.mtp_layers.*.input_proj.weight",
),
AutoMapping(
megatron_param="mtp.layers.*.final_layernorm.weight",
hf_param="model.mtp_layers.*.final_layernorm.weight",
),
]
)
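        # With the registry's "*" index substitution (exercised by the unit tests),
        # each pair above expands per MTP layer index, e.g. for layer 0:
        #   mtp.layers.0.enorm.weight   <-> model.mtp_layers.0.token_layernorm.weight
        #   mtp.layers.0.eh_proj.weight <-> model.mtp_layers.0.input_proj.weight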

        # Support both naming conventions: Megatron-Core may expose MTP layers as
        # either "transformer_layer" or "mtp_model_layer", depending on configuration.
for layer_prefix in ("transformer_layer", "mtp_model_layer"):
layer_path = f"mtp.layers.*.{layer_prefix}"
mapping_list.extend(
[
AutoMapping(
megatron_param=f"{layer_path}.self_attention.linear_qkv.layer_norm_weight",
hf_param="model.mtp_layers.*.input_layernorm.weight",
),
AutoMapping(
megatron_param=f"{layer_path}.mlp.linear_fc1.layer_norm_weight",
hf_param="model.mtp_layers.*.post_attention_layernorm.weight",
),
AutoMapping(
megatron_param=f"{layer_path}.self_attention.linear_proj.weight",
hf_param="model.mtp_layers.*.self_attn.o_proj.weight",
),
AutoMapping(
megatron_param=f"{layer_path}.mlp.linear_fc2.weight",
hf_param="model.mtp_layers.*.mlp.down_proj.weight",
),
QKVMapping(
megatron_param=f"{layer_path}.self_attention.linear_qkv.weight",
q="model.mtp_layers.*.self_attn.q_proj.weight",
k="model.mtp_layers.*.self_attn.k_proj.weight",
v="model.mtp_layers.*.self_attn.v_proj.weight",
),
QKVMapping(
megatron_param=f"{layer_path}.self_attention.linear_qkv.bias",
q="model.mtp_layers.*.self_attn.q_proj.bias",
k="model.mtp_layers.*.self_attn.k_proj.bias",
v="model.mtp_layers.*.self_attn.v_proj.bias",
),
GatedMLPMapping(
megatron_param=f"{layer_path}.mlp.linear_fc1.weight",
gate="model.mtp_layers.*.mlp.gate_proj.weight",
up="model.mtp_layers.*.mlp.up_proj.weight",
),
]
)

return MegatronMappingRegistry(*mapping_list)

@staticmethod
def _swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor:
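        """Swap the two column halves of an MTP input projection weight.

        HF MiMo's ``input_proj`` and Megatron's ``eh_proj`` concatenate their two
        inputs (normalized token embedding and hidden state) in opposite order,
        so the halves along dim=1 are exchanged. The swap is its own inverse and
        is applied in both conversion directions.
        """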
if weight.ndim < 2:
raise ValueError(
f"Expected tensor with at least 2 dimensions for input_proj weight, got shape {weight.shape}"
)
if weight.shape[1] % 2 != 0:
raise ValueError(f"Expected even dimension at dim=1 for input_proj weight, got shape {weight.shape}")
first_half, second_half = weight.chunk(2, dim=1)
return torch.cat((second_half, first_half), dim=1)

def maybe_modify_loaded_hf_weight(
self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor]
) -> torch.Tensor:
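        """Swap input_proj column halves when loading MiMo MTP weights from HF."""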
hf_weights = super().maybe_modify_loaded_hf_weight(hf_param, hf_state_dict)
if isinstance(hf_param, str) and hf_param.endswith(".input_proj.weight"):
return self._swap_input_proj_halves(hf_weights)
return hf_weights

def maybe_modify_converted_hf_weight(
self,
task: WeightConversionTask,
converted_weights_dict: dict[str, torch.Tensor],
hf_state_dict: Mapping[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
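        """Apply the same input_proj half-swap when exporting eh_proj weights back to HF."""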
converted_weights_dict = super().maybe_modify_converted_hf_weight(
task,
converted_weights_dict,
hf_state_dict,
)

if not task.global_param_name.endswith(".eh_proj.weight"):
return converted_weights_dict

for hf_name, weight in list(converted_weights_dict.items()):
if hf_name.endswith(".input_proj.weight"):
converted_weights_dict[hf_name] = self._swap_input_proj_halves(weight)

return converted_weights_dict
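
For context, a minimal usage sketch of the new bridge. This assumes the AutoBridge entry point (from_hf_pretrained / to_megatron_provider) available elsewhere in megatron.bridge; the HF model id shown is illustrative, not part of this diff:

from megatron.bridge import AutoBridge

# The registered MimoBridge is selected via the "MiMoForCausalLM" architecture
# string in the checkpoint's config.json (model id below is an assumption).
bridge = AutoBridge.from_hf_pretrained("XiaomiMiMo/MiMo-7B-Base")
provider = bridge.to_megatron_provider()  # GPTModelProvider with mtp_num_layers set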
128 changes: 128 additions & 0 deletions tests/unit_tests/models/mimo/test_mimo_bridge.py
@@ -0,0 +1,128 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock

import pytest
import torch
from transformers import GenerationConfig

from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.mimo.mimo_bridge import MimoBridge


class TestMimoBridge:
"""Test cases for MimoBridge."""

@pytest.fixture
def mimo_config(self):
return {
"architectures": ["MiMoForCausalLM"],
"attention_bias": True,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 32768,
"model_type": "mimo",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"num_nextn_predict_layers": 1,
"rms_norm_eps": 1e-05,
"rope_theta": 640000.0,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"vocab_size": 151680,
}

@pytest.fixture
def mock_pretrained_mimo(self, mimo_config):
cfg = Mock(spec=list(mimo_config.keys()))
for key, value in mimo_config.items():
setattr(cfg, key, value)

model = Mock(spec=PreTrainedCausalLM)
model.config = cfg
model.generation_config = Mock(spec=GenerationConfig)
return model

def test_registration(self):
assert issubclass(MimoBridge, MegatronModelBridge)

def test_provider_bridge_maps_mtp_config(self, mock_pretrained_mimo):
bridge = MimoBridge()
provider = bridge.provider_bridge(mock_pretrained_mimo)

assert isinstance(provider, GPTModelProvider)
assert provider.hidden_size == mock_pretrained_mimo.config.hidden_size
assert provider.num_attention_heads == mock_pretrained_mimo.config.num_attention_heads
assert provider.ffn_hidden_size == mock_pretrained_mimo.config.intermediate_size
assert provider.vocab_size == mock_pretrained_mimo.config.vocab_size
assert provider.qk_layernorm is False
assert provider.add_qkv_bias is True
assert provider.mtp_num_layers == mock_pretrained_mimo.config.num_nextn_predict_layers
assert provider.mtp_loss_scaling_factor == 0.1
assert provider.bf16 is True
assert provider.params_dtype == torch.bfloat16

def test_mapping_registry_includes_mtp_paths(self):
bridge = MimoBridge()
registry = bridge.mapping_registry()

mapping = registry.megatron_to_hf_lookup("mtp.layers.0.eh_proj.weight")
assert mapping is not None
assert mapping.hf_param == "model.mtp_layers.0.input_proj.weight"

transformer_mapping = registry.megatron_to_hf_lookup(
"mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight"
)
assert transformer_mapping is not None
assert transformer_mapping.hf_param["q"] == "model.mtp_layers.0.self_attn.q_proj.weight"

mtp_model_mapping = registry.megatron_to_hf_lookup(
"mtp.layers.0.mtp_model_layer.self_attention.linear_qkv.weight"
)
assert mtp_model_mapping is not None
assert mtp_model_mapping.hf_param["q"] == "model.mtp_layers.0.self_attn.q_proj.weight"

def test_mtp_input_proj_swap_on_hf_load(self):
bridge = MimoBridge()
weight = torch.arange(24, dtype=torch.float32).reshape(3, 8)
hf_key = "model.mtp_layers.0.input_proj.weight"

modified = bridge.maybe_modify_loaded_hf_weight(hf_key, {hf_key: weight})

expected = torch.cat((weight[:, 4:], weight[:, :4]), dim=1)
assert torch.equal(modified, expected)

def test_mtp_input_proj_swap_on_hf_export(self):
bridge = MimoBridge()
weight = torch.arange(24, dtype=torch.float32).reshape(3, 8)

task = WeightConversionTask(
param_name="mtp.layers.0.eh_proj.weight",
global_param_name="mtp.layers.0.eh_proj.weight",
mapping=Mock(),
)
converted = {"model.mtp_layers.0.input_proj.weight": weight}

modified = bridge.maybe_modify_converted_hf_weight(task, dict(converted), {})

expected = torch.cat((weight[:, 4:], weight[:, :4]), dim=1)
assert torch.equal(modified["model.mtp_layers.0.input_proj.weight"], expected)
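
These tests can be run on their own with, e.g.:

pytest tests/unit_tests/models/mimo/test_mimo_bridge.py -v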