Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 76 additions & 9 deletions vllm/lora/layers/column_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,9 @@ def can_replace_layer(
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2
)
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list
) == len(source_layer.output_sizes)


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
Expand Down Expand Up @@ -613,8 +612,17 @@ def can_replace_layer(
if len(packed_modules_list) >= 3:
return True

# If packed_modules_list has exactly 2 items, let
# MergedColumnParallelLinearWithLoRA handle it
# Handle mismatch between packed modules count and output_sizes
# count. E.g. Qwen3.5 GDN layers have 2 packed modules
# (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight
# covers multiple output slices.
if hasattr(source_layer, "output_sizes") and len(packed_modules_list) != len(
source_layer.output_sizes
):
return True

# If packed_modules_list has exactly 2 items matching output_sizes,
# let MergedColumnParallelLinearWithLoRA handle it
if len(packed_modules_list) == 2:
return False

Expand All @@ -633,10 +641,19 @@ def set_lora(
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Override to handle single tensor weights
that need to be split into slices."""
"""Override to handle weights that need to be split into slices.

Handles three cases:
1. Single tensor: split lora_b by output_sizes, duplicate lora_a
2. List matching n_slices: pass through directly
3. List shorter than n_slices: expand by splitting lora_b entries
that cover multiple output slices (e.g., 2 packed modules
mapping to 4 output_sizes in Qwen3.5 GDN layers)
"""
self.reset_lora(index)

output_sizes = self.base_layer.output_sizes

# Handle case where checkpoint has single tensor weights
# lora_a shape: (rank, input_size) - same for all slices, duplicate it
if isinstance(lora_a, torch.Tensor):
Expand All @@ -645,7 +662,6 @@ def set_lora(
# lora_b shape: (total_output_size, rank) -
# split along dim 0 based on output_sizes
if isinstance(lora_b, torch.Tensor):
output_sizes = self.base_layer.output_sizes
lora_b_list = []
start_idx = 0
for output_size in output_sizes:
Expand All @@ -654,5 +670,56 @@ def set_lora(
start_idx = end_idx
lora_b = lora_b_list

# Handle list shorter than n_slices: one packed module may cover
# multiple output slices. Greedily match each lora_b entry to
# consecutive output_sizes by dimension, then split accordingly.
if isinstance(lora_a, list) and len(lora_a) < self.n_slices:
expanded_a: list[torch.Tensor | None] = []
expanded_b: list[torch.Tensor | None] = []
slice_idx = 0
for i, (a_i, b_i) in enumerate(zip(lora_a, lora_b)):
if b_i is None:
# Figure out how many slices this None covers
remaining = len(lora_a) - i - 1
remaining_slices = self.n_slices - slice_idx
count = remaining_slices - remaining
expanded_a.extend([None] * count)
expanded_b.extend([None] * count)
slice_idx += count
else:
b_dim = b_i.shape[0]
# Greedily consume output_sizes until we match b_dim
consumed = 0
start_slice = slice_idx
while slice_idx < self.n_slices and consumed < b_dim:
consumed += output_sizes[slice_idx]
slice_idx += 1
if consumed != b_dim:
raise ValueError(
f"Packed LoRA B dimension {b_dim} does not match "
f"the sum of output sizes {consumed} for LoRA {i}."
)
num_covered = slice_idx - start_slice
if num_covered == 1:
expanded_a.append(a_i)
expanded_b.append(b_i)
else:
# Split lora_b and duplicate lora_a
split_start = 0
for j in range(start_slice, slice_idx):
sz = output_sizes[j]
expanded_a.append(a_i)
expanded_b.append(b_i[split_start : split_start + sz, :])
split_start += sz

# Pad remaining slices with None (e.g. dummy LoRA warmup
# where per-slice dimensions don't span multiple output_sizes)
while len(expanded_a) < self.n_slices:
expanded_a.append(None)
expanded_b.append(None)

lora_a = expanded_a
lora_b = expanded_b

# Now call parent's set_lora which expects lists
super().set_lora(index, lora_a, lora_b)
Loading