Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,8 @@ def export_hf_weights(
... cpu=True
... ))
"""
dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
return model_bridge.stream_weights_megatron_to_hf(
dispatch_instance,
bridge = self._model_bridge
return bridge.stream_weights_megatron_to_hf(
model,
self.hf_pretrained,
cpu=cpu,
Expand Down Expand Up @@ -420,13 +419,8 @@ def export_adapter_weights(
Yields:
HFWeightTuple: Named tuples of (param_name, weight_tensor) for adapter parameters
"""
dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
return model_bridge.stream_adapter_weights_megatron_to_hf(
dispatch_instance,
model,
cpu=cpu,
show_progress=show_progress,
)
bridge = self._model_bridge
return bridge.stream_adapter_weights_megatron_to_hf(model, cpu=cpu, show_progress=show_progress)

def save_hf_adapter(
self,
Expand Down Expand Up @@ -664,9 +658,8 @@ def save_hf_weights(
"""
if dist.is_available() and dist.is_initialized():
dist.barrier()
dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
generator = model_bridge.stream_weights_megatron_to_hf(
dispatch_instance,
bridge = self._model_bridge
generator = bridge.stream_weights_megatron_to_hf(
model,
self.hf_pretrained,
cpu=True,
Expand Down Expand Up @@ -1297,14 +1290,8 @@ def mla_transformer_config(self) -> MLATransformerConfig:

@property
def _model_bridge(self) -> "MegatronModelBridge":
hf_config = getattr(self.hf_pretrained, "hf_config", None)
if hf_config is None:
if isinstance(self.hf_pretrained, PreTrainedCausalLM):
hf_config = self.hf_pretrained.config
else:
hf_config = self.hf_pretrained

return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config)
# Pass the full HF context so bridge helpers can reach config-only and state-backed metadata.
return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=self.hf_pretrained)

@property
def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim:
Expand Down
10 changes: 8 additions & 2 deletions src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,9 @@ def build_conversion_tasks(
if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
raise ValueError("hf_pretrained.state.source is required for weight ordering")

self.hf_pretrained = hf_pretrained
self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys()

mapping_registry = self.mapping_registry()
Expand Down Expand Up @@ -1524,7 +1527,9 @@ def register_bridge_implementation(
def _get_model_bridge_impl(_, hf_config=None) -> "MegatronModelBridge":
bridge = bridge_class()
if hf_config is not None:
bridge.hf_config = hf_config
# `hf_config` may be a raw config or a full HF wrapper; normalize both onto the bridge.
bridge.hf_pretrained = hf_config
bridge.hf_config = hf_config.config if hasattr(hf_config, "config") else hf_config
return bridge

@stream_weights_megatron_to_hf.impl((source, target))
Expand All @@ -1539,7 +1544,8 @@ def _megatron_to_hf_registered_impl(
) -> Iterable[HFWeightTuple]:
bridge = bridge_class()

# allow bridge to access model config (config-only shims or raw configs lack .config)
# allow bridge to access model config
bridge.hf_pretrained = hf_pretrained
bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

return bridge.stream_weights_megatron_to_hf(
Expand Down
33 changes: 12 additions & 21 deletions src/megatron/bridge/models/glm/glm45_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial

import torch
Expand Down Expand Up @@ -43,9 +42,6 @@
HAVE_TE = False


logger = logging.getLogger(__name__)


@MegatronModelBridge.register_bridge(source=Glm4MoeForCausalLM, target=GPTModel, model_type="glm4_moe")
class GLM45Bridge(MegatronModelBridge):
"""
Expand Down Expand Up @@ -99,14 +95,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider

return provider

def build_conversion_tasks(self, hf_pretrained, megatron_model):
"""Override to store config before mapping_registry is called."""
# Store config on instance for use in mapping_registry
self._hf_config = hf_pretrained.config
self._hf_state_source = hf_pretrained.state.source
self._hf_keys = list(self._hf_state_source.get_all_keys())
return super().build_conversion_tasks(hf_pretrained, megatron_model)

def mapping_registry(self) -> MegatronMappingRegistry:
mapping_list = []
use_fused_experts = self._uses_fused_experts()
Expand Down Expand Up @@ -204,11 +192,8 @@ def mapping_registry(self) -> MegatronMappingRegistry:
),
]
)
# optionally add MTP mappings
if not hasattr(self, "_hf_config"):
logger.warning("No HF config found, skipping MTP mappings.")
return MegatronMappingRegistry(*mapping_list)
hf_config = self._hf_config
# add MTP mappings
hf_config = self.hf_config
num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
num_transformer_layers = hf_config.num_hidden_layers
for mtp_layer in range(num_mtp_layers):
Expand Down Expand Up @@ -318,26 +303,32 @@ def mapping_registry(self) -> MegatronMappingRegistry:

return MegatronMappingRegistry(*mapping_list)

