diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 360f1c32f03b..3f78b70e78db 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -46,7 +46,9 @@
     "pangu_ultra_moe_mtp",
     "step3p5_mtp",
 ]
-EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
+DFlashModelTypes = Literal["dflash"]
+EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes, DFlashModelTypes]
+
 NgramGPUTypes = Literal["ngram_gpu"]
 SpeculativeMethod = Literal[
     "ngram",
@@ -196,7 +198,11 @@ def compute_hash(self) -> str:
         factors: list[Any] = []
         # Eagle3 and extract_hidden_states affect the computation graph because
         # they return intermediate hidden states in addition to the final hidden state.
-        uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
+        uses_aux_hidden_states = self.method in (
+            "eagle3",
+            "extract_hidden_states",
+            "dflash",
+        )
         factors.append(uses_aux_hidden_states)
 
         # The specific layers used also affect the computation graph
@@ -480,7 +486,7 @@ def __post_init__(self):
             )
 
             # Automatically detect the method
-            if self.method in ("eagle", "eagle3"):
+            if self.method in ("eagle", "eagle3", "dflash"):
                 pass
             # examples:
             # yuhuili/EAGLE-LLaMA3-Instruct-8B
@@ -490,6 +496,8 @@
                 self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                 self.method = "eagle3"
+            elif "dflash" in self.draft_model_config.model.lower():
+                self.method = "dflash"
             elif self.draft_model_config.hf_config.model_type == "medusa":
                 self.method = "medusa"
             elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
@@ -795,7 +803,7 @@ def _verify_args(self) -> Self:
                 "kimi_k25",
             ]
             if (
-                self.method in ("eagle3", "extract_hidden_states")
+                self.method in ("eagle3", "extract_hidden_states", "dflash")
                 and self.target_model_config
                 and not any(
                     supported_model in self.target_model_config.hf_text_config.model_type
@@ -843,7 +851,7 @@ def max_num_new_slots_for_drafting(self) -> int:
         return slots_per_req
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "mtp")
+        return self.method in ("eagle", "eagle3", "mtp", "dflash")
 
     def uses_draft_model(self) -> bool:
         return self.method == "draft_model"
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 266ad5477b33..19863d0ebac6 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -310,9 +310,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
 
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+    def build_target_layer_ids(self, num_target_layers: int, num_draft_layers: int) -> list[int]:
+        if num_draft_layers == 1:
+            return [num_target_layers // 2]
+        start = 1
+        end = num_target_layers - 3
+        span = end - start
+        target_layer_ids = [
+            int(round(start + (i * span) / (num_draft_layers - 1)))
+            for i in range(num_draft_layers)
+        ]
+        return target_layer_ids
+
+    def get_eagle3_aux_hidden_state_layers(self, method: str) -> tuple[int, ...]:
         num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
+        if method == "dflash":
+            return_layers = self.build_target_layer_ids(num_layers, 5)
+        else:
+            return_layers = [2, num_layers // 2, num_layers - 3]
+        return tuple(return_layers)
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
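A minimal standalone sketch (not part of the patch) of the aux-layer selection
above: DFlash taps num_draft_layers evenly spaced target layers between layer 1
and num_target_layers - 3, instead of eagle3's fixed (2, mid, n - 3) triple.
The 36-layer target below is an illustrative assumption:

    def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
        if num_draft_layers == 1:
            return [num_target_layers // 2]
        start, end = 1, num_target_layers - 3
        span = end - start
        return [
            int(round(start + i * span / (num_draft_layers - 1)))
            for i in range(num_draft_layers)
        ]

    print(build_target_layer_ids(36, 5))  # [1, 9, 17, 25, 33]
    print(build_target_layer_ids(36, 1))  # [18]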
diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py
new file mode 100644
index 000000000000..c1ac3edd9884
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_dflash.py
@@ -0,0 +1,520 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+from torch import nn
+from transformers import Qwen3Config
+
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from vllm.multimodal.inputs import NestedTensors
+from vllm.transformers_utils.config import set_default_rope_theta
+from vllm.v1.attention.backend import AttentionType
+
+from .qwen2 import Qwen2MLP as Qwen3MLP
+from .qwen3 import Qwen3ForCausalLM
+from .utils import (
+    AutoWeightsLoader,
+    extract_layer_index,
+    get_draft_quant_config,
+    maybe_prefix,
+    process_eagle_weight,
+)
+
+logger = init_logger(__name__)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Like the stock RoPE application, but tolerates q shorter than k: the queries
+# are assumed to occupy the final q_len positions of the cos/sin tables.
+def apply_rotary_pos_emb_safe(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_len = q.size(-2)
+    q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :])
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
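+# DFlash draft attention is context-conditioned: K/V are computed over the
+# concatenation [context_states; hidden_states] (target-model states followed
+# by the draft's own query states), while Q keeps only the query rows, so each
+# draft token attends to the full target context without spending query-side
+# compute on context positions (see Qwen3Attention.forward below).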
+class Qwen3Attention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_parameters: dict,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: dict[str, Any] | None = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.dual_chunk_attention_config = dual_chunk_attention_config
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            max_position=max_position,
+            rope_parameters=rope_parameters,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            attn_type=attn_type,
+            **{
+                "layer_idx": extract_layer_index(prefix),
+                "dual_chunk_attention_config": dual_chunk_attention_config,
+            }
+            if dual_chunk_attention_config
+            else {},
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        context_states: torch.Tensor,
+    ) -> torch.Tensor:
+        assert context_states is not None
+
+        num_context = context_states.shape[0]
+
+        concat_states = torch.cat([context_states, hidden_states], dim=0)
+        qkv, _ = self.qkv_proj(concat_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        q = self.q_norm(q_by_head).view(q.shape)
+        k = self.k_norm(k_by_head).view(k.shape)
+
+        q, k = self.rotary_emb(positions, q, k)
+
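+        # e.g. with 128 context tokens and 8 query tokens, qkv is computed
+        # over all 136 rows; K/V keep 136 rows while Q keeps only the last 8.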
+        # Remove context states from Q
+        q = q[num_context:]
+
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        *,
+        config: Qwen3Config,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        layer_idx: int = 0,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        set_default_rope_theta(config, default_theta=1000000)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
+        attn_type = AttentionType.DECODER
+
+        self.self_attn = Qwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, "attention_bias", False),
+            head_dim=getattr(config, "head_dim", None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_parameters=config.rope_parameters,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        context_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is not None:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            context_states=context_states,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3Model(nn.Module):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        start_layer_id: int = 0,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.vocab_size = self.config.vocab_size
+
+        # Get drafter's quantization config
+        self.quant_config = get_draft_quant_config(vllm_config)
+
+        # Merge eagle_config and dflash_config; guard against either being
+        # present but set to None.
+        drafter_config = getattr(self.config, "eagle_config", {}) or {}
+        drafter_config.update(getattr(self.config, "dflash_config", {}) or {})
+
+        if "use_aux_hidden_state" in drafter_config:
+            self.use_aux_hidden_state = drafter_config["use_aux_hidden_state"]
+        else:
+            self.use_aux_hidden_state = True
+
+        current_vllm_config = get_current_vllm_config()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                Qwen3DecoderLayer(
+                    current_vllm_config,
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
+                    config=self.config,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(self.config.num_hidden_layers)
+            ]
+        )
+        if self.use_aux_hidden_state:
+            num_features_to_use = self.config.num_hidden_layers
+            if "layer_ids" in drafter_config:
+                num_features_to_use = len(drafter_config["layer_ids"])
+            if hasattr(self.config, "target_hidden_size"):
+                fc_input_size = self.config.target_hidden_size * num_features_to_use
+            else:
+                fc_input_size = self.config.hidden_size * num_features_to_use
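+            # e.g. 5 aux layers from a 4096-dim target: fc maps
+            # 5 * 4096 = 20480 features down to hidden_size (illustrative).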
+            self.fc = ReplicatedLinear(
+                input_size=fc_input_size,
+                output_size=self.config.hidden_size,
+                bias=False,
+                params_dtype=vllm_config.model_config.dtype,
+                quant_config=self.quant_config,
+                prefix=maybe_prefix(prefix, "fc"),
+                return_bias=False,
+            )
+            self.hidden_norm = RMSNorm(
+                self.config.hidden_size,
+                eps=self.config.rms_norm_eps,
+            )
+        self.norm = RMSNorm(
+            self.config.hidden_size,
+            eps=self.config.rms_norm_eps,
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            input_embeds = self.embed_input_ids(input_ids)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
+
+        # Target-model states act as the attention context; the drafter's own
+        # stream starts from the (mask) token embeddings.
+        context_states = hidden_states
+        hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                context_states=context_states,
+                residual=residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "midlayer." in name:
+                name = name.replace("midlayer.", "layers.0.")
+            # Handle kv cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            # Remapping the name of FP8 kv-scale
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set;
+        # default to the base vocab size when absent.
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+
+        # Store target layer count in draft config for
+        # proper layer_types indexing in draft models
+        self.config.target_layer_count = target_layer_num
+        self.model = Qwen3Model(
+            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
+        )
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.lm_head = ParallelLMHead(
+            self.config.draft_vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        self.logits_processor = LogitsProcessor(
+            self.config.draft_vocab_size, scale=logit_scale
+        )
+        # Optional draft-to-target vocab mapping; intended to be populated
+        # from the "d2t" tensor when the checkpoint provides one
+        # (see load_weights below).
+        self.draft_id_to_target_id = None
+
+        self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting
+
+        if self.use_parallel_drafting:
+            # The 5 matches the number of aux layers taken from the target
+            # (see get_eagle3_aux_hidden_state_layers in qwen3.py).
+            self.register_buffer(
+                "mask_hidden",
+                torch.zeros(
+                    1,
+                    (5 if self.model.use_aux_hidden_state else 1)
+                    * self.config.hidden_size,
+                ),
+                persistent=False,
+            )
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: NestedTensors | None = None,
+        is_multimodal: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
+
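+    # Hedged worked example for the d2t remap in compute_logits below: with
+    # draft_vocab_size=4 and draft_id_to_target_id=[0, 2, 5, 9] (offsets from
+    # draft id to target id), draft logits land at target ids
+    # [0+0, 1+2, 2+5, 3+9] = [0, 3, 7, 12]; all other target logits are -inf.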
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        if self.draft_id_to_target_id is None:
+            assert logits.shape[1] == self.config.vocab_size, (
+                "Expected logits to have shape "
+                f"(*, {self.config.vocab_size}), but got {logits.shape}"
+            )
+            return logits
+
+        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
+        targets = base + self.draft_id_to_target_id
+        logits_new = logits.new_full(
+            (
+                logits.shape[0],
+                self.config.vocab_size,
+            ),
+            float("-inf"),
+        )
+        logits_new[:, targets] = logits
+        return logits_new
+
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        if not self.model.use_aux_hidden_state:
+            return hidden_states
+        # Combine the multiple auxiliary hidden states returned by the target
+        if hidden_states.dim() == 1:
+            # Single-token case: add and strip a batch dim so hidden_norm
+            # normalizes over the hidden dim, not a size-1 trailing dim.
+            return self.model.hidden_norm(
+                self.model.fc(hidden_states).view(1, -1)
+            ).view(-1)
+        return self.model.hidden_norm(self.model.fc(hidden_states))
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        model_weights = {}
+        includes_draft_id_mapping = False
+        includes_embed_tokens = False
+        for name, loaded_weight in weights:
+            if "t2d" in name:
+                continue
+            if "d2t" in name:
+                name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
+            elif "lm_head" not in name:
+                name = "model." + name
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
+            model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
+
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
+        if not self.model.use_aux_hidden_state:
+            skip_substrs.append("fc.")
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=skip_substrs,
+        )
+        loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index d5d3bd265bee..16395725c674 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -544,6 +544,7 @@
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
     "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 902e335cb632..5955ffcfed52 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -41,8 +41,8 @@ def __init__(
 
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
-        # LlamaForCausalLM -> Eagle3LlamaForCausalLM
-        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
+        # LlamaForCausalLM -> Eagle3LlamaForCausalLM
+        # LlamaForCausalLM -> DFlashLlamaForCausalLM
         if method == "eagle":
             assert self.model is not None, (
                 "model should not be None when method is eagle"
@@ -62,6 +62,16 @@
                 else f"Eagle3{arch}"
                 for arch in self.model.architectures
             ]
+        elif method == "dflash":
+            assert self.model is not None, (
+                "model should not be None when method is dflash"
+            )
+            kwargs["architectures"] = [
+                arch
+                if arch.startswith("DFlash") or arch.endswith("DFlash")
+                else f"DFlash{arch}"
+                for arch in self.model.architectures
+            ]
         else:
             raise ValueError(
                 f"Invalid method {method}. Supported methods are eagle and eagle3."
             )
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 445bb403b4b3..1b95d2e3d677 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -23,6 +23,7 @@
 from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
@@ -147,7 +148,7 @@ def __init__(
             # 1D-RoPE.
             # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
+                (3, 2 * (self.max_num_tokens + 1)), dtype=torch.int64, device=device
             )
         elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
             self.xdrope_positions = torch.zeros(
@@ -158,11 +159,22 @@ def __init__(
         else:
             # RoPE need (max_num_tokens,)
             self.positions = torch.zeros(
-                self.max_num_tokens, dtype=torch.int64, device=device
+                2 * self.max_num_tokens, dtype=torch.int64, device=device
             )
         self.hidden_states = torch.zeros(
             (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
         )
+        if self.method == "dflash":
+            # --- DFlash scratch state (kept across propose() calls) ---
+            self._dflash_ctx_len: int = 0
+            self.dflash_mask_token_id: int = 151669  # model-specific mask token
+            self._dflash_kv_len: int = 0
+            self._dflash_num_query_tokens: int = 0
+
+            # Cached query_start_loc buffers for (batch_size, num_query_tokens) layout
+            self._dflash_query_start_loc_buffer: torch.Tensor | None = None
+            self._dflash_query_start_loc_cpu_buffer: torch.Tensor | None = None
+            self._dflash_query_offsets: torch.Tensor | None = None
 
         # Will be set when we initialize the attention backend
         self.block_size: int = -1
@@ -359,6 +371,12 @@ def _get_slot_mapping(
         view = self._slot_mapping_buffer[:num_tokens]
         return {name: view for name in self._draft_attn_layer_names}
 
+    def _get_dflash_slot_mapping(
+        self,
+        slot_mapping: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        return {name: slot_mapping for name in self._draft_attn_layer_names}
+
     def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
         """Initialize cudagraph dispatcher keys for eagle.
@@ -403,9 +421,14 @@ def propose(
     ) -> torch.Tensor:
         batch_size = common_attn_metadata.batch_size()
 
-        if self.method == "eagle3":
+        if self.method in ("eagle3", "dflash"):
             assert isinstance(
-                self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
+                self.model,
+                (
+                    Eagle3LlamaForCausalLM,
+                    Eagle3DeepseekV2ForCausalLM,
+                    DFlashQwen3ForCausalLM,
+                ),
             )
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states
@@ -453,13 +476,29 @@
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
 
+        if self.method == "dflash":
+            positions_len = self._dflash_kv_len  # context + query tokens
+            hidden_states_len = self._dflash_ctx_len  # target context only
+        else:
+            positions_len = num_input_tokens
+            hidden_states_len = num_input_tokens
+
         model_kwargs = {
             "input_ids": input_ids,
-            "positions": self._get_positions(num_input_tokens),
+            "positions": self._get_positions(positions_len),
             "inputs_embeds": inputs_embeds,
         }
         if self.pass_hidden_states_to_model:
-            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
+            model_kwargs["hidden_states"] = self.hidden_states[:hidden_states_len]
+
+        if self.method == "dflash":
+            forward_slot_mapping = self._get_dflash_slot_mapping(
+                common_attn_metadata.slot_mapping
+            )
+        else:
+            forward_slot_mapping = self._get_slot_mapping(
+                num_input_tokens, common_attn_metadata.slot_mapping
+            )
 
         with set_forward_context(
             per_layer_attn_metadata,
@@ -467,9 +506,7 @@
             num_tokens=num_input_tokens,
             num_tokens_across_dp=num_tokens_across_dp,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
-            slot_mapping=self._get_slot_mapping(
-                num_input_tokens, common_attn_metadata.slot_mapping
-            ),
+            slot_mapping=forward_slot_mapping,
         ):
             ret_hidden_states = self.model(**model_kwargs)
             if not self.model_returns_tuple():
@@ -481,7 +518,11 @@
         sample_hidden_states = last_hidden_states[token_indices_to_sample]
 
         # Early exit if there is only one draft token to be generated.
-        if self.num_speculative_tokens == 1 or self.parallel_drafting:
+        if (
+            self.num_speculative_tokens == 1
+            or self.parallel_drafting
+            or self.method == "dflash"
+        ):
             draft_token_ids = self._greedy_sample(sample_hidden_states)
             return draft_token_ids.view(-1, self.num_speculative_tokens)
@@ -648,6 +689,141 @@
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids
 
+    def set_dflash_first_pass(
+        self,
+        target_token_ids: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        target_positions: torch.Tensor,
+        target_hidden_states: torch.Tensor,
+        token_indices_to_sample: torch.Tensor | None,
+        cad: CommonAttentionMetadata,
+        num_rejected_tokens_gpu: torch.Tensor | None,
+    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
+        if self.dflash_mask_token_id is None:
+            raise ValueError("DFlash requires mask_token_id.")
+
+        if target_positions.dim() != 1:
+            target_positions = target_positions[0]
+
+        batch_size = cad.batch_size()
+        device = target_hidden_states.device
+        num_query_tokens = 1 + self.num_speculative_tokens
+        num_query_tokens_total = batch_size * num_query_tokens
+
+        num_context_tokens = target_hidden_states.shape[0]
+        num_kv_tokens = num_context_tokens + num_query_tokens_total
+
+        self._dflash_ctx_len = num_context_tokens
+        self._dflash_kv_len = num_kv_tokens
+        self._dflash_num_query_tokens = num_query_tokens
+        self._dflash_num_query_tokens_total = num_query_tokens_total
+
+        # Fill all query slots with the mask token, then place each request's
+        # next (bonus) token at the start of its query block.
+        self.input_ids[:num_query_tokens_total].fill_(self.dflash_mask_token_id)
+        self.input_ids[:num_query_tokens_total:num_query_tokens] = next_token_ids
+        last_positions = cad.seq_lens.to(torch.long) - 1
+
+        query_offsets = self._dflash_query_offsets
+        if query_offsets is None or (
+            query_offsets.device != device
+            or query_offsets.shape[1] != num_query_tokens
+            or query_offsets.dtype != target_positions.dtype
+        ):
+            self._dflash_query_offsets = torch.arange(
+                num_query_tokens,
+                device=device,
+                dtype=target_positions.dtype,
+            ).view(1, -1)
+            query_offsets = self._dflash_query_offsets
+
+        assert query_offsets is not None
+        query_positions = last_positions.view(-1, 1) + 1 + query_offsets
+        query_positions_flat = query_positions.reshape(-1)
+
+        self.positions[:num_context_tokens] = target_positions[:num_context_tokens]
+        self.positions[num_context_tokens:num_kv_tokens] = query_positions_flat
+
+        self.hidden_states[:num_context_tokens] = target_hidden_states
+
+        token_indices_to_sample = (
+            torch.arange(
+                num_query_tokens_total,
+                device=device,
+                dtype=torch.int32,
+            )
+            .view(batch_size, num_query_tokens)[:, 1:]
+            .reshape(-1)
+        )
+
+        block_size = self.block_size
+
+        block_table_tensor = getattr(cad, "block_table_tensor", None)
+        if block_table_tensor is None:
+            raise RuntimeError(
+                "DFlash requires block_table_tensor in CommonAttentionMetadata."
+            )
+
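+        # Hedged example of the slot math below: with block_size=16, a query
+        # at position 35 on a request whose block_table row is [7, 3, 9] falls
+        # in logical block 35 // 16 = 2 -> physical block 9, giving slot
+        # 9 * 16 + 35 % 16 = 147.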
+        block_numbers_bt = (query_positions // block_size).to(torch.long)
+        block_ids = block_table_tensor.gather(dim=1, index=block_numbers_bt)
+        query_slot_mapping = (
+            block_ids * block_size + (query_positions % block_size)
+        ).reshape(-1)
+
+        ctx_slot_mapping = cad.slot_mapping[:num_context_tokens]
+        full_slot_mapping = torch.cat([ctx_slot_mapping, query_slot_mapping], dim=0)
+
+        query_start_loc_buffer = self._dflash_query_start_loc_buffer
+        if query_start_loc_buffer is None or (
+            query_start_loc_buffer.shape[0] < batch_size + 1
+            or query_start_loc_buffer.device != device
+        ):
+            self._dflash_query_start_loc_buffer = torch.empty(
+                batch_size + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            query_start_loc_buffer = self._dflash_query_start_loc_buffer
+
+        assert query_start_loc_buffer is not None
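+        # e.g. batch_size=3, num_query_tokens=4 (1 bonus token + 3
+        # speculative slots): query_start_loc = [0, 4, 8, 12].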
+        qsl = query_start_loc_buffer[: batch_size + 1]
+        qsl.copy_(torch.arange(batch_size + 1, device=device, dtype=torch.int32))
+        qsl.mul_(num_query_tokens)
+
+        query_start_loc_cpu_buffer = self._dflash_query_start_loc_cpu_buffer
+        if query_start_loc_cpu_buffer is None or (
+            query_start_loc_cpu_buffer.shape[0] < batch_size + 1
+        ):
+            self._dflash_query_start_loc_cpu_buffer = torch.empty(
+                batch_size + 1,
+                dtype=torch.int32,
+                pin_memory=is_pin_memory_available(),
+            )
+            query_start_loc_cpu_buffer = self._dflash_query_start_loc_cpu_buffer
+
+        assert query_start_loc_cpu_buffer is not None
+        qsl_cpu = query_start_loc_cpu_buffer[: batch_size + 1]
+        qsl_cpu.copy_(
+            torch.arange(batch_size + 1, dtype=torch.int32).mul_(num_query_tokens)
+        )
+
+        new_cad = replace(
+            cad,
+            slot_mapping=full_slot_mapping,
+            num_actual_tokens=num_query_tokens_total,
+            max_query_len=num_query_tokens,
+            query_start_loc=qsl,
+            query_start_loc_cpu=qsl_cpu,
+            max_seq_len=min(
+                int(cad.max_seq_len + num_query_tokens), self.max_model_len
+            ),
+            seq_lens=(cad.seq_lens + num_query_tokens),
+            causal=False,
+        )
+
+        new_cad._seq_lens_cpu = None
+        new_cad._num_computed_tokens_cpu = None
+
+        return num_query_tokens_total, token_indices_to_sample, new_cad
+
     def set_inputs_first_pass(
         self,
         target_token_ids: torch.Tensor,
@@ -658,6 +834,16 @@
         cad: CommonAttentionMetadata,
         num_rejected_tokens_gpu: torch.Tensor | None,
     ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
+        if self.method == "dflash":
+            return self.set_dflash_first_pass(
+                target_token_ids=target_token_ids,
+                next_token_ids=next_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                token_indices_to_sample=token_indices_to_sample,
+                cad=cad,
+                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
+            )
         if not self.needs_extra_input_slots:
             # Default EAGLE pathway: no reshaping of input tensors needed.
             # Simply rotate the input ids and leave the positions unchanged,
@@ -784,7 +970,7 @@
         return total_num_output_tokens, token_indices_to_sample, new_cad
 
     def model_returns_tuple(self) -> bool:
-        return self.method not in ("mtp", "draft_model")
+        return self.method not in ("mtp", "draft_model", "dflash")
 
     def prepare_next_token_ids_cpu(
         self,
@@ -1499,6 +1685,8 @@ def dummy_run(
         else:
             slot_mapping_dict = slot_mappings or {}
 
+        is_dflash = self.method == "dflash"
+
         with set_forward_context(
             None,
             self.vllm_config,
@@ -1514,14 +1702,33 @@
                 input_ids = self.input_ids[:num_input_tokens]
                 inputs_embeds = None
 
-            kwargs = dict(
-                input_ids=input_ids,
-                positions=self._get_positions(num_input_tokens),
-                inputs_embeds=inputs_embeds,
-            )
-            if self.pass_hidden_states_to_model:
-                kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
-            self.model(**kwargs)
+            if is_dflash:
+                ctx_len = 0
+                q_len = num_input_tokens
+                kv_len = q_len  # positions length == concat_states length
+                if input_ids is not None:
+                    self.input_ids[:q_len].fill_(1)
+                    input_ids = self.input_ids[:q_len]
+
+                kwargs = dict(
+                    input_ids=input_ids,
+                    positions=self._get_positions(kv_len),
+                    inputs_embeds=inputs_embeds,
+                )
+                if self.pass_hidden_states_to_model:
+                    kwargs["hidden_states"] = self.hidden_states[:ctx_len]
+
+                self.model(**kwargs)
+
+            else:
+                kwargs = dict(
+                    input_ids=input_ids,
+                    positions=self._get_positions(num_input_tokens),
+                    inputs_embeds=inputs_embeds,
+                )
+                if self.pass_hidden_states_to_model:
+                    kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
+                self.model(**kwargs)
 
     def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
         """
@@ -1530,13 +1737,17 @@
         They might indicate this by setting "use_aux_hidden_state" to False
         inside the "eagle_config" dict of their hf_config.
""" - if self.method != "eagle3": + if self.method not in ("eagle3", "dflash"): return False - # Assume that eagle3 heads use aux hidden states by default use_aux_hidden_state = True eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None) + dflash_config = getattr( + self.draft_model_config.hf_config, "dflash_config", None + ) if eagle_config is not None: use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True) + if dflash_config is not None: + use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True) return use_aux_hidden_state def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b53bd71a1cd1..262a9fed58d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -544,7 +544,7 @@ def __init__( self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) - if self.speculative_config.method == "eagle3": + if self.speculative_config.method in ("eagle3", "dflash"): self.use_aux_hidden_state_outputs = ( self.drafter.eagle3_use_aux_hidden_state ) @@ -4565,7 +4565,13 @@ def load_model(self, load_dummy_weights: bool = False) -> None: aux_layers, ) else: - aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + speculative_config = self.speculative_config + assert speculative_config is not None + assert speculative_config.method is not None + + aux_layers = self.model.get_eagle3_aux_hidden_state_layers( + method=speculative_config.method + ) self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter()