Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ def export_hf_weights(
gathering of distributed tensors and format conversion. It's useful for
streaming weight export or custom processing. All ranks get full tensors.

If the model contains LoRA adapters, they will be automatically merged
into the base weights before export. This ensures the exported model
contains the full fine-tuned weights.

Args:
model: Megatron model instance or list of instances
cpu: Whether to move tensors to CPU before yielding
Expand Down Expand Up @@ -363,6 +367,10 @@ def save_hf_pretrained(
and weights to a directory that can be loaded with HuggingFace's
from_pretrained methods.

If the model contains LoRA adapters, they will be automatically merged
into the base weights before saving. This ensures the saved model
contains the full fine-tuned weights.

If the original model was loaded with trust_remote_code=True, any custom
modeling files (e.g., modeling_*.py, configuration_*.py) will be preserved
to ensure the saved model can be loaded properly.
Expand Down Expand Up @@ -409,6 +417,10 @@ def save_hf_weights(self, model: list[MegatronModelT], path: str | Path, show_pr
to handle large models efficiently without requiring all weights in memory
at once.

If the model contains LoRA adapters, they will be automatically merged
into the base weights before saving. This ensures the saved weights
contain the full fine-tuned parameters.

The weights are gathered from distributed ranks and saved in the standard
HuggingFace sharded format when the model is large.

Expand Down
116 changes: 112 additions & 4 deletions src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ def _megatron_local_name_to_global(

def _update_expert_number(param_name: str, param_type: str) -> str:
"""Update expert number from local to global for weight or bias parameters."""
local_expert_number = int(param_name.split(f".{param_type}")[-1])
suffix = param_name.split(f".{param_type}")[-1]
# Check if suffix contains a valid expert number
# (this can be missing from PEFT adapters weight)
if not suffix or not suffix.isdigit():
# No expert number suffix, return original param_name
return param_name
local_expert_number = int(suffix)
global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number
return param_name.replace(
f".{param_type}{local_expert_number}",
Expand Down Expand Up @@ -313,6 +319,8 @@ def _megatron_global_param_names_all_pp_ranks(
global_param_name = _megatron_local_name_to_global(
models_list, model_config, local_param_name, vp_stage
)
if self._is_adapter_param_name(global_param_name):
continue
global_param_names.append(global_param_name)

gathered_global_param_names = [None] * pp_group.size()
Expand Down Expand Up @@ -625,6 +633,9 @@ def stream_weights_megatron_to_hf(
task, converted_weights_dict
) # dict will be none except for one expert;
# All ranks get the full tensor

converted_weights_dict = self._merge_lora_adapter_weights(task, megatron_model, converted_weights_dict)

for hf_name, tensor in converted_weights_dict.items():
final_tensor = tensor.cpu() if cpu else tensor

Expand All @@ -648,6 +659,91 @@ def stream_weights_megatron_to_hf(
# Regular case - yield the tensor normally
yield HFWeightTuple(hf_name, final_tensor)

def _merge_lora_adapter_weights(
    self,
    task: WeightConversionTask,
    megatron_model: List[MegatronModel],
    converted_weights_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Merge LoRA adapter weights back into the base tensors for HF export.

    When ``task`` targets a base weight wrapped by a ``LoRALinear``
    (i.e. a ``*.to_wrap.weight`` parameter), the adapter's low-rank
    matrices are collected (tensor-parallel gather for ``linear_in`` when
    the base layer is column/row parallel) and folded into every converted
    HF tensor via ``LoRAMerge.merge``. In all other cases the input dict
    is returned unchanged.

    Args:
        task: Conversion task describing the Megatron parameter being exported.
        megatron_model: Per-VP-stage list of Megatron model chunks.
        converted_weights_dict: HF-name -> tensor mapping produced by the
            regular weight conversion for this task.

    Returns:
        The same mapping, with LoRA deltas merged in when applicable.
    """
    # Nothing to merge if conversion produced no tensors for this task.
    if not converted_weights_dict:
        return converted_weights_dict

    # Only base weights wrapped by a LoRALinear carry the `.to_wrap.weight`
    # suffix; anything else passes through untouched.
    if not task.param_name.endswith(".to_wrap.weight") or task.megatron_module is None:
        return converted_weights_dict

    # Navigate up from `<parent>.to_wrap.weight` to the LoRALinear wrapper.
    parent_name = task.param_name[: -len(".to_wrap.weight")]
    try:
        lora_module, _ = get_module_and_param_from_name(megatron_model, parent_name, task.vp_stage)
    except ValueError:
        return converted_weights_dict

    adapter = getattr(lora_module, "adapter", None)
    to_wrap = getattr(lora_module, "to_wrap", None)
    if adapter is None or to_wrap is None:
        return converted_weights_dict

    # Bail out on adapters that do not expose the expected LoRA attributes.
    required_attrs = ("linear_in", "linear_out", "alpha", "dim")
    if not all(hasattr(adapter, attr) for attr in required_attrs):
        return converted_weights_dict

    from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

    from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, RowParallelMapping
    from megatron.bridge.peft.lora import LoRAMerge
    from megatron.bridge.peft.utils import HAVE_TE, TECL, TERL

    # Determine which mapping to use based on the base layer's parallel type.
    # NOTE: `to_wrap` is guaranteed non-None here (checked above), so the
    # former `if to_wrap is not None:` guard around this selection was dead
    # code and has been removed.
    mapping_class = None
    if (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TECL)) or isinstance(
        to_wrap, ColumnParallelLinear
    ):
        # Base layer is ColumnParallel, so use ColumnParallelMapping for linear_in.
        mapping_class = ColumnParallelMapping
    elif (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TERL)) or isinstance(
        to_wrap, RowParallelLinear
    ):
        # Base layer is RowParallel, so use RowParallelMapping for linear_in.
        mapping_class = RowParallelMapping

    # Gather LoRA adapter weights using the determined mapping class.
    if mapping_class is not None:
        # Gather TP-sharded linear_in weights onto every rank.
        linear_in_name = parent_name + ".adapter.linear_in.weight"
        linear_in_mapping = mapping_class(linear_in_name, linear_in_name)
        linear_in_dict = linear_in_mapping.megatron_to_hf(adapter.linear_in.weight, adapter.linear_in)
        linear_in_weight = linear_in_dict.get(linear_in_name) if linear_in_dict else None
    else:
        # Non-parallel case: the local weight is already complete.
        linear_in_weight = adapter.linear_in.weight

    # linear_out is never tensor-parallel, so its local weight is complete.
    linear_out_weight = getattr(adapter.linear_out, "weight", None)
    if linear_in_weight is None or linear_out_weight is None:
        return converted_weights_dict

    alpha = adapter.alpha
    dim = adapter.dim
    merger = LoRAMerge()

    # All ranks hold the gathered adapter weights, so every rank can merge.
    for hf_name, base_weight in list(converted_weights_dict.items()):
        # Move adapter factors to the base tensor's device before merging.
        base_device = base_weight.device
        converted_weights_dict[hf_name] = merger.merge(
            base_weight,
            linear_out_weight.to(base_device),
            linear_in_weight.to(base_device),
            alpha,
            dim,
        )

    return converted_weights_dict

def dtype_from_hf(self, config, default=None):
"""Extract torch dtype from a HuggingFace config.

Expand Down Expand Up @@ -810,6 +906,15 @@ def _broadcast_shared_embeddings(self, megatron_model: Union[MegatronModel, List
if hasattr(unwrapped_model, "output_layer"):
unwrapped_model.output_layer.weight.data.copy_(embd_weights)

def _get_lora_unwrapped_name(self, megatron_param: str) -> str:
"""Remove .to_wrap from LoRA parameter names."""
return megatron_param.replace(".to_wrap.", ".")

def _is_adapter_param_name(self, param_name: str) -> bool:
"""Return True if the parameter only belongs to a PEFT adapter."""

return ".adapter." in param_name

def build_conversion_tasks(
self,
hf_pretrained: HFPreTrained,
Expand Down Expand Up @@ -848,7 +953,7 @@ def build_conversion_tasks(
for vp_stage, model in enumerate(megatron_model):
# persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately
for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)):
if "_extra_state" in local_name:
if "_extra_state" in local_name or self._is_adapter_param_name(local_name):
continue

local_name = self._unwrap_name(local_name)
Expand All @@ -858,7 +963,7 @@ def build_conversion_tasks(
print_rank_0(f"WARNING: {global_name} not in global_names_index_dict")
continue
global_name_idx = global_names_index_dict[global_name]
mapping = mapping_registry.megatron_to_hf_lookup(global_name)
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))

if not mapping:
logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}")
Expand Down Expand Up @@ -897,8 +1002,11 @@ def build_conversion_tasks(

# Fill the remaining ones for pp communications
for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks):
mapping = mapping_registry.megatron_to_hf_lookup(global_name)
if tasks[idx] is None:
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
# Skip tasks with no mapping found
if mapping is None:
continue
# This is an exception here we pass in global name
# we are not using global_name to extract module and weights
# only use it for param mapping auto dispatch checks
Expand Down
16 changes: 8 additions & 8 deletions src/megatron/bridge/models/conversion/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,8 @@ def megatron_to_hf(
config = self.broadcast_obj_from_pp_rank(None, "qkv_config")
else:
config = self._get_config(megatron_module)
# create shallow copy and remove non-picklable objects with max depth=2
config = remove_non_pickleables(config, max_depth=2)
# create shallow copy and remove non-picklable objects with max depth=3
config = remove_non_pickleables(config, max_depth=3)
config = self.broadcast_obj_from_pp_rank(config, "qkv_config")

# Delegate TP/PP gathering.
Expand Down Expand Up @@ -1469,8 +1469,8 @@ def megatron_to_hf(
config = self.broadcast_obj_from_pp_rank(None)
else:
config = self._get_config(megatron_module)
# create shallow copy and remove non-picklable objects with max depth=2
config = remove_non_pickleables(config, max_depth=2)
# create shallow copy and remove non-picklable objects with max depth=3
config = remove_non_pickleables(config, max_depth=3)
config = self.broadcast_obj_from_pp_rank(config)

d_inner_local = (config.mamba_num_heads * config.mamba_head_dim) // self.tp_size
Expand Down Expand Up @@ -1577,8 +1577,8 @@ def megatron_to_hf(
config = self.broadcast_obj_from_pp_rank(None)
else:
config = self._get_config(megatron_module)
# create shallow copy and remove non-picklable objects with max depth=2
config = remove_non_pickleables(config, max_depth=2)
# create shallow copy and remove non-picklable objects with max depth=3
config = remove_non_pickleables(config, max_depth=3)
config = self.broadcast_obj_from_pp_rank(config)

d_inner_local = (config.mamba_num_heads * config.mamba_head_dim) // self.tp_size
Expand Down Expand Up @@ -1661,8 +1661,8 @@ def megatron_to_hf(
config = self.broadcast_obj_from_pp_rank(None)
else:
config = self._get_config(megatron_module)
# create shallow copy and remove non-picklable objects with max depth=2
config = remove_non_pickleables(config, max_depth=2)
# create shallow copy and remove non-picklable objects with max depth=3
config = remove_non_pickleables(config, max_depth=3)
config = self.broadcast_obj_from_pp_rank(config)

# Delegate TP/PP gathering.
Expand Down
4 changes: 2 additions & 2 deletions src/megatron/bridge/models/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def try_get_param(parts):
raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}")


def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0):
def remove_non_pickleables(obj, max_depth: int = 3, current_depth: int = 0):
"""Remove non-pickleable objects from a configuration object recursively.

This utility function identifies and removes objects that cannot be pickled for
Expand All @@ -168,7 +168,7 @@ def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0):

Args:
obj: The object to clean
max_depth: Maximum recursion depth (default: 2)
max_depth: Maximum recursion depth (default: 3)
current_depth: Current recursion depth (internal use)

Returns:
Expand Down
38 changes: 31 additions & 7 deletions src/megatron/bridge/peft/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,30 @@ class LoRAMerge(PEFT):
Implements the LoRA weight merge for parameter-efficient fine-tuning.
"""

def merge(
    self,
    base_weight: torch.Tensor,
    linear_out: torch.Tensor,
    linear_in: torch.Tensor,
    alpha: int,
    dim: int,
) -> torch.Tensor:
    """
    Merges the LoRA adapter weights with the base model weights.

    Computes ``base_weight + (alpha / dim) * (linear_out @ linear_in)``.

    Args:
        base_weight (torch.Tensor): The base model weights.
        linear_out (torch.Tensor): LoRA's B matrix.
        linear_in (torch.Tensor): LoRA's A matrix.
        alpha (int): Weighting factor for the low-rank projection.
        dim (int): Dimension of the low-rank projection space.

    Returns:
        torch.Tensor: The merged weights.
    """
    scaling = alpha / dim
    low_rank_update = linear_out @ linear_in
    return base_weight + scaling * low_rank_update

@torch.no_grad()
def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
"""
Expand All @@ -176,13 +200,13 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
if not isinstance(module, LoRALinear):
return module
logging.info(f"merging {(prefix if prefix else '') + '.' + (name if name else '')}")
base_weight = module.to_wrap.weight
lora_weight = (
module.adapter.alpha
/ module.adapter.dim
* module.adapter.linear_out.weight.to(base_weight.device)
@ module.adapter.linear_in.weight.to(base_weight.device)
base_device = module.to_wrap.weight.device
merged_weight = self.merge(
module.to_wrap.weight,
module.adapter.linear_out.weight.to(base_device),
module.adapter.linear_in.weight.to(base_device),
module.adapter.alpha,
module.adapter.dim,
)
merged_weight = base_weight + lora_weight
module.to_wrap.weight.data = merged_weight
return module
Loading