def _hf_source_and_keys(self):
"""Return HF state source and cached key order for expert-mapping helpers."""
hf_source = self.hf_pretrained.state.source
if getattr(self, "_cached_hf_state_source", None) is not hf_source:
self._cached_hf_state_source = hf_source
self._cached_hf_keys = hf_source.get_all_keys()
return hf_source, self._cached_hf_keys

def _uses_fused_experts(self) -> bool:
hf_keys = getattr(self, "_hf_keys", None)
hf_source, hf_keys = self._hf_source_and_keys()
if hf_keys:
if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any(
"mlp.experts.down_proj" in key for key in hf_keys
):
return True

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None:
return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*")

return False

def _hf_expert_suffix(self, base_name: str) -> str:
hf_keys = getattr(self, "_hf_keys", None) or []
hf_source, hf_keys = self._hf_source_and_keys()
if any(f"{base_name}.weight" in key for key in hf_keys):
return ".weight"

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"):
return ".weight"

Expand Down
21 changes: 10 additions & 11 deletions src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> GLM45VModelProvider:
)
return provider

def build_conversion_tasks(self, hf_pretrained, megatron_model):
"""Override to store config before mapping_registry is called."""
self._hf_config = hf_pretrained.config
self._hf_state_source = hf_pretrained.state.source
self._hf_keys = list(self._hf_state_source.get_all_keys())
return super().build_conversion_tasks(hf_pretrained, megatron_model)

@classmethod
def get_hf_tokenizer_kwargs(cls) -> dict:
"""Return HuggingFace tokenizer kwargs specific to GLM 4.5V models.
Expand Down Expand Up @@ -211,26 +204,32 @@ def mapping_registry(self) -> MegatronMappingRegistry:
)
return MegatronMappingRegistry(*mapping_list)

def _hf_source_and_keys(self):
"""Return HF state source and cached key order for expert-mapping helpers."""
hf_source = self.hf_pretrained.state.source
if getattr(self, "_cached_hf_state_source", None) is not hf_source:
self._cached_hf_state_source = hf_source
self._cached_hf_keys = hf_source.get_all_keys()
return hf_source, self._cached_hf_keys

def _uses_fused_experts(self) -> bool:
hf_keys = getattr(self, "_hf_keys", None)
hf_source, hf_keys = self._hf_source_and_keys()
if hf_keys:
if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any(
"mlp.experts.down_proj" in key for key in hf_keys
):
return True

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None:
return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*")

return False

def _hf_expert_suffix(self, base_name: str) -> str:
hf_keys = getattr(self, "_hf_keys", None) or []
hf_source, hf_keys = self._hf_source_and_keys()
if any(f"{base_name}.weight" in key for key in hf_keys):
return ".weight"

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"):
return ".weight"

Expand Down
12 changes: 11 additions & 1 deletion tests/unit_tests/models/glm/test_glm45_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def mock_pretrained_355b(self, glm45_355b_config):
m = Mock(spec=PreTrainedCausalLM)
m.config = cfg
m.generation_config = Mock(spec=GenerationConfig)
m.state = Mock()
m.state.source = Mock()
m.state.source.get_all_keys.return_value = []
m.state.source.has_glob.return_value = False
return m

@pytest.fixture
Expand All @@ -135,6 +139,10 @@ def mock_pretrained_air_106b(self, glm45_air_106b_config):
m = Mock(spec=PreTrainedCausalLM)
m.config = cfg
m.generation_config = Mock(spec=GenerationConfig)
m.state = Mock()
m.state.source = Mock()
m.state.source.get_all_keys.return_value = []
m.state.source.has_glob.return_value = False
return m

def test_registration(self):
Expand Down Expand Up @@ -198,9 +206,11 @@ def test_provider_bridge_maps_config_air_106b(self, mock_pretrained_air_106b):
assert provider.bf16 is True
assert provider.params_dtype == torch.bfloat16

def test_mapping_registry_exists(self):
def test_mapping_registry_exists(self, mock_pretrained_355b):
"""Test that mapping registry is properly defined."""
bridge = GLM45Bridge()
bridge.hf_pretrained = mock_pretrained_355b
bridge.hf_config = mock_pretrained_355b.config
registry = bridge.mapping_registry()

# Verify registry has mappings
Expand Down
10 changes: 8 additions & 2 deletions tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,19 @@ def mock_hf_pretrained(mock_hf_config):
"""Create a mock HF pretrained VLM."""
pretrained = Mock(spec=PreTrainedVLM)
pretrained.config = mock_hf_config
pretrained.state = Mock()
pretrained.state.source = Mock()
pretrained.state.source.get_all_keys.return_value = []
pretrained.state.source.has_glob.return_value = False
return pretrained


@pytest.fixture
def glm_45v_bridge():
def glm_45v_bridge(mock_hf_pretrained):
"""Create a GLM45VBridge instance."""
return GLM45VBridge()
bridge = GLM45VBridge()
bridge.hf_pretrained = mock_hf_pretrained
return bridge


class TestGLM45VBridgeInitialization:
Expand Down
Loading
Loading