37 changes: 33 additions & 4 deletions vllm/model_executor/models/qwen3_moe.py
@@ -64,7 +64,7 @@
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -422,6 +422,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
+        # Track layers for auxiliary hidden state outputs (EAGLE3)
+        self.aux_hidden_state_layers: tuple[int, ...] = ()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
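
Note: `aux_hidden_state_layers` defaults to an empty tuple, so the membership check added to `forward()` below never fires and existing callers see no behavioral change; collection is strictly opt-in. A minimal sketch of opting in, using only the two accessor methods added at the bottom of this diff (the actual call site inside vLLM's model runner is not shown in this excerpt):

    # Sketch, not the actual vLLM wiring; `model` is assumed to be a
    # loaded Qwen3MoeForCausalLM instance.
    model.set_aux_hidden_state_layers(model.get_eagle3_aux_hidden_state_layers())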
@@ -432,7 +434,9 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[
+        torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
+    ]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -443,13 +447,29 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer),
+            start=self.start_layer,
+        ):
+            # Collect auxiliary hidden states if specified
+            if layer_idx in self.aux_hidden_state_layers:
+                aux_hidden_state = (
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+                aux_hidden_states.append(aux_hidden_state)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
             )
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        # Return auxiliary hidden states if collected
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
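
Two details of the hunk above are worth noting. First, the loop indexes layers with `enumerate(islice(...), start=self.start_layer)`, so the indices compared against `aux_hidden_state_layers` are absolute layer numbers even when pipeline parallelism assigns this rank only a sub-range of layers. Second, the recorded value is `hidden_states + residual`, because the decoder layers keep the residual stream separate until the final norm. A standalone sketch of the indexing behavior (toy data, no vLLM dependency):

    from itertools import islice

    layers = [f"layer{i}" for i in range(8)]  # toy stand-ins for decoder layers
    start_layer, end_layer = 4, 8             # this PP rank owns layers 4..7
    for idx, layer in enumerate(islice(layers, start_layer, end_layer), start=start_layer):
        assert layer == f"layer{idx}"         # absolute index, not a local offset

Callers must also handle the widened return type: when `aux_hidden_state_layers` is non-empty, the last PP rank returns a `(hidden_states, aux_hidden_states)` tuple instead of a bare tensor, which is why the `Union` annotation gains a tuple member.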
@@ -606,7 +626,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
+class Qwen3MoeForCausalLM(
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -702,6 +724,13 @@ def update_physical_experts_metadata(
                 moe.n_redundant_experts = self.num_redundant_experts
                 moe.experts.update_expert_map()
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

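The default selection in `get_eagle3_aux_hidden_state_layers` follows the common EAGLE3 practice of tapping an early, a middle, and a near-final layer so the drafter sees features from several depths. A worked example, assuming a hypothetical 48-layer checkpoint:

    num_layers = 48
    print((2, num_layers // 2, num_layers - 3))  # -> (2, 24, 45)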