diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index eaed6e2265cd..dd6646262a4e 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -283,10 +283,9 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2 - ) + return type(source_layer) is MergedColumnParallelLinear and len( + packed_modules_list + ) == len(source_layer.output_sizes) class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -613,8 +612,17 @@ def can_replace_layer( if len(packed_modules_list) >= 3: return True - # If packed_modules_list has exactly 2 items, let - # MergedColumnParallelLinearWithLoRA handle it + # Handle mismatch between packed modules count and output_sizes + # count. E.g. Qwen3.5 GDN layers have 2 packed modules + # (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight + # covers multiple output slices. + if hasattr(source_layer, "output_sizes") and len(packed_modules_list) != len( + source_layer.output_sizes + ): + return True + + # If packed_modules_list has exactly 2 items matching output_sizes, + # let MergedColumnParallelLinearWithLoRA handle it if len(packed_modules_list) == 2: return False @@ -633,10 +641,19 @@ def set_lora( lora_a: torch.Tensor | list[torch.Tensor], lora_b: torch.Tensor | list[torch.Tensor], ): - """Override to handle single tensor weights - that need to be split into slices.""" + """Override to handle weights that need to be split into slices. + + Handles three cases: + 1. Single tensor: split lora_b by output_sizes, duplicate lora_a + 2. List matching n_slices: pass through directly + 3. List shorter than n_slices: expand by splitting lora_b entries + that cover multiple output slices (e.g., 2 packed modules + mapping to 4 output_sizes in Qwen3.5 GDN layers) + """ self.reset_lora(index) + output_sizes = self.base_layer.output_sizes + # Handle case where checkpoint has single tensor weights # lora_a shape: (rank, input_size) - same for all slices, duplicate it if isinstance(lora_a, torch.Tensor): @@ -645,7 +662,6 @@ def set_lora( # lora_b shape: (total_output_size, rank) - # split along dim 0 based on output_sizes if isinstance(lora_b, torch.Tensor): - output_sizes = self.base_layer.output_sizes lora_b_list = [] start_idx = 0 for output_size in output_sizes: @@ -654,5 +670,56 @@ def set_lora( start_idx = end_idx lora_b = lora_b_list + # Handle list shorter than n_slices: one packed module may cover + # multiple output slices. Greedily match each lora_b entry to + # consecutive output_sizes by dimension, then split accordingly. + if isinstance(lora_a, list) and len(lora_a) < self.n_slices: + expanded_a: list[torch.Tensor | None] = [] + expanded_b: list[torch.Tensor | None] = [] + slice_idx = 0 + for i, (a_i, b_i) in enumerate(zip(lora_a, lora_b)): + if b_i is None: + # Figure out how many slices this None covers + remaining = len(lora_a) - i - 1 + remaining_slices = self.n_slices - slice_idx + count = remaining_slices - remaining + expanded_a.extend([None] * count) + expanded_b.extend([None] * count) + slice_idx += count + else: + b_dim = b_i.shape[0] + # Greedily consume output_sizes until we match b_dim + consumed = 0 + start_slice = slice_idx + while slice_idx < self.n_slices and consumed < b_dim: + consumed += output_sizes[slice_idx] + slice_idx += 1 + if consumed != b_dim: + raise ValueError( + f"Packed LoRA B dimension {b_dim} does not match " + f"the sum of output sizes {consumed} for LoRA {i}." + ) + num_covered = slice_idx - start_slice + if num_covered == 1: + expanded_a.append(a_i) + expanded_b.append(b_i) + else: + # Split lora_b and duplicate lora_a + split_start = 0 + for j in range(start_slice, slice_idx): + sz = output_sizes[j] + expanded_a.append(a_i) + expanded_b.append(b_i[split_start : split_start + sz, :]) + split_start += sz + + # Pad remaining slices with None (e.g. dummy LoRA warmup + # where per-slice dimensions don't span multiple output_sizes) + while len(expanded_a) < self.n_slices: + expanded_a.append(None) + expanded_b.append(None) + + lora_a = expanded_a + lora_b = expanded_b + # Now call parent's set_lora which expects lists super().set_lora(index, lora_a, lora_b)