From 01599fc1519431aff93c4d56f775dc996ce5df1f Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Thu, 13 Nov 2025 23:19:12 +0200 Subject: [PATCH 1/4] feat: Support PEFT weight mapping and merge LoRA adapters when export to hf Signed-off-by: Hollow Man --- .../bridge/models/conversion/auto_bridge.py | 12 ++ .../bridge/models/conversion/model_bridge.py | 110 +++++++++++++++++- .../bridge/models/conversion/param_mapping.py | 16 +-- .../bridge/models/conversion/utils.py | 4 +- src/megatron/bridge/peft/lora.py | 38 ++++-- .../models/test_model_bridge_lora.py | 105 +++++++++++++++++ 6 files changed, 265 insertions(+), 20 deletions(-) create mode 100644 tests/unit_tests/models/test_model_bridge_lora.py diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 8ebcfe4849..2da1e7f4e9 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -314,6 +314,10 @@ def export_hf_weights( gathering of distributed tensors and format conversion. It's useful for streaming weight export or custom processing. All ranks get full tensors. + If the model contains LoRA adapters, they will be automatically merged + into the base weights before export. This ensures the exported model + contains the full fine-tuned weights. + Args: model: Megatron model instance or list of instances cpu: Whether to move tensors to CPU before yielding @@ -363,6 +367,10 @@ def save_hf_pretrained( and weights to a directory that can be loaded with HuggingFace's from_pretrained methods. + If the model contains LoRA adapters, they will be automatically merged + into the base weights before saving. This ensures the saved model + contains the full fine-tuned weights. + If the original model was loaded with trust_remote_code=True, any custom modeling files (e.g., modeling_*.py, configuration_*.py) will be preserved to ensure the saved model can be loaded properly. 
@@ -409,6 +417,10 @@ def save_hf_weights(self, model: list[MegatronModelT], path: str | Path, show_pr to handle large models efficiently without requiring all weights in memory at once. + If the model contains LoRA adapters, they will be automatically merged + into the base weights before saving. This ensures the saved weights + contain the full fine-tuned parameters. + The weights are gathered from distributed ranks and saved in the standard HuggingFace sharded format when the model is large. diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 5cdedda69a..b9aa848abc 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -141,7 +141,13 @@ def _megatron_local_name_to_global( def _update_expert_number(param_name: str, param_type: str) -> str: """Update expert number from local to global for weight or bias parameters.""" - local_expert_number = int(param_name.split(f".{param_type}")[-1]) + suffix = param_name.split(f".{param_type}")[-1] + # Check if suffix contains a valid expert number + # (this can be missing from PEFT adapters weight) + if not suffix or not suffix.isdigit(): + # No expert number suffix, return original param_name + return param_name + local_expert_number = int(suffix) global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number return param_name.replace( f".{param_type}{local_expert_number}", @@ -313,6 +319,8 @@ def _megatron_global_param_names_all_pp_ranks( global_param_name = _megatron_local_name_to_global( models_list, model_config, local_param_name, vp_stage ) + if self._is_adapter_param_name(global_param_name): + continue global_param_names.append(global_param_name) gathered_global_param_names = [None] * pp_group.size() @@ -625,6 +633,9 @@ def stream_weights_megatron_to_hf( task, converted_weights_dict ) # dict will be none except for one expert; # All ranks get the full 
tensor + + converted_weights_dict = self._merge_lora_adapter_weights(task, megatron_model, converted_weights_dict) + for hf_name, tensor in converted_weights_dict.items(): final_tensor = tensor.cpu() if cpu else tensor @@ -648,6 +659,86 @@ def stream_weights_megatron_to_hf( # Regular case - yield the tensor normally yield HFWeightTuple(hf_name, final_tensor) + def _merge_lora_adapter_weights( + self, + task: WeightConversionTask, + megatron_model: List[MegatronModel], + converted_weights_dict: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Merge LoRA adapter weights back into the base tensor for HF export.""" + + if not converted_weights_dict: + return converted_weights_dict + + if not task.param_name.endswith(".to_wrap.weight") or task.megatron_module is None: + return converted_weights_dict + + # Get the LoRALinear wrapper by navigating up from to_wrap.weight + parent_name = task.param_name[: -len(".to_wrap.weight")] + try: + lora_module, _ = get_module_and_param_from_name(megatron_model, parent_name, task.vp_stage) + except ValueError: + return converted_weights_dict + + adapter = getattr(lora_module, "adapter", None) + to_wrap = getattr(lora_module, "to_wrap", None) + if adapter is None or to_wrap is None: + return converted_weights_dict + + required_attrs = ("linear_in", "linear_out", "alpha", "dim") + if not all(hasattr(adapter, attr) for attr in required_attrs): + return converted_weights_dict + + from megatron.bridge.peft.lora import LoRAMerge + from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, RowParallelMapping + from megatron.bridge.peft.utils import HAVE_TE, TECL, TERL + from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear + + mapping_class = None + if to_wrap is not None: + # Determine which mapping to use based on the base layer's parallel type + if (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TECL)) or isinstance(to_wrap, ColumnParallelLinear): + # Base layer 
is ColumnParallel, so use ColumnParallelMapping for linear_in + mapping_class = ColumnParallelMapping + elif (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TERL)) or isinstance(to_wrap, RowParallelLinear): + # Base layer is RowParallel, so use RowParallelMapping for linear_in + mapping_class = RowParallelMapping + + # Gather LoRA adapter weights using the determined mapping class + if mapping_class is not None: + # Gather linear_in weights + linear_in_name = parent_name + ".adapter.linear_in.weight" + linear_in_mapping = mapping_class(linear_in_name, linear_in_name) + linear_in_dict = linear_in_mapping.megatron_to_hf(adapter.linear_in.weight, adapter.linear_in) + linear_in_weight = linear_in_dict.get(linear_in_name) if linear_in_dict else None + else: + # Non-parallel case: use weights directly + linear_in_weight = adapter.linear_in.weight + + # Always no parallel for linear_out + linear_out_weight = getattr(adapter.linear_out, "weight", None) + if linear_in_weight is None or linear_out_weight is None: + return converted_weights_dict + + alpha = adapter.alpha + dim = adapter.dim + merger = LoRAMerge() + + # All ranks get the gathered weights, so we can merge on all ranks + for hf_name, base_weight in list(converted_weights_dict.items()): + # Merge LoRA weights for each converted weight in the dict + base_device = base_weight.device + merged_weight = merger.merge( + base_weight, + linear_out_weight.to(base_device), + linear_in_weight.to(base_device), + alpha, + dim, + ) + converted_weights_dict[hf_name] = merged_weight + + return converted_weights_dict + def dtype_from_hf(self, config, default=None): """Extract torch dtype from a HuggingFace config. 
@@ -810,6 +901,16 @@ def _broadcast_shared_embeddings(self, megatron_model: Union[MegatronModel, List if hasattr(unwrapped_model, "output_layer"): unwrapped_model.output_layer.weight.data.copy_(embd_weights) + def _get_lora_unwrapped_name(self, megatron_param: str) -> str: + """Remove .to_wrap from LoRA parameter names. + """ + return megatron_param.replace(".to_wrap.", ".") + + def _is_adapter_param_name(self, param_name: str) -> bool: + """Return True if the parameter only belongs to a PEFT adapter.""" + + return ".adapter." in param_name + def build_conversion_tasks( self, hf_pretrained: HFPreTrained, @@ -858,7 +959,7 @@ def build_conversion_tasks( print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") continue global_name_idx = global_names_index_dict[global_name] - mapping = mapping_registry.megatron_to_hf_lookup(global_name) + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) if not mapping: logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") @@ -897,8 +998,11 @@ def build_conversion_tasks( # Fill the remaining ones for pp communications for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): - mapping = mapping_registry.megatron_to_hf_lookup(global_name) if tasks[idx] is None: + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) + # Skip tasks with no mapping found + if mapping is None: + continue # This is an exception here we pass in global name # we are not using global_name to extract module and weights # only use it for param mapping auto dispatch checks diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index b23b7f0f55..61325ae0be 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -1361,8 +1361,8 @@ def megatron_to_hf( config = 
self.broadcast_obj_from_pp_rank(None, "qkv_config") else: config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) + # create shallow copy and remove non-picklable objects with max depth=3 + config = remove_non_pickleables(config, max_depth=3) config = self.broadcast_obj_from_pp_rank(config, "qkv_config") # Delegate TP/PP gathering. @@ -1469,8 +1469,8 @@ def megatron_to_hf( config = self.broadcast_obj_from_pp_rank(None) else: config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) + # create shallow copy and remove non-picklable objects with max depth=3 + config = remove_non_pickleables(config, max_depth=3) config = self.broadcast_obj_from_pp_rank(config) d_inner_local = (config.mamba_num_heads * config.mamba_head_dim) // self.tp_size @@ -1577,8 +1577,8 @@ def megatron_to_hf( config = self.broadcast_obj_from_pp_rank(None) else: config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) + # create shallow copy and remove non-picklable objects with max depth=3 + config = remove_non_pickleables(config, max_depth=3) config = self.broadcast_obj_from_pp_rank(config) d_inner_local = (config.mamba_num_heads * config.mamba_head_dim) // self.tp_size @@ -1661,8 +1661,8 @@ def megatron_to_hf( config = self.broadcast_obj_from_pp_rank(None) else: config = self._get_config(megatron_module) - # create shallow copy and remove non-picklable objects with max depth=2 - config = remove_non_pickleables(config, max_depth=2) + # create shallow copy and remove non-picklable objects with max depth=3 + config = remove_non_pickleables(config, max_depth=3) config = self.broadcast_obj_from_pp_rank(config) # Delegate TP/PP gathering. 
diff --git a/src/megatron/bridge/models/conversion/utils.py b/src/megatron/bridge/models/conversion/utils.py index c2140f1586..5a66e71983 100644 --- a/src/megatron/bridge/models/conversion/utils.py +++ b/src/megatron/bridge/models/conversion/utils.py @@ -159,7 +159,7 @@ def try_get_param(parts): raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}") -def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0): +def remove_non_pickleables(obj, max_depth: int = 3, current_depth: int = 0): """Remove non-pickleable objects from a configuration object recursively. This utility function identifies and removes objects that cannot be pickled for @@ -168,7 +168,7 @@ def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0): Args: obj: The object to clean - max_depth: Maximum recursion depth (default: 2) + max_depth: Maximum recursion depth (default: 3) current_depth: Current recursion depth (internal use) Returns: diff --git a/src/megatron/bridge/peft/lora.py b/src/megatron/bridge/peft/lora.py index 3a8b3496fd..ab42113ebc 100644 --- a/src/megatron/bridge/peft/lora.py +++ b/src/megatron/bridge/peft/lora.py @@ -159,6 +159,30 @@ class LoRAMerge(PEFT): Implements the LoRA weight merge for parameter-efficient fine-tuning. """ + def merge( + self, + base_weight: torch.Tensor, + linear_out: torch.Tensor, + linear_in: torch.Tensor, + alpha: int, + dim: int, + ) -> torch.Tensor: + """ + Merges the LoRA adapter weights with the base model weights. + + Args: + base_weight (torch.Tensor): The base model weights. + linear_out (torch.Tensor): LoRA's B matrix. + linear_in (torch.Tensor): LoRA's A matrix. + alpha (int): Weighting factor for the low-rank projection. + dim (int): Dimension of the low-rank projection space. + + Returns: + torch.Tensor: The merged weights. 
+ """ + lora_weight = alpha / dim * (linear_out @ linear_in) + return base_weight + lora_weight + @torch.no_grad() def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: """ @@ -176,13 +200,13 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio if not isinstance(module, LoRALinear): return module logging.info(f"merging {(prefix if prefix else '') + '.' + (name if name else '')}") - base_weight = module.to_wrap.weight - lora_weight = ( - module.adapter.alpha - / module.adapter.dim - * module.adapter.linear_out.weight.to(base_weight.device) - @ module.adapter.linear_in.weight.to(base_weight.device) + base_device = module.to_wrap.weight.device + merged_weight = self.merge( + module.to_wrap.weight, + module.adapter.linear_out.weight.to(base_device), + module.adapter.linear_in.weight.to(base_device), + module.adapter.alpha, + module.adapter.dim, ) - merged_weight = base_weight + lora_weight module.to_wrap.weight.data = merged_weight return module diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py new file mode 100644 index 0000000000..8ee2d09733 --- /dev/null +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -0,0 +1,105 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import torch + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask + + +class DummyBridge(MegatronModelBridge): + def provider_bridge(self, hf_pretrained): # pragma: no cover - not used in tests + return None + + def mapping_registry(self): # pragma: no cover - not used in tests + return MegatronMappingRegistry() + + +def _make_lora_module(alpha=8, dim=4): + linear_in = SimpleNamespace(weight=torch.eye(dim)) + linear_out = SimpleNamespace(weight=torch.eye(dim)) + adapter = 
SimpleNamespace(linear_in=linear_in, linear_out=linear_out, alpha=alpha, dim=dim) + base_linear = torch.nn.Linear(dim, dim, bias=False) + lora_module = SimpleNamespace(adapter=adapter, to_wrap=base_linear) + return lora_module + + +def test_merge_lora_adapter_weights_merges(monkeypatch): + bridge = DummyBridge() + base_weight = torch.zeros(4, 4) + converted = {"hf.weight": base_weight.clone()} + task = WeightConversionTask( + param_name="decoder.layers.0.mlp.linear_fc1.to_wrap.weight", + mapping=Mock(), + megatron_module=Mock(), + vp_stage=0, + ) + + lora_module = _make_lora_module(alpha=4, dim=4) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.get_module_and_param_from_name", + lambda *_, **__: (lora_module, None), + ) + monkeypatch.setattr("megatron.bridge.models.conversion.model_bridge.print_rank_0", lambda *_, **__: None) + monkeypatch.setattr("megatron.bridge.peft.utils.HAVE_TE", False) + + updated = bridge._merge_lora_adapter_weights(task, [Mock()], converted) + expected = base_weight + torch.eye(4) + torch.testing.assert_close(updated["hf.weight"], expected) + + +def test_merge_lora_adapter_weights_noop_without_adapter(monkeypatch): + bridge = DummyBridge() + converted = {"hf.weight": torch.ones(2, 2)} + task = WeightConversionTask( + param_name="decoder.layers.0.mlp.linear_fc1.weight", + mapping=Mock(), + megatron_module=Mock(), + ) + + updated = bridge._merge_lora_adapter_weights(task, [Mock()], converted) + torch.testing.assert_close(updated["hf.weight"], converted["hf.weight"]) + + +def test_global_param_names_skip_adapter(monkeypatch): + bridge = DummyBridge() + + class DummyGroup: + def size(self): + return 1 + + fake_param = torch.nn.Parameter(torch.zeros(1, 1)) + + class FakeModel: + def __init__(self): + self.config = SimpleNamespace() + + def named_parameters(self): + return [ + ("decoder.layers.0.mlp.adapter.linear_in.weight", fake_param), + ("decoder.layers.0.mlp.linear_fc1.to_wrap.weight", fake_param), + ] + + 
monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.parallel_state.get_pipeline_model_parallel_group", + lambda: DummyGroup(), + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.persistent_buffers", + lambda *_: [], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge._megatron_local_name_to_global", + lambda *_args, **_kwargs: _args[2], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.unwrap_model", + lambda models: models if isinstance(models, list) else [models], + ) + monkeypatch.setattr( + "torch.distributed.all_gather_object", + lambda output, obj, group=None: output.__setitem__(0, obj), + ) + + names = bridge._megatron_global_param_names_all_pp_ranks([FakeModel()]) + assert names == ["decoder.layers.0.mlp.linear_fc1.to_wrap.weight"] From 8191df9d8cc2b44e9c8dab73dee3fff1830bff46 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 13 Nov 2025 21:51:27 -0800 Subject: [PATCH 2/4] lint Signed-off-by: Chen Cui --- .../bridge/models/conversion/model_bridge.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index b9aa848abc..1a7285c5a3 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -689,18 +689,23 @@ def _merge_lora_adapter_weights( if not all(hasattr(adapter, attr) for attr in required_attrs): return converted_weights_dict - from megatron.bridge.peft.lora import LoRAMerge + from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear + from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, RowParallelMapping + from megatron.bridge.peft.lora import LoRAMerge from megatron.bridge.peft.utils import HAVE_TE, TECL, TERL - from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear 
mapping_class = None if to_wrap is not None: # Determine which mapping to use based on the base layer's parallel type - if (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TECL)) or isinstance(to_wrap, ColumnParallelLinear): + if (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TECL)) or isinstance( + to_wrap, ColumnParallelLinear + ): # Base layer is ColumnParallel, so use ColumnParallelMapping for linear_in mapping_class = ColumnParallelMapping - elif (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TERL)) or isinstance(to_wrap, RowParallelLinear): + elif (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TERL)) or isinstance( + to_wrap, RowParallelLinear + ): # Base layer is RowParallel, so use RowParallelMapping for linear_in mapping_class = RowParallelMapping @@ -902,8 +907,7 @@ def _broadcast_shared_embeddings(self, megatron_model: Union[MegatronModel, List unwrapped_model.output_layer.weight.data.copy_(embd_weights) def _get_lora_unwrapped_name(self, megatron_param: str) -> str: - """Remove .to_wrap from LoRA parameter names. - """ + """Remove .to_wrap from LoRA parameter names.""" return megatron_param.replace(".to_wrap.", ".") def _is_adapter_param_name(self, param_name: str) -> bool: From 33eeac6c100d9cbb8b26940a9e6af4a144f79284 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 13 Nov 2025 21:51:44 -0800 Subject: [PATCH 3/4] copyright Signed-off-by: Chen Cui --- tests/unit_tests/models/test_model_bridge_lora.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index 8ee2d09733..bbcc5e59a0 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from types import SimpleNamespace from unittest.mock import Mock From b2bb00a0e01112c2738b1865ca6e7cb65ae2f5c4 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Fri, 14 Nov 2025 14:58:26 +0200 Subject: [PATCH 4/4] Suppress WARNING: xxx.adapter.xxx not in global_names_index_dict Signed-off-by: Hollow Man --- src/megatron/bridge/models/conversion/model_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 1a7285c5a3..5f9d95a6d4 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -953,7 +953,7 @@ def build_conversion_tasks( for vp_stage, model in enumerate(megatron_model): # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): - if "_extra_state" in local_name: + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): continue local_name = self._unwrap_name(local_name)