diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index eaed6e2265cd..228d1c5fff24 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -283,10 +283,9 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2 - ) + return type(source_layer) is MergedColumnParallelLinear and len( + packed_modules_list + ) == len(source_layer.output_sizes) class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -466,18 +465,14 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo def slice_lora_a( self, lora_a: list[torch.Tensor | None] ) -> list[torch.Tensor | None]: - # NOTE: lora_a contains 2 subloras, and each sublora could be None. output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][output_start_idx : output_start_idx + output_shard_size, :] - if lora_a[0] is not None - else None, - lora_a[1][output_start_idx : output_start_idx + output_shard_size, :] - if lora_a[1] is not None - else None, + return [ + a[output_start_idx : output_start_idx + output_shard_size, :] + if a is not None + else None + for a in lora_a ] - return lora_a def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @@ -548,21 +543,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): def slice_lora_a( self, lora_a: list[torch.Tensor | None] ) -> list[torch.Tensor | None]: - # NOTE: lora_a contains 3 subloras, and each sublora could be None. - shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] - start_idx = [self.tp_rank * shard_size[i] for i in range(3)] - lora_a = [ - lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :] - if lora_a[0] is not None - else None, - lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :] - if lora_a[1] is not None - else None, - lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :] - if lora_a[2] is not None - else None, + return [ + a[self.tp_rank * s.shape[2] : (self.tp_rank + 1) * s.shape[2], :] + if a is not None + else None + for a, s in zip(lora_a, self.lora_a_stacked) ] - return lora_a def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @@ -613,8 +599,17 @@ def can_replace_layer( if len(packed_modules_list) >= 3: return True - # If packed_modules_list has exactly 2 items, let - # MergedColumnParallelLinearWithLoRA handle it + # Handle mismatch between packed modules count and output_sizes + # count. E.g. Qwen3.5 GDN layers have 2 packed modules + # (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight + # covers multiple output slices. + if hasattr(source_layer, "output_sizes") and len(packed_modules_list) != len( + source_layer.output_sizes + ): + return True + + # If packed_modules_list has exactly 2 items matching output_sizes, + # let MergedColumnParallelLinearWithLoRA handle it if len(packed_modules_list) == 2: return False @@ -633,10 +628,19 @@ def set_lora( lora_a: torch.Tensor | list[torch.Tensor], lora_b: torch.Tensor | list[torch.Tensor], ): - """Override to handle single tensor weights - that need to be split into slices.""" + """Override to handle weights that need to be split into slices. + + Handles three cases: + 1. Single tensor: split lora_b by output_sizes, duplicate lora_a + 2. List matching n_slices: pass through directly + 3. List shorter than n_slices: expand by splitting lora_b entries + that cover multiple output slices (e.g., 2 packed modules + mapping to 4 output_sizes in Qwen3.5 GDN layers) + """ self.reset_lora(index) + output_sizes = self.base_layer.output_sizes + # Handle case where checkpoint has single tensor weights # lora_a shape: (rank, input_size) - same for all slices, duplicate it if isinstance(lora_a, torch.Tensor): @@ -645,7 +649,6 @@ def set_lora( # lora_b shape: (total_output_size, rank) - # split along dim 0 based on output_sizes if isinstance(lora_b, torch.Tensor): - output_sizes = self.base_layer.output_sizes lora_b_list = [] start_idx = 0 for output_size in output_sizes: @@ -654,5 +657,56 @@ def set_lora( start_idx = end_idx lora_b = lora_b_list + # Handle list shorter than n_slices: one packed module may cover + # multiple output slices. Greedily match each lora_b entry to + # consecutive output_sizes by dimension, then split accordingly. + if isinstance(lora_a, list) and len(lora_a) < self.n_slices: + expanded_a: list[torch.Tensor | None] = [] + expanded_b: list[torch.Tensor | None] = [] + slice_idx = 0 + for i, (a_i, b_i) in enumerate(zip(lora_a, lora_b)): + if b_i is None: + # Figure out how many slices this None covers + remaining = len(lora_a) - i - 1 + remaining_slices = self.n_slices - slice_idx + count = remaining_slices - remaining + expanded_a.extend([None] * count) + expanded_b.extend([None] * count) + slice_idx += count + else: + b_dim = b_i.shape[0] + # Greedily consume output_sizes until we match b_dim + consumed = 0 + start_slice = slice_idx + while slice_idx < self.n_slices and consumed < b_dim: + consumed += output_sizes[slice_idx] + slice_idx += 1 + if consumed != b_dim: + raise ValueError( + f"Packed LoRA B dimension {b_dim} does not match " + f"the sum of output sizes {consumed} for LoRA {i}." + ) + num_covered = slice_idx - start_slice + if num_covered == 1: + expanded_a.append(a_i) + expanded_b.append(b_i) + else: + # Split lora_b and duplicate lora_a + split_start = 0 + for j in range(start_slice, slice_idx): + sz = output_sizes[j] + expanded_a.append(a_i) + expanded_b.append(b_i[split_start : split_start + sz, :]) + split_start += sz + + # Pad remaining slices with None (e.g. dummy LoRA warmup + # where per-slice dimensions don't span multiple output_sizes) + while len(expanded_a) < self.n_slices: + expanded_a.append(None) + expanded_b.append(None) + + lora_a = expanded_a + lora_b = expanded_b + # Now call parent's set_lora which expects lists super().set_lora(index, lora_a, lora_b)