From 9bd9f68b10709cf0c06b3a085a7bd8c1e3b96e3d Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Wed, 4 Mar 2026 17:03:02 +0200 Subject: [PATCH] Store hf_pretrained as properties of Megatron*Bridge classes So that downstream model bridges that need hf_pretrained config information to build mapping_registry no longer need to override build_conversion_tasks (e.g. GLM 4.5 bridge). Signed-off-by: Hollow Man --- .../bridge/models/conversion/auto_bridge.py | 29 ++--- .../bridge/models/conversion/model_bridge.py | 10 +- .../bridge/models/glm/glm45_bridge.py | 33 ++---- .../bridge/models/glm_vl/glm_45v_bridge.py | 21 ++-- .../models/glm/test_glm45_bridge.py | 12 +- .../models/glm_vl/test_glm_45v_bridge.py | 10 +- tests/unit_tests/models/test_auto_bridge.py | 103 ++++++++++++------ .../models/test_model_bridge_lora.py | 16 ++- 8 files changed, 144 insertions(+), 90 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index b70af56bf6..249da85395 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -389,9 +389,8 @@ def export_hf_weights( ... cpu=True ... 
)) """ - dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) - return model_bridge.stream_weights_megatron_to_hf( - dispatch_instance, + bridge = self._model_bridge + return bridge.stream_weights_megatron_to_hf( model, self.hf_pretrained, cpu=cpu, @@ -420,13 +419,8 @@ def export_adapter_weights( Yields: HFWeightTuple: Named tuples of (param_name, weight_tensor) for adapter parameters """ - dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) - return model_bridge.stream_adapter_weights_megatron_to_hf( - dispatch_instance, - model, - cpu=cpu, - show_progress=show_progress, - ) + bridge = self._model_bridge + return bridge.stream_adapter_weights_megatron_to_hf(model, cpu=cpu, show_progress=show_progress) def save_hf_adapter( self, @@ -664,9 +658,8 @@ def save_hf_weights( """ if dist.is_available() and dist.is_initialized(): dist.barrier() - dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) - generator = model_bridge.stream_weights_megatron_to_hf( - dispatch_instance, + bridge = self._model_bridge + generator = bridge.stream_weights_megatron_to_hf( model, self.hf_pretrained, cpu=True, @@ -1297,14 +1290,8 @@ def mla_transformer_config(self) -> MLATransformerConfig: @property def _model_bridge(self) -> "MegatronModelBridge": - hf_config = getattr(self.hf_pretrained, "hf_config", None) - if hf_config is None: - if isinstance(self.hf_pretrained, PreTrainedCausalLM): - hf_config = self.hf_pretrained.config - else: - hf_config = self.hf_pretrained - - return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config) + # Pass the full HF context so bridge helpers can reach config-only and state-backed metadata. 
+ return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=self.hf_pretrained) @property def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim: diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 795af457df..bffcfa6650 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -1307,6 +1307,9 @@ def build_conversion_tasks( if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): raise ValueError("hf_pretrained.state.source is required for weight ordering") + self.hf_pretrained = hf_pretrained + self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() mapping_registry = self.mapping_registry() @@ -1524,7 +1527,9 @@ def register_bridge_implementation( def _get_model_bridge_impl(_, hf_config=None) -> "MegatronModelBridge": bridge = bridge_class() if hf_config is not None: - bridge.hf_config = hf_config + # `hf_config` may be a raw config or a full HF wrapper; normalize both onto the bridge. 
+ bridge.hf_pretrained = hf_config + bridge.hf_config = hf_config.config if hasattr(hf_config, "config") else hf_config return bridge @stream_weights_megatron_to_hf.impl((source, target)) @@ -1539,7 +1544,8 @@ def _megatron_to_hf_registered_impl( ) -> Iterable[HFWeightTuple]: bridge = bridge_class() - # allow bridge to access model config (config-only shims or raw configs lack .config) + # allow bridge to access model config + bridge.hf_pretrained = hf_pretrained bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained return bridge.stream_weights_megatron_to_hf( diff --git a/src/megatron/bridge/models/glm/glm45_bridge.py b/src/megatron/bridge/models/glm/glm45_bridge.py index 8dcec421f3..e49a8160bc 100644 --- a/src/megatron/bridge/models/glm/glm45_bridge.py +++ b/src/megatron/bridge/models/glm/glm45_bridge.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from functools import partial import torch @@ -43,9 +42,6 @@ HAVE_TE = False -logger = logging.getLogger(__name__) - - @MegatronModelBridge.register_bridge(source=Glm4MoeForCausalLM, target=GPTModel, model_type="glm4_moe") class GLM45Bridge(MegatronModelBridge): """ @@ -99,14 +95,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider return provider - def build_conversion_tasks(self, hf_pretrained, megatron_model): - """Override to store config before mapping_registry is called.""" - # Store config on instance for use in mapping_registry - self._hf_config = hf_pretrained.config - self._hf_state_source = hf_pretrained.state.source - self._hf_keys = list(self._hf_state_source.get_all_keys()) - return super().build_conversion_tasks(hf_pretrained, megatron_model) - def mapping_registry(self) -> MegatronMappingRegistry: mapping_list = [] use_fused_experts = self._uses_fused_experts() @@ -204,11 +192,8 @@ def mapping_registry(self) -> MegatronMappingRegistry: ), ] 
) - # optionally add MTP mappings - if not hasattr(self, "_hf_config"): - logger.warning("No HF config found, skipping MTP mappings.") - return MegatronMappingRegistry(*mapping_list) - hf_config = self._hf_config + # add MTP mappings + hf_config = self.hf_config num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0) num_transformer_layers = hf_config.num_hidden_layers for mtp_layer in range(num_mtp_layers): @@ -318,26 +303,32 @@ def mapping_registry(self) -> MegatronMappingRegistry: return MegatronMappingRegistry(*mapping_list) + def _hf_source_and_keys(self): + """Return HF state source and cached key order for expert-mapping helpers.""" + hf_source = self.hf_pretrained.state.source + if getattr(self, "_cached_hf_state_source", None) is not hf_source: + self._cached_hf_state_source = hf_source + self._cached_hf_keys = hf_source.get_all_keys() + return hf_source, self._cached_hf_keys + def _uses_fused_experts(self) -> bool: - hf_keys = getattr(self, "_hf_keys", None) + hf_source, hf_keys = self._hf_source_and_keys() if hf_keys: if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any( "mlp.experts.down_proj" in key for key in hf_keys ): return True - hf_source = getattr(self, "_hf_state_source", None) if hf_source is not None: return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*") return False def _hf_expert_suffix(self, base_name: str) -> str: - hf_keys = getattr(self, "_hf_keys", None) or [] + hf_source, hf_keys = self._hf_source_and_keys() if any(f"{base_name}.weight" in key for key in hf_keys): return ".weight" - hf_source = getattr(self, "_hf_state_source", None) if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"): return ".weight" diff --git a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py index 5f9c0ab8cf..28903b6372 100644 --- a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py +++ 
b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py @@ -88,13 +88,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> GLM45VModelProvider: ) return provider - def build_conversion_tasks(self, hf_pretrained, megatron_model): - """Override to store config before mapping_registry is called.""" - self._hf_config = hf_pretrained.config - self._hf_state_source = hf_pretrained.state.source - self._hf_keys = list(self._hf_state_source.get_all_keys()) - return super().build_conversion_tasks(hf_pretrained, megatron_model) - @classmethod def get_hf_tokenizer_kwargs(cls) -> dict: """Return HuggingFace tokenizer kwargs specific to GLM 4.5V models. @@ -211,26 +204,32 @@ def mapping_registry(self) -> MegatronMappingRegistry: ) return MegatronMappingRegistry(*mapping_list) + def _hf_source_and_keys(self): + """Return HF state source and cached key order for expert-mapping helpers.""" + hf_source = self.hf_pretrained.state.source + if getattr(self, "_cached_hf_state_source", None) is not hf_source: + self._cached_hf_state_source = hf_source + self._cached_hf_keys = hf_source.get_all_keys() + return hf_source, self._cached_hf_keys + def _uses_fused_experts(self) -> bool: - hf_keys = getattr(self, "_hf_keys", None) + hf_source, hf_keys = self._hf_source_and_keys() if hf_keys: if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any( "mlp.experts.down_proj" in key for key in hf_keys ): return True - hf_source = getattr(self, "_hf_state_source", None) if hf_source is not None: return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*") return False def _hf_expert_suffix(self, base_name: str) -> str: - hf_keys = getattr(self, "_hf_keys", None) or [] + hf_source, hf_keys = self._hf_source_and_keys() if any(f"{base_name}.weight" in key for key in hf_keys): return ".weight" - hf_source = getattr(self, "_hf_state_source", None) if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"): return ".weight" diff 
--git a/tests/unit_tests/models/glm/test_glm45_bridge.py b/tests/unit_tests/models/glm/test_glm45_bridge.py index 6cd9d92cac..e047797cc0 100644 --- a/tests/unit_tests/models/glm/test_glm45_bridge.py +++ b/tests/unit_tests/models/glm/test_glm45_bridge.py @@ -122,6 +122,10 @@ def mock_pretrained_355b(self, glm45_355b_config): m = Mock(spec=PreTrainedCausalLM) m.config = cfg m.generation_config = Mock(spec=GenerationConfig) + m.state = Mock() + m.state.source = Mock() + m.state.source.get_all_keys.return_value = [] + m.state.source.has_glob.return_value = False return m @pytest.fixture @@ -135,6 +139,10 @@ def mock_pretrained_air_106b(self, glm45_air_106b_config): m = Mock(spec=PreTrainedCausalLM) m.config = cfg m.generation_config = Mock(spec=GenerationConfig) + m.state = Mock() + m.state.source = Mock() + m.state.source.get_all_keys.return_value = [] + m.state.source.has_glob.return_value = False return m def test_registration(self): @@ -198,9 +206,11 @@ def test_provider_bridge_maps_config_air_106b(self, mock_pretrained_air_106b): assert provider.bf16 is True assert provider.params_dtype == torch.bfloat16 - def test_mapping_registry_exists(self): + def test_mapping_registry_exists(self, mock_pretrained_355b): """Test that mapping registry is properly defined.""" bridge = GLM45Bridge() + bridge.hf_pretrained = mock_pretrained_355b + bridge.hf_config = mock_pretrained_355b.config registry = bridge.mapping_registry() # Verify registry has mappings diff --git a/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py b/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py index 348d94b79b..1cfa9f302a 100644 --- a/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py +++ b/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py @@ -83,13 +83,19 @@ def mock_hf_pretrained(mock_hf_config): """Create a mock HF pretrained VLM.""" pretrained = Mock(spec=PreTrainedVLM) pretrained.config = mock_hf_config + pretrained.state = Mock() + pretrained.state.source = Mock() + 
pretrained.state.source.get_all_keys.return_value = [] + pretrained.state.source.has_glob.return_value = False return pretrained @pytest.fixture -def glm_45v_bridge(): +def glm_45v_bridge(mock_hf_pretrained): """Create a GLM45VBridge instance.""" - return GLM45VBridge() + bridge = GLM45VBridge() + bridge.hf_pretrained = mock_hf_pretrained + return bridge class TestGLM45VBridgeInitialization: diff --git a/tests/unit_tests/models/test_auto_bridge.py b/tests/unit_tests/models/test_auto_bridge.py index a1ffa82d9d..8e437925c4 100644 --- a/tests/unit_tests/models/test_auto_bridge.py +++ b/tests/unit_tests/models/test_auto_bridge.py @@ -619,18 +619,13 @@ def test_export_hf_weights(self): mock_hf_model.config.architectures = ["LlamaForCausalLM"] mock_hf_model.config.auto_map = None - mock_megatron_model = [Mock()] - mock_megatron_model[0].module = None # No nested module + mock_megatron_model = [object()] - # Mock the export process - with patch( - "megatron.bridge.models.conversion.auto_bridge.model_bridge.stream_weights_megatron_to_hf" - ) as mock_bridge_state: - mock_weight_iter = [ - ("weight1", torch.randn(10, 10)), - ("weight2", torch.randn(5, 5)), - ] - mock_bridge_state.return_value = iter(mock_weight_iter) + with patch.object(AutoBridge, "_model_bridge", new_callable=PropertyMock) as mock_model_bridge_prop: + mock_model_bridge = Mock() + mock_weight_iter = [("weight1", torch.randn(10, 10)), ("weight2", torch.randn(5, 5))] + mock_model_bridge.stream_weights_megatron_to_hf.return_value = iter(mock_weight_iter) + mock_model_bridge_prop.return_value = mock_model_bridge with patch("megatron.bridge.models.conversion.auto_bridge.transformers") as mock_transformers: mock_arch_class = Mock() @@ -648,6 +643,48 @@ def test_export_hf_weights(self): assert weights[1][0] == "weight2" assert isinstance(weights[0][1], torch.Tensor) assert isinstance(weights[1][1], torch.Tensor) + mock_model_bridge.stream_weights_megatron_to_hf.assert_called_once_with( + mock_megatron_model, + 
mock_hf_model, + cpu=True, + show_progress=True, + conversion_tasks=None, + merge_adapter_weights=True, + ) + + def test_export_adapter_weights(self): + """Test exporting adapter weights from Megatron to HF format.""" + mock_hf_model = Mock(spec=PreTrainedCausalLM) + mock_hf_model.config = Mock() + mock_hf_model.config.architectures = ["LlamaForCausalLM"] + mock_hf_model.config.auto_map = None + + mock_megatron_model = [object()] + + with patch.object(AutoBridge, "_model_bridge", new_callable=PropertyMock) as mock_model_bridge_prop: + mock_model_bridge = Mock() + mock_weight_iter = [("adapter.weight", torch.randn(4, 4))] + mock_model_bridge.stream_adapter_weights_megatron_to_hf.return_value = iter(mock_weight_iter) + mock_model_bridge_prop.return_value = mock_model_bridge + + with patch("megatron.bridge.models.conversion.auto_bridge.transformers") as mock_transformers: + mock_arch_class = Mock() + mock_transformers.LlamaForCausalLM = mock_arch_class + + bridge = AutoBridge(mock_hf_model) + + with patch.object(AutoBridge, "_causal_lm_architecture", new_callable=PropertyMock) as mock_prop: + mock_prop.return_value = mock_arch_class + weights = list(bridge.export_adapter_weights(mock_megatron_model, cpu=False, show_progress=False)) + + assert len(weights) == 1 + assert weights[0][0] == "adapter.weight" + assert isinstance(weights[0][1], torch.Tensor) + mock_model_bridge.stream_adapter_weights_megatron_to_hf.assert_called_once_with( + mock_megatron_model, + cpu=False, + show_progress=False, + ) def test_get_causal_lm_architecture(self): """Test getting the CausalLM architecture class.""" @@ -1200,22 +1237,17 @@ def test_save_hf_weights_filters_quantizer_tensors(self, mock_get_rank, mock_is_ bridge.hf_pretrained = mock_hf_model with ( - patch.object( - AutoBridge, - "_causal_lm_architecture", - new_callable=PropertyMock, - return_value=Mock(), - ), - patch( - "megatron.bridge.models.conversion.auto_bridge.model_bridge.stream_weights_megatron_to_hf", - 
return_value=iter(weight_iter), - ), + patch.object(AutoBridge, "_model_bridge", new_callable=PropertyMock) as mock_model_bridge_prop, patch( "megatron.bridge.models.conversion.auto_bridge.is_quantized", return_value=True, ), patch("torch.save") as mock_torch_save, ): + mock_model_bridge = Mock() + mock_model_bridge.stream_weights_megatron_to_hf.return_value = iter(weight_iter) + mock_model_bridge_prop.return_value = mock_model_bridge + # Capture what save_generator receives by consuming the generator it's passed saved_pairs = [] @@ -1230,6 +1262,13 @@ def fake_save_generator(gen, *args, **kwargs): # Only the normal weight should have passed through to save_generator assert len(saved_pairs) == 1 assert saved_pairs[0][0] == "model.layers.0.self_attn.q_proj.weight" + mock_model_bridge.stream_weights_megatron_to_hf.assert_called_once_with( + mock_megatron_model, + mock_hf_model, + cpu=True, + show_progress=True, + merge_adapter_weights=True, + ) # The quantizer tensor should have been saved via torch.save sidecar mock_torch_save.assert_called_once() @@ -1265,23 +1304,25 @@ def test_save_hf_weights_no_sidecar_when_not_quantized( bridge.hf_pretrained = mock_hf_model with ( - patch.object( - AutoBridge, - "_causal_lm_architecture", - new_callable=PropertyMock, - return_value=Mock(), - ), - patch( - "megatron.bridge.models.conversion.auto_bridge.model_bridge.stream_weights_megatron_to_hf", - return_value=iter(weight_iter), - ), + patch.object(AutoBridge, "_model_bridge", new_callable=PropertyMock) as mock_model_bridge_prop, patch( "megatron.bridge.models.conversion.auto_bridge.is_quantized", return_value=False, ), patch("torch.save") as mock_torch_save, ): + mock_model_bridge = Mock() + mock_model_bridge.stream_weights_megatron_to_hf.return_value = iter(weight_iter) + mock_model_bridge_prop.return_value = mock_model_bridge + mock_source.save_generator = Mock() bridge.save_hf_weights(mock_megatron_model, "/tmp/output") + 
mock_model_bridge.stream_weights_megatron_to_hf.assert_called_once_with( + mock_megatron_model, + mock_hf_model, + cpu=True, + show_progress=True, + merge_adapter_weights=True, + ) mock_torch_save.assert_not_called() diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index 6641f72c4c..a5e8a830f8 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -377,6 +377,8 @@ def test_construct_adapters_names(): def test_build_adapter_conversion_tasks(monkeypatch): bridge = DummyBridge() + bridge.hf_pretrained = SimpleNamespace() + bridge.hf_config = bridge.hf_pretrained adapters_info = [ ( @@ -465,6 +467,8 @@ def megatron_to_hf(self, weight, module): def test_stream_adapter_weights_megatron_to_hf(monkeypatch): bridge = DummyBridge() + bridge.hf_pretrained = SimpleNamespace() + bridge.hf_config = bridge.hf_pretrained adapter_task = AdapterWeightConversionTask( global_base_prefix="decoder.layers.0.mlp.linear_fc1", @@ -512,7 +516,13 @@ def test_stream_adapter_weights_megatron_to_hf(monkeypatch): ) megatron_model = [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))] - weights = list(bridge.stream_adapter_weights_megatron_to_hf(megatron_model, cpu=False, show_progress=False)) + weights = list( + bridge.stream_adapter_weights_megatron_to_hf( + megatron_model, + cpu=False, + show_progress=False, + ) + ) assert len(weights) == 2 assert weights[0].param_name.endswith("lora_A.weight") assert weights[1].param_name.endswith("lora_B.weight") @@ -522,6 +532,8 @@ def test_stream_adapter_weights_megatron_to_hf(monkeypatch): def test_stream_adapter_weights_megatron_to_hf_qkv(monkeypatch): bridge = DummyBridge() + bridge.hf_pretrained = SimpleNamespace() + bridge.hf_config = bridge.hf_pretrained adapter_task = AdapterWeightConversionTask( global_base_prefix="decoder.layers.0.self_attn.linear_qkv", @@ -596,6 +608,8 @@ def 
test_stream_adapter_weights_megatron_to_hf_qkv(monkeypatch): def test_stream_adapter_weights_megatron_to_hf_fused_fc1(monkeypatch): bridge = DummyBridge() + bridge.hf_pretrained = SimpleNamespace() + bridge.hf_config = bridge.hf_pretrained adapter_task = AdapterWeightConversionTask( global_base_prefix="decoder.layers.0.mlp.linear_fc1",