17 changes: 11 additions & 6 deletions src/megatron/bridge/models/conversion/model_bridge.py
@@ -358,17 +358,20 @@ def hf_config_to_provider_kwargs(self, hf_config) -> dict:
         # Map config fields using CONFIG_MAPPING
         # Supports dot notation for nested dict access (e.g., "rope_scaling.factor")
         for hf_name, megatron_name in self.CONFIG_MAPPING:
+            has_value = False
+            value = None
             if "." in hf_name:
                 # Nested dict access: "parent.child" -> getattr(config, parent).get(child)
                 parts = hf_name.split(".", 1)
                 parent = getattr(hf_config, parts[0], None)
                 if parent is not None and isinstance(parent, dict):
-                    value = parent.get(parts[1])
-                else:
-                    value = None
+                    if parts[1] in parent:
+                        value = parent[parts[1]]
+                        has_value = True
             else:
                 value = getattr(hf_config, hf_name, None)
-            if value is not None:
+                has_value = hasattr(hf_config, hf_name)
+            if has_value:
                 provider_kwargs[megatron_name] = value
 
         # Extract rotary_base via compat function (handles both legacy rope_theta
@@ -505,8 +508,9 @@ def megatron_to_hf_config(cls, provider) -> dict:
         # Map config fields using CONFIG_MAPPING (reverse direction)
         # Supports dot notation for nested dict building (e.g., "rope_scaling.factor")
         for hf_name, megatron_name in cls.CONFIG_MAPPING:
+            has_value = hasattr(provider, megatron_name)
             value = getattr(provider, megatron_name, None)
-            if value is not None:
+            if has_value:
                 if "." in hf_name:
                     # Nested dict: "parent.child" -> hf_config["parent"]["child"] = value
                     parts = hf_name.split(".", 1)
@@ -524,8 +528,9 @@
             hf_config["rope_scaling"]["rope_type"] = "yarn"
 
             for hf_key, megatron_key in cls.YARN_ROPE_SCALING_MAPPING:
+                has_value = hasattr(provider, megatron_key)
                 value = getattr(provider, megatron_key, None)
-                if value is not None:
+                if has_value:
                     hf_config["rope_scaling"][hf_key] = value
 
             yarn_correction_range_round_to_int = getattr(provider, "yarn_correction_range_round_to_int", None)
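
The crux of the change above: both mapping loops now key the copy decision on attribute presence (hasattr) rather than on value truthiness (value is not None), so a field that is explicitly None on the HF config or on the provider is forwarded instead of being silently dropped in favour of the target's default. A minimal standalone sketch of the two behaviours, using a toy config object rather than the real bridge types:

    from types import SimpleNamespace

    config = SimpleNamespace(q_lora_rank=None)  # explicitly None, as on DeepSeek-style configs

    # Old behaviour: an explicit None is indistinguishable from a missing
    # attribute, so the key is dropped and a downstream default wins.
    old_kwargs = {}
    value = getattr(config, "q_lora_rank", None)
    if value is not None:
        old_kwargs["q_lora_rank"] = value
    assert "q_lora_rank" not in old_kwargs

    # New behaviour: presence is decided by hasattr(), so the explicit
    # None survives the mapping.
    new_kwargs = {}
    if hasattr(config, "q_lora_rank"):
        new_kwargs["q_lora_rank"] = getattr(config, "q_lora_rank", None)
    assert new_kwargs == {"q_lora_rank": None}
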
95 changes: 95 additions & 0 deletions tests/unit_tests/models/deepseek/test_deepseek_bridges.py
@@ -122,6 +122,74 @@ def test_provider_bridge_maps_config(self, mock_pretrained_v2):
         assert provider.bf16 is True
         assert provider.params_dtype == torch.bfloat16
 
+    def test_hf_config_to_provider_kwargs_preserves_none_q_lora_rank(self, mock_pretrained_v2):
+        mock_pretrained_v2.config.q_lora_rank = None
+        bridge = DeepSeekV2Bridge()
+
+        provider_kwargs = bridge.hf_config_to_provider_kwargs(mock_pretrained_v2.config)
+
+        assert "q_lora_rank" in provider_kwargs
+        assert provider_kwargs["q_lora_rank"] is None
+
+    def test_provider_bridge_preserves_none_q_lora_rank(self, mock_pretrained_v2):
+        mock_pretrained_v2.config.q_lora_rank = None
+        bridge = DeepSeekV2Bridge()
+
+        provider = bridge.provider_bridge(mock_pretrained_v2)
+
+        assert provider.q_lora_rank is None
+
+    def test_megatron_to_hf_config_preserves_none_q_lora_rank(self, mock_pretrained_v2):
+        mock_pretrained_v2.config.q_lora_rank = None
+        bridge = DeepSeekV2Bridge()
+        provider = bridge.provider_bridge(mock_pretrained_v2)
+
+        hf_config = bridge.megatron_to_hf_config(provider)
+
+        assert "q_lora_rank" in hf_config
+        assert hf_config["q_lora_rank"] is None
+
+    def test_hf_config_to_provider_kwargs_nested_dot_notation(self, mock_pretrained_v2):
+        """Test that dot-notation CONFIG_MAPPING reads nested dict values (including None)."""
+        bridge = DeepSeekV2Bridge()
+        # Patch CONFIG_MAPPING with a dot-notation entry pointing into rope_scaling dict
+        original = bridge.CONFIG_MAPPING
+        bridge.CONFIG_MAPPING = list(original) + [("rope_scaling.factor", "yarn_rotary_scaling_factor")]
+        mock_pretrained_v2.config.rope_scaling = {"factor": 40, "type": "yarn"}
+
+        kwargs = bridge.hf_config_to_provider_kwargs(mock_pretrained_v2.config)
+
+        bridge.CONFIG_MAPPING = original
+        assert kwargs.get("yarn_rotary_scaling_factor") == 40
+
+    def test_hf_config_to_provider_kwargs_nested_dot_notation_none_value(self, mock_pretrained_v2):
+        """Test that dot-notation CONFIG_MAPPING preserves None values from nested dicts."""
+        bridge = DeepSeekV2Bridge()
+        original = bridge.CONFIG_MAPPING
+        bridge.CONFIG_MAPPING = list(original) + [("rope_scaling.factor", "yarn_rotary_scaling_factor")]
+        mock_pretrained_v2.config.rope_scaling = {"factor": None, "type": "yarn"}
+
+        kwargs = bridge.hf_config_to_provider_kwargs(mock_pretrained_v2.config)
+
+        bridge.CONFIG_MAPPING = original
+        assert "yarn_rotary_scaling_factor" in kwargs
+        assert kwargs["yarn_rotary_scaling_factor"] is None
+
+    def test_megatron_to_hf_config_yarn_none_value(self, mock_pretrained_v2):
+        """Test that YARN_ROPE_SCALING_MAPPING preserves None values on provider."""
+        bridge = DeepSeekV2Bridge()
+        provider = bridge.provider_bridge(mock_pretrained_v2)
+        # Ensure YARN rope_scaling block is emitted
+        provider.yarn_rotary_scaling_factor = 40
+        # Set a YARN key to None — should still appear in hf_config["rope_scaling"]
+        provider.yarn_mscale = None
+
+        hf_config = bridge.megatron_to_hf_config(provider)
+
+        assert "rope_scaling" in hf_config
+        assert "mscale" in hf_config["rope_scaling"]
+        assert hf_config["rope_scaling"]["mscale"] is None
+
 
 class TestDeepSeekV3Bridge:
     """Test cases for DeepSeekV3Bridge."""
@@ -222,6 +290,33 @@ def test_provider_bridge_maps_config(self, mock_pretrained_v3):
         assert provider.bf16 is True
         assert provider.params_dtype == torch.bfloat16
 
+    def test_hf_config_to_provider_kwargs_preserves_none_q_lora_rank(self, mock_pretrained_v3):
+        mock_pretrained_v3.config.q_lora_rank = None
+        bridge = DeepSeekV3Bridge()
+
+        provider_kwargs = bridge.hf_config_to_provider_kwargs(mock_pretrained_v3.config)
+
+        assert "q_lora_rank" in provider_kwargs
+        assert provider_kwargs["q_lora_rank"] is None
+
+    def test_provider_bridge_preserves_none_q_lora_rank(self, mock_pretrained_v3):
+        mock_pretrained_v3.config.q_lora_rank = None
+        bridge = DeepSeekV3Bridge()
+
+        provider = bridge.provider_bridge(mock_pretrained_v3)
+
+        assert provider.q_lora_rank is None
+
+    def test_megatron_to_hf_config_preserves_none_q_lora_rank(self, mock_pretrained_v3):
+        mock_pretrained_v3.config.q_lora_rank = None
+        bridge = DeepSeekV3Bridge()
+        provider = bridge.provider_bridge(mock_pretrained_v3)
+
+        hf_config = bridge.megatron_to_hf_config(provider)
+
+        assert "q_lora_rank" in hf_config
+        assert hf_config["q_lora_rank"] is None
+
     def test_export_injects_inv_freq_for_layer(self, mock_pretrained_v3):
         bridge = DeepSeekV3Bridge()
         bridge.hf_config = mock_pretrained_v3.config
23 changes: 7 additions & 16 deletions tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py
@@ -28,7 +28,10 @@
 @pytest.fixture
 def mock_text_config():
     """Create a mock text config for Gemma3 VL."""
-    config = Mock()
+    # Use spec=[] so hasattr() only returns True for explicitly-set attributes,
+    # matching real HF config behaviour (Gemma3 text config has no MLA fields
+    # like q_lora_rank, so they must not appear in the provider kwargs).
+    config = Mock(spec=[])
     config.num_hidden_layers = 28
     config.hidden_size = 2560
     config.intermediate_size = 15360
@@ -46,13 +49,7 @@ def mock_text_config():
     config.rope_scaling = None
     config.rope_parameters = None
     config.hidden_act = "gelu_pytorch_tanh"
-    # Set MLA-specific fields to None (these are auto-mapped in CONFIG_MAPPING)
-    config.q_lora_rank = None
-    config.kv_lora_rank = None
-    config.qk_nope_head_dim = None
-    config.qk_rope_head_dim = None
-    config.v_head_dim = None
-    config.num_nextn_predict_layers = None
+    config.torch_dtype = "bfloat16"
     return config
 
 
@@ -367,7 +364,7 @@ def test_provider_bridge_with_minimal_config(self, gemma3_vl_bridge):
         minimal_config = Mock()
 
         # Create minimal text config
-        text_config = Mock()
+        text_config = Mock(spec=[])
         text_config.num_hidden_layers = 18
         text_config.hidden_size = 2048
         text_config.intermediate_size = 8192
@@ -385,13 +382,7 @@ def test_provider_bridge_with_minimal_config(self, gemma3_vl_bridge):
         text_config.rope_scaling = None
         text_config.rope_parameters = None
         text_config.hidden_act = "gelu_pytorch_tanh"
-        # Set MLA-specific fields to None
-        text_config.q_lora_rank = None
-        text_config.kv_lora_rank = None
-        text_config.qk_nope_head_dim = None
-        text_config.qk_rope_head_dim = None
-        text_config.v_head_dim = None
-        text_config.num_nextn_predict_layers = None
+        text_config.torch_dtype = "bfloat16"
 
         # Create minimal vision config
         vision_config = SiglipVisionConfig()
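
A note on the fixture change above, since the same edit recurs in the Nemotron-VL and Qwen2.5-VL tests below: a bare Mock() fabricates attributes on first access, so hasattr() answers True for everything, and the hasattr-based CONFIG_MAPPING loop would pull phantom MLA fields into the provider kwargs. Mock(spec=[]) starts with an empty attribute surface while still permitting explicit assignment, which is also why torch_dtype now has to be set by hand. A quick standalone illustration of the unittest.mock behaviour:

    from unittest.mock import Mock

    loose = Mock()
    assert hasattr(loose, "q_lora_rank")  # a bare Mock invents any attribute on access

    strict = Mock(spec=[])
    assert not hasattr(strict, "q_lora_rank")  # unset attributes now look absent

    strict.num_hidden_layers = 28  # explicit assignment still works with spec (unlike spec_set)
    assert strict.num_hidden_layers == 28
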
7 changes: 4 additions & 3 deletions tests/unit_tests/models/llama/test_llama_bridge.py
@@ -23,6 +23,7 @@
 
 from megatron.bridge.models import AutoBridge
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf
 from megatron.bridge.models.gpt_provider import GPTModelProvider
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
 from megatron.bridge.models.llama.llama_bridge import LlamaBridge
@@ -155,7 +156,7 @@ def test_provider_bridge_architecture_mapping(self, mock_pretrained_llama, llama
         assert result.num_attention_heads == llama_config.num_attention_heads
         assert result.num_query_groups == llama_config.num_key_value_heads
         assert result.seq_length == llama_config.max_position_embeddings
-        assert result.rotary_base == llama_config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(llama_config)
         assert result.vocab_size == llama_config.vocab_size
         assert result.layernorm_epsilon == llama_config.rms_norm_eps
         assert result.init_method_std == llama_config.initializer_range
@@ -225,7 +226,7 @@ def test_provider_bridge_rope_scaling_params(self, mock_pretrained_llama):
         assert result.rope_scaling is True
         assert result.rope_scaling_factor == 32.0
         # Check position embedding
-        assert result.rotary_base == mock_pretrained_llama.config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(mock_pretrained_llama.config)
 
     def test_provider_bridge_embedding_sharing(self, llama_config):
         """Test embedding sharing configuration."""
@@ -508,7 +509,7 @@ def test_roundtrip_hf_to_megatron_to_hf(self):
         assert result_hf_config["num_key_value_heads"] == hf_config_dict["num_key_value_heads"]
         assert result_hf_config["vocab_size"] == hf_config_dict["vocab_size"]
         assert result_hf_config["max_position_embeddings"] == hf_config_dict["max_position_embeddings"]
-        assert result_hf_config["rope_theta"] == hf_config_dict["rope_parameters"]["rope_theta"]
+        assert result_hf_config["rope_theta"] == rope_theta_from_hf(config)
         assert result_hf_config["rms_norm_eps"] == hf_config_dict["rms_norm_eps"]
         assert result_hf_config["tie_word_embeddings"] == hf_config_dict["tie_word_embeddings"]
         # Check new mappings are preserved
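
These assertions previously reached into config.rope_parameters["rope_theta"] directly; they now go through rope_theta_from_hf from transformers_compat, which, per the comment in model_bridge.py above, handles both the legacy top-level rope_theta attribute and the newer nested rope_parameters layout. A plausible sketch of what such a compat helper does, offered as an assumption rather than the actual implementation:

    # Hypothetical re-creation of a rope_theta_from_hf-style helper; the real
    # function lives in megatron.bridge.models.conversion.transformers_compat
    # and may differ in defaults and error handling.
    def rope_theta_from_hf(hf_config, default=10000.0):
        # Newer transformers configs nest the value under rope_parameters.
        rope_parameters = getattr(hf_config, "rope_parameters", None)
        if isinstance(rope_parameters, dict) and "rope_theta" in rope_parameters:
            return rope_parameters["rope_theta"]
        # Older configs expose rope_theta as a top-level attribute.
        return getattr(hf_config, "rope_theta", default)
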
5 changes: 3 additions & 2 deletions tests/unit_tests/models/mistral/test_mistral_model_bridge.py
@@ -22,6 +22,7 @@
 
 from megatron.bridge.models import AutoBridge
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
 from megatron.bridge.models.mistral.mistral_bridge import MistralBridge
 from megatron.bridge.models.mistral.mistral_provider import MistralModelProvider
@@ -101,7 +102,7 @@ def test_provider_bridge_basic(self, mock_pretrained_mistral, mistral_config):
         assert result.hidden_size == mistral_config.hidden_size
         assert result.num_attention_heads == mistral_config.num_attention_heads
         assert result.seq_length == mistral_config.max_position_embeddings
-        assert result.rotary_base == mistral_config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(mistral_config)
 
     def test_provider_bridge_vocabulary(self, mock_pretrained_mistral, mistral_config):
         """Test vocabulary size mapping."""
@@ -149,7 +150,7 @@ def test_provider_bridge_position_embedding(self, mock_pretrained_mistral, mistr
         result = bridge.provider_bridge(mock_pretrained_mistral)
 
         # Check position embedding
-        assert result.rotary_base == mistral_config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(mistral_config)
 
     def test_provider_bridge_mistral_specific_features(self, mock_pretrained_mistral):
         """Test Mistral-specific features."""
13 changes: 5 additions & 8 deletions tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py
@@ -27,7 +27,10 @@
 
 @pytest.fixture
 def mock_llm_config():
-    cfg = Mock()
+    # Use spec=[] so hasattr() only returns True for explicitly-set attributes,
+    # matching real HF config behaviour (Nemotron config has no MLA fields
+    # like q_lora_rank, so they must not appear in the provider kwargs).
+    cfg = Mock(spec=[])
     cfg.num_hidden_layers = 28
     cfg.hidden_size = 5120
     cfg.intermediate_size = 20480
@@ -38,14 +41,8 @@ def mock_llm_config():
     cfg.vocab_size = 262144
     cfg.max_position_embeddings = 131072
     cfg.hidden_act = "relu2"
-    # Set MLA-specific fields to None (these are auto-mapped in CONFIG_MAPPING)
-    cfg.q_lora_rank = None
-    cfg.kv_lora_rank = None
-    cfg.qk_nope_head_dim = None
-    cfg.qk_rope_head_dim = None
-    cfg.v_head_dim = None
-    cfg.num_nextn_predict_layers = None
     cfg.rope_scaling = None
+    cfg.torch_dtype = "bfloat16"
     return cfg
 
 
5 changes: 3 additions & 2 deletions tests/unit_tests/models/qwen/test_qwen3_bridge.py
@@ -22,6 +22,7 @@
 
 from megatron.bridge.models import AutoBridge
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf
 from megatron.bridge.models.gpt_provider import GPTModelProvider
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
 from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge
@@ -100,7 +101,7 @@ def test_provider_bridge_basic(self, mock_pretrained_qwen3, qwen3_config):
         assert result.hidden_size == qwen3_config.hidden_size
         assert result.num_attention_heads == qwen3_config.num_attention_heads
         assert result.seq_length == qwen3_config.max_position_embeddings
-        assert result.rotary_base == qwen3_config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(qwen3_config)
 
     def test_provider_bridge_vocabulary(self, mock_pretrained_qwen3, qwen3_config):
         """Test vocabulary size mapping."""
@@ -148,7 +149,7 @@ def test_provider_bridge_position_embedding(self, mock_pretrained_qwen3, qwen3_c
         result = bridge.provider_bridge(mock_pretrained_qwen3)
 
         # Check position embedding
-        assert result.rotary_base == qwen3_config.rope_parameters["rope_theta"]
+        assert result.rotary_base == rope_theta_from_hf(qwen3_config)
 
     def test_provider_bridge_qwen3_specific_features(self, mock_pretrained_qwen3):
         """Test Qwen3-specific features."""
20 changes: 5 additions & 15 deletions tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py
@@ -27,7 +27,7 @@
 @pytest.fixture
 def mock_text_config():
     """Create a mock text config for Qwen2.5-VL."""
-    text_config = Mock()
+    text_config = Mock(spec=[])
     text_config.num_hidden_layers = 32
     text_config.hidden_size = 4096
     text_config.intermediate_size = 11008
@@ -40,15 +40,10 @@ def mock_text_config():
     text_config.rope_theta = 1000000.0
     text_config.tie_word_embeddings = False
     text_config.hidden_act = "silu"
-    text_config.q_lora_rank = None
-    text_config.kv_lora_rank = None
-    text_config.qk_nope_head_dim = None
-    text_config.qk_rope_head_dim = None
-    text_config.v_head_dim = None
-    text_config.num_nextn_predict_layers = None
     text_config.rope_scaling = None
     text_config.bos_token_id = 151643
     text_config.eos_token_id = 151645
+    text_config.torch_dtype = "bfloat16"
     return text_config
 
 
@@ -318,10 +313,10 @@ class TestQwen25VLBridgeEdgeCases:
     def test_provider_bridge_with_minimal_config(self, qwen25_vl_bridge):
         """Test provider_bridge with minimal HF config."""
         minimal_pretrained = Mock(spec=PreTrainedVLM)
-        minimal_config = Mock()
+        minimal_config = Mock(spec=[])
 
         # Text config with required fields
-        text_config = Mock()
+        text_config = Mock(spec=[])
         text_config.num_hidden_layers = 24
         text_config.hidden_size = 2048
         text_config.intermediate_size = 5504
@@ -333,13 +328,8 @@ def test_provider_bridge_with_minimal_config(self, qwen25_vl_bridge):
         text_config.max_position_embeddings = 4096
         text_config.rope_theta = 1000000.0
         text_config.hidden_act = "silu"
-        text_config.q_lora_rank = None
-        text_config.kv_lora_rank = None
-        text_config.qk_nope_head_dim = None
-        text_config.qk_rope_head_dim = None
-        text_config.v_head_dim = None
-        text_config.num_nextn_predict_layers = None
         text_config.rope_scaling = None
+        text_config.torch_dtype = "bfloat16"
 
         minimal_config.text_config = text_config
         minimal_config.vision_config = Qwen2_5_VLVisionConfig()