From fe2525a6ca8a8cb108ece7a9757b4d178cd25c33 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Thu, 8 Jan 2026 17:38:55 -0800 Subject: [PATCH] Attach hf config to auto bridge --- src/megatron/bridge/models/conversion/auto_bridge.py | 9 ++++++++- src/megatron/bridge/models/conversion/model_bridge.py | 6 ++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 9dc684e5ba..123019e693 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -919,7 +919,14 @@ def mla_transformer_config(self) -> MLATransformerConfig: @property def _model_bridge(self) -> "MegatronModelBridge": - return model_bridge.get_model_bridge(self._causal_lm_architecture) + hf_config = getattr(self.hf_pretrained, "hf_config", None) + if hf_config is None: + if isinstance(self.hf_pretrained, PreTrainedCausalLM): + hf_config = self.hf_pretrained.config + else: + hf_config = self.hf_pretrained + + return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config) @property def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim: diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 6889c8c15f..92c551db61 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -1486,7 +1486,7 @@ def is_tensor_parallel(param) -> bool: # Core dispatch functions @dispatch -def get_model_bridge(hf_architecture) -> "MegatronModelBridge": +def get_model_bridge(hf_architecture, hf_config=None) -> "MegatronModelBridge": """Get the appropriate model bridge for a given HuggingFace architecture.""" ... @@ -1522,8 +1522,10 @@ def register_bridge_implementation( bridge_class_name = bridge_class.__name__ @get_model_bridge.impl(source) - def _get_model_bridge_impl(_) -> "MegatronModelBridge": + def _get_model_bridge_impl(_, hf_config=None) -> "MegatronModelBridge": bridge = bridge_class() + if hf_config is not None: + bridge.hf_config = hf_config return bridge @stream_weights_megatron_to_hf.impl((source, target))