Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .buildkite/test_areas/lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ steps:
- vllm/lora
- tests/lora
commands:
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py
- pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen35_densemoel_lora.py
parallelism: 4


Expand All @@ -30,4 +30,5 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py
- pytest -v -s -x lora/test_qwen35_densemoel_lora.py
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ def whisper_lora_files():
return snapshot_download(repo_id="chengyili2005/whisper-small-mandarin-lora")


@pytest.fixture(scope="session")
def qwen35_dense_model_lora_files():
    """Download (once per test session) the Qwen3.5-4B text-to-SQL LoRA adapter."""
    adapter_repo = "jeeejeee/qwen35-4b-text-only-sql-lora"
    return snapshot_download(repo_id=adapter_repo)


@pytest.fixture
def reset_default_device():
"""
Expand Down
132 changes: 132 additions & 0 deletions tests/lora/test_qwen35_densemoel_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

import vllm
import vllm.config
from vllm.lora.request import LoRARequest

from ..utils import create_new_process_for_each_test, multi_gpu_test

# Base model the LoRA adapter below was trained against.
MODEL_PATH = "Qwen/Qwen3.5-4B"

# Text-to-SQL prompt: a fixed concert-database schema plus a per-test question.
PROMPT_TEMPLATE = """Write a SQL query for the given database.\nSchema:\nTables:\n - stadium(Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average)\n - singer(Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male)\n - concert(concert_ID, concert_Name, Theme, Stadium_ID, Year)\n - singer_in_concert(concert_ID, Singer_ID)\n\nQuestion:\n{query}""" # noqa: E501

# Greedy-decoded outputs expected from the LoRA-adapted model, one per prompt
# built in do_sample (same order).
EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",
    "SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)",
]


# Loaded at import time so all tests in this module share one tokenizer;
# used only to apply the chat template (tokenize=False).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate SQL for three fixed questions and return the stripped outputs.

    When ``lora_id`` is non-zero, generation runs with the LoRA adapter at
    ``lora_path``; a ``lora_id`` of 0 runs the base model.
    """
    questions = [
        "How many singers do we have?",
        "What is the average, minimum, and maximum age of all singers from France?",
        "What are the names of the stadiums without any concerts?",
    ]
    # Wrap each question in the schema prompt, then apply the chat template
    # (thinking disabled) to get the final text fed to the engine.
    chat_inputs = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPT_TEMPLATE.format(query=q)}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable thinking
        )
        for q in questions
    ]
    request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
    params = vllm.SamplingParams(temperature=0, max_tokens=512)
    outputs = llm.generate(chat_inputs, params, lora_request=request)

    results: list[str] = []
    for out in outputs:
        text = out.outputs[0].text.strip()
        results.append(text)
        print(f"Prompt: {out.prompt!r}, Generated text: {text!r}")
    return results


@create_new_process_for_each_test()
def test_qwen35_dense_model_lora(qwen35_dense_model_lora_files):
    """Single-GPU check: two LoRA ids on one engine both reproduce the expected SQL."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=8,
        trust_remote_code=True,
    )

    # Run the same adapter under two distinct LoRA ids to exercise adapter
    # slot management; both must match the reference outputs exactly.
    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for got, expected in zip(generated, EXPECTED_LORA_OUTPUT):
            assert got == expected


@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4(qwen35_dense_model_lora_files):
    """TP=4 check (non-fully-sharded LoRA): both LoRA ids reproduce the expected SQL.

    Mirrors the single-GPU test but with tensor_parallel_size=4 and
    fully_sharded_loras=False.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    # Removed a leftover debug print(output1) so this test matches its
    # siblings; assertions are unchanged.
    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for got, expected in zip(generated, EXPECTED_LORA_OUTPUT):
            assert got == expected


@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4_fully_sharded_loras(qwen35_dense_model_lora_files):
    """TP=4 check with fully_sharded_loras=True: both LoRA ids reproduce the expected SQL."""
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    for lora_id in (1, 2):
        generated = do_sample(llm, qwen35_dense_model_lora_files, lora_id=lora_id)
        for got, expected in zip(generated, EXPECTED_LORA_OUTPUT):
            assert got == expected
