diff --git a/examples/conversion/adapter/stream_adapter_weights.py b/examples/conversion/adapter/stream_adapter_weights.py index 4b7444c2fb..f6665830da 100644 --- a/examples/conversion/adapter/stream_adapter_weights.py +++ b/examples/conversion/adapter/stream_adapter_weights.py @@ -51,6 +51,7 @@ from __future__ import annotations import argparse +import math import os from contextlib import contextmanager from pathlib import Path @@ -162,14 +163,19 @@ def configure_device(device_index: int = 0) -> torch.device: def calculate_required_world_size(args: argparse.Namespace) -> int: - """Compute the model-parallel product used to validate distributed setup.""" + """Compute the minimum world size compatible with the requested parallelism. - return ( - args.tensor_model_parallel_size - * args.pipeline_model_parallel_size - * args.expert_model_parallel_size - * args.expert_tensor_parallel_size + Megatron requires WORLD_SIZE to be divisible by both the dense TP/PP domain + and the expert ETP/EP/PP domain. Those domains reuse the same global ranks, + so the minimum compatible world size is their least common multiple instead + of the raw product of tp, pp, ep, and etp. + """ + + dense_model_parallel_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size + expert_model_parallel_size = ( + args.expert_tensor_parallel_size * args.expert_model_parallel_size * args.pipeline_model_parallel_size ) + return math.lcm(dense_model_parallel_size, expert_model_parallel_size) @contextmanager @@ -189,7 +195,7 @@ def distributed_context( raise RuntimeError( f"Requested world_size={required_world_size} from model-parallel settings " f"(tp={tp}, pp={pp}, ep={ep}, etp={etp}), but initialized world_size={world_size}. " - "Launch with torchrun --nproc_per_node equal to the product." + f"Launch with torchrun --nproc_per_node={required_world_size}." ) yield world_size return @@ -200,7 +206,7 @@ def distributed_context( if required_world_size > 1 and "WORLD_SIZE" not in os.environ: raise RuntimeError( "Distributed world size is greater than 1 but WORLD_SIZE is not set. " - "Launch with torchrun --nproc_per_node equal to the requested world size." + f"Launch with torchrun --nproc_per_node={required_world_size}." ) if "MASTER_ADDR" in os.environ and "MASTER_PORT" in os.environ: @@ -223,7 +229,7 @@ def distributed_context( raise RuntimeError( f"Requested world_size={required_world_size} from model-parallel settings " f"(tp={tp}, pp={pp}, ep={ep}, etp={etp}), but initialized world_size={world_size}. " - "Launch with torchrun --nproc_per_node equal to the product." + f"Launch with torchrun --nproc_per_node={required_world_size}." ) yield world_size finally: @@ -274,7 +280,7 @@ def stream_and_collect_adapters( ) for weight_name, tensor in generator: - adapter_state[weight_name] = tensor + adapter_state[weight_name] = tensor.clone() print_rank_0(f"Collected adapter tensor: {weight_name} with shape {tuple(tensor.shape)}") if not adapter_state: @@ -286,9 +292,7 @@ def stream_and_collect_adapters( def _normalize_base_weight_name(param_name: str) -> str: """Remove the 'base_layer' suffix emitted when merge_adapter_weights=False.""" - if param_name.endswith("base_layer.weight"): - return param_name[: -len("base_layer.weight")] + "weight" - return param_name + return param_name.replace(".base_layer.", ".") def collect_hf_state_dict( @@ -327,10 +331,14 @@ def merge_hf_lora_adapters( for name, tensor in adapter_state.items(): if name.endswith(".lora_A.weight"): - base_name = name[: -len(".lora_A.weight")] + ".weight" + base_name = name[: -len(".lora_A.weight")] + if base_name not in base_state and f"{base_name}.weight" in base_state: + base_name = f"{base_name}.weight" grouped.setdefault(base_name, {})["A"] = tensor elif name.endswith(".lora_B.weight"): - base_name = name[: -len(".lora_B.weight")] + ".weight" + base_name = name[: -len(".lora_B.weight")] + if base_name not in base_state and f"{base_name}.weight" in base_state: + base_name = f"{base_name}.weight" grouped.setdefault(base_name, {})["B"] = tensor scale = alpha / float(dim) @@ -395,7 +403,7 @@ def main() -> None: f"🧮 Model-parallel settings: tp={args.tensor_model_parallel_size}, " f"pp={args.pipeline_model_parallel_size}, " f"ep={args.expert_model_parallel_size}, etp={args.expert_tensor_parallel_size}. " - f"Expected world_size={required_world_size}." + f"Minimum example world_size={required_world_size}." ) print_rank_0(f"🔧 Loading Hugging Face model {args.hf_model_id} with bfloat16 weights...") diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 795af457df..8b8410bc8b 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -1097,7 +1097,11 @@ def stream_weights_megatron_to_hf( final_tensor = tensor.cpu() if cpu else tensor if not merge_adapter_weights and "to_wrap.weight" in task.global_param_name: - hf_name = hf_name[: -len("weight")] + "base_layer.weight" + suffix_pos = hf_name.rfind(".") + if suffix_pos == -1: + hf_name = hf_name + ".base_layer" + else: + hf_name = hf_name[:suffix_pos] + ".base_layer" + hf_name[suffix_pos:] # Handle tied embeddings case # TODO(yuya): fix this hard coded naming diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 877a72ecaa..c6d315a262 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -2676,12 +2676,26 @@ def merge_gdn_linear_weights( return in_proj -def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor, tp_size: int = 1) -> torch.Tensor: - """Split GDN linear weights into QKVZ and BA.""" +def split_gdn_linear_weights( + provider: TransformerConfig, + in_proj: torch.Tensor, + tp_size: int = 1, + feature_dim: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Split GDN linear weights into QKVZ and BA. + + Args: + provider: Transformer config with GDN dimensions. + in_proj: Packed in-proj tensor. + tp_size: Tensor-parallel world size used for packing layout. + feature_dim: Trailing tensor dimension used for reshape/split. + Defaults to ``provider.hidden_size`` for base weights, but LoRA + paths can pass the adapter rank here. + """ assert tp_size >= 1, f"tp_size must be greater than 0, but got {tp_size=}" - hidden_size = provider.hidden_size + feature_dim = provider.hidden_size if feature_dim is None else feature_dim qk_head_dim = provider.linear_key_head_dim v_head_dim = provider.linear_value_head_dim num_qk_heads = provider.linear_num_key_heads @@ -2690,7 +2704,7 @@ def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor, qk_dim_local_tp = qk_head_dim * num_qk_heads_local_tp v_dim_local_tp = v_head_dim * num_v_heads_local_tp - in_proj = in_proj.reshape(tp_size, -1, hidden_size) + in_proj = in_proj.reshape(tp_size, -1, feature_dim) q, k, v, z, b, a = torch.split( in_proj, [ @@ -2704,12 +2718,12 @@ def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor, dim=1, ) - q, k, v, z, b, a = [weight.reshape(num_qk_heads, -1, hidden_size) for weight in [q, k, v, z, b, a]] + q, k, v, z, b, a = [weight.reshape(num_qk_heads, -1, feature_dim) for weight in [q, k, v, z, b, a]] qkvz = torch.cat([q, k, v, z], dim=1) ba = torch.cat([b, a], dim=1) - qkvz = qkvz.reshape(-1, hidden_size) - ba = ba.reshape(-1, hidden_size) + qkvz = qkvz.reshape(-1, feature_dim) + ba = ba.reshape(-1, feature_dim) assert qkvz.numel() + ba.numel() == in_proj.numel(), ( f"QKVZBA weights are not correctly split, {qkvz.numel()=}, {ba.numel()=}, {in_proj.numel()=}" @@ -2782,6 +2796,7 @@ def _split_gdn_grouped_to_separate( config: TransformerConfig, qkvz: torch.Tensor, ba: torch.Tensor, + feature_dim: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Convert head-grouped ``qkvz`` and ``ba`` tensors (as produced by :func:`split_gdn_linear_weights`) back into four flat tensors. @@ -2789,7 +2804,7 @@ def _split_gdn_grouped_to_separate( Returns: Tuple of (qkv, z, b, a) where each tensor has a flat per-component layout. """ - hidden_size = config.hidden_size + feature_dim = config.hidden_size if feature_dim is None else feature_dim qk_head_dim = config.linear_key_head_dim v_head_dim = config.linear_value_head_dim num_qk_heads = config.linear_num_key_heads @@ -2798,31 +2813,31 @@ def _split_gdn_grouped_to_separate( expected_qkvz_dim0 = num_qk_heads * (qk_head_dim * 2 + v_per_group * v_head_dim * 2) expected_ba_dim0 = num_qk_heads * v_per_group * 2 - if qkvz.ndim != 2 or qkvz.shape[0] != expected_qkvz_dim0 or qkvz.shape[1] != hidden_size: + if qkvz.ndim != 2 or qkvz.shape[0] != expected_qkvz_dim0 or qkvz.shape[1] != feature_dim: raise ValueError( - f"qkvz shape mismatch: expected ({expected_qkvz_dim0}, {hidden_size}), got {tuple(qkvz.shape)}" + f"qkvz shape mismatch: expected ({expected_qkvz_dim0}, {feature_dim}), got {tuple(qkvz.shape)}" ) - if ba.ndim != 2 or ba.shape[0] != expected_ba_dim0 or ba.shape[1] != hidden_size: - raise ValueError(f"ba shape mismatch: expected ({expected_ba_dim0}, {hidden_size}), got {tuple(ba.shape)}") + if ba.ndim != 2 or ba.shape[0] != expected_ba_dim0 or ba.shape[1] != feature_dim: + raise ValueError(f"ba shape mismatch: expected ({expected_ba_dim0}, {feature_dim}), got {tuple(ba.shape)}") # --- Split grouped QKVZ --- - qkvz_g = qkvz.reshape(num_qk_heads, -1, hidden_size) + qkvz_g = qkvz.reshape(num_qk_heads, -1, feature_dim) q_g, k_g, v_g, z_g = torch.split( qkvz_g, [qk_head_dim, qk_head_dim, v_per_group * v_head_dim, v_per_group * v_head_dim], dim=1, ) - q_flat = q_g.reshape(-1, hidden_size) - k_flat = k_g.reshape(-1, hidden_size) - v_flat = v_g.reshape(-1, hidden_size) - z_flat = z_g.reshape(-1, hidden_size) + q_flat = q_g.reshape(-1, feature_dim) + k_flat = k_g.reshape(-1, feature_dim) + v_flat = v_g.reshape(-1, feature_dim) + z_flat = z_g.reshape(-1, feature_dim) qkv = torch.cat([q_flat, k_flat, v_flat], dim=0) # --- Split grouped BA --- - ba_g = ba.reshape(num_qk_heads, -1, hidden_size) + ba_g = ba.reshape(num_qk_heads, -1, feature_dim) b_g, a_g = torch.split(ba_g, [v_per_group, v_per_group], dim=1) - b_flat = b_g.reshape(-1, hidden_size) - a_flat = a_g.reshape(-1, hidden_size) + b_flat = b_g.reshape(-1, feature_dim) + a_flat = a_g.reshape(-1, feature_dim) return qkv, z_flat, b_flat, a_flat diff --git a/src/megatron/bridge/models/conversion/peft_bridge.py b/src/megatron/bridge/models/conversion/peft_bridge.py index 61e8d440c4..b9108ccc44 100644 --- a/src/megatron/bridge/models/conversion/peft_bridge.py +++ b/src/megatron/bridge/models/conversion/peft_bridge.py @@ -30,6 +30,8 @@ ColumnParallelMapping, ReplicatedMapping, RowParallelMapping, + _split_gdn_grouped_to_separate, + split_gdn_linear_weights, split_qkv_weights, ) from megatron.bridge.models.conversion.utils import ( @@ -67,6 +69,8 @@ ".linear_out.weight": ".lora_B.weight", } +GDN_IN_PROJ_KEYS = ("in_proj_qkv", "in_proj_z", "in_proj_b", "in_proj_a") + @dataclass(frozen=True) class AdapterWeightConversionTask: @@ -97,7 +101,7 @@ def _select_hf_base_param_name(base_mapping, adapter_key: Optional[str], expecte hf_param = base_mapping.hf_param if isinstance(hf_param, str): - return hf_param if hf_param.endswith(expected_suffix) else None + return hf_param if hf_param.endswith(expected_suffix) or expected_suffix == ".weight" else None if isinstance(hf_param, dict): if adapter_key: @@ -109,7 +113,7 @@ def _select_hf_base_param_name(base_mapping, adapter_key: Optional[str], expecte # For fused qkv/gate_up case, we just need a placeholder here value = next(iter(hf_param.values())) - return value if value.endswith(expected_suffix) else None + return value if value.endswith(expected_suffix) or expected_suffix == ".weight" else None return None @@ -169,10 +173,16 @@ def _resolve_hf_adapter_param_name( # Strip expert layers numbering base_suffix = base_suffix.rstrip(digits) hf_base_name = _select_hf_base_param_name(base_mapping, adapter_key, base_suffix) - if hf_base_name is None or not hf_base_name.endswith(base_suffix): + if hf_base_name is None: return None - return hf_base_name[: -len(base_suffix)] + hf_suffix + if hf_base_name.endswith(base_suffix): + return hf_base_name[: -len(base_suffix)] + hf_suffix + + # Some HF base names (e.g., Qwen3.5 MoE expert gate_up_proj / down_proj) + # don't include a trailing ".weight". Allow LoRA suffix to be appended directly. + if base_suffix == ".weight": + return hf_base_name + hf_suffix def _get_base_hf_param_names_for_adapter( self, @@ -203,14 +213,15 @@ def _get_base_hf_param_names_for_adapter( def _make_lora_param_name(self, base_name: str, megatron_adapter_suffix: str) -> Optional[str]: """Translate a base HF weight name into its LoRA-specific counterpart.""" - if not base_name.endswith(".weight"): - return None - hf_suffix = MEGATRON_TO_HF_LORA_SUFFIX.get(megatron_adapter_suffix) if hf_suffix is None: return None - return base_name[: -len(".weight")] + hf_suffix + if base_name.endswith(".weight"): + return base_name[: -len(".weight")] + hf_suffix + + # Some HF base names (e.g., Qwen3.5 MoE expert gate_up_proj) omit ".weight". + return base_name + hf_suffix def _is_fused_qkv(self, hf_weight_names: Iterable[str]) -> bool: """Check whether the provided HF names correspond to a fused QKV weight.""" @@ -223,6 +234,16 @@ def _is_fused_qkv(self, hf_weight_names: Iterable[str]) -> bool: discovered = {token for name in names for token in required if token in name} return discovered == required + def _is_gdn_in_proj_split(self, hf_weight_names: Iterable[str]) -> bool: + """Check whether the provided HF names correspond to split GDN in_proj weights.""" + + names = list(hf_weight_names) + if len(names) != 4: + return False + required = set(GDN_IN_PROJ_KEYS) + discovered = {token for name in names for token in required if token in name} + return discovered == required and all("linear_attn" in name for name in names) + def _is_fused_fc1_gate_up( self, base_hf_weight_names: Iterable[str], @@ -261,6 +282,14 @@ def _infer_qkv_projection_from_name(self, hf_name: str) -> Optional[str]: return "v_proj" return None + def _infer_gdn_in_proj_projection_from_name(self, hf_name: str) -> Optional[str]: + """Return in_proj_qkv/z/b/a identifier based on the HF name.""" + + for projection_key in GDN_IN_PROJ_KEYS: + if projection_key in hf_name: + return projection_key + return None + def _infer_hf_expert_idx(self, hf_name: str) -> Optional[int]: """Return the expert index embedded in an HF MoE weight name.""" @@ -280,6 +309,120 @@ def _split_qkv_linear_out_weight( q_out, k_out, v_out = split_qkv_weights(model.config, linear_out_weight) return {"q_proj": q_out, "k_proj": k_out, "v_proj": v_out} + def _split_gdn_in_proj_linear_out_weight( + self, + megatron_model: Union[MegatronModel, List[MegatronModel]], + linear_out_weight: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Split a fused LoRA linear_out tensor for GDN in_proj adapters.""" + + model = megatron_model[0] if isinstance(megatron_model, list) else megatron_model + tp_size = parallel_state.get_tensor_model_parallel_world_size() + feature_dim = linear_out_weight.shape[1] + qkvz, ba = split_gdn_linear_weights( + model.config, + linear_out_weight, + tp_size=tp_size, + feature_dim=feature_dim, + ) + qkv, z, b, a = _split_gdn_grouped_to_separate(model.config, qkvz, ba, feature_dim=feature_dim) + return {"in_proj_qkv": qkv, "in_proj_z": z, "in_proj_b": b, "in_proj_a": a} + + def _build_lora_hf_names(self, base_hf_weight_names: List[str]) -> tuple[List[str], List[str]]: + """Build LoRA A/B names for a list of HF base parameter names.""" + + linear_in_hf_names = [ + self._make_lora_param_name(base_name, ".linear_in.weight") for base_name in base_hf_weight_names + ] + linear_out_hf_names = [ + self._make_lora_param_name(base_name, ".linear_out.weight") for base_name in base_hf_weight_names + ] + return linear_in_hf_names, linear_out_hf_names + + def _collect_packed_expert_adapter_tensors( + self, + linear_in_tensor: torch.Tensor, + linear_out_tensor: torch.Tensor, + expert_linear_in_gathered: Optional[List[torch.Tensor]], + expert_linear_out_gathered: Optional[List[torch.Tensor]], + num_moe_experts: int, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + """Collect one LoRA A/B tensor per expert for grouped expert exports.""" + + per_expert_linear_in: List[torch.Tensor] = [] + per_expert_linear_out: List[torch.Tensor] = [] + if linear_in_tensor.ndim > 2 or linear_out_tensor.ndim > 2: + # Already carries local expert dim; concatenate across EP ranks if needed. + linear_in_all = ( + torch.cat(expert_linear_in_gathered, dim=0) + if expert_linear_in_gathered is not None + else linear_in_tensor + ) + linear_out_all = ( + torch.cat(expert_linear_out_gathered, dim=0) + if expert_linear_out_gathered is not None + else linear_out_tensor + ) + per_expert_linear_in = list(linear_in_all) + per_expert_linear_out = list(linear_out_all) + return per_expert_linear_in, per_expert_linear_out + + for expert_idx in range(num_moe_experts): + per_expert_linear_in.append( + self._select_expert_adapter_weight( + linear_in_tensor, + expert_linear_in_gathered, + expert_idx, + num_moe_experts, + ) + ) + per_expert_linear_out.append( + self._select_expert_adapter_weight( + linear_out_tensor, + expert_linear_out_gathered, + expert_idx, + num_moe_experts, + ) + ) + return per_expert_linear_in, per_expert_linear_out + + def _build_packed_expert_linear_out_by_base( + self, + megatron_model: List[MegatronModel], + base_hf_weight_names: List[str], + per_expert_linear_out: List[torch.Tensor], + is_expert: bool, + ) -> Dict[str, torch.Tensor]: + """Build per-base stacked LoRA-B tensors for packed grouped-expert export.""" + + if not per_expert_linear_out: + return {} + + # Handle fused adapters (qkv/gate_up/gdn in_proj) by splitting per-expert then stacking. + per_base_linear_out = self._get_fused_adapter_linear_out_slices( + megatron_model, + base_hf_weight_names, + per_expert_linear_out[0], + is_expert=is_expert, + ) + if per_base_linear_out is None: + stacked = torch.stack(per_expert_linear_out, dim=0) + return {base_name: stacked for base_name in base_hf_weight_names} + + per_base_stacks: Dict[str, List[torch.Tensor]] = {name: [] for name in base_hf_weight_names} + for expert_out in per_expert_linear_out: + per_base = self._get_fused_adapter_linear_out_slices( + megatron_model, + base_hf_weight_names, + expert_out, + is_expert=is_expert, + ) + assert per_base is not None, "Expected fused adapter split for expert LoRA" + for base_name in base_hf_weight_names: + per_base_stacks[base_name].append(per_base[base_name]) + + return {base_name: torch.stack(parts, dim=0) for base_name, parts in per_base_stacks.items()} + def _split_fused_fc1_linear_out_weight( self, linear_out_weight: torch.Tensor, @@ -640,6 +783,61 @@ def stream_adapter_weights_megatron_to_hf( if is_grouped_expert: base_suffixes = [f".weight{expert_num}" for expert_num in range(num_moe_experts)] + # If the HF base names don't include experts.N, emit packed expert weights + # (stacked along dim 0) once per HF name instead of duplicating per expert. + packed_expert = False + base_hf_weight_names: List[str] = [] + if is_grouped_expert and base_suffixes: + base_hf_weight_names = self._get_base_hf_param_names_for_adapter( + mapping_registry, + adapter_task.global_base_prefix, + adapter_task.adapter_key, + base_suffixes[0], + ) + if base_hf_weight_names and not any( + re.search(r"experts\.(\d+)", name) for name in base_hf_weight_names + ): + packed_expert = True + + if packed_expert: + linear_in_hf_names, linear_out_hf_names = self._build_lora_hf_names(base_hf_weight_names) + per_expert_linear_in, per_expert_linear_out = self._collect_packed_expert_adapter_tensors( + linear_in_tensor, + linear_out_tensor, + expert_linear_in_gathered, + expert_linear_out_gathered, + num_moe_experts, + ) + + if not per_expert_linear_in or not per_expert_linear_out: + raise ValueError( + f"Expected to find per-expert adapter weights for grouped expert " + f"linear layer but none found, global_base_prefix={adapter_task.global_base_prefix}" + ) + linear_in_stacked = torch.stack(per_expert_linear_in, dim=0) + if cpu: + linear_in_stacked = linear_in_stacked.cpu() + + if adapter_task.adapter_key is None: + linear_out_by_base = self._build_packed_expert_linear_out_by_base( + megatron_model, + base_hf_weight_names, + per_expert_linear_out, + is_expert=is_expert_linear(adapter_task.global_base_prefix), + ) + else: + shared_linear_out = torch.stack(per_expert_linear_out, dim=0) + linear_out_by_base = {base_name: shared_linear_out for base_name in base_hf_weight_names} + + for index, base_name in enumerate(base_hf_weight_names): + linear_out_stacked = linear_out_by_base[base_name] + if cpu: + linear_out_stacked = linear_out_stacked.cpu() + yield HFWeightTuple(linear_in_hf_names[index], linear_in_stacked) + yield HFWeightTuple(linear_out_hf_names[index], linear_out_stacked) + + continue + for base_suffix in base_suffixes: current_linear_in_tensor = linear_in_tensor current_linear_out_tensor = linear_out_tensor @@ -668,11 +866,7 @@ def stream_adapter_weights_megatron_to_hf( adapter_task.adapter_key, base_suffix, ) - linear_in_hf_names = [] - linear_out_hf_names = [] - for base_name in base_hf_weight_names: - linear_in_hf_names.append(self._make_lora_param_name(base_name, ".linear_in.weight")) - linear_out_hf_names.append(self._make_lora_param_name(base_name, ".linear_out.weight")) + linear_in_hf_names, linear_out_hf_names = self._build_lora_hf_names(base_hf_weight_names) if adapter_task.adapter_key is None: # Handle fused adapters (e.g., gate/up or q/k/v) by splitting the fused tensor # into per-base slices keyed by the HF weight names. @@ -719,6 +913,16 @@ def _get_fused_adapter_linear_out_slices( per_base[base_name] = qkv_linear_out_weights[projection_key] return per_base + if self._is_gdn_in_proj_split(base_hf_weight_names): + gdn_linear_out_weights = self._split_gdn_in_proj_linear_out_weight(megatron_model, linear_out_tensor) + per_base = {} + for base_name in base_hf_weight_names: + projection_key = self._infer_gdn_in_proj_projection_from_name(base_name) + if projection_key is None: + raise ValueError(f"Unknown GDN in_proj base weight name: {base_name}") + per_base[base_name] = gdn_linear_out_weights[projection_key] + return per_base + is_fused_fc1 = self._is_fused_fc1_gate_up(base_hf_weight_names, linear_out_tensor) if is_fused_fc1: gate_weight, up_weight = self._split_fused_fc1_linear_out_weight( @@ -745,6 +949,10 @@ def _merge_lora_adapter_weights( ) -> Dict[str, torch.Tensor]: """Merge LoRA adapter weights back into the base tensor for HF export.""" + if not converted_weights_dict: + # Nothing to merge on this rank (e.g., non-owning PP rank or filtered mapping). + return converted_weights_dict + if len(adapter_weights) > 1 and all( w.adapter_key in ADAPTER_NAME_MAP.values() for w in adapter_weights if w.adapter_key ): @@ -765,8 +973,41 @@ def _merge_lora_adapter_weights( expert_linear_in_gathered = self._gather_expert_adapter_weight(linear_in_weight) expert_linear_out_gathered = self._gather_expert_adapter_weight(linear_out_weight) - base_weight_shape = next(iter(converted_weights_dict.values())).shape + base_weight = next(iter(converted_weights_dict.values())) + base_weight_shape = base_weight.shape weight_names = converted_weights_dict.keys() + if self._is_gdn_in_proj_split(weight_names): + # GDN in_proj LoRA is defined on the fused Megatron tensor; split it into + # the four HF tensors (qkv/z/b/a) before merging. + config = unwrap_model(megatron_model)[0].config + hidden_size = config.hidden_size + qk_dim = config.linear_key_head_dim * config.linear_num_key_heads + v_dim = config.linear_value_head_dim * config.linear_num_value_heads + num_v_heads = config.linear_num_value_heads + fused_dim0 = 2 * qk_dim + 2 * v_dim + 2 * num_v_heads + + base_device = base_weight.device + linear_out_on_base = ( + linear_out_weight if linear_out_weight.device == base_device else linear_out_weight.to(base_device) + ) + linear_in_on_base = ( + linear_in_weight if linear_in_weight.device == base_device else linear_in_weight.to(base_device) + ) + dummy_base = torch.zeros((fused_dim0, hidden_size), device=base_device, dtype=base_weight.dtype) + lora_weight = LoRAMerge().merge(dummy_base, linear_out_on_base, linear_in_on_base, alpha, dim) + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + qkvz, ba = split_gdn_linear_weights(config, lora_weight, tp_size=tp_size) + qkv, z, b, a = _split_gdn_grouped_to_separate(config, qkvz, ba) + gdn_slices = {"in_proj_qkv": qkv, "in_proj_z": z, "in_proj_b": b, "in_proj_a": a} + + for hf_name, base_tensor in list(converted_weights_dict.items()): + projection_key = self._infer_gdn_in_proj_projection_from_name(hf_name) + if projection_key is None: + raise ValueError(f"Unknown GDN in_proj weight name: {hf_name}") + converted_weights_dict[hf_name] = base_tensor + gdn_slices[projection_key] + + return converted_weights_dict is_fused_fc1 = self._is_fused_fc1_gate_up(weight_names, linear_out_weight, base_weight_shape) is_fused_qkv = self._is_fused_qkv(weight_names) and not is_expert qkv_linear_out_weights = ( @@ -784,18 +1025,19 @@ def _merge_lora_adapter_weights( current_linear_out_weight = linear_out_weight if is_grouped_expert: expert_idx = self._infer_hf_expert_idx(hf_name) - current_linear_in_weight = self._select_expert_adapter_weight( - linear_in_weight, - expert_linear_in_gathered, - expert_idx, - num_moe_experts, - ) - current_linear_out_weight = self._select_expert_adapter_weight( - linear_out_weight, - expert_linear_out_gathered, - expert_idx, - num_moe_experts, - ) + if expert_idx is not None: + current_linear_in_weight = self._select_expert_adapter_weight( + linear_in_weight, + expert_linear_in_gathered, + expert_idx, + num_moe_experts, + ) + current_linear_out_weight = self._select_expert_adapter_weight( + linear_out_weight, + expert_linear_out_gathered, + expert_idx, + num_moe_experts, + ) if is_fused_fc1: if is_expert: fc1_gate_weight, fc1_up_weight = self._split_fused_fc1_linear_out_weight( @@ -890,18 +1132,19 @@ def _merge_canonical_adapter_from_weights( linear_out_weight = target_adapter.linear_out_weight.weight if is_grouped_expert: expert_idx = self._infer_hf_expert_idx(hf_name) - linear_in_weight = self._select_expert_adapter_weight( - linear_in_weight, - expert_linear_in_gathered.get(target_adapter_key), - expert_idx, - num_moe_experts, - ) - linear_out_weight = self._select_expert_adapter_weight( - linear_out_weight, - expert_linear_out_gathered.get(target_adapter_key), - expert_idx, - num_moe_experts, - ) + if expert_idx is not None: + linear_in_weight = self._select_expert_adapter_weight( + linear_in_weight, + expert_linear_in_gathered.get(target_adapter_key), + expert_idx, + num_moe_experts, + ) + linear_out_weight = self._select_expert_adapter_weight( + linear_out_weight, + expert_linear_out_gathered.get(target_adapter_key), + expert_idx, + num_moe_experts, + ) merged_weight = self._merge_single_adapter_weight( base_weight, diff --git a/src/megatron/bridge/peft/adapter_wrapper.py b/src/megatron/bridge/peft/adapter_wrapper.py index c84c9f6907..fcc828e21f 100644 --- a/src/megatron/bridge/peft/adapter_wrapper.py +++ b/src/megatron/bridge/peft/adapter_wrapper.py @@ -204,7 +204,7 @@ def sharded_state_dict( The combined sharded state dictionary. """ adapter_sharded_state_dict_kwargs = {} - if isinstance(self.adapter, ParallelLinearAdapter) and "in_proj" in self.adapter.base_linear_name: + if isinstance(self.adapter, ParallelLinearAdapter) and "mixer.in_proj" in self.adapter.base_linear_name: adapter_sharded_state_dict_kwargs["mamba_dim_info"] = _compute_mamba_dim_info(self.to_wrap) sharded_state_dict = {} diff --git a/src/megatron/bridge/peft/canonical_lora.py b/src/megatron/bridge/peft/canonical_lora.py index b832ab1672..1f99a69d2d 100644 --- a/src/megatron/bridge/peft/canonical_lora.py +++ b/src/megatron/bridge/peft/canonical_lora.py @@ -31,6 +31,12 @@ logger = logging.getLogger(__name__) +def _should_treat_linear_fc1_as_unfused(full_name: str) -> bool: + """Return True when CanonicalLoRA should keep linear_fc1 as a single adapter.""" + + return full_name.startswith("vision_model.") or full_name.endswith(".mlp.experts.linear_fc1") + + class ModuleDict(nn.ModuleDict): """ nn.ModuleDict with a sharded_state_dict implementation for checkpointing @@ -302,6 +308,11 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s base_linear_is_parallel=attrs.base_linear_is_parallel, ) + if name == "linear_fc1" and _should_treat_linear_fc1_as_unfused(full_name): + logger.info(f"Adding lora to: {full_name} (treating unsupported canonical linear_fc1 as unfused)") + adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs) + return LoRALinear(m, adapter) + canonical_submodules = self.canonical_mapping[match] logger.info(f"Adding lora to: {full_name} ({canonical_submodules})") if name == "linear_qkv": diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index 6641f72c4c..03287065ca 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -15,6 +15,7 @@ from types import SimpleNamespace from unittest.mock import Mock +import pytest import torch from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -24,7 +25,13 @@ MegatronWeightTuple, WeightConversionTask, ) -from megatron.bridge.models.conversion.param_mapping import AutoMapping, ColumnParallelMapping, merge_qkv_weights +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + ColumnParallelMapping, + _fuse_gdn_separate_to_grouped, + merge_gdn_linear_weights, + merge_qkv_weights, +) from megatron.bridge.models.conversion.peft_bridge import AdapterWeight from megatron.bridge.peft.utils import AdapterAttributes @@ -37,6 +44,30 @@ def mapping_registry(self): # pragma: no cover - not used in tests return MegatronMappingRegistry() +@pytest.fixture(autouse=True) +def _patch_parallel_state(monkeypatch): + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size", + lambda: 1, + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_expert_tensor_parallel_world_size", + lambda: 1, + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_expert_model_parallel_world_size", + lambda: 1, + ) + monkeypatch.setattr( + "megatron.bridge.peft.lora.parallel_state.get_tensor_model_parallel_world_size", + lambda: 1, + ) + monkeypatch.setattr( + "megatron.bridge.peft.lora.parallel_state.get_tensor_model_parallel_group", + lambda: None, + ) + + def test_merge_lora_adapter_weights_merges(monkeypatch): bridge = DummyBridge() base_weight = torch.zeros(4, 4) @@ -59,6 +90,25 @@ def test_merge_lora_adapter_weights_merges(monkeypatch): torch.testing.assert_close(updated["hf.weight"], expected) +def test_merge_lora_adapter_weights_empty_returns_empty(): + bridge = DummyBridge() + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.linear_fc1", + adapter_key=None, + alpha=4, + dim=4, + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", torch.eye(4), vp_stage=0), + ) + + updated = bridge._merge_lora_adapter_weights( + [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))], + {}, + [adapter_weight], + ) + assert updated == {} + + def test_merge_single_adapter_weight_matches_loramerge(): bridge = DummyBridge() base = torch.zeros(2, 2) @@ -179,6 +229,62 @@ def test_merge_lora_adapter_weights_qkv_split(monkeypatch): torch.testing.assert_close(updated["v_proj.weight"], v_weight) +def test_merge_lora_adapter_weights_grouped_expert_missing_expert_idx(monkeypatch): + bridge = DummyBridge() + base = torch.zeros(2, 2) + converted = {"model.layers.0.mlp.experts.down_proj.weight": base.clone()} + + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.experts.linear_fc2", + adapter_key=None, + alpha=2, + dim=2, + linear_in_weight=MegatronWeightTuple("in", torch.eye(2), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", 2 * torch.eye(2), vp_stage=0), + ) + + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_expert_model_parallel_world_size", + lambda: 1, + ) + + updated = bridge._merge_lora_adapter_weights( + [SimpleNamespace(config=SimpleNamespace(num_moe_experts=2))], + converted, + [adapter_weight], + ) + + torch.testing.assert_close(updated["model.layers.0.mlp.experts.down_proj.weight"], 2 * torch.eye(2)) + + +def test_merge_lora_adapter_weights_grouped_expert_gate_up_proj_unfused(monkeypatch): + bridge = DummyBridge() + base = torch.zeros(2, 2) + converted = {"model.language_model.layers.0.mlp.experts.gate_up_proj": base.clone()} + + adapter_weight = AdapterWeight( + global_base_prefix="language_model.decoder.layers.0.mlp.experts.linear_fc1", + adapter_key=None, + alpha=2, + dim=2, + linear_in_weight=MegatronWeightTuple("in", torch.eye(2), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", 2 * torch.eye(2), vp_stage=0), + ) + + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_expert_model_parallel_world_size", + lambda: 1, + ) + + updated = bridge._merge_lora_adapter_weights( + [SimpleNamespace(config=SimpleNamespace(num_moe_experts=2))], + converted, + [adapter_weight], + ) + + torch.testing.assert_close(updated["model.language_model.layers.0.mlp.experts.gate_up_proj"], 2 * torch.eye(2)) + + def test_merge_canonical_adapter_from_weights(monkeypatch): bridge = DummyBridge() converted = { @@ -375,6 +481,32 @@ def test_construct_adapters_names(): assert linear_out_k.endswith("adapter_q.linear_out.weight") +def test_make_lora_param_name_without_weight_suffix(): + bridge = DummyBridge() + base_name = "model.layers.0.mlp.experts.down_proj" + assert bridge._make_lora_param_name(base_name, ".linear_in.weight") == base_name + ".lora_A.weight" + assert bridge._make_lora_param_name(base_name, ".linear_out.weight") == base_name + ".lora_B.weight" + + +def test_resolve_hf_adapter_param_name_without_weight_suffix(): + bridge = DummyBridge() + registry = MegatronMappingRegistry( + AutoMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + hf_param="model.layers.*.mlp.experts.gate_up_proj", + ) + ) + + name = bridge._resolve_hf_adapter_param_name( + registry, + "decoder.layers.0.mlp.linear_fc1", + ".linear_in.weight", + ".weight", + None, + ) + assert name == "model.layers.0.mlp.experts.gate_up_proj.lora_A.weight" + + def test_build_adapter_conversion_tasks(monkeypatch): bridge = DummyBridge() @@ -666,12 +798,215 @@ def test_stream_adapter_weights_megatron_to_hf_fused_fc1(monkeypatch): torch.testing.assert_close(weights[3].weight, torch.full((2, 2), 2.0)) -def test_stream_weights_megatron_to_hf_skips_merge_when_disabled(monkeypatch): +def test_stream_adapter_weights_megatron_to_hf_packed_expert_stacks(monkeypatch): + bridge = DummyBridge() + + adapter_task = AdapterWeightConversionTask( + global_base_prefix="decoder.layers.0.mlp.experts.linear_fc2", + adapter_key=None, + alpha=2, + dim=4, + linear_in_task=WeightConversionTask( + param_name="local_in", + global_param_name="decoder.layers.0.mlp.experts.linear_fc2.adapter.linear_in.weight", + mapping=Mock(), + ), + linear_out_task=WeightConversionTask( + param_name="local_out", + global_param_name="decoder.layers.0.mlp.experts.linear_fc2.adapter.linear_out.weight", + mapping=Mock(), + ), + ) + + linear_in = torch.stack( + [torch.full((2, 2), 1.0), torch.full((2, 2), 2.0)], + dim=0, + ) + linear_out = torch.stack( + [torch.full((2, 2), 3.0), torch.full((2, 2), 4.0)], + dim=0, + ) + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.experts.linear_fc2", + adapter_key=None, + alpha=2, + dim=4, + linear_in_weight=MegatronWeightTuple("local_in", linear_in, vp_stage=0), + linear_out_weight=MegatronWeightTuple("local_out", linear_out, vp_stage=0), + ) + + monkeypatch.setattr( + bridge, + "build_adapter_conversion_tasks", + lambda *_: {"decoder.layers.0.mlp.experts.linear_fc2": [adapter_task]}, + ) + monkeypatch.setattr(bridge, "materialize_adapter_weights", lambda *_: [adapter_weight]) + monkeypatch.setattr( + bridge, + "_get_base_hf_param_names_for_adapter", + lambda *_args, **_kwargs: ["model.layers.0.mlp.experts.down_proj"], + ) + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_expert_model_parallel_world_size", + lambda: 1, + ) + + weights = list( + bridge.stream_adapter_weights_megatron_to_hf( + [SimpleNamespace(config=SimpleNamespace(num_moe_experts=2))], + cpu=False, + show_progress=False, + ) + ) + + assert len(weights) == 2 + assert weights[0].param_name.endswith("down_proj.lora_A.weight") + assert weights[1].param_name.endswith("down_proj.lora_B.weight") + torch.testing.assert_close(weights[0].weight, linear_in) + torch.testing.assert_close(weights[1].weight, linear_out) + + +def test_split_gdn_in_proj_linear_out_weight_roundtrip(monkeypatch): + bridge = DummyBridge() + config = SimpleNamespace( + hidden_size=4, + linear_key_head_dim=1, + linear_num_key_heads=2, + linear_value_head_dim=1, + linear_num_value_heads=2, + ) + + qkv = torch.arange(24, dtype=torch.float32).reshape(6, 4) + z = torch.full((2, 4), 100.0) + b = torch.full((2, 4), 200.0) + a = torch.full((2, 4), 300.0) + + qkvz, ba = _fuse_gdn_separate_to_grouped(config, qkv, z, b, a) + linear_out = merge_gdn_linear_weights(config, qkvz, ba, tp_size=1) + + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size", + lambda: 1, + ) + + split = bridge._split_gdn_in_proj_linear_out_weight( + [SimpleNamespace(config=config)], + linear_out, + ) + + torch.testing.assert_close(split["in_proj_qkv"], qkv) + torch.testing.assert_close(split["in_proj_z"], z) + torch.testing.assert_close(split["in_proj_b"], b) + torch.testing.assert_close(split["in_proj_a"], a) + + +def test_get_fused_adapter_linear_out_slices_gdn_mapping(monkeypatch): + bridge = DummyBridge() + base_hf_weight_names = [ + "model.layers.0.linear_attn.in_proj_qkv.weight", + "model.layers.0.linear_attn.in_proj_z.weight", + "model.layers.0.linear_attn.in_proj_b.weight", + "model.layers.0.linear_attn.in_proj_a.weight", + ] + gdn_slices = { + "in_proj_qkv": torch.full((6, 2), 1.0), + "in_proj_z": torch.full((2, 2), 2.0), + "in_proj_b": torch.full((2, 2), 3.0), + "in_proj_a": torch.full((2, 2), 4.0), + } + + monkeypatch.setattr( + bridge, + "_split_gdn_in_proj_linear_out_weight", + lambda *_args, **_kwargs: gdn_slices, + ) + + per_base = bridge._get_fused_adapter_linear_out_slices( + [SimpleNamespace(config=SimpleNamespace())], + base_hf_weight_names, + torch.zeros(1, 1), + is_expert=False, + ) + + assert per_base is not None + torch.testing.assert_close(per_base["model.layers.0.linear_attn.in_proj_qkv.weight"], gdn_slices["in_proj_qkv"]) + torch.testing.assert_close(per_base["model.layers.0.linear_attn.in_proj_z.weight"], gdn_slices["in_proj_z"]) + torch.testing.assert_close(per_base["model.layers.0.linear_attn.in_proj_b.weight"], gdn_slices["in_proj_b"]) + torch.testing.assert_close(per_base["model.layers.0.linear_attn.in_proj_a.weight"], gdn_slices["in_proj_a"]) + assert bridge._infer_gdn_in_proj_projection_from_name("foo.bar.in_proj_z.weight") == "in_proj_z" + assert bridge._infer_gdn_in_proj_projection_from_name("foo.bar.unknown.weight") is None + + +def test_merge_lora_adapter_weights_gdn_in_proj_split(monkeypatch): + bridge = DummyBridge() + config = SimpleNamespace( + hidden_size=4, + linear_key_head_dim=1, + linear_num_key_heads=2, + linear_value_head_dim=1, + linear_num_value_heads=2, + num_moe_experts=0, + ) + megatron_model = [SimpleNamespace(config=config)] + + qkv_delta = torch.arange(24, dtype=torch.float32).reshape(6, 4) + z_delta = torch.full((2, 4), 10.0) + b_delta = torch.full((2, 4), 20.0) + a_delta = torch.full((2, 4), 30.0) + + base_qkv = torch.full((6, 4), -1.0) + base_z = torch.full((2, 4), -2.0) + base_b = torch.full((2, 4), -3.0) + base_a = torch.full((2, 4), -4.0) + converted = { + "model.layers.0.linear_attn.in_proj_qkv.weight": base_qkv.clone(), + "model.layers.0.linear_attn.in_proj_z.weight": base_z.clone(), + "model.layers.0.linear_attn.in_proj_b.weight": base_b.clone(), + "model.layers.0.linear_attn.in_proj_a.weight": base_a.clone(), + } + + qkvz, ba = _fuse_gdn_separate_to_grouped(config, qkv_delta, z_delta, b_delta, a_delta) + linear_out = merge_gdn_linear_weights(config, qkvz, ba, tp_size=1) + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.self_attention.linear_qkv", + adapter_key=None, + alpha=4, + dim=4, + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", linear_out, vp_stage=0), + ) + + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size", + lambda: 1, + ) + + updated = bridge._merge_lora_adapter_weights(megatron_model, converted, [adapter_weight]) + + torch.testing.assert_close( + updated["model.layers.0.linear_attn.in_proj_qkv.weight"], + base_qkv + qkv_delta, + ) + torch.testing.assert_close( + updated["model.layers.0.linear_attn.in_proj_z.weight"], + base_z + z_delta, + ) + torch.testing.assert_close( + updated["model.layers.0.linear_attn.in_proj_b.weight"], + base_b + b_delta, + ) + torch.testing.assert_close( + updated["model.layers.0.linear_attn.in_proj_a.weight"], + base_a + a_delta, + ) + + +def _stream_weights_with_merge_disabled(monkeypatch, converted_name: str): bridge = DummyBridge() class DummyMapping: def megatron_to_hf(self, weight, module): - return {"hf.weight": torch.ones(1)} + return {converted_name: torch.ones(1)} task = WeightConversionTask( param_name="decoder.layers.0.mlp.linear_fc1.to_wrap.weight", @@ -719,8 +1054,22 @@ def _raise_on_build(*_args, **_kwargs): ) ) + return weights + + +@pytest.mark.parametrize( + ("converted_name", "expected_name"), + [ + ("hf.weight", "hf.base_layer.weight"), + ("hf.tensor", "hf.base_layer.tensor"), + ("hf", "hf.base_layer"), + ], +) +def test_stream_weights_megatron_to_hf_skips_merge_when_disabled(monkeypatch, converted_name, expected_name): + weights = _stream_weights_with_merge_disabled(monkeypatch, converted_name) + assert len(weights) == 1 - assert weights[0].param_name in {"hf.weight", "hf.base_layer.weight"} + assert weights[0].param_name == expected_name def test_column_parallel_mapping_skips_ep_gather_for_adapters(monkeypatch): diff --git a/tests/unit_tests/peft/test_adapter_wrapper.py b/tests/unit_tests/peft/test_adapter_wrapper.py index 3c2a180bb3..1ffd70fd65 100644 --- a/tests/unit_tests/peft/test_adapter_wrapper.py +++ b/tests/unit_tests/peft/test_adapter_wrapper.py @@ -19,7 +19,7 @@ modules with adapters in Parameter-Efficient Fine-Tuning scenarios. """ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest import torch @@ -88,6 +88,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output +class MockParallelLinearAdapter(nn.Module): + """Minimal ParallelLinearAdapter stand-in for sharded_state_dict tests.""" + + def __init__(self, base_linear_name: str): + super().__init__() + self.base_linear_name = base_linear_name + + class TestAdapterWrapper: """Test the AdapterWrapper base class.""" @@ -235,6 +243,46 @@ def test_sharded_state_dict(self, mock_linear_simple, simple_adapter): mock_linear_simple.sharded_state_dict.assert_called_once_with("test_", (), None) simple_adapter.sharded_state_dict.assert_called_once_with("test_adapter.", (), None) + def test_sharded_state_dict_skips_mamba_metadata_for_non_mixer_in_proj(self, mock_linear_simple): + """Test that non-Mamba in_proj adapters do not request Mamba metadata.""" + mock_linear_simple.sharded_state_dict = Mock(return_value={"linear_shard": "value1"}) + adapter = MockParallelLinearAdapter("decoder.layers.0.self_attention.in_proj") + adapter.sharded_state_dict = Mock(return_value={"adapter_shard": "value2"}) + + with ( + patch("megatron.bridge.peft.adapter_wrapper.ParallelLinearAdapter", MockParallelLinearAdapter), + patch( + "megatron.bridge.peft.adapter_wrapper._compute_mamba_dim_info", return_value={"dummy": 1} + ) as mock_dim, + ): + wrapper = ConcreteAdapterWrapper(mock_linear_simple, adapter) + result = wrapper.sharded_state_dict(prefix="test_") + + assert "linear_shard" in result + assert "adapter_shard" in result + mock_dim.assert_not_called() + adapter.sharded_state_dict.assert_called_once_with("test_adapter.", (), None) + + def test_sharded_state_dict_adds_mamba_metadata_for_mixer_in_proj(self, mock_linear_simple): + """Test that Mamba mixer.in_proj adapters request Mamba metadata.""" + mock_linear_simple.sharded_state_dict = Mock(return_value={"linear_shard": "value1"}) + adapter = MockParallelLinearAdapter("decoder.layers.0.mixer.in_proj") + adapter.sharded_state_dict = Mock(return_value={"adapter_shard": "value2"}) + + with ( + patch("megatron.bridge.peft.adapter_wrapper.ParallelLinearAdapter", MockParallelLinearAdapter), + patch( + "megatron.bridge.peft.adapter_wrapper._compute_mamba_dim_info", return_value={"dummy": 1} + ) as mock_dim, + ): + wrapper = ConcreteAdapterWrapper(mock_linear_simple, adapter) + result = wrapper.sharded_state_dict(prefix="test_") + + assert "linear_shard" in result + assert "adapter_shard" in result + mock_dim.assert_called_once_with(mock_linear_simple) + adapter.sharded_state_dict.assert_called_once_with("test_adapter.", (), None, mamba_dim_info={"dummy": 1}) + def test_forward_integration(self, mock_linear_simple, simple_adapter): """Test full forward pass integration.""" wrapper = ConcreteAdapterWrapper(mock_linear_simple, simple_adapter) diff --git a/tests/unit_tests/peft/test_canonical_lora.py b/tests/unit_tests/peft/test_canonical_lora.py index 5b0ec0f2fc..cb34016516 100644 --- a/tests/unit_tests/peft/test_canonical_lora.py +++ b/tests/unit_tests/peft/test_canonical_lora.py @@ -84,6 +84,37 @@ def __init__(self): self.linear_fc2 = MockMegatronLinear(2048, 512) +class VisionLanguageMegatronStyleModel(nn.Module): + """Model with both language and vision linear_fc1 modules.""" + + def __init__(self): + super().__init__() + self.language_model = nn.Module() + self.language_model.linear_fc1 = MockMegatronLinear(512, 2048) + + self.vision_model = nn.Module() + self.vision_model.merger = nn.Module() + self.vision_model.merger.linear_fc1 = MockMegatronLinear(512, 512) + + +class MoEMegatronStyleModel(nn.Module): + """Model with dense, expert, and shared-expert linear_fc1 modules.""" + + def __init__(self): + super().__init__() + self.language_model = nn.Module() + self.language_model.decoder = nn.Module() + self.language_model.decoder.layers = nn.ModuleList([nn.Module()]) + + layer = self.language_model.decoder.layers[0] + layer.mlp = nn.Module() + layer.mlp.linear_fc1 = MockMegatronLinear(512, 2048) + layer.mlp.experts = nn.Module() + layer.mlp.experts.linear_fc1 = MockMegatronLinear(512, 2048) + layer.mlp.shared_experts = nn.Module() + layer.mlp.shared_experts.linear_fc1 = MockMegatronLinear(512, 2048) + + class NestedModel(nn.Module): """Model with nested structure for testing pattern matching.""" @@ -273,6 +304,62 @@ def mock_get_attrs(module, is_expert=False): assert isinstance(transformed_model.linear_proj, MockMegatronLinear) assert isinstance(transformed_model.linear_fc2, MockMegatronLinear) + def test_canonical_lora_treats_visual_linear_fc1_as_unfused(self): + """Vision-side linear_fc1 should keep a single unfused LoRA adapter.""" + model = VisionLanguageMegatronStyleModel() + lora = CanonicalLoRA(target_modules=["linear_fc1_up", "linear_fc1_gate"]) + + def mock_get_attrs(module, is_expert=False): + return AdapterAttributes( + input_is_parallel=False, + in_features=module.in_features, + out_features=module.out_features, + disable_tensor_parallel_comm=False, + disable_sequence_parallel_comm=True, + base_linear_is_parallel=True, + ) + + with patch( + "megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear", side_effect=mock_get_attrs + ): + with patch("megatron.bridge.peft.canonical_lora.ParallelLinearAdapter") as mock_adapter: + mock_adapter.return_value = nn.Linear(1, 1) + + transformed_model = lora(model, training=True) + + assert isinstance(transformed_model.language_model.linear_fc1, LoRALinearSplitFC1UpGate) + assert isinstance(transformed_model.vision_model.merger.linear_fc1, LoRALinear) + assert not isinstance(transformed_model.vision_model.merger.linear_fc1, LoRALinearSplitFC1UpGate) + + def test_canonical_lora_treats_moe_expert_linear_fc1_as_unfused(self): + """Grouped expert linear_fc1 should keep a single unfused LoRA adapter.""" + model = MoEMegatronStyleModel() + lora = CanonicalLoRA(target_modules=["linear_fc1_up", "linear_fc1_gate"]) + + def mock_get_attrs(module, is_expert=False): + return AdapterAttributes( + input_is_parallel=False, + in_features=module.in_features, + out_features=module.out_features, + disable_tensor_parallel_comm=False, + disable_sequence_parallel_comm=True, + base_linear_is_parallel=True, + ) + + with patch( + "megatron.bridge.peft.canonical_lora.get_adapter_attributes_from_linear", side_effect=mock_get_attrs + ): + with patch("megatron.bridge.peft.canonical_lora.ParallelLinearAdapter") as mock_adapter: + mock_adapter.return_value = nn.Linear(1, 1) + + transformed_model = lora(model, training=True) + + layer = transformed_model.language_model.decoder.layers[0] + assert isinstance(layer.mlp.linear_fc1, LoRALinearSplitFC1UpGate) + assert isinstance(layer.mlp.experts.linear_fc1, LoRALinear) + assert not isinstance(layer.mlp.experts.linear_fc1, LoRALinearSplitFC1UpGate) + assert isinstance(layer.mlp.shared_experts.linear_fc1, LoRALinearSplitFC1UpGate) + def test_canonical_lora_transform_nested_model(self): """Test CanonicalLoRA transformation on nested model structures.""" model = NestedModel()