From cbe122798ac34ca79a8571a40d6d0efe261ae3d0 Mon Sep 17 00:00:00 2001
From: hallerite
Date: Wed, 11 Mar 2026 20:25:09 +0000
Subject: [PATCH 1/5] fix qwen3.5 lora slicing

Signed-off-by: hallerite
---
 vllm/lora/layers/column_parallel_linear.py | 2 +-
 vllm/model_executor/models/qwen3_5.py      | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index eaed6e2265cd..8f5f0f0464b9 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -285,7 +285,7 @@ def can_replace_layer(
     ) -> bool:
         return (
             type(source_layer) is MergedColumnParallelLinear
-            and len(packed_modules_list) == 2
+            and len(packed_modules_list) == len(source_layer.output_sizes)
         )
 
 
diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index 85f455101e3e..be7c852314f6 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -528,8 +528,9 @@ class Qwen3_5ForCausalLMBase(
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        # GDN fused projections.
-        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
+        # GDN fused projections — 4 packed modules to match 4 output_sizes
+        # in create_qkvz_proj for correct per-slice TP sharding with LoRA.
+        "in_proj_qkvz": ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"],
         "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
 
@@ -632,7 +633,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
     supports_multimodal_pruning = False
 
     packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
-        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
+        "in_proj_qkvz": ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"],
         "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
 

From b9284d956ff130872466b812062f0367bfc23304 Mon Sep 17 00:00:00 2001
From: hallerite
Date: Wed, 11 Mar 2026 22:13:26 +0000
Subject: [PATCH 2/5] fix pre-commit

Signed-off-by: hallerite
---
 vllm/lora/layers/column_parallel_linear.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 8f5f0f0464b9..6697a2b60199 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -283,10 +283,9 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        return (
-            type(source_layer) is MergedColumnParallelLinear
-            and len(packed_modules_list) == len(source_layer.output_sizes)
-        )
+        return type(source_layer) is MergedColumnParallelLinear and len(
+            packed_modules_list
+        ) == len(source_layer.output_sizes)
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
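A standalone sketch of what the relaxed check in PATCH 1/5 compares. ToyMergedLayer and the slice widths below are invented for illustration; only the output_sizes attribute mirrors the real MergedColumnParallelLinear:

class ToyMergedLayer:
    """Stand-in for MergedColumnParallelLinear: one entry per fused slice."""

    def __init__(self, output_sizes: list[int]):
        self.output_sizes = output_sizes

def can_replace(layer: ToyMergedLayer, packed_modules: list[str]) -> bool:
    # Old check: len(packed_modules) == 2, which only accepted two-way
    # merges like gate_up_proj. New check: the LoRA packing must have
    # exactly as many sub-modules as the layer has fused output slices.
    return len(packed_modules) == len(layer.output_sizes)

# Two-way merges still match (widths are hypothetical):
assert can_replace(ToyMergedLayer([256, 256]), ["gate_proj", "up_proj"])
# A four-slice fused projection now matches too, which the old
# hard-coded "== 2" rejected:
assert can_replace(
    ToyMergedLayer([64, 16, 16, 64]),
    ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"],
)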
From eb7439fcd65473528a2213e72bcd3f7f989d4162 Mon Sep 17 00:00:00 2001
From: hallerite
Date: Fri, 13 Mar 2026 21:12:29 +0000
Subject: [PATCH 3/5] Fix Qwen3.5 LoRA + bitsandbytes compatibility

Revert packed_modules_mapping to real HF weight names (in_proj_qkv,
in_proj_z) to fix bitsandbytes quant state stacking, and extend
MergedColumnParallelLinearVariableSliceWithLoRA to handle the mismatch
between packed module count (2) and output_sizes count (4) in GDN
layers.

Co-Authored-By: Claude Opus 4.6 (1M context)
Signed-off-by: hallerite
---
 vllm/lora/layers/column_parallel_linear.py | 75 ++++++++++++++++++++--
 vllm/model_executor/models/qwen3_5.py      |  7 +++----
 2 files changed, 73 insertions(+), 9 deletions(-)
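A rough sketch of the constraint driving the revert: every sub-module name in packed_modules_mapping must resolve to a weight that actually exists in the checkpoint, because bitsandbytes keeps quant state per checkpoint weight when stacking. The "layers.0.linear_attn" prefix and the resolver below are hypothetical, not vLLM's real lookup:

checkpoint = {
    "layers.0.linear_attn.in_proj_qkv.weight",
    "layers.0.linear_attn.in_proj_z.weight",
}

def sub_names_resolve(mapping: dict[str, list[str]], prefix: str) -> bool:
    # Every mapped sub-module must name a real checkpoint weight
    return all(
        f"{prefix}.{sub}.weight" in checkpoint
        for subs in mapping.values()
        for sub in subs
    )

# Real HF weight names resolve:
assert sub_names_resolve(
    {"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"]}, "layers.0.linear_attn"
)
# The synthetic q/k/v split from PATCH 1/5 does not, so quant state
# stacking fails for those entries:
assert not sub_names_resolve(
    {"in_proj_qkvz": ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"]},
    "layers.0.linear_attn",
)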
diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 6697a2b60199..84758be8a11f 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -612,8 +612,18 @@ def can_replace_layer(
         if len(packed_modules_list) >= 3:
             return True
 
-        # If packed_modules_list has exactly 2 items, let
-        # MergedColumnParallelLinearWithLoRA handle it
+        # Handle mismatch between packed modules count and output_sizes
+        # count. E.g. Qwen3.5 GDN layers have 2 packed modules
+        # (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight
+        # covers multiple output slices.
+        if (
+            hasattr(source_layer, "output_sizes")
+            and len(packed_modules_list) != len(source_layer.output_sizes)
+        ):
+            return True
+
+        # If packed_modules_list has exactly 2 items matching output_sizes,
+        # let MergedColumnParallelLinearWithLoRA handle it
         if len(packed_modules_list) == 2:
             return False
 
@@ -632,10 +642,19 @@ def set_lora(
         index: int,
         lora_a: torch.Tensor | list[torch.Tensor],
         lora_b: torch.Tensor | list[torch.Tensor],
     ):
-        """Override to handle single tensor weights
-        that need to be split into slices."""
+        """Override to handle weights that need to be split into slices.
+
+        Handles three cases:
+        1. Single tensor: split lora_b by output_sizes, duplicate lora_a
+        2. List matching n_slices: pass through directly
+        3. List shorter than n_slices: expand by splitting lora_b entries
+           that cover multiple output slices (e.g., 2 packed modules
+           mapping to 4 output_sizes in Qwen3.5 GDN layers)
+        """
         self.reset_lora(index)
 
+        output_sizes = self.base_layer.output_sizes
+
         # Handle case where checkpoint has single tensor weights
         # lora_a shape: (rank, input_size) - same for all slices, duplicate it
         if isinstance(lora_a, torch.Tensor):
 
         # lora_b shape: (total_output_size, rank) -
         # split along dim 0 based on output_sizes
         if isinstance(lora_b, torch.Tensor):
-            output_sizes = self.base_layer.output_sizes
             lora_b_list = []
             start_idx = 0
             for output_size in output_sizes:
@@ -653,5 +671,52 @@ def set_lora(
                 start_idx = end_idx
             lora_b = lora_b_list
 
+        # Handle list shorter than n_slices: one packed module may cover
+        # multiple output slices. Greedily match each lora_b entry to
+        # consecutive output_sizes by dimension, then split accordingly.
+        if isinstance(lora_a, list) and len(lora_a) < self.n_slices:
+            expanded_a: list[torch.Tensor | None] = []
+            expanded_b: list[torch.Tensor | None] = []
+            slice_idx = 0
+            for i, (a_i, b_i) in enumerate(zip(lora_a, lora_b)):
+                if b_i is None:
+                    # Figure out how many slices this None covers
+                    remaining = len(lora_a) - i - 1
+                    remaining_slices = self.n_slices - slice_idx
+                    count = remaining_slices - remaining
+                    expanded_a.extend([None] * count)
+                    expanded_b.extend([None] * count)
+                    slice_idx += count
+                else:
+                    b_dim = b_i.shape[0]
+                    # Greedily consume output_sizes until we match b_dim
+                    consumed = 0
+                    start_slice = slice_idx
+                    while slice_idx < self.n_slices and consumed < b_dim:
+                        consumed += output_sizes[slice_idx]
+                        slice_idx += 1
+                    num_covered = slice_idx - start_slice
+                    if num_covered == 1:
+                        expanded_a.append(a_i)
+                        expanded_b.append(b_i)
+                    else:
+                        # Split lora_b and duplicate lora_a
+                        split_start = 0
+                        for j in range(start_slice, slice_idx):
+                            sz = output_sizes[j]
+                            expanded_a.append(a_i)
+                            expanded_b.append(
+                                b_i[split_start:split_start + sz, :])
+                            split_start += sz
+
+            # Pad remaining slices with None (e.g. dummy LoRA warmup
+            # where per-slice dimensions don't span multiple output_sizes)
+            while len(expanded_a) < self.n_slices:
+                expanded_a.append(None)
+                expanded_b.append(None)
+
+            lora_a = expanded_a
+            lora_b = expanded_b
+
         # Now call parent's set_lora which expects lists
         super().set_lora(index, lora_a, lora_b)
 
diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index be7c852314f6..85f455101e3e 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -528,9 +528,8 @@ class Qwen3_5ForCausalLMBase(
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        # GDN fused projections — 4 packed modules to match 4 output_sizes
-        # in create_qkvz_proj for correct per-slice TP sharding with LoRA.
-        "in_proj_qkvz": ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"],
+        # GDN fused projections.
+        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
         "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
 
@@ -633,7 +632,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
     supports_multimodal_pruning = False
 
     packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
-        "in_proj_qkvz": ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"],
+        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
         "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
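A standalone re-enactment of the greedy expansion added above: two packed LoRA-B tensors are split into four per-slice tensors by consuming output_sizes until each tensor's rows are accounted for. The widths (16, 4, 4, 8) and rank are toy values, not real GDN dimensions:

import torch

rank = 8
output_sizes = [16, 4, 4, 8]  # hypothetical q, k, v, z slice widths
# Two packed HF LoRA-B weights: in_proj_qkv covers the first three
# slices (16 + 4 + 4 rows), in_proj_z covers the last one (8 rows).
lora_b = [torch.randn(24, rank), torch.randn(8, rank)]

expanded_b = []
slice_idx = 0
for b_i in lora_b:
    consumed, start_slice = 0, slice_idx
    # Greedily consume output_sizes until this tensor's rows are covered
    while slice_idx < len(output_sizes) and consumed < b_i.shape[0]:
        consumed += output_sizes[slice_idx]
        slice_idx += 1
    split_start = 0
    for j in range(start_slice, slice_idx):
        sz = output_sizes[j]
        expanded_b.append(b_i[split_start : split_start + sz, :])
        split_start += sz

# 2 packed tensors expand into 4 per-slice tensors matching output_sizes
assert [t.shape[0] for t in expanded_b] == output_sizes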
From 6005bc56e47e48e3bd84759026feb0e221ffc27d Mon Sep 17 00:00:00 2001
From: hallerite
Date: Fri, 13 Mar 2026 22:25:23 +0000
Subject: [PATCH 4/5] run precommit

Signed-off-by: hallerite
---
 vllm/lora/layers/column_parallel_linear.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 84758be8a11f..2ced3f5f42ba 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -616,9 +616,8 @@ def can_replace_layer(
         # count. E.g. Qwen3.5 GDN layers have 2 packed modules
         # (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight
         # covers multiple output slices.
-        if (
-            hasattr(source_layer, "output_sizes")
-            and len(packed_modules_list) != len(source_layer.output_sizes)
+        if hasattr(source_layer, "output_sizes") and len(packed_modules_list) != len(
+            source_layer.output_sizes
         ):
             return True
 
@@ -705,8 +704,7 @@ def set_lora(
             for j in range(start_slice, slice_idx):
                 sz = output_sizes[j]
                 expanded_a.append(a_i)
-                expanded_b.append(
-                    b_i[split_start:split_start + sz, :])
+                expanded_b.append(b_i[split_start : split_start + sz, :])
                 split_start += sz
 
         # Pad remaining slices with None (e.g. dummy LoRA warmup

From 8d6e14eedfc965a5710c4924ace207edeb7040dd Mon Sep 17 00:00:00 2001
From: hallerite
Date: Fri, 13 Mar 2026 23:23:16 +0000
Subject: [PATCH 5/5] validate greedy output_sizes consumption in
 VariableSlice.set_lora

Assert that consumed dimensions exactly match lora_b's shape after
greedily matching output_sizes. Prevents silent data corruption if
dimensions don't align.

Signed-off-by: hallerite
Co-Authored-By: Claude Opus 4.6 (1M context)
---
 vllm/lora/layers/column_parallel_linear.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 2ced3f5f42ba..dd6646262a4e 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -694,6 +694,11 @@ def set_lora(
                 while slice_idx < self.n_slices and consumed < b_dim:
                     consumed += output_sizes[slice_idx]
                     slice_idx += 1
+                if consumed != b_dim:
+                    raise ValueError(
+                        f"Packed LoRA B dimension {b_dim} does not match "
+                        f"the sum of output sizes {consumed} for LoRA {i}."
+                    )
                 num_covered = slice_idx - start_slice
                 if num_covered == 1:
                     expanded_a.append(a_i)
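A toy illustration of the misalignment this final commit guards against: when a packed tensor's row count falls between slice boundaries, greedy consumption overshoots, and without the new ValueError the subsequent split would silently produce wrong slices. The widths are the same invented values as above:

output_sizes = [16, 4, 4, 8]  # hypothetical slice widths
b_dim = 18                    # neither 16 nor 16 + 4: misaligned rows

consumed, slice_idx = 0, 0
while slice_idx < len(output_sizes) and consumed < b_dim:
    consumed += output_sizes[slice_idx]
    slice_idx += 1

# Greedy consumption overshoots to 20; PATCH 5/5 turns this mismatch
# into a ValueError instead of slicing 18 rows as if they were 20.
assert consumed == 20 and consumed != b_dim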