Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions vllm/lora/layers/column_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def __init__(
# There are two LoRA layers
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes
self.output_sizes = self.base_layer.output_sizes
self.output_slices = tuple(
divide(output_size, self.tp_size) for output_size in output_sizes
divide(output_size, self.tp_size) for output_size in self.output_sizes
)
self.n_slices = len(self.output_slices)
self.output_ids = (self.tp_rank,) * self.n_slices
Expand Down Expand Up @@ -253,6 +253,42 @@ def slice_lora_b(
]
return sliced_lora_b

def expand_packed_lora(
    self,
    lora_a: list[torch.Tensor],
    lora_b: list[torch.Tensor],
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Expand packed adapter groups when they don't match n_slices.
    E.g. in_proj_qkv (covers Q+K+V) + in_proj_z

    Each packed lora_b is split row-wise along ``self.output_sizes`` and
    the corresponding lora_a is replicated once per covered slice.
    """
    out_a: list[torch.Tensor] = []
    out_b: list[torch.Tensor] = []
    slice_idx = 0
    for a_w, b_w in zip(lora_a, lora_b):
        total_rows = b_w.shape[0]
        # Walk forward over the remaining output slices until their
        # cumulative row count exactly matches this packed b's rows.
        num_covered = 0
        acc = 0
        for offset, size in enumerate(self.output_sizes[slice_idx : self.n_slices]):
            acc += size
            if acc == total_rows:
                num_covered = offset + 1
                break
        if num_covered == 0:
            # No exact prefix-sum match: the packed group cannot be
            # aligned with the layer's slice layout.
            raise ValueError(
                f"Cannot determine how to split lora_b with {total_rows} rows "
                f"into {self.n_slices} slices with output sizes "
                f"{self.output_sizes} starting from index {slice_idx}."
            )
        # Emit one (a, b) pair per covered slice.
        row = 0
        for size in self.output_sizes[slice_idx : slice_idx + num_covered]:
            out_b.append(b_w[row : row + size, :])
            out_a.append(a_w)
            row += size
        slice_idx += num_covered
    return out_a, out_b

def set_lora(
self,
index: int,
Expand All @@ -261,6 +297,12 @@ def set_lora(
):
self.reset_lora(index)

# Expand packed adapter groups when they don't match n_slices.
# E.g. in_proj_qkv (covers Q+K+V) + in_proj_z as 2 groups for a
# 4-slice layer: split b_qkv by output_sizes and replicate a_qkv.
if isinstance(lora_b, list) and len(lora_b) != self.n_slices:
lora_a, lora_b = self.expand_packed_lora(lora_a, lora_b)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
Expand Down Expand Up @@ -467,18 +509,14 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[0] is not None
else None,
lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[1] is not None
else None,
return [
lora_a_i[output_start_idx : output_start_idx + output_shard_size, :]
if (lora_a_i := lora_a[i]) is not None
else None
for i in range(len(lora_a))
]
return lora_a

def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
Expand Down
5 changes: 5 additions & 0 deletions vllm/lora/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,12 @@ def create_dummy_lora(
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
n_slices = getattr(module, "n_slices", len(replacements))
subloras: list[LoRALayerWeights | None] = []
# HACK: overrides replacements for qkvz = qkv + z case.
# Any better methods to handle this case?
if n_slices != len(replacements):
replacements = [f"slice_{i}" for i in range(n_slices)]
Comment on lines +559 to +562
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use of a 'HACK' comment here is concerning as it suggests the solution is not robust and could lead to future maintenance issues. Code with 'HACK' comments is often difficult to understand and easy to break.

If this logic is indeed the correct and necessary approach for handling dummy LoRA creation for packed modules like in_proj_qkvz, please replace the 'HACK' comment with a more detailed explanation. The explanation should clarify:

  1. Why there's a mismatch between n_slices and len(replacements).
  2. Why generating generic slice_i names is the appropriate solution for creating dummy LoRAs in this scenario.
  3. How this interacts with the loading of real LoRA weights.

A clear explanation will improve code maintainability and prevent future confusion.

Alternatively, if a more robust, less 'hacky' solution is possible (perhaps by making the relationship between packed modules and slices more explicit in the model configuration), that would be preferable.

for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
Expand Down
104 changes: 13 additions & 91 deletions vllm/model_executor/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
GemmaRMSNorm as Qwen3_5RMSNorm,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -131,40 +130,6 @@ def fix_query_key_value_ordering(
"Qwen3.5 Series dont need to fix query key value ordering"
)

def __init__(
self,
config: Qwen3_5Config,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
create_in_proj_qkvz = vllm_config.lora_config is None
super().__init__(
config,
vllm_config=vllm_config,
prefix=prefix,
create_in_proj_qkvz=create_in_proj_qkvz,
)
if vllm_config.lora_config is not None:
# Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
# Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
# linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
# output sizes [key_dim, key_dim, value_dim] are not representable
# with a single QKVParallelLinear (which ties K and V head counts).
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_z",
)

def create_qkvz_proj(
self,
hidden_size: int,
Expand Down Expand Up @@ -215,21 +180,15 @@ def forward(
# ============================================================
# Part 1: Input Projection
# ============================================================
if hasattr(self, "in_proj_qkv"):
# LoRA path: separate in_proj_qkv and in_proj_z
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
z, _ = self.in_proj_z(hidden_states)
else:
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b, a = ba.chunk(2, dim=-1)

Expand Down Expand Up @@ -368,7 +327,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_redundant_experts = eplb_config.num_redundant_experts

self.config = config
self.enable_lora = vllm_config.lora_config is not None

self.vocab_size = config.vocab_size

Expand Down Expand Up @@ -427,6 +385,9 @@ def load_fused_expert_weights(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
# self attention
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
Expand All @@ -438,21 +399,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("in_proj_ba", "in_proj_a", 1),
]

if self.enable_lora:
stacked_params_mapping.extend(
[
("in_proj_qkv", "in_proj_qkv", (0, 1, 2)),
("in_proj_z", "in_proj_z", 0),
]
)
else:
stacked_params_mapping.extend(
[
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
]
)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
Expand Down Expand Up @@ -500,10 +446,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
param = params_dict[name]
weight_loader = param.weight_loader
if param_name == "in_proj_z" and self.enable_lora:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
Expand Down Expand Up @@ -633,15 +576,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)

# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
# instead of merged in_proj_qkvz; pack mapping must match.
if vllm_config.lora_config:
base = getattr(Qwen3_5ForCausalLMBase, "packed_modules_mapping", {})
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
self.packed_modules_mapping.pop("in_proj_qkvz", None)
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"]

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
Expand Down Expand Up @@ -734,7 +668,6 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
Expand Down Expand Up @@ -762,16 +695,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
self.language_model.make_empty_intermediate_tensors
)

def update_packed_mapping(self, enable_lora: bool):
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
if enable_lora:
base = getattr(
Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {}
)
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
self.packed_modules_mapping.pop("in_proj_qkvz", None)
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]

def embed_input_ids(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -958,7 +881,6 @@ class Qwen3_5MoeForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
Expand Down
16 changes: 7 additions & 9 deletions vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def __init__(
config: Qwen3NextConfig,
vllm_config: VllmConfig,
prefix: str = "",
create_in_proj_qkvz: bool = True,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -454,14 +453,13 @@ def __init__(
# we need to create qkvz_proj adaptively here.
# When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5),
# the subclass creates in_proj_qkv and in_proj_z separately.
if create_in_proj_qkvz:
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization.
# Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
# layouts, so we use a factory method to create the projection.
Expand Down
Loading