Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 86 additions & 32 deletions vllm/lora/layers/column_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,9 @@ def can_replace_layer(
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2
)
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list
) == len(source_layer.output_sizes)


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
Expand Down Expand Up @@ -466,18 +465,14 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[0] is not None
else None,
lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[1] is not None
else None,
return [
a[output_start_idx : output_start_idx + output_shard_size, :]
if a is not None
else None
for a in lora_a
]
return lora_a

def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
Expand Down Expand Up @@ -548,21 +543,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
if lora_a[0] is not None
else None,
lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
if lora_a[1] is not None
else None,
lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
if lora_a[2] is not None
else None,
return [
a[self.tp_rank * s.shape[2] : (self.tp_rank + 1) * s.shape[2], :]
if a is not None
else None
for a, s in zip(lora_a, self.lora_a_stacked)
]
return lora_a

def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
Expand Down Expand Up @@ -613,8 +599,17 @@ def can_replace_layer(
if len(packed_modules_list) >= 3:
return True

# If packed_modules_list has exactly 2 items, let
# MergedColumnParallelLinearWithLoRA handle it
# Handle mismatch between packed modules count and output_sizes
# count. E.g. Qwen3.5 GDN layers have 2 packed modules
# (in_proj_qkv, in_proj_z) but 4 output_sizes — one HF weight
# covers multiple output slices.
if hasattr(source_layer, "output_sizes") and len(packed_modules_list) != len(
source_layer.output_sizes
):
return True

# If packed_modules_list has exactly 2 items matching output_sizes,
# let MergedColumnParallelLinearWithLoRA handle it
if len(packed_modules_list) == 2:
return False

Expand All @@ -633,10 +628,19 @@ def set_lora(
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Override to handle single tensor weights
that need to be split into slices."""
"""Override to handle weights that need to be split into slices.

Handles three cases:
1. Single tensor: split lora_b by output_sizes, duplicate lora_a
2. List matching n_slices: pass through directly
3. List shorter than n_slices: expand by splitting lora_b entries
that cover multiple output slices (e.g., 2 packed modules
mapping to 4 output_sizes in Qwen3.5 GDN layers)
"""
self.reset_lora(index)

output_sizes = self.base_layer.output_sizes

# Handle case where checkpoint has single tensor weights
# lora_a shape: (rank, input_size) - same for all slices, duplicate it
if isinstance(lora_a, torch.Tensor):
Expand All @@ -645,7 +649,6 @@ def set_lora(
# lora_b shape: (total_output_size, rank) -
# split along dim 0 based on output_sizes
if isinstance(lora_b, torch.Tensor):
output_sizes = self.base_layer.output_sizes
lora_b_list = []
start_idx = 0
for output_size in output_sizes:
Expand All @@ -654,5 +657,56 @@ def set_lora(
start_idx = end_idx
lora_b = lora_b_list

# Handle list shorter than n_slices: one packed module may cover
# multiple output slices. Greedily match each lora_b entry to
# consecutive output_sizes by dimension, then split accordingly.
if isinstance(lora_a, list) and len(lora_a) < self.n_slices:
expanded_a: list[torch.Tensor | None] = []
expanded_b: list[torch.Tensor | None] = []
slice_idx = 0
for i, (a_i, b_i) in enumerate(zip(lora_a, lora_b)):
if b_i is None:
# Figure out how many slices this None covers
remaining = len(lora_a) - i - 1
remaining_slices = self.n_slices - slice_idx
count = remaining_slices - remaining
expanded_a.extend([None] * count)
expanded_b.extend([None] * count)
slice_idx += count
else:
b_dim = b_i.shape[0]
# Greedily consume output_sizes until we match b_dim
consumed = 0
start_slice = slice_idx
while slice_idx < self.n_slices and consumed < b_dim:
consumed += output_sizes[slice_idx]
slice_idx += 1
if consumed != b_dim:
raise ValueError(
f"Packed LoRA B dimension {b_dim} does not match "
f"the sum of output sizes {consumed} for LoRA {i}."
)
num_covered = slice_idx - start_slice
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The greedy consumption of output_sizes to match b_dim is a good approach. However, with the loop condition consumed < b_dim, the loop can exit with consumed > b_dim if consecutive output_sizes do not sum exactly to b_dim. Because Python/PyTorch slice bounds are clamped rather than bounds-checked, the subsequent slicing of b_i would then silently produce wrong-sized shards instead of failing loudly.

To prevent this, it's safer to verify that consumed is exactly equal to b_dim after the loop and raise a ValueError otherwise. This ensures the dimensions are valid and provides a more informative error message if they are not.

if consumed != b_dim:
    raise ValueError(
        f"Packed LoRA B dimension {b_dim} does not match "
        f"the sum of output sizes {consumed} for LoRA {i}."
    )
num_covered = slice_idx - start_slice

if num_covered == 1:
expanded_a.append(a_i)
expanded_b.append(b_i)
else:
# Split lora_b and duplicate lora_a
split_start = 0
for j in range(start_slice, slice_idx):
sz = output_sizes[j]
expanded_a.append(a_i)
expanded_b.append(b_i[split_start : split_start + sz, :])
split_start += sz

# Pad remaining slices with None (e.g. dummy LoRA warmup
# where per-slice dimensions don't span multiple output_sizes)
while len(expanded_a) < self.n_slices:
expanded_a.append(None)
expanded_b.append(None)

lora_a = expanded_a
lora_b = expanded_b

# Now call parent's set_lora which expects lists
super().set_lora(index, lora_a, lora_b)
Loading