Merged
11 changes: 11 additions & 0 deletions vllm/model_executor/models/qwen3_5.py
@@ -75,6 +75,7 @@
IsHybrid,
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
_require_is_multimodal,
@@ -353,6 +354,8 @@ def get_layer(prefix: str):
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers: tuple[int, ...] = ()

def load_fused_expert_weights(
self,
name: str,
@@ -536,6 +539,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
class Qwen3_5ForCausalLMBase(
nn.Module,
HasInnerState,
SupportsEagle3,
Collaborator:
Please also add SupportsEagle. It's not currently used everywhere but I'm trying to get all the models to have both for consistency, at least for now. See #36063

SupportsLoRA,
SupportsPP,
):
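
A minimal sketch of the base-class list if the review comment above is applied, assuming SupportsEagle is exported from the same interfaces module as SupportsEagle3 (that import is not shown in this diff):

class Qwen3_5ForCausalLMBase(
    nn.Module,
    HasInnerState,
    SupportsEagle,  # requested in review for consistency across models (see #36063)
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
):
    ...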
@@ -592,6 +596,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def forward(
self,
input_ids: torch.Tensor,
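For orientation (not part of the diff): the default selection in get_eagle3_aux_hidden_state_layers picks an early, a middle, and a late decoder layer. A quick check of the arithmetic for a hypothetical 48-layer model:

num_layers = 48  # hypothetical depth, for illustration only
default_layers = (2, num_layers // 2, num_layers - 3)
print(default_layers)  # (2, 24, 45): early, middle, and late hidden states for EAGLE3
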
16 changes: 14 additions & 2 deletions vllm/model_executor/models/qwen3_next.py
@@ -1148,6 +1148,8 @@ def get_layer(prefix: str):
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers: tuple[int, ...] = ()

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

@@ -1157,7 +1159,7 @@ def forward(
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -1169,7 +1171,15 @@
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
@@ -1181,6 +1191,8 @@
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if aux_hidden_states:
return hidden_states, aux_hidden_states
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
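
A hedged end-to-end sketch of how the pieces added in this PR fit together; the variable names and call site are illustrative placeholders, not taken from vLLM's speculative-decoding code:

# Register which decoder layers should emit auxiliary hidden states; the inner
# model's forward then returns them alongside the final hidden states.
layers = model.get_eagle3_aux_hidden_state_layers()  # e.g. (2, num_layers // 2, num_layers - 3)
model.set_aux_hidden_state_layers(layers)

out = model.model(input_ids, positions)
if isinstance(out, tuple):
    hidden_states, aux_hidden_states = out  # aux states feed the EAGLE3 drafter
else:
    hidden_states = out  # no aux layers registered, plain hidden states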