Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ def _verify_args(self) -> Self:
"hunyuan_v1_dense",
"afmoe",
"nemotron_h",
"kimi_k2",
]
if (
self.method in ("eagle3", "extract_hidden_states")
Expand Down
47 changes: 42 additions & 5 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import typing
from collections.abc import Callable, Iterable
from itertools import islice

import torch
from torch import nn
Expand Down Expand Up @@ -82,7 +81,13 @@
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec

from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .interfaces import (
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
Expand Down Expand Up @@ -1166,6 +1171,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers = tuple[int, ...]()

self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
Expand All @@ -1179,7 +1187,7 @@ def forward(
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
Expand All @@ -1205,7 +1213,13 @@ def forward(
else:
llama_4_scaling = None

for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
global_layer_idx = self.start_layer + idx
if global_layer_idx in self.aux_hidden_state_layers:
# Pre-normalization state
aux_hidden_states.append(hidden_states + residual)

hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
Expand All @@ -1216,6 +1230,10 @@ def forward(
)

hidden_states, _ = self.norm(hidden_states, residual)

if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states

return hidden_states


Expand Down Expand Up @@ -1261,7 +1279,12 @@ def update_physical_experts_metadata(


class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
nn.Module,
SupportsPP,
DeepseekV2MixtureOfExperts,
SupportsLoRA,
SupportsEagle,
SupportsEagle3,
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
Expand Down Expand Up @@ -1343,6 +1366,20 @@ def set_moe_parameters(self):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which decoder layers must emit auxiliary hidden states.

    The indices are stored on the inner model, which checks them during
    ``forward`` to collect pre-normalization hidden states (consumed by
    EAGLE3 speculative decoding).
    """
    self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default EAGLE3 auxiliary layer indices.

    Selects one early, one middle, and one late layer. Uses
    ``config.num_hidden_layers`` (not the local layer list) so the count
    is correct regardless of pipeline-parallel partitioning. Degenerates
    to a single middle layer for very shallow models and to an empty
    tuple when there are no layers at all.
    """
    n = self.model.config.num_hidden_layers
    if n <= 0:
        return ()
    if n < 4:
        return (n // 2,)
    # A set drops duplicates when the picks collide on small models.
    return tuple(sorted({2, n // 2, n - 3}))

def forward(
self,
input_ids: torch.Tensor | None,
Expand Down
11 changes: 10 additions & 1 deletion vllm/model_executor/models/kimi_k25.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CompressedTensorsConfig,
)
from vllm.model_executor.models.interfaces import (
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
Expand Down Expand Up @@ -310,7 +311,7 @@ def split_video_chunks(self, video):
dummy_inputs=KimiK25DummyInputsBuilder,
)
class KimiK25ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant, SupportsEagle3
):
"""Kimi-K2.5 model for conditional generation.

Expand Down Expand Up @@ -456,6 +457,14 @@ def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
vision_embeddings = self._process_media_input(media_input)
return vision_embeddings

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Forward the auxiliary-layer selection to the wrapped language model.

    The multimodal wrapper holds no layer state of its own; the text
    backbone decides how the indices are applied.
    """
    self.language_model.set_aux_hidden_state_layers(layers)

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Delegate the default EAGLE3 auxiliary-layer choice to the text backbone."""
    return self.language_model.get_eagle3_aux_hidden_state_layers()

def forward(
self,
input_ids: torch.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,10 @@ def load_model(self, target_model: nn.Module) -> None:
self.model.config.image_token_index = (
target_model.config.vision_config.image_token_id
)
elif self.get_model_name(target_model) == "KimiK25ForConditionalGeneration":
self.model.config.image_token_index = getattr(
target_model.config, "media_placeholder_token_id", None
)
else:
self.model.config.image_token_index = (
target_model.config.image_token_index
Expand Down