Skip to content

Commit 19e3c56

Browse files
YamPengLiDarkLight1337
authored andcommitted
[Model] Add Qwen3 and Qwen3MoE (vllm-project#15289)
Signed-off-by: YamPengLi <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 7ca9c8e commit 19e3c56

File tree

6 files changed

+893
-5
lines changed

6 files changed

+893
-5
lines changed

docs/source/models/supported_models.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,16 @@ See [this page](#generative-models) for more information on how to use generativ
478478
* `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
479479
*
480480
* ✅︎
481+
- * `Qwen3ForCausalLM`
482+
* Qwen3
483+
* `Qwen/Qwen3-8B`, etc.
484+
* ✅︎
485+
* ✅︎
486+
- * `Qwen3MoeForCausalLM`
487+
* Qwen3MoE
488+
* `Qwen/Qwen3-MoE-15B-A2B`, etc.
489+
* ✅︎
490+
* ✅︎
481491
- * `StableLmForCausalLM`
482492
* StableLM
483493
* `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.

tests/models/registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,16 @@ def check_available_online(
202202
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
203203
extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501
204204
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
205+
"Qwen3ForCausalLM": _HfExamplesInfo(
206+
"Qwen/Qwen3-8B",
207+
is_available_online=False,
208+
min_transformers_version="4.51"
209+
),
210+
"Qwen3MoeForCausalLM": _HfExamplesInfo(
211+
"Qwen/Qwen3-MoE-15B-A2B",
212+
is_available_online=False,
213+
min_transformers_version="4.51"
214+
),
205215
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
206216
is_available_online=False),
207217
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501

vllm/model_executor/models/qwen2.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,11 @@ def forward(
263263
})
264264
class Qwen2Model(nn.Module):
265265

266-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
266+
def __init__(self,
267+
*,
268+
vllm_config: VllmConfig,
269+
prefix: str = "",
270+
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
267271
super().__init__()
268272

269273
config = vllm_config.model_config.hf_config
@@ -297,12 +301,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
297301
else:
298302
self.embed_tokens = PPMissingLayer()
299303

304+
# Use the provided decoder layer type or default to Qwen2DecoderLayer
305+
decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
300306
self.start_layer, self.end_layer, self.layers = make_layers(
301307
config.num_hidden_layers,
302-
lambda prefix: Qwen2DecoderLayer(config=config,
303-
cache_config=cache_config,
304-
quant_config=quant_config,
305-
prefix=prefix),
308+
lambda prefix: decoder_layer_type(config=config,
309+
cache_config=cache_config,
310+
quant_config=quant_config,
311+
prefix=prefix),
306312
prefix=f"{prefix}.layers",
307313
)
308314

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Copyright 2024 The Qwen team.
4+
# Copyright 2023 The vLLM team.
5+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
6+
#
7+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
8+
# and OPT implementations in this library. It has been modified from its
9+
# original forms to accommodate minor architectural differences compared
10+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
11+
#
12+
# Licensed under the Apache License, Version 2.0 (the "License");
13+
# you may not use this file except in compliance with the License.
14+
# You may obtain a copy of the License at
15+
#
16+
# http://www.apache.org/licenses/LICENSE-2.0
17+
#
18+
# Unless required by applicable law or agreed to in writing, software
19+
# distributed under the License is distributed on an "AS IS" BASIS,
20+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
# See the License for the specific language governing permissions and
22+
# limitations under the License.
23+
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
24+
from typing import Iterable, Optional, Set, Tuple, Union
25+
26+
import torch
27+
from torch import nn
28+
from transformers import Qwen3Config
29+
30+
from vllm.attention import Attention, AttentionType
31+
from vllm.compilation.decorators import support_torch_compile
32+
from vllm.config import CacheConfig, VllmConfig
33+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34+
from vllm.logger import init_logger
35+
from vllm.model_executor.layers.layernorm import RMSNorm
36+
from vllm.model_executor.layers.linear import (QKVParallelLinear,
37+
RowParallelLinear)
38+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39+
from vllm.model_executor.layers.quantization import QuantizationConfig
40+
from vllm.model_executor.layers.rotary_embedding import get_rope
41+
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
42+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
43+
from vllm.model_executor.sampling_metadata import SamplingMetadata
44+
from vllm.sequence import IntermediateTensors
45+
46+
from .interfaces import SupportsLoRA, SupportsPP
47+
from .qwen2 import Qwen2MLP as Qwen3MLP
48+
from .qwen2 import Qwen2Model
49+
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
50+
51+
logger = init_logger(__name__)
52+
53+
54+
class Qwen3Attention(nn.Module):
55+
56+
def __init__(self,
57+
hidden_size: int,
58+
num_heads: int,
59+
num_kv_heads: int,
60+
max_position: int = 4096 * 32,
61+
head_dim: Optional[int] = None,
62+
rms_norm_eps: float = 1e-06,
63+
qkv_bias: bool = False,
64+
rope_theta: float = 10000,
65+
cache_config: Optional[CacheConfig] = None,
66+
quant_config: Optional[QuantizationConfig] = None,
67+
rope_scaling: Optional[Tuple] = None,
68+
prefix: str = "",
69+
attn_type: str = AttentionType.DECODER) -> None:
70+
super().__init__()
71+
self.hidden_size = hidden_size
72+
tp_size = get_tensor_model_parallel_world_size()
73+
self.total_num_heads = num_heads
74+
assert self.total_num_heads % tp_size == 0
75+
self.num_heads = self.total_num_heads // tp_size
76+
self.total_num_kv_heads = num_kv_heads
77+
if self.total_num_kv_heads >= tp_size:
78+
# Number of KV heads is greater than TP size, so we partition
79+
# the KV heads across multiple tensor parallel GPUs.
80+
assert self.total_num_kv_heads % tp_size == 0
81+
else:
82+
# Number of KV heads is less than TP size, so we replicate
83+
# the KV heads across multiple tensor parallel GPUs.
84+
assert tp_size % self.total_num_kv_heads == 0
85+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
86+
self.head_dim = head_dim or hidden_size // self.total_num_heads
87+
self.q_size = self.num_heads * self.head_dim
88+
self.kv_size = self.num_kv_heads * self.head_dim
89+
self.scaling = self.head_dim**-0.5
90+
self.rope_theta = rope_theta
91+
92+
self.qkv_proj = QKVParallelLinear(
93+
hidden_size,
94+
self.head_dim,
95+
self.total_num_heads,
96+
self.total_num_kv_heads,
97+
bias=qkv_bias,
98+
quant_config=quant_config,
99+
prefix=f"{prefix}.qkv_proj",
100+
)
101+
self.o_proj = RowParallelLinear(
102+
self.total_num_heads * self.head_dim,
103+
hidden_size,
104+
bias=False,
105+
quant_config=quant_config,
106+
prefix=f"{prefix}.o_proj",
107+
)
108+
109+
self.rotary_emb = get_rope(
110+
self.head_dim,
111+
rotary_dim=self.head_dim,
112+
max_position=max_position,
113+
base=self.rope_theta,
114+
rope_scaling=rope_scaling,
115+
)
116+
self.attn = Attention(self.num_heads,
117+
self.head_dim,
118+
self.scaling,
119+
num_kv_heads=self.num_kv_heads,
120+
cache_config=cache_config,
121+
quant_config=quant_config,
122+
prefix=f"{prefix}.attn",
123+
attn_type=attn_type)
124+
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
125+
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
126+
127+
def forward(
128+
self,
129+
positions: torch.Tensor,
130+
hidden_states: torch.Tensor,
131+
) -> torch.Tensor:
132+
qkv, _ = self.qkv_proj(hidden_states)
133+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
134+
# Add qk-norm
135+
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
136+
self.head_dim)
137+
q_by_head = self.q_norm.forward_native(q_by_head)
138+
q = q_by_head.view(q.shape)
139+
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
140+
self.head_dim)
141+
k_by_head = self.k_norm.forward_native(k_by_head)
142+
k = k_by_head.view(k.shape)
143+
q, k = self.rotary_emb(positions, q, k)
144+
attn_output = self.attn(q, k, v)
145+
output, _ = self.o_proj(attn_output)
146+
return output
147+
148+
149+
class Qwen3DecoderLayer(nn.Module):
150+
151+
def __init__(
152+
self,
153+
config: Qwen3Config,
154+
cache_config: Optional[CacheConfig] = None,
155+
quant_config: Optional[QuantizationConfig] = None,
156+
prefix: str = "",
157+
) -> None:
158+
super().__init__()
159+
self.hidden_size = config.hidden_size
160+
# Requires transformers > 4.32.0
161+
rope_theta = getattr(config, "rope_theta", 1000000)
162+
rope_scaling = getattr(config, "rope_scaling", None)
163+
164+
# By default, Qwen3 uses causal attention as it is a decoder-only model.
165+
# You can override the HF config with `is_causal=False` to enable
166+
# bidirectional attention, which is used in some embedding models
167+
# (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
168+
if getattr(config, "is_causal", True):
169+
attn_type = AttentionType.DECODER
170+
else:
171+
attn_type = AttentionType.ENCODER_ONLY
172+
173+
self.self_attn = Qwen3Attention(
174+
hidden_size=self.hidden_size,
175+
num_heads=config.num_attention_heads,
176+
max_position=config.max_position_embeddings,
177+
num_kv_heads=config.num_key_value_heads,
178+
rope_theta=rope_theta,
179+
rms_norm_eps=config.rms_norm_eps,
180+
qkv_bias=getattr(config, 'attention_bias', False),
181+
head_dim=getattr(config, 'head_dim', None),
182+
cache_config=cache_config,
183+
quant_config=quant_config,
184+
rope_scaling=rope_scaling,
185+
prefix=f"{prefix}.self_attn",
186+
attn_type=attn_type,
187+
)
188+
self.mlp = Qwen3MLP(
189+
hidden_size=self.hidden_size,
190+
intermediate_size=config.intermediate_size,
191+
hidden_act=config.hidden_act,
192+
quant_config=quant_config,
193+
prefix=f"{prefix}.mlp",
194+
)
195+
self.input_layernorm = RMSNorm(config.hidden_size,
196+
eps=config.rms_norm_eps)
197+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
198+
eps=config.rms_norm_eps)
199+
200+
def forward(
201+
self,
202+
positions: torch.Tensor,
203+
hidden_states: torch.Tensor,
204+
residual: Optional[torch.Tensor],
205+
) -> Tuple[torch.Tensor, torch.Tensor]:
206+
# Self Attention
207+
if residual is None:
208+
residual = hidden_states
209+
hidden_states = self.input_layernorm(hidden_states)
210+
else:
211+
hidden_states, residual = self.input_layernorm(
212+
hidden_states, residual)
213+
hidden_states = self.self_attn(
214+
positions=positions,
215+
hidden_states=hidden_states,
216+
)
217+
218+
# Fully Connected
219+
hidden_states, residual = self.post_attention_layernorm(
220+
hidden_states, residual)
221+
hidden_states = self.mlp(hidden_states)
222+
return hidden_states, residual
223+
224+
225+
ALL_DECODER_LAYER_TYPES = {
226+
"attention": Qwen3DecoderLayer,
227+
}
228+
229+
230+
@support_torch_compile(
231+
dynamic_arg_dims={
232+
"input_ids": 0,
233+
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
234+
# otherwise (seq_len, ).
235+
"positions": -1,
236+
"intermediate_tensors": 0,
237+
"inputs_embeds": 0,
238+
})
239+
class Qwen3Model(Qwen2Model):
240+
241+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
242+
super().__init__(vllm_config=vllm_config,
243+
prefix=prefix,
244+
decoder_layer_type=Qwen3DecoderLayer)
245+
246+
247+
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
248+
packed_modules_mapping = {
249+
"qkv_proj": [
250+
"q_proj",
251+
"k_proj",
252+
"v_proj",
253+
],
254+
"gate_up_proj": [
255+
"gate_proj",
256+
"up_proj",
257+
],
258+
}
259+
260+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
261+
super().__init__()
262+
config = vllm_config.model_config.hf_config
263+
quant_config = vllm_config.quant_config
264+
lora_config = vllm_config.lora_config
265+
266+
self.config = config
267+
self.lora_config = lora_config
268+
269+
self.quant_config = quant_config
270+
self.model = Qwen3Model(vllm_config=vllm_config,
271+
prefix=maybe_prefix(prefix, "model"))
272+
273+
if get_pp_group().is_last_rank:
274+
if config.tie_word_embeddings:
275+
self.lm_head = self.model.embed_tokens
276+
else:
277+
self.lm_head = ParallelLMHead(config.vocab_size,
278+
config.hidden_size,
279+
quant_config=quant_config,
280+
prefix=maybe_prefix(
281+
prefix, "lm_head"))
282+
else:
283+
self.lm_head = PPMissingLayer()
284+
285+
self.logits_processor = LogitsProcessor(config.vocab_size)
286+
self.sampler = get_sampler()
287+
288+
self.make_empty_intermediate_tensors = (
289+
self.model.make_empty_intermediate_tensors)
290+
291+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
292+
return self.model.get_input_embeddings(input_ids)
293+
294+
def forward(
295+
self,
296+
input_ids: torch.Tensor,
297+
positions: torch.Tensor,
298+
intermediate_tensors: Optional[IntermediateTensors] = None,
299+
inputs_embeds: Optional[torch.Tensor] = None,
300+
) -> Union[torch.Tensor, IntermediateTensors]:
301+
hidden_states = self.model(input_ids, positions, intermediate_tensors,
302+
inputs_embeds)
303+
return hidden_states
304+
305+
def compute_logits(
306+
self,
307+
hidden_states: torch.Tensor,
308+
sampling_metadata: SamplingMetadata,
309+
) -> Optional[torch.Tensor]:
310+
logits = self.logits_processor(self.lm_head, hidden_states,
311+
sampling_metadata)
312+
return logits
313+
314+
def sample(
315+
self,
316+
logits: torch.Tensor,
317+
sampling_metadata: SamplingMetadata,
318+
) -> Optional[SamplerOutput]:
319+
next_tokens = self.sampler(logits, sampling_metadata)
320+
return next_tokens
321+
322+
def load_weights(self, weights: Iterable[Tuple[str,
323+
torch.Tensor]]) -> Set[str]:
324+
loader = AutoWeightsLoader(
325+
self,
326+
skip_prefixes=(["lm_head."]
327+
if self.config.tie_word_embeddings else None),
328+
)
329+
return loader.load_weights(weights)

0 commit comments

Comments
 (0)