Commit a2651c6

Merge pull request vllm-project#20 from nebius/feature/qwen3_moe_0.10.2
Added Qwen MoE
2 parents: 21f44bc + 3ae424c

3 files changed: +19 additions, -14 deletions

vllm/model_executor/models/qwen2.py

Lines changed: 0 additions & 2 deletions

@@ -336,8 +336,6 @@ def __init__(self,
 
         self.aux_hidden_state_layers: tuple[int] = tuple()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
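The deleted line was a leftover duplicate: tuple[int, ...] is a runtime generic alias, and calling it forwards to the origin type, so the assignment simply re-created the same empty tuple bound (with an annotation) two lines earlier. For illustration:

# Calling a subscripted generic alias constructs the plain origin type,
# so tuple[int, ...]() is just an empty tuple (Python 3.9+).
assert tuple[int, ...]() == ()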

vllm/model_executor/models/qwen3.py

Lines changed: 0 additions & 9 deletions

@@ -304,8 +304,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        self.aux_hidden_state_layers: tuple[int] = tuple()
-
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
 
@@ -316,13 +314,6 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor,
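
The deleted block was a second, redundant copy of set_aux_hidden_state_layers and get_eagle3_aux_hidden_state_layers, already defined earlier in the class with the more precise tuple[int, ...] annotations. Python does not flag this: the later definition silently rebinds the name in the class namespace, as a minimal example shows.

class Demo:
    def f(self) -> str:
        return "first"

    def f(self) -> str:  # no error: rebinds f, shadowing the one above
        return "second"

assert Demo().f() == "second"  # only the last definition survives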

vllm/model_executor/models/qwen3_moe.py

Lines changed: 19 additions & 3 deletions

@@ -57,7 +57,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP, SupportsEagle3
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
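
Based on the two methods this commit adds to Qwen3MoeForCausalLM, the SupportsEagle3 interface amounts to roughly the following protocol. This is a sketch inferred from the diff, not a copy of vllm's actual interfaces.py definition:

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsEagle3(Protocol):
    """Models that can expose intermediate ("auxiliary") hidden states
    for an EAGLE3 speculative-decoding drafter. Sketch inferred from
    this commit's usage, not vllm's canonical definition."""

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Choose which decoder layers should emit auxiliary hidden states."""
        ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default layer indices to tap for the EAGLE3 drafter."""
        ...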

@@ -408,6 +408,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 

@@ -428,14 +430,21 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)):
+            if layer_idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
         hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
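
The key mechanic: vllm's decoder layers carry the residual stream separately from hidden_states, so hidden_states + residual reconstructs the full activation entering a layer, and the loop snapshots that sum just before each tapped layer runs. A minimal, self-contained sketch of the same pattern, with positions, pipeline-parallel handling, and the layer signature simplified for clarity:

import torch
from torch import nn

def run_layers_with_taps(
    layers: nn.ModuleList,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    aux_layer_ids: tuple[int, ...],
):
    # hidden_states + residual is the full activation entering a layer;
    # snapshot it right before each tapped layer executes.
    aux_hidden_states = []
    for layer_idx, layer in enumerate(layers):
        if layer_idx in aux_layer_ids:
            aux_hidden_states.append(hidden_states + residual)
        hidden_states, residual = layer(hidden_states, residual)
    return hidden_states, residual, aux_hidden_states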

@@ -579,7 +588,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3,
                           MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [

@@ -674,6 +683,13 @@ def update_physical_experts_metadata(
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
+
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
 
     def forward(
         self,
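
With both methods in place, a speculative-decoding setup can ask the target model for its default tap points and enable them; the heuristic picks one early, one middle, and one late layer. A quick evaluation for a hypothetical depth (only the two method names and the formula come from this commit; the wiring comment is illustrative):

num_layers = 48  # hypothetical Qwen3-MoE depth
print((2, num_layers // 2, num_layers - 3))  # -> (2, 24, 45)

# Illustrative wiring, assuming `model` is a loaded Qwen3MoeForCausalLM:
#     model.set_aux_hidden_state_layers(
#         model.get_eagle3_aux_hidden_state_layers())
# Afterwards Qwen3MoeModel.forward() returns (hidden_states,
# aux_hidden_states) on the last pipeline rank whenever taps are set.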
