diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py
index 0251fb9ee3..d47bc43ed5 100644
--- a/src/megatron/bridge/models/conversion/auto_bridge.py
+++ b/src/megatron/bridge/models/conversion/auto_bridge.py
@@ -1008,7 +1008,14 @@ def mla_transformer_config(self) -> MLATransformerConfig:
 
     @property
     def _model_bridge(self) -> "MegatronModelBridge":
-        return model_bridge.get_model_bridge(self._causal_lm_architecture)
+        hf_config = getattr(self.hf_pretrained, "hf_config", None)
+        if hf_config is None:
+            if isinstance(self.hf_pretrained, PreTrainedCausalLM):
+                hf_config = self.hf_pretrained.config
+            else:
+                hf_config = self.hf_pretrained
+
+        return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config)
 
     @property
     def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim:
diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py
index fa3ccafaf5..b8c92e6bf8 100644
--- a/src/megatron/bridge/models/conversion/model_bridge.py
+++ b/src/megatron/bridge/models/conversion/model_bridge.py
@@ -1060,7 +1060,7 @@ def is_tensor_parallel(param) -> bool:
 
 # Core dispatch functions
 @dispatch
-def get_model_bridge(hf_architecture) -> "MegatronModelBridge":
+def get_model_bridge(hf_architecture, hf_config=None) -> "MegatronModelBridge":
     """Get the appropriate model bridge for a given HuggingFace architecture."""
     ...
 
@@ -1108,8 +1108,10 @@ def register_bridge_implementation(
     bridge_class_name = bridge_class.__name__
 
     @get_model_bridge.impl(source)
-    def _get_model_bridge_impl(_) -> "MegatronModelBridge":
+    def _get_model_bridge_impl(_, hf_config=None) -> "MegatronModelBridge":
         bridge = bridge_class()
+        if hf_config is not None:
+            bridge.hf_config = hf_config
         return bridge
 
     @stream_weights_megatron_to_hf.impl((source, target))