From 37b2a2e4f99e7ad0cbffd9b2e792b2cb0c1ba850 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 7 Dec 2025 02:25:32 +0200 Subject: [PATCH 1/3] [LoRA] Fix LoRA merge and support CanonicalLoRA merge Previous LoRA merge contains several errors: - It didn't handle fused QKV/gate up correctly - The handling of PP gathering is problematic Signed-off-by: Hollow Man --- .../bridge/models/conversion/model_bridge.py | 475 +++++++++++++++--- .../bridge/models/conversion/param_mapping.py | 9 +- .../models/test_model_bridge_lora.py | 339 +++++++++++-- 3 files changed, 714 insertions(+), 109 deletions(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index e191142bc8..69de7f5a99 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -17,6 +17,7 @@ import itertools import logging import re +from collections import defaultdict from dataclasses import dataclass from typing import ( Callable, @@ -44,7 +45,13 @@ from transformers.modeling_utils import PreTrainedModel from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping +from megatron.bridge.models.conversion.param_mapping import ( + ColumnParallelMapping, + MegatronParamMapping, + ReplicatedMapping, + RowParallelMapping, + split_qkv_weights, +) from megatron.bridge.models.conversion.utils import ( extract_sort_key, get_module_and_param_from_name, @@ -52,6 +59,9 @@ ) from megatron.bridge.models.decorators.dispatch import dispatch from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.bridge.peft.canonical_lora import ModuleDict +from megatron.bridge.peft.lora import LoRAMerge +from megatron.bridge.peft.utils import get_adapter_attributes_from_linear from megatron.bridge.utils.common_utils import print_rank_0 @@ -63,6 +73,14 @@ MegatronModel = TypeVar("MegatronModel", bound=MegatronModule) _BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") +ADAPTER_NAME_MAP = { + ".q_proj.weight": "adapter_q", + ".k_proj.weight": "adapter_k", + ".v_proj.weight": "adapter_v", + ".gate_proj.weight": "adapter_gate", + ".up_proj.weight": "adapter_up", +} + class MegatronWeightTuple(NamedTuple): """Tuple representing a Megatron model weight with its metadata.""" @@ -91,6 +109,7 @@ class WeightConversionTask(Generic[MappingT]): Attributes: param_name (str): *unwrapped, local* parameter name (no ``module.`` prefixes). + global_param_name (str): *unwrapped, global* parameter name (no ``module.`` prefixes). mapping (MappingT): Concrete :pyclass:`MegatronParamMapping` instance responsible for weight transformation and distribution. @@ -104,6 +123,7 @@ class WeightConversionTask(Generic[MappingT]): """ param_name: str + global_param_name: str mapping: MappingT pp_rank: Optional[int] = None vp_stage: Optional[int] = None @@ -111,6 +131,35 @@ class WeightConversionTask(Generic[MappingT]): param_weight: Optional[torch.Tensor] = None +@dataclass(frozen=True) +class AdapterWeightConversionTask: + """Task describing an adapter's LoRA weights for conversion or merging. + + The task reuses :class:`WeightConversionTask` to gather the adapter's + linear_in/linear_out weights (if they are tensor-parallel) and carries the + adapter metadata required by the merge step. + """ + + global_base_prefix: str + adapter_key: Optional[str] # For canonical LoRA only + alpha: int + dim: int + linear_in_task: WeightConversionTask + linear_out_task: WeightConversionTask + + +@dataclass(frozen=True) +class AdapterWeight: + """Materialized adapter weights ready for merge.""" + + global_base_prefix: str + adapter_key: Optional[str] # For canonical LoRA only + alpha: int + dim: int + linear_in_weight: torch.Tensor + linear_out_weight: torch.Tensor + + def _megatron_local_name_to_global( models: MegatronModule | List[MegatronModule], config: TransformerConfig, @@ -136,19 +185,14 @@ def _megatron_local_name_to_global( # EP ep_group = parallel_state.get_expert_model_parallel_group() - if ".mlp.experts.linear_fc" in param_name and get_pg_size(ep_group) > 1: + # For now adapters are not sharded across EP ranks + if ".mlp.experts.linear_fc" in param_name and get_pg_size(ep_group) > 1 and not ".adapter." in param_name: num_experts = config.num_moe_experts num_experts_per_rank = num_experts // ep_group.size() def _update_expert_number(param_name: str, param_type: str) -> str: """Update expert number from local to global for weight or bias parameters.""" - suffix = param_name.split(f".{param_type}")[-1] - # Check if suffix contains a valid expert number - # (this can be missing from PEFT adapters weight) - if not suffix or not suffix.isdigit(): - # No expert number suffix, return original param_name - return param_name - local_expert_number = int(suffix) + local_expert_number = int(param_name.split(f".{param_type}")[-1]) global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number return param_name.replace( f".{param_type}{local_expert_number}", @@ -295,6 +339,17 @@ def mapping_registry(self): """ raise NotImplementedError("Subclass must implement mapping_registry method") + def _get_adapter_wrap_module( + self, local_base_prefix: str, megatron_model: Union[MegatronModel, List[MegatronModel]], vp_stage: int + ) -> tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]: + """Get the adapter and to_wrap modules from the parent name.""" + lora_module, _ = get_module_and_param_from_name(megatron_model, local_base_prefix, vp_stage) + adapter = getattr(lora_module, "adapter", None) + if adapter is None: + # For CanonicalLoRA module + lora_module, _ = get_module_and_param_from_name(megatron_model, local_base_prefix + ".to_wrap", vp_stage) + return getattr(lora_module, "adapter", None), getattr(lora_module, "to_wrap", None) + def _megatron_global_param_names_all_pp_ranks( self, megatron_model: Union[MegatronModel, List[MegatronModel]] ) -> List[str]: @@ -341,6 +396,207 @@ def _megatron_global_param_names_all_pp_ranks( return self._cached_param_names + def _megatron_global_adapters_info_all_pp_ranks( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> List[tuple[str, str, bool, bool, int, int, int, int]]: + """Get all adapters' information tuple: + (global_base_name, local_base_prefix, input_is_parallel, base_linear_is_parallel, alpha, dim, pp_rank, vp_stage) + across all pipeline parallel ranks.""" + # Cache the result after first call + if hasattr(self, "_cached_param_objects_adapter"): + return self._cached_param_objects_adapter + + # Compute the result + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + model_config = unwrap_model(megatron_model)[0].config + global_param_objects = [] + + # Ensure megatron_model is a list for consistent handling + models_list = megatron_model if isinstance(megatron_model, list) else [megatron_model] + + for vp_stage, model in enumerate(models_list): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_param_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_param_name: + continue + local_param_name = self._unwrap_name(local_param_name) + global_param_name = _megatron_local_name_to_global( + models_list, model_config, local_param_name, vp_stage + ) + is_adapter_param = self._is_adapter_param_name(global_param_name) + # only collect linear_in.weight for deduplication + if not is_adapter_param or not global_param_name.endswith(".linear_in.weight"): + continue + local_base_prefix = local_param_name.partition(".adapter.")[0] + global_base_name = global_param_name[: -len(".linear_in.weight")] + adapter, to_wrap = self._get_adapter_wrap_module(local_base_prefix, models_list, vp_stage) + if isinstance(adapter, ModuleDict): + adapter_name = local_param_name.removeprefix(local_base_prefix + ".adapter.").split(".")[0] + adapter = adapter[adapter_name] + input_is_parallel, _, _, _, base_linear_is_parallel = get_adapter_attributes_from_linear(to_wrap) + global_param_objects.append( + ( + global_base_name, + local_base_prefix, + input_is_parallel, + base_linear_is_parallel, + adapter.alpha, + adapter.dim, + pp_rank, + vp_stage, + ) + ) + + gathered_global_param_objects = [None] * pp_group.size() + torch.distributed.all_gather_object(gathered_global_param_objects, global_param_objects, group=pp_group) + + # flatten the list, sort it and remove duplicates + # the order matters here, casually re-order will cause a hang. + flattened_names = list(set(sum(gathered_global_param_objects, []))) + + # the order cannot be changed, this sync for all ranks for conversion + # change this might cause a hang + gathered_global_param_objects = sorted(flattened_names, key=lambda x: extract_sort_key(x[0])) + + self._cached_param_objects_adapter = gathered_global_param_objects + + return gathered_global_param_objects + + def _construct_adapters_names(self, prefix: str, adapter_key: Optional[str]) -> tuple[str, str]: + linear_in_name, linear_out_name = prefix + ".adapter", prefix + ".adapter" + if adapter_key is not None: + linear_in_name += f".{adapter_key}" + linear_out_name += f".{adapter_key}" + linear_in_name += ".linear_in.weight" + linear_out_name += ".linear_out.weight" + return linear_in_name, linear_out_name + + def build_adapter_conversion_tasks( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> Dict[str, List[AdapterWeightConversionTask]]: + """Construct adapter merge tasks keyed by their base parameter. + + The returned dict is keyed by the *global* LoRA-wrapped parameter name + (e.g., ``decoder.layers.0.mlp.linear_fc1.to_wrap.weight``). Each value + contains the adapter tasks (canonical or regular) that should be + merged into that base weight. + """ + + models_list = megatron_model if isinstance(megatron_model, list) else [megatron_model] + + adapters_info = self._megatron_global_adapters_info_all_pp_ranks(models_list) + tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = defaultdict(list) + + for ( + global_base_name, + local_base_prefix, + input_is_parallel, + base_linear_is_parallel, + alpha, + dim, + pp_rank, + vp_stage, + ) in adapters_info: + # global_base_name example: decoder.layers.0.mlp.linear_fc1.adapter.adapter_q + global_base_prefix, _, adapter_suffix = global_base_name.partition(".adapter") + + adapter_key = None + if adapter_suffix: + key_token = adapter_suffix.split(".")[-1] + if key_token.startswith("adapter_"): + adapter_key = key_token + + global_linear_in_name, global_linear_out_name = self._construct_adapters_names( + global_base_prefix, adapter_key + ) + # In case the adapter doesn't exist locally, we use the global names + local_linear_in_name, local_linear_out_name = global_linear_in_name, global_linear_out_name + linear_in_module, linear_in_weight = None, None + linear_out_module, linear_out_weight = None, None + if parallel_state.get_pipeline_model_parallel_rank() == pp_rank: + adapter, _ = self._get_adapter_wrap_module(local_base_prefix, models_list, vp_stage) + if isinstance(adapter, ModuleDict): + adapter = adapter[adapter_key] + linear_in_module, linear_in_weight = adapter.linear_in, adapter.linear_in.weight + linear_out_module, linear_out_weight = adapter.linear_out, adapter.linear_out.weight + local_linear_in_name, local_linear_out_name = self._construct_adapters_names( + local_base_prefix, adapter_key + ) + + # Pick mapping strategies based on base layer parallelism + if base_linear_is_parallel: + linear_in_mapping_cls = RowParallelMapping if input_is_parallel else ColumnParallelMapping + linear_out_mapping_cls = ColumnParallelMapping + else: + linear_in_mapping_cls = ReplicatedMapping + linear_out_mapping_cls = ReplicatedMapping + + linear_in_task = WeightConversionTask( + param_name=local_linear_in_name, + global_param_name=global_linear_in_name, + # TODO: use some actual HF param name mapping + mapping=linear_in_mapping_cls(local_linear_in_name, local_linear_out_name), + pp_rank=pp_rank, + vp_stage=vp_stage, + megatron_module=linear_in_module, + param_weight=linear_in_weight, + ) + + linear_out_task = WeightConversionTask( + param_name=local_linear_out_name, + global_param_name=global_linear_out_name, + # TODO: use some actual HF param name mapping + mapping=linear_out_mapping_cls(local_linear_out_name, local_linear_out_name), + pp_rank=pp_rank, + vp_stage=vp_stage, + megatron_module=linear_out_module, + param_weight=linear_out_weight, + ) + + tasks_by_base[global_base_prefix].append( + AdapterWeightConversionTask( + global_base_prefix=global_base_prefix, + adapter_key=adapter_key, + alpha=alpha, + dim=dim, + linear_in_task=linear_in_task, + linear_out_task=linear_out_task, + ) + ) + + return tasks_by_base + + def materialize_adapter_weights(self, adapter_tasks: List[AdapterWeightConversionTask]) -> List[AdapterWeight]: + """Run adapter merge tasks to gather full adapter weights.""" + + materialized: List[AdapterWeight] = [] + for adapter_task in adapter_tasks: + mapping = adapter_task.linear_in_task.mapping + linear_in_dict = mapping.megatron_to_hf( + adapter_task.linear_in_task.param_weight, adapter_task.linear_in_task.megatron_module + ) + linear_in_weight = next(iter(linear_in_dict.values())) + + mapping = adapter_task.linear_out_task.mapping + linear_out_dict = mapping.megatron_to_hf( + adapter_task.linear_out_task.param_weight, adapter_task.linear_out_task.megatron_module + ) + linear_out_weight = next(iter(linear_out_dict.values())) + + materialized.append( + AdapterWeight( + global_base_prefix=adapter_task.global_base_prefix, + adapter_key=adapter_task.adapter_key, + alpha=adapter_task.alpha, + dim=adapter_task.dim, + linear_in_weight=linear_in_weight, + linear_out_weight=linear_out_weight, + ) + ) + + return materialized + def _with_progress_tracking(self, tasks, description: str, show_progress: bool = True): """Helper method to wrap an iterable with progress tracking. @@ -646,6 +902,9 @@ def stream_weights_megatron_to_hf( if conversion_tasks is None: conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + # Collect adapter conversion tasks + adapter_tasks_by_base = self.build_adapter_conversion_tasks(megatron_model) + megatron_to_hf_tasks = conversion_tasks unwrapped_model = unwrap_model(megatron_model)[0] model_config = unwrapped_model.config @@ -657,7 +916,18 @@ def stream_weights_megatron_to_hf( ) # dict will be none except for one expert; # All ranks get the full tensor - converted_weights_dict = self._merge_lora_adapter_weights(task, megatron_model, converted_weights_dict) + adapter_tasks = None + if "to_wrap.weight" in task.global_param_name: + task_global_base_prefix, _, _ = task.global_param_name.partition(".to_wrap.weight") + adapter_tasks = adapter_tasks_by_base.get(task_global_base_prefix) + if adapter_tasks: + adapter_weights = self.materialize_adapter_weights(adapter_tasks) + # Merge LoRA adapter weights back into the base tensor for HF export + converted_weights_dict = self._merge_lora_adapter_weights( + megatron_model, + converted_weights_dict, + adapter_weights, + ) for hf_name, tensor in converted_weights_dict.items(): final_tensor = tensor.cpu() if cpu else tensor @@ -684,84 +954,137 @@ def stream_weights_megatron_to_hf( def _merge_lora_adapter_weights( self, - task: WeightConversionTask, megatron_model: List[MegatronModel], converted_weights_dict: Dict[str, torch.Tensor], + adapter_weights: List[AdapterWeight], ) -> Dict[str, torch.Tensor]: """Merge LoRA adapter weights back into the base tensor for HF export.""" - if not converted_weights_dict: - return converted_weights_dict + # CanonicalLoRA case when adapter_keys are provided via adapter_weights + if len(adapter_weights) > 1 and all( + w.adapter_key in ADAPTER_NAME_MAP.values() for w in adapter_weights if w.adapter_key + ): + return self._merge_canonical_adapter_from_weights(converted_weights_dict, adapter_weights) + + assert len(adapter_weights) == 1, "Expected a single adapter weight for standard LoRA merging" + + adapter_weight = adapter_weights[0] + alpha, dim = adapter_weight.alpha, adapter_weight.dim + linear_in_weight, linear_out_weight = adapter_weight.linear_in_weight, adapter_weight.linear_out_weight + + # Check if this is a fused layer that gets split into multiple projections + # For fused FC1: splits into gate_proj and up_proj (2 parts) + # For fused QKV: splits into q_proj, k_proj, v_proj (3 parts, interleaved) + base_weight_shape = next(iter(converted_weights_dict.values())).shape + weight_names = converted_weights_dict.keys() + is_fused_fc1 = ( + len(weight_names) % 2 == 0 + and all("gate_proj" in name or "up_proj" in name for name in weight_names) + and linear_out_weight.shape[0] == 2 * base_weight_shape[0] + ) + is_fused_qkv = len(weight_names) == 3 and all( + "q_proj" in name or "k_proj" in name or "v_proj" in name for name in weight_names + ) - if not task.param_name.endswith(".to_wrap.weight") or task.megatron_module is None: - return converted_weights_dict + # For QKV, split using the same interleaving logic as the base weight + if is_fused_qkv: + # Use the same interleaving pattern as split_qkv_weights + q_out, k_out, v_out = split_qkv_weights(megatron_model[0].config, linear_out_weight) + qkv_linear_out_weights = { + "q_proj": q_out, + "k_proj": k_out, + "v_proj": v_out, + } + else: + qkv_linear_out_weights = None - # Get the LoRALinear wrapper by navigating up from to_wrap.weight - parent_name = task.param_name[: -len(".to_wrap.weight")] - try: - lora_module, _ = get_module_and_param_from_name(megatron_model, parent_name, task.vp_stage) - except ValueError: - return converted_weights_dict + # All ranks get the gathered weights, so we can merge on all ranks + for hf_name, base_weight in list(converted_weights_dict.items()): + # For fused layers, split linear_out_weight based on which projection we're merging + current_linear_out_weight = linear_out_weight + if is_fused_fc1: + split_size = linear_out_weight.shape[0] // 2 + if "gate_proj" in hf_name: + # FC1: first half for gate_proj + current_linear_out_weight = linear_out_weight[:split_size, :] + elif "up_proj" in hf_name: + # FC1: second half for up_proj + current_linear_out_weight = linear_out_weight[split_size:, :] + else: + raise ValueError(f"Unknown weight name: {hf_name}") + elif is_fused_qkv and qkv_linear_out_weights is not None: + # QKV: Use properly split weights based on interleaving pattern + if "q_proj" in hf_name: + current_linear_out_weight = qkv_linear_out_weights["q_proj"] + elif "k_proj" in hf_name: + current_linear_out_weight = qkv_linear_out_weights["k_proj"] + elif "v_proj" in hf_name: + current_linear_out_weight = qkv_linear_out_weights["v_proj"] + else: + raise ValueError(f"Unknown weight name: {hf_name}") - adapter = getattr(lora_module, "adapter", None) - to_wrap = getattr(lora_module, "to_wrap", None) - if adapter is None or to_wrap is None: - return converted_weights_dict - - required_attrs = ("linear_in", "linear_out", "alpha", "dim") - if not all(hasattr(adapter, attr) for attr in required_attrs): - return converted_weights_dict - - from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear - - from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, RowParallelMapping - from megatron.bridge.peft.lora import LoRAMerge - from megatron.bridge.peft.utils import HAVE_TE, TECL, TERL - - mapping_class = None - if to_wrap is not None: - # Determine which mapping to use based on the base layer's parallel type - if (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TECL)) or isinstance( - to_wrap, ColumnParallelLinear - ): - # Base layer is ColumnParallel, so use ColumnParallelMapping for linear_in - mapping_class = ColumnParallelMapping - elif (HAVE_TE and any(isinstance(to_wrap, te_cls) for te_cls in TERL)) or isinstance( - to_wrap, RowParallelLinear - ): - # Base layer is RowParallel, so use RowParallelMapping for linear_in - mapping_class = RowParallelMapping - - # Gather LoRA adapter weights using the determined mapping class - if mapping_class is not None: - # Gather linear_in weights - linear_in_name = parent_name + ".adapter.linear_in.weight" - linear_in_mapping = mapping_class(linear_in_name, linear_in_name) - linear_in_dict = linear_in_mapping.megatron_to_hf(adapter.linear_in.weight, adapter.linear_in) - linear_in_weight = linear_in_dict.get(linear_in_name) if linear_in_dict else None - else: - # Non-parallel case: use weights directly - linear_in_weight = adapter.linear_in.weight + # Merge LoRA weights for each converted weight in the dict + merged_weight = self._merge_single_adapter_weight( + base_weight, alpha, dim, linear_in_weight, current_linear_out_weight + ) + converted_weights_dict[hf_name] = merged_weight + + return converted_weights_dict + + def _merge_single_adapter_weight( + self, + base_weight: torch.Tensor, + alpha: int, + dim: int, + linear_in_weight: torch.Tensor, + linear_out_weight: torch.Tensor, + ) -> torch.Tensor: + """Merge a single adapter's weights with base weight. - # Always no parallel for linear_out - linear_out_weight = getattr(adapter.linear_out, "weight", None) - if linear_in_weight is None or linear_out_weight is None: - return converted_weights_dict + Args: + base_weight: Base weight tensor to merge with + alpha: Alpha value for the adapter + dim: Dimension of the adapter + linear_in_weight: Gathered linear_in weight + linear_out_weight: linear_out weight - alpha = adapter.alpha - dim = adapter.dim + Returns: + Merged weight tensor + """ merger = LoRAMerge() + base_device = base_weight.device + return merger.merge( + base_weight, + linear_out_weight.to(base_device), + linear_in_weight.to(base_device), + alpha, + dim, + ) - # All ranks get the gathered weights, so we can merge on all ranks - for hf_name, base_weight in list(converted_weights_dict.items()): - # Merge LoRA weights for each converted weight in the dict - base_device = base_weight.device - merged_weight = merger.merge( + def _merge_canonical_adapter_from_weights( + self, + converted_weights_dict: Dict[str, torch.Tensor], + adapter_weights: List[AdapterWeight], + ) -> Dict[str, torch.Tensor]: + """Merge CanonicalLoRA adapters using pre-materialized adapter weights.""" + adapter_lookup = {aw.adapter_key: aw for aw in adapter_weights} + + for hf_name, base_weight in converted_weights_dict.items(): + target_adapter = None + for suffix, adapter_key in ADAPTER_NAME_MAP.items(): + if hf_name.endswith(suffix): + target_adapter = adapter_lookup.get(adapter_key) + break + + if target_adapter is None: + raise ValueError(f"Adapter name mapping not found for {hf_name}") + + merged_weight = self._merge_single_adapter_weight( base_weight, - linear_out_weight.to(base_device), - linear_in_weight.to(base_device), - alpha, - dim, + target_adapter.alpha, + target_adapter.dim, + target_adapter.linear_in_weight, + target_adapter.linear_out_weight, ) converted_weights_dict[hf_name] = merged_weight @@ -1027,6 +1350,7 @@ def build_conversion_tasks( pp_rank=pp_rank, vp_stage=vp_stage, param_name=local_name, + global_param_name=global_name, megatron_module=local_module, param_weight=local_weights, mapping=mapping, @@ -1046,6 +1370,7 @@ def build_conversion_tasks( pp_rank=pp_rank, vp_stage=None, param_name=global_name, + global_param_name=global_name, megatron_module=None, param_weight=None, mapping=mapping, diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index c958822834..84adb91b89 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -171,6 +171,11 @@ def is_expert(self) -> bool: """ return ".mlp.experts.linear_fc" in self.megatron_param or ".mlp.experts.local_experts." in self.megatron_param + @property + def is_adapter(self) -> bool: + """Check if this mapping is for an adapter parameter.""" + return ".adapter." in self.megatron_param + def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]: """Resolve wildcard patterns with captured values. @@ -852,7 +857,7 @@ def megatron_to_hf( gathered = self.gather_from_tp_ranks(megatron_weights) full_weights = torch.cat(gathered, dim=0) - if self.is_expert: + if self.is_expert and not self.is_adapter: return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) return {str(self.hf_param): full_weights} @@ -945,7 +950,7 @@ def megatron_to_hf( gathered = self.gather_from_tp_ranks(megatron_weights) full_weights = torch.cat(gathered, dim=1) - if self.is_expert: + if self.is_expert and not self.is_adapter: return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) return {str(self.hf_param): full_weights} diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index bbcc5e59a0..fde492c6c1 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -18,7 +18,13 @@ import torch from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask +from megatron.bridge.models.conversion.model_bridge import ( + AdapterWeight, + AdapterWeightConversionTask, + MegatronModelBridge, + WeightConversionTask, +) +from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, merge_qkv_weights class DummyBridge(MegatronModelBridge): @@ -29,50 +35,133 @@ def mapping_registry(self): # pragma: no cover - not used in tests return MegatronMappingRegistry() -def _make_lora_module(alpha=8, dim=4): - linear_in = SimpleNamespace(weight=torch.eye(dim)) - linear_out = SimpleNamespace(weight=torch.eye(dim)) - adapter = SimpleNamespace(linear_in=linear_in, linear_out=linear_out, alpha=alpha, dim=dim) - base_linear = torch.nn.Linear(dim, dim, bias=False) - lora_module = SimpleNamespace(adapter=adapter, to_wrap=base_linear) - return lora_module - - def test_merge_lora_adapter_weights_merges(monkeypatch): bridge = DummyBridge() base_weight = torch.zeros(4, 4) converted = {"hf.weight": base_weight.clone()} - task = WeightConversionTask( - param_name="decoder.layers.0.mlp.linear_fc1.to_wrap.weight", - mapping=Mock(), - megatron_module=Mock(), - vp_stage=0, + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.linear_fc1", + adapter_key=None, + alpha=4, + dim=4, + linear_in_weight=torch.eye(4), + linear_out_weight=torch.eye(4), ) - lora_module = _make_lora_module(alpha=4, dim=4) - monkeypatch.setattr( - "megatron.bridge.models.conversion.model_bridge.get_module_and_param_from_name", - lambda *_, **__: (lora_module, None), - ) - monkeypatch.setattr("megatron.bridge.models.conversion.model_bridge.print_rank_0", lambda *_, **__: None) - monkeypatch.setattr("megatron.bridge.peft.utils.HAVE_TE", False) - - updated = bridge._merge_lora_adapter_weights(task, [Mock()], converted) + updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) expected = base_weight + torch.eye(4) torch.testing.assert_close(updated["hf.weight"], expected) -def test_merge_lora_adapter_weights_noop_without_adapter(monkeypatch): +def test_merge_single_adapter_weight_matches_loramerge(): bridge = DummyBridge() - converted = {"hf.weight": torch.ones(2, 2)} - task = WeightConversionTask( - param_name="decoder.layers.0.mlp.linear_fc1.weight", - mapping=Mock(), - megatron_module=Mock(), + base = torch.zeros(2, 2) + linear_in = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + linear_out = torch.tensor([[5.0, 6.0], [7.0, 8.0]]) + + merged = bridge._merge_single_adapter_weight( + base, alpha=2, dim=2, linear_in_weight=linear_in, linear_out_weight=linear_out ) + expected = base + 2 / 2 * (linear_out @ linear_in) + torch.testing.assert_close(merged, expected) - updated = bridge._merge_lora_adapter_weights(task, [Mock()], converted) - torch.testing.assert_close(updated["hf.weight"], converted["hf.weight"]) + +def test_merge_lora_adapter_weights_fused_fc1(monkeypatch): + bridge = DummyBridge() + base = torch.zeros(4, 4) + converted = { + "decoder.layers.0.mlp.gate_proj.weight": base.clone(), + "decoder.layers.0.mlp.up_proj.weight": base.clone(), + } + + linear_out = torch.cat([torch.eye(4), 2 * torch.eye(4)], dim=0) + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.linear_fc1", + adapter_key=None, + alpha=1, + dim=1, + linear_in_weight=torch.eye(4), + linear_out_weight=linear_out, + ) + + updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) + torch.testing.assert_close(updated["decoder.layers.0.mlp.gate_proj.weight"], torch.eye(4)) + torch.testing.assert_close(updated["decoder.layers.0.mlp.up_proj.weight"], 2 * torch.eye(4)) + + +def test_merge_lora_adapter_weights_qkv_split(monkeypatch): + bridge = DummyBridge() + config = SimpleNamespace( + num_attention_heads=2, + num_query_groups=1, + kv_channels=None, + hidden_size=4, + attention_output_gate=False, + ) + megatron_model = [SimpleNamespace(config=config)] + converted = { + "q_proj.weight": torch.zeros(4, 4), + "k_proj.weight": torch.zeros(2, 4), + "v_proj.weight": torch.zeros(2, 4), + } + + q_weight = torch.eye(4) + k_weight = torch.ones(2, 4) + v_weight = torch.full((2, 4), 2.0) + linear_out = merge_qkv_weights(config, q_weight, k_weight, v_weight) + + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.self_attention.linear_qkv", + adapter_key=None, + alpha=4, + dim=4, + linear_in_weight=torch.eye(4), + linear_out_weight=linear_out, + ) + + updated = bridge._merge_lora_adapter_weights(megatron_model, converted, [adapter_weight]) + torch.testing.assert_close(updated["q_proj.weight"], q_weight) + torch.testing.assert_close(updated["k_proj.weight"], k_weight) + torch.testing.assert_close(updated["v_proj.weight"], v_weight) + + +def test_merge_canonical_adapter_from_weights(monkeypatch): + bridge = DummyBridge() + converted = { + "decoder.layers.0.self_attn.q_proj.weight": torch.zeros(2, 2), + "decoder.layers.0.self_attn.k_proj.weight": torch.zeros(1, 2), + "decoder.layers.0.self_attn.v_proj.weight": torch.zeros(1, 2), + } + + adapter_q = AdapterWeight( + global_base_prefix="decoder.layers.0.self_attn.linear_qkv", + adapter_key="adapter_q", + alpha=1, + dim=1, + linear_in_weight=torch.eye(2), + linear_out_weight=torch.ones(2, 2), + ) + adapter_k = AdapterWeight( + global_base_prefix="decoder.layers.0.self_attn.linear_qkv", + adapter_key="adapter_k", + alpha=1, + dim=1, + linear_in_weight=torch.eye(2), + linear_out_weight=2 * torch.ones(1, 2), + ) + adapter_v = AdapterWeight( + global_base_prefix="decoder.layers.0.self_attn.linear_qkv", + adapter_key="adapter_v", + alpha=1, + dim=1, + linear_in_weight=torch.eye(2), + linear_out_weight=3 * torch.ones(1, 2), + ) + + updated = bridge._merge_canonical_adapter_from_weights(converted, [adapter_q, adapter_k, adapter_v]) + torch.testing.assert_close(updated["decoder.layers.0.self_attn.q_proj.weight"], torch.ones(2, 2)) + torch.testing.assert_close(updated["decoder.layers.0.self_attn.k_proj.weight"], 2 * torch.ones(1, 2)) + torch.testing.assert_close(updated["decoder.layers.0.self_attn.v_proj.weight"], 3 * torch.ones(1, 2)) def test_global_param_names_skip_adapter(monkeypatch): @@ -117,3 +206,189 @@ def named_parameters(self): names = bridge._megatron_global_param_names_all_pp_ranks([FakeModel()]) assert names == ["decoder.layers.0.mlp.linear_fc1.to_wrap.weight"] + + +def test_megatron_global_adapters_info_all_pp_ranks(monkeypatch): + bridge = DummyBridge() + + class DummyGroup: + def size(self): + return 1 + + class FakeAdapter: + def __init__(self): + self.linear_in = SimpleNamespace(weight=torch.ones(2, 2)) + self.linear_out = SimpleNamespace(weight=torch.ones(2, 2)) + self.alpha = 8 + self.dim = 2 + + class FakeModel: + def __init__(self): + self.config = SimpleNamespace() + param = torch.nn.Parameter(torch.zeros(2, 2)) + self._params = [ + ("decoder.layers.0.mlp.linear_fc1.adapter.linear_in.weight", param), + ("decoder.layers.0.mlp.linear_fc1.adapter.linear_out.weight", param), + ] + + def named_parameters(self): + return self._params + + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.parallel_state.get_pipeline_model_parallel_group", + lambda: DummyGroup(), + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.parallel_state.get_pipeline_model_parallel_rank", + lambda: 0, + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.persistent_buffers", + lambda *_: [], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge._megatron_local_name_to_global", + lambda *_args, **_kwargs: _args[2], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.unwrap_model", + lambda models: models if isinstance(models, list) else [models], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.get_adapter_attributes_from_linear", + lambda *_args, **_kwargs: (True, None, None, None, False), + ) + monkeypatch.setattr( + "torch.distributed.all_gather_object", + lambda output, obj, group=None: output.__setitem__(0, obj), + ) + + adapter = FakeAdapter() + monkeypatch.setattr(bridge, "_get_adapter_wrap_module", lambda *_: (adapter, Mock())) + + info = bridge._megatron_global_adapters_info_all_pp_ranks([FakeModel()]) + assert len(info) == 1 + ( + global_base_name, + local_base_prefix, + input_is_parallel, + base_linear_is_parallel, + alpha, + dim, + pp_rank, + vp_stage, + ) = info[0] + assert global_base_name == "decoder.layers.0.mlp.linear_fc1.adapter" + assert local_base_prefix == "decoder.layers.0.mlp.linear_fc1" + assert input_is_parallel is True and base_linear_is_parallel is False + assert alpha == 8 and dim == 2 and pp_rank == 0 and vp_stage == 0 + + +def test_construct_adapters_names(): + bridge = DummyBridge() + linear_in, linear_out = bridge._construct_adapters_names("decoder.layers.0.mlp.linear_fc1", None) + assert linear_in == "decoder.layers.0.mlp.linear_fc1.adapter.linear_in.weight" + assert linear_out == "decoder.layers.0.mlp.linear_fc1.adapter.linear_out.weight" + + linear_in_k, linear_out_k = bridge._construct_adapters_names("decoder.layers.0.attn.q_proj", "adapter_q") + assert linear_in_k.endswith("adapter_q.linear_in.weight") + assert linear_out_k.endswith("adapter_q.linear_out.weight") + + +def test_build_adapter_conversion_tasks(monkeypatch): + bridge = DummyBridge() + + adapters_info = [ + ( + "decoder.layers.0.mlp.linear_fc1.adapter", + "decoder.layers.0.mlp.linear_fc1", + False, + False, + 4, + 8, + 0, + 0, + ) + ] + + adapter = SimpleNamespace( + linear_in=SimpleNamespace(weight=torch.ones(2, 2)), + linear_out=SimpleNamespace(weight=torch.ones(2, 2)), + alpha=4, + dim=8, + ) + + monkeypatch.setattr(bridge, "_megatron_global_adapters_info_all_pp_ranks", lambda *_: adapters_info) + monkeypatch.setattr(bridge, "_get_adapter_wrap_module", lambda *_: (adapter, Mock())) + monkeypatch.setattr( + "megatron.bridge.models.conversion.model_bridge.parallel_state.get_pipeline_model_parallel_rank", + lambda: 0, + ) + + tasks_by_base = bridge.build_adapter_conversion_tasks([Mock()]) + assert "decoder.layers.0.mlp.linear_fc1" in tasks_by_base + tasks = tasks_by_base["decoder.layers.0.mlp.linear_fc1"] + assert len(tasks) == 1 + task = tasks[0] + assert task.adapter_key is None + assert task.linear_in_task.param_weight.shape == torch.Size([2, 2]) + assert task.linear_out_task.param_weight.shape == torch.Size([2, 2]) + + +def test_materialize_adapter_weights(monkeypatch): + bridge = DummyBridge() + + class DummyMapping: + def __init__(self, payload): + self.payload = payload + + def megatron_to_hf(self, weight, module): + return {"hf": self.payload} + + adapter_tasks = [ + AdapterWeightConversionTask( + global_base_prefix="decoder.layers.0.mlp.linear_fc1", + adapter_key=None, + alpha=2, + dim=4, + linear_in_task=WeightConversionTask( + param_name="in_name", + global_param_name="in_name", + mapping=DummyMapping(torch.ones(2, 2)), + megatron_module=None, + param_weight=None, + ), + linear_out_task=WeightConversionTask( + param_name="out_name", + global_param_name="out_name", + mapping=DummyMapping(2 * torch.ones(2, 2)), + megatron_module=None, + param_weight=None, + ), + ) + ] + + materials = bridge.materialize_adapter_weights(adapter_tasks) + assert len(materials) == 1 + assert torch.all(materials[0].linear_in_weight == torch.ones(2, 2)) + assert torch.all(materials[0].linear_out_weight == 2 * torch.ones(2, 2)) + + +def test_column_parallel_mapping_skips_ep_gather_for_adapters(monkeypatch): + mapping = ColumnParallelMapping( + "decoder.layers.0.mlp.experts.linear_fc1.adapter.linear_in.weight", + "hf_param", + ) + + # Avoid distributed calls + monkeypatch.setattr(ColumnParallelMapping, "broadcast_from_pp_rank", lambda self, tensor, cache_key=None: tensor) + monkeypatch.setattr(ColumnParallelMapping, "gather_from_tp_ranks", lambda self, tensor: [tensor]) + monkeypatch.setattr(ColumnParallelMapping, "tp_size", property(lambda self: 1)) + + def _raise(*args, **kwargs): + raise AssertionError("gather_from_ep_ranks should not be called for adapters") + + monkeypatch.setattr(ColumnParallelMapping, "gather_from_ep_ranks", _raise) + + result = mapping.megatron_to_hf(torch.ones(2, 2), None) + torch.testing.assert_close(result["hf_param"], torch.ones(2, 2)) From 21c086e916e9463e52fef82ee4f616c825dc082d Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 7 Dec 2025 13:05:39 +0200 Subject: [PATCH 2/3] update docstrings and abstraction Signed-off-by: Hollow Man --- .../bridge/models/conversion/model_bridge.py | 53 +++++++++++++++---- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 69de7f5a99..6cd657242d 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -74,6 +74,7 @@ _BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") ADAPTER_NAME_MAP = { + # Map HF weight suffixes (keys) to CanonicalLoRA adapter keys (values) ".q_proj.weight": "adapter_q", ".k_proj.weight": "adapter_k", ".v_proj.weight": "adapter_v", @@ -156,8 +157,8 @@ class AdapterWeight: adapter_key: Optional[str] # For canonical LoRA only alpha: int dim: int - linear_in_weight: torch.Tensor - linear_out_weight: torch.Tensor + linear_in_weight: MegatronWeightTuple + linear_out_weight: MegatronWeightTuple def _megatron_local_name_to_global( @@ -342,7 +343,17 @@ def mapping_registry(self): def _get_adapter_wrap_module( self, local_base_prefix: str, megatron_model: Union[MegatronModel, List[MegatronModel]], vp_stage: int ) -> tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]: - """Get the adapter and to_wrap modules from the parent name.""" + """Locate the adapter wrapper and its underlying module. + + Args: + local_base_prefix: Module prefix without the ``.adapter`` suffix (e.g. ``decoder.layers.0.mlp.linear_fc1``). + megatron_model: Single model or list of models indexed by virtual pipeline stage. + vp_stage: Virtual pipeline stage corresponding to the provided prefix. + + Returns: + A tuple ``(adapter, to_wrap)`` where ``adapter`` is the LoRA wrapper (or ``None`` if absent) + and ``to_wrap`` is the base linear module being wrapped. + """ lora_module, _ = get_module_and_param_from_name(megatron_model, local_base_prefix, vp_stage) adapter = getattr(lora_module, "adapter", None) if adapter is None: @@ -464,6 +475,17 @@ def _megatron_global_adapters_info_all_pp_ranks( return gathered_global_param_objects def _construct_adapters_names(self, prefix: str, adapter_key: Optional[str]) -> tuple[str, str]: + """Build linear_in/linear_out parameter names for an adapter. + + Args: + prefix: Base module prefix without any adapter suffix (global or local, depending on caller). + adapter_key: Optional adapter identifier used by CanonicalLoRA (e.g. ``adapter_q``). ``None`` for + standard single-adapter LoRA modules. + + Returns: + Tuple ``(linear_in_name, linear_out_name)`` containing the parameter names for the adapter's + input and output projection weights. + """ linear_in_name, linear_out_name = prefix + ".adapter", prefix + ".adapter" if adapter_key is not None: linear_in_name += f".{adapter_key}" @@ -576,13 +598,13 @@ def materialize_adapter_weights(self, adapter_tasks: List[AdapterWeightConversio linear_in_dict = mapping.megatron_to_hf( adapter_task.linear_in_task.param_weight, adapter_task.linear_in_task.megatron_module ) - linear_in_weight = next(iter(linear_in_dict.values())) + linear_in_tensor = next(iter(linear_in_dict.values())) mapping = adapter_task.linear_out_task.mapping linear_out_dict = mapping.megatron_to_hf( adapter_task.linear_out_task.param_weight, adapter_task.linear_out_task.megatron_module ) - linear_out_weight = next(iter(linear_out_dict.values())) + linear_out_tensor = next(iter(linear_out_dict.values())) materialized.append( AdapterWeight( @@ -590,8 +612,16 @@ def materialize_adapter_weights(self, adapter_tasks: List[AdapterWeightConversio adapter_key=adapter_task.adapter_key, alpha=adapter_task.alpha, dim=adapter_task.dim, - linear_in_weight=linear_in_weight, - linear_out_weight=linear_out_weight, + linear_in_weight=MegatronWeightTuple( + adapter_task.linear_in_task.param_name, + linear_in_tensor, + adapter_task.linear_in_task.vp_stage, + ), + linear_out_weight=MegatronWeightTuple( + adapter_task.linear_out_task.param_name, + linear_out_tensor, + adapter_task.linear_out_task.vp_stage, + ), ) ) @@ -970,7 +1000,10 @@ def _merge_lora_adapter_weights( adapter_weight = adapter_weights[0] alpha, dim = adapter_weight.alpha, adapter_weight.dim - linear_in_weight, linear_out_weight = adapter_weight.linear_in_weight, adapter_weight.linear_out_weight + linear_in_weight, linear_out_weight = ( + adapter_weight.linear_in_weight.weight, + adapter_weight.linear_out_weight.weight, + ) # Check if this is a fused layer that gets split into multiple projections # For fused FC1: splits into gate_proj and up_proj (2 parts) @@ -1083,8 +1116,8 @@ def _merge_canonical_adapter_from_weights( base_weight, target_adapter.alpha, target_adapter.dim, - target_adapter.linear_in_weight, - target_adapter.linear_out_weight, + target_adapter.linear_in_weight.weight, + target_adapter.linear_out_weight.weight, ) converted_weights_dict[hf_name] = merged_weight From 19c4c5e7a1f8714ce33585e06112a62f6cd3c573 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 7 Dec 2025 13:42:48 +0200 Subject: [PATCH 3/3] fix test cases Signed-off-by: Hollow Man --- .../models/test_model_bridge_lora.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index fde492c6c1..b7b6e881b5 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -22,6 +22,7 @@ AdapterWeight, AdapterWeightConversionTask, MegatronModelBridge, + MegatronWeightTuple, WeightConversionTask, ) from megatron.bridge.models.conversion.param_mapping import ColumnParallelMapping, merge_qkv_weights @@ -44,8 +45,8 @@ def test_merge_lora_adapter_weights_merges(monkeypatch): adapter_key=None, alpha=4, dim=4, - linear_in_weight=torch.eye(4), - linear_out_weight=torch.eye(4), + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", torch.eye(4), vp_stage=0), ) updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) @@ -80,8 +81,8 @@ def test_merge_lora_adapter_weights_fused_fc1(monkeypatch): adapter_key=None, alpha=1, dim=1, - linear_in_weight=torch.eye(4), - linear_out_weight=linear_out, + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", linear_out, vp_stage=0), ) updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) @@ -115,8 +116,8 @@ def test_merge_lora_adapter_weights_qkv_split(monkeypatch): adapter_key=None, alpha=4, dim=4, - linear_in_weight=torch.eye(4), - linear_out_weight=linear_out, + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", linear_out, vp_stage=0), ) updated = bridge._merge_lora_adapter_weights(megatron_model, converted, [adapter_weight]) @@ -138,24 +139,24 @@ def test_merge_canonical_adapter_from_weights(monkeypatch): adapter_key="adapter_q", alpha=1, dim=1, - linear_in_weight=torch.eye(2), - linear_out_weight=torch.ones(2, 2), + linear_in_weight=MegatronWeightTuple("in_q", torch.eye(2), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out_q", torch.ones(2, 2), vp_stage=0), ) adapter_k = AdapterWeight( global_base_prefix="decoder.layers.0.self_attn.linear_qkv", adapter_key="adapter_k", alpha=1, dim=1, - linear_in_weight=torch.eye(2), - linear_out_weight=2 * torch.ones(1, 2), + linear_in_weight=MegatronWeightTuple("in_k", torch.eye(2), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out_k", 2 * torch.ones(1, 2), vp_stage=0), ) adapter_v = AdapterWeight( global_base_prefix="decoder.layers.0.self_attn.linear_qkv", adapter_key="adapter_v", alpha=1, dim=1, - linear_in_weight=torch.eye(2), - linear_out_weight=3 * torch.ones(1, 2), + linear_in_weight=MegatronWeightTuple("in_v", torch.eye(2), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out_v", 3 * torch.ones(1, 2), vp_stage=0), ) updated = bridge._merge_canonical_adapter_from_weights(converted, [adapter_q, adapter_k, adapter_v]) @@ -370,8 +371,8 @@ def megatron_to_hf(self, weight, module): materials = bridge.materialize_adapter_weights(adapter_tasks) assert len(materials) == 1 - assert torch.all(materials[0].linear_in_weight == torch.ones(2, 2)) - assert torch.all(materials[0].linear_out_weight == 2 * torch.ones(2, 2)) + assert torch.all(materials[0].linear_in_weight.weight == torch.ones(2, 2)) + assert torch.all(materials[0].linear_out_weight.weight == 2 * torch.ones(2, 2)) def test_column_parallel_mapping_skips_ep_gather_for_adapters(monkeypatch):