diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index bf533bf14e55..f0eabd973e00 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -618,7 +618,14 @@ def _verify_args(self) -> Self:
                 f"{self.disable_by_batch_size=}"
             )

-        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
+        eagle3_target_supported = [
+            "llama",
+            "qwen",
+            "minicpm",
+            "gpt_oss",
+            "deepseek_v2",
+            "kimi_k2",
+        ]
         if (
             self.method == "eagle3"
             and self.target_model_config
diff --git a/vllm/model_executor/models/deepseek_eagle3.py b/vllm/model_executor/models/deepseek_eagle3.py
new file mode 100644
index 000000000000..640ba89914b2
--- /dev/null
+++ b/vllm/model_executor/models/deepseek_eagle3.py
@@ -0,0 +1,419 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Eagle3 speculative decoding model for DeepseekV2/V3 with MLP (no MoE)."""
+
+import copy
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from transformers import DeepseekV2Config, DeepseekV3Config
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2ForCausalLM,
+    DeepseekV2MLAAttention,
+    DeepseekV2MLP,
+)
+from vllm.multimodal.inputs import NestedTensors
+
+from .utils import (
+    AutoWeightsLoader,
+    get_draft_quant_config,
+    maybe_prefix,
+    process_eagle_weight,
+)
+
+logger = init_logger(__name__)
+
+
+class DeepseekV2Eagle3DecoderLayer(nn.Module):
+    """
+    Eagle3 decoder layer for Deepseek that:
+    1. Always uses MLP (not MoE)
+    2. First layer accepts concatenated embeds + hidden_states
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str,
+        config: DeepseekV2Config | DeepseekV3Config | None = None,
+        layer_idx: int = 0,
+    ) -> None:
+        super().__init__()
+
+        if config is None:
+            config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = get_draft_quant_config(vllm_config)
+
+        self.hidden_size = config.hidden_size
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+
+        self.layer_idx = layer_idx
+
+        # MLA attention parameters
+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        v_head_dim = getattr(config, "v_head_dim", 0)
+        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
+        config = copy.copy(config)
+        if rope_scaling:
+            rope_params = rope_scaling.copy()
+            rope_params["rope_type"] = "deepseek_yarn"
+        else:
+            rope_params = {"rope_type": "default"}
+        config.rope_parameters = rope_params
+        self.self_attn = DeepseekV2MLAAttention(
+            vllm_config=vllm_config,
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            v_head_dim=v_head_dim,
+            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
+            kv_lora_rank=kv_lora_rank,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            input_size=2 * self.hidden_size if layer_idx == 0 else self.hidden_size,
+        )
+
+        # Always use MLP (not MoE) for Eagle3
+        self.mlp = DeepseekV2MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if getattr(config, "norm_before_residual", False):
+            self._residual_norm = self._norm_before_residual
+        else:
+            self._residual_norm = self._norm_after_residual
+
+    def _norm_before_residual(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.hidden_norm(hidden_states)
+        residual = hidden_states
+        return hidden_states, residual
+
+    def _norm_after_residual(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+        return hidden_states, residual
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        embeds: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            llama_4_scaling=None,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Fully Connected (MLP, not MoE)
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+@support_torch_compile
+class DeepseekV2Eagle3Model(nn.Module):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        start_layer_id: int = 0,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.vocab_size = self.config.vocab_size
+
+        # Get drafter's quantization config
+        self.quant_config = get_draft_quant_config(vllm_config)
+
+        current_vllm_config = get_current_vllm_config()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                DeepseekV2Eagle3DecoderLayer(
+                    current_vllm_config,
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
+                    config=self.config,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(self.config.num_hidden_layers)
+            ]
+        )
+
+        # fc layer for combining auxiliary hidden states (3x hidden size input)
+        if hasattr(self.config, "target_hidden_size"):
+            fc_input_size = self.config.target_hidden_size * 3
+        else:
+            fc_input_size = self.config.hidden_size * 3
+
+        self.fc = ReplicatedLinear(
+            input_size=fc_input_size,
+            output_size=self.config.hidden_size,
+            bias=False,
+            params_dtype=vllm_config.model_config.dtype,
+            quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "fc"),
+            return_bias=False,
+        )
+
+        self.norm = RMSNorm(
+            self.config.hidden_size,
+            eps=self.config.rms_norm_eps,
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_embeds: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if input_embeds is None:
+            input_embeds = self.embed_input_ids(input_ids)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
+        return hidden_states, hidden_prenorm
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+            (".fused_qkv_a_proj", ".q_a_proj", 0),
+            (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "midlayer." in name:
+                name = name.replace("midlayer.", "layers.0.")
+
+            # Handle kv cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            # Remapping the name of FP8 kv-scale
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
+
+class Eagle3DeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
+    """Eagle3 speculative decoding model for DeepseekV2/V3."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+
+        # Ensure draft_vocab_size is set
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
+
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+
+        # Store target layer count in draft config
+        self.config.target_layer_count = target_layer_num
+
+        self.model = DeepseekV2Eagle3Model(
+            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
+        )
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.lm_head = ParallelLMHead(
+            self.config.draft_vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        self.logits_processor = LogitsProcessor(
+            self.config.draft_vocab_size, scale=logit_scale
+        )
+        self.draft_id_to_target_id = nn.Parameter(
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
+            requires_grad=False,
+        )
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: NestedTensors | None = None,
+        is_multimodal: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        if self.draft_id_to_target_id is None:
+            assert logits.shape[1] == self.config.vocab_size, (
+                "Expected logits to have shape "
+                f"(*, {self.config.vocab_size}), but got {logits.shape}"
+            )
+            return logits
+
+        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
+        targets = base + self.draft_id_to_target_id
+        logits_new = logits.new_full(
+            (
+                logits.shape[0],
+                self.config.vocab_size,
+            ),
+            float("-inf"),
+        )
+        logits_new[:, targets] = logits
+        return logits_new
+
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # Combine multiple auxiliary hidden states returned by Eagle3
+        return self.model.fc(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        model_weights = {}
+        includes_draft_id_mapping = False
+        includes_embed_tokens = False
+
+        for name, loaded_weight in weights:
+            if "t2d" in name:
+                continue
+            if "d2t" in name:
+                name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
+            elif "lm_head" not in name:
+                name = "model." + name
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
+            model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
+
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=skip_substrs,
+        )
+        loader.load_weights(model_weights.items())
+
+
+# Aliases for compatibility
+Eagle3DeepseekV3ForCausalLM = Eagle3DeepseekV2ForCausalLM
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index b22cdb6d6c80..981719d7bbfc 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -85,7 +85,13 @@
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 from vllm.v1.worker.workspace import current_workspace_manager

-from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
+from .interfaces import (
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -946,6 +952,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         topk_indices_buffer: torch.Tensor | None = None,
+        input_size: int | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -965,9 +972,13 @@
         self.scaling = self.qk_head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings

+        # Use input_size for the projection input dimension if provided,
+        # otherwise default to hidden_size (used by Eagle3 Deepseek with MLA)
+        proj_input_size = input_size if input_size is not None else self.hidden_size
+
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj = MergedColumnParallelLinear(
-                self.hidden_size,
+                proj_input_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
@@ -976,7 +987,7 @@
             )
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
+                proj_input_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
@@ -994,7 +1005,7 @@
             )
         else:
             self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
+                proj_input_size,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
@@ -1290,6 +1301,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             ["hidden_states", "residual"], config.hidden_size
         )

+        self.aux_hidden_state_layers = tuple[int, ...]()
+
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -1325,7 +1338,13 @@ def forward(
         else:
             llama_4_scaling = None

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer),
+            start=self.start_layer,
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions, hidden_states, residual, llama_4_scaling
             )
@@ -1336,6 +1355,8 @@
             )

         hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states


@@ -1381,7 +1402,12 @@ def update_physical_experts_metadata(


 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
+    nn.Module,
+    SupportsPP,
+    DeepseekV2MixtureOfExperts,
+    SupportsLoRA,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -1460,6 +1486,13 @@ def set_moe_parameters(self):

         self.extract_moe_parameters(example_moe)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e0e346fcd878..f709a2105795 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -444,6 +444,8 @@
         "mistral_large_3_eagle",
         "EagleMistralLarge3ForCausalLM",
     ),
+    "Eagle3DeepseekV2ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
+    "Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 16cb2bea3e32..35b24e31589c 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -82,6 +82,7 @@ def __getitem__(self, key):
     flex_olmo="FlexOlmoConfig",
     hunyuan_vl="HunYuanVLConfig",
     isaac="IsaacConfig",
+    kimi_k2="DeepseekV3Config",  # Kimi K2 uses the same architecture as DeepSeek V3
     kimi_linear="KimiLinearConfig",
     kimi_vl="KimiVLConfig",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 66697132b365..f6494546d2d2 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
+from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -245,7 +246,9 @@ def propose(
         last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

         if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            assert isinstance(
+                self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
+            )
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states
             )
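
Usage sketch (not part of the patch): a minimal example of how the new Eagle3
Deepseek drafter might be exercised through vLLM's offline LLM API once this
diff is applied. The checkpoint paths are placeholders, and the
speculative_config keys simply follow the existing eagle3 convention; nothing
below is introduced by the diff itself.

    from vllm import LLM, SamplingParams

    # Hypothetical paths: a DeepSeek-V3-style target model and an Eagle3 draft
    # head trained against it (its config would declare one of the
    # Eagle3Deepseek* architectures registered above).
    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",
        speculative_config={
            "method": "eagle3",
            "model": "/path/to/eagle3-deepseek-draft",
            "num_speculative_tokens": 3,
        },
    )

    outputs = llm.generate(
        ["Briefly explain speculative decoding."],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)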