From a5719ab43b77b119d55b425e01669fd2c40896d0 Mon Sep 17 00:00:00 2001 From: huanglei Date: Wed, 4 Mar 2026 14:17:03 +0800 Subject: [PATCH 1/3] [feat] Kimi K2/DeepSeek Support eagle3 --- vllm/config/speculative.py | 1 + vllm/model_executor/models/deepseek_v2.py | 58 +++++++++++++++++++++-- vllm/model_executor/models/kimi_k25.py | 23 ++++++++- vllm/v1/spec_decode/eagle.py | 4 ++ 4 files changed, 81 insertions(+), 5 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a950ba531ad2..3cbb298037d2 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -774,6 +774,7 @@ def _verify_args(self) -> Self: "hunyuan_v1_dense", "afmoe", "nemotron_h", + "kimi_k2", ] if ( self.method in ("eagle3", "extract_hidden_states") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5dd883f222e5..c1ff000e9863 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -82,7 +82,13 @@ ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP +from .interfaces import ( + MixtureOfExperts, + SupportsEagle, + SupportsEagle3, + SupportsLoRA, + SupportsPP, +) from .utils import ( PPMissingLayer, is_pp_missing_parameter, @@ -1166,6 +1172,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() + + # Eagle3 support: track which layers should output auxiliary hidden states + self.aux_hidden_state_layers = tuple[int, ...]() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) @@ -1179,7 +1189,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -1205,7 +1215,16 @@ def forward( else: llama_4_scaling = None - for layer in islice(self.layers, self.start_layer, self.end_layer): + # Eagle3 support: collect auxiliary hidden states from specified layers + aux_hidden_states = [] + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + # Save hidden states before layer processing if this layer is marked + if idx in self.aux_hidden_state_layers: + # Store pre-normalization state (hidden_states + residual) + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer( positions, hidden_states, residual, llama_4_scaling ) @@ -1216,6 +1235,11 @@ def forward( ) hidden_states, _ = self.norm(hidden_states, residual) + + # Eagle3 support: return auxiliary hidden states if collected + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states @@ -1261,7 +1285,7 @@ def update_physical_experts_metadata( class DeepseekV2ForCausalLM( - nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle + nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle, SupportsEagle3 ): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], @@ -1343,6 +1367,32 @@ def set_moe_parameters(self): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return 
self.model.embed_input_ids(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for Eagle3. + + Args: + layers: Tuple of layer indices that should output auxiliary hidden states. + """ + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get default auxiliary layer indices for DeepseekV2. + + Returns default layer selection based on model size. These defaults + can be overridden by speculative config if the draft model specifies + eagle_aux_hidden_state_layer_ids in its config. + + Returns: + Tuple of layer indices (typically: early, middle, late layers) + """ + num_layers = len(self.model.layers) + # Select 3 representative layers: early, middle, and late + return ( + 2, # Early layer (captures low-level features) + num_layers // 2, # Middle layer (captures mid-level semantics) + num_layers - 3 # Late layer (captures high-level semantics) + ) + def forward( self, input_ids: torch.Tensor | None, diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py index 248339337fa9..b086b89b7e5d 100644 --- a/vllm/model_executor/models/kimi_k25.py +++ b/vllm/model_executor/models/kimi_k25.py @@ -28,6 +28,7 @@ CompressedTensorsConfig, ) from vllm.model_executor.models.interfaces import ( + SupportsEagle3, SupportsMultiModal, SupportsPP, SupportsQuant, @@ -310,7 +311,7 @@ def split_video_chunks(self, video): dummy_inputs=KimiK25DummyInputsBuilder, ) class KimiK25ForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant + nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant, SupportsEagle3 ): """Kimi-K2.5 model for conditional generation. @@ -456,6 +457,26 @@ def embed_multimodal(self, **kwargs: object) -> NestedTensors | None: vision_embeddings = self._process_media_input(media_input) return vision_embeddings + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for Eagle3. + + Delegates to the underlying language model (DeepseekV2ForCausalLM). + + Args: + layers: Tuple of layer indices that should output auxiliary hidden states. + """ + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get default auxiliary layer indices for Kimi K2.5. + + Delegates to the underlying language model (DeepseekV2ForCausalLM). 
+ + Returns: + Tuple of layer indices (typically: early, middle, late layers) + """ + return self.language_model.get_eagle3_aux_hidden_state_layers() + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ca58c441f46d..4f960a2a0d18 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1301,6 +1301,10 @@ def load_model(self, target_model: nn.Module) -> None: self.model.config.image_token_index = ( target_model.config.vision_config.image_token_id ) + elif self.get_model_name(target_model) == "KimiK25ForConditionalGeneration": + self.model.config.image_token_index = getattr( + target_model.config, "media_placeholder_token_id", None + ) else: self.model.config.image_token_index = ( target_model.config.image_token_index From beb2594638683af6cce36ca96621d4bc4f6dc0a4 Mon Sep 17 00:00:00 2001 From: huanglei Date: Wed, 4 Mar 2026 15:40:45 +0800 Subject: [PATCH 2/3] fix(eagle3): Fix layer indexing for pipeline parallelism --- vllm/model_executor/models/deepseek_v2.py | 39 ++++++++++++++++------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c1ff000e9863..4f6061ac9333 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -26,7 +26,6 @@ import typing from collections.abc import Callable, Iterable -from itertools import islice import torch from torch import nn @@ -1217,11 +1216,11 @@ def forward( # Eagle3 support: collect auxiliary hidden states from specified layers aux_hidden_states = [] - for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer) - ): + for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + # Calculate global layer index for pipeline parallelism + global_layer_idx = self.start_layer + idx # Save hidden states before layer processing if this layer is marked - if idx in self.aux_hidden_state_layers: + if global_layer_idx in self.aux_hidden_state_layers: # Store pre-normalization state (hidden_states + residual) aux_hidden_states.append(hidden_states + residual) @@ -1285,7 +1284,12 @@ def update_physical_experts_metadata( class DeepseekV2ForCausalLM( - nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle, SupportsEagle3 + nn.Module, + SupportsPP, + DeepseekV2MixtureOfExperts, + SupportsLoRA, + SupportsEagle, + SupportsEagle3, ): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], @@ -1385,12 +1389,25 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: Returns: Tuple of layer indices (typically: early, middle, late layers) """ - num_layers = len(self.model.layers) + # Use total number of layers from config, not len(layers) which + # only reflects the current pipeline parallel rank + num_layers = self.model.config.num_hidden_layers + + # Handle small models gracefully + if num_layers < 4: + # For very small models, return middle layer only (or empty if no layers) + return (num_layers // 2,) if num_layers > 0 else () + # Select 3 representative layers: early, middle, and late - return ( - 2, # Early layer (captures low-level features) - num_layers // 2, # Middle layer (captures mid-level semantics) - num_layers - 3 # Late layer (captures high-level semantics) + # Use set to avoid duplicates in edge cases, then sort + return tuple( + sorted( + { + 2, # Early layer (captures low-level features) + num_layers // 2, # Middle layer 
(captures mid-level semantics) + num_layers - 3, # Late layer (captures high-level semantics) + } + ) ) def forward( From 84b87827dcf92569bf1d23ceb521c9eecb7d004e Mon Sep 17 00:00:00 2001 From: huanglei Date: Fri, 6 Mar 2026 10:42:19 +0800 Subject: [PATCH 3/3] refactor: Trim unnecessary comments in Eagle3 implementation Simplify docstrings and remove redundant comments that duplicate what the code already expresses. Keep only essential technical notes that explain non-obvious implementation details. Co-Authored-By: Claude Sonnet 4.5 --- vllm/model_executor/models/deepseek_v2.py | 40 +++-------------------- vllm/model_executor/models/kimi_k25.py | 16 ++------- 2 files changed, 7 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4f6061ac9333..0fe31e994073 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1172,7 +1172,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - # Eagle3 support: track which layers should output auxiliary hidden states self.aux_hidden_state_layers = tuple[int, ...]() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( @@ -1214,14 +1213,11 @@ def forward( else: llama_4_scaling = None - # Eagle3 support: collect auxiliary hidden states from specified layers aux_hidden_states = [] for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): - # Calculate global layer index for pipeline parallelism global_layer_idx = self.start_layer + idx - # Save hidden states before layer processing if this layer is marked if global_layer_idx in self.aux_hidden_state_layers: - # Store pre-normalization state (hidden_states + residual) + # Pre-normalization state aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer( @@ -1235,7 +1231,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) - # Eagle3 support: return auxiliary hidden states if collected if len(aux_hidden_states) > 0: return hidden_states, aux_hidden_states @@ -1372,43 +1367,18 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: - """Set which layers should output auxiliary hidden states for Eagle3. - - Args: - layers: Tuple of layer indices that should output auxiliary hidden states. - """ + """Set which layers should output auxiliary hidden states.""" self.model.aux_hidden_state_layers = layers def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: - """Get default auxiliary layer indices for DeepseekV2. - - Returns default layer selection based on model size. These defaults - can be overridden by speculative config if the draft model specifies - eagle_aux_hidden_state_layer_ids in its config. 
- - Returns: - Tuple of layer indices (typically: early, middle, late layers) - """ - # Use total number of layers from config, not len(layers) which - # only reflects the current pipeline parallel rank + """Return default auxiliary layer indices: early, middle, and late layers.""" + # Use config.num_hidden_layers for correct count across pipeline stages num_layers = self.model.config.num_hidden_layers - # Handle small models gracefully if num_layers < 4: - # For very small models, return middle layer only (or empty if no layers) return (num_layers // 2,) if num_layers > 0 else () - # Select 3 representative layers: early, middle, and late - # Use set to avoid duplicates in edge cases, then sort - return tuple( - sorted( - { - 2, # Early layer (captures low-level features) - num_layers // 2, # Middle layer (captures mid-level semantics) - num_layers - 3, # Late layer (captures high-level semantics) - } - ) - ) + return tuple(sorted({2, num_layers // 2, num_layers - 3})) def forward( self, diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py index b086b89b7e5d..e021908a65cc 100644 --- a/vllm/model_executor/models/kimi_k25.py +++ b/vllm/model_executor/models/kimi_k25.py @@ -458,23 +458,11 @@ def embed_multimodal(self, **kwargs: object) -> NestedTensors | None: return vision_embeddings def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: - """Set which layers should output auxiliary hidden states for Eagle3. - - Delegates to the underlying language model (DeepseekV2ForCausalLM). - - Args: - layers: Tuple of layer indices that should output auxiliary hidden states. - """ + """Set which layers should output auxiliary hidden states.""" self.language_model.set_aux_hidden_state_layers(layers) def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: - """Get default auxiliary layer indices for Kimi K2.5. - - Delegates to the underlying language model (DeepseekV2ForCausalLM). - - Returns: - Tuple of layer indices (typically: early, middle, late layers) - """ + """Return default auxiliary layer indices.""" return self.language_model.get_eagle3_aux_hidden_state_layers() def forward(
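
Usage sketch (illustrative, not part of the patches). First, a minimal
standalone rendering of the default auxiliary-layer selection that the series
settles on in get_eagle3_aux_hidden_state_layers(); the helper name
default_aux_layers and the layer counts below are assumptions for
illustration, while the logic mirrors the final code, which reads
num_hidden_layers from the config so the result is stable across
pipeline-parallel ranks:

    def default_aux_layers(num_layers: int) -> tuple[int, ...]:
        # Very small models fall back to the middle layer (or nothing at all).
        if num_layers < 4:
            return (num_layers // 2,) if num_layers > 0 else ()
        # Early, middle, and late layers; the set collapses duplicates when
        # the indices collide (e.g. num_layers == 5 yields {2, 2, 2}).
        return tuple(sorted({2, num_layers // 2, num_layers - 3}))

    assert default_aux_layers(61) == (2, 30, 58)  # e.g. a 61-layer model
    assert default_aux_layers(5) == (2,)          # all three picks collapse
    assert default_aux_layers(0) == ()

Second, a hedged sketch of enabling EAGLE-3 speculative decoding against a
Kimi K2.5 target; the model id and draft checkpoint path are hypothetical,
and the speculative_config keys follow vLLM's existing eagle3 interface:

    from vllm import LLM

    llm = LLM(
        model="moonshotai/Kimi-K2.5",          # hypothetical target model id
        speculative_config={
            "method": "eagle3",
            "model": "/path/to/eagle3-draft",  # hypothetical draft weights
            "num_speculative_tokens": 3,
        },
    )
    # With the patches applied, DeepseekV2Model also returns the hidden
    # states of the configured aux layers during target forward passes, and
    # the eagle3 drafter consumes them when proposing draft tokens.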