123 changes: 100 additions & 23 deletions vllm/model_executor/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
VllmConfig,
)
from vllm.config import VllmConfig
from vllm.distributed import (
get_pp_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm,
)
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
Expand Down Expand Up @@ -130,6 +131,40 @@ def fix_query_key_value_ordering(
"Qwen3.5 Series dont need to fix query key value ordering"
)

def __init__(
    self,
    config: Qwen3_5Config,
    vllm_config: VllmConfig,
    prefix: str = "",
) -> None:
    """Build the gated-delta-net projections, splitting QKV/Z when LoRA is on."""
    lora_enabled = vllm_config.lora_config is not None
    # Without LoRA the base class keeps the fused in_proj_qkvz projection;
    # with LoRA we skip it and create separate layers below.
    super().__init__(
        config,
        vllm_config=vllm_config,
        prefix=prefix,
        create_in_proj_qkvz=not lora_enabled,
    )
    if lora_enabled:
        # Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
        # MergedColumnParallelLinear is used for in_proj_qkv because GDN may
        # have linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32),
        # so output sizes [key_dim, key_dim, value_dim] cannot be expressed by
        # a single QKVParallelLinear (which ties K and V head counts).
        self.in_proj_qkv = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.in_proj_qkv",
        )
        self.in_proj_z = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.value_dim,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.in_proj_z",
        )

def create_qkvz_proj(
self,
hidden_size: int,
Expand Down Expand Up @@ -180,15 +215,21 @@ def forward(
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
if hasattr(self, "in_proj_qkv"):
# LoRA path: separate in_proj_qkv and in_proj_z
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
z, _ = self.in_proj_z(hidden_states)
else:
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
sum(self.in_proj_ba.output_sizes) // self.tp_size,
self.prefix,
)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b, a = ba.chunk(2, dim=-1)

Expand Down Expand Up @@ -240,18 +281,14 @@ def __init__(
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config

self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)

if self.layer_type == "linear_attention":
self.linear_attn = Qwen3_5GatedDeltaNet(
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
config=config,
vllm_config=vllm_config,
prefix=f"{prefix}.linear_attn",
)
elif self.layer_type == "full_attention":
Expand Down Expand Up @@ -331,6 +368,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_redundant_experts = eplb_config.num_redundant_experts

self.config = config
self.enable_lora = vllm_config.lora_config is not None

self.vocab_size = config.vocab_size

Expand Down Expand Up @@ -396,13 +434,25 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# mlp
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
]

if self.enable_lora:
stacked_params_mapping.extend(
[
("in_proj_qkv", "in_proj_qkv", (0, 1, 2)),
("in_proj_z", "in_proj_z", 0),
]
)
else:
stacked_params_mapping.extend(
[
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
]
)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
Expand Down Expand Up @@ -450,7 +500,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
if param_name == "in_proj_z" and self.enable_lora:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
Expand Down Expand Up @@ -580,6 +633,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)

# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
# instead of merged in_proj_qkvz; pack mapping must match.
if vllm_config.lora_config:
base = getattr(Qwen3_5ForCausalLMBase, "packed_modules_mapping", {})
self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
self.packed_modules_mapping.pop("in_proj_qkvz", None)
self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"]

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
Expand Down Expand Up @@ -672,6 +734,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
Expand Down Expand Up @@ -699,6 +762,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
self.language_model.make_empty_intermediate_tensors
)

def update_packed_mapping(self, enable_lora: bool):
    """Rewrite the LoRA packed-modules mapping for the split GDN projections.

    When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z layers
    instead of the fused in_proj_qkvz, so the class-level mapping is copied
    to the instance with the fused entry dropped and in_proj_qkv registered.
    """
    # When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
    if enable_lora:
        # Read the class attribute explicitly (not self) so repeated calls
        # always start from the pristine class-level mapping; values are
        # copied to fresh lists before mutating.
        base = getattr(
            Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {}
        )
        self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
        self.packed_modules_mapping.pop("in_proj_qkvz", None)
        self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
        # NOTE(review): unlike the CausalLM path, this does not add an
        # "in_proj_z" entry — confirm whether in_proj_z should also be
        # registered here or is intentionally excluded for this model.

def embed_input_ids(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -879,9 +952,13 @@ def set_moe_parameters(self):
class Qwen3_5MoeForConditionalGeneration(
Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
):
# For MoE LoRA weights loading
is_3d_moe_weight: bool = True

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
Expand Down
Loading
Loading