diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 4695f6f19662..0afc3800cb38 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -7,6 +7,7 @@ import pytest import torch +from tqdm import tqdm from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline from tests.utils import ( @@ -1015,19 +1016,178 @@ def some_high_acceptance_metrics() -> dict: } -def compute_acceptance_rate(metrics: list[Metric]) -> float: +def compute_acceptance_rate( + metrics: list[Metric], prev_metrics: list[Metric] | None = None +) -> float: name2metric = {metric.name: metric for metric in metrics} - n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore + n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value if n_draft_toks == 0: return float("nan") - n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value + if prev_metrics is not None: + prev_name2metric = {metric.name: metric for metric in prev_metrics} + n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value + n_accepted_toks -= prev_name2metric[ + "vllm:spec_decode_num_accepted_tokens" + ].value + if n_draft_toks <= 0: + return float("nan") return n_accepted_toks / n_draft_toks -def compute_acceptance_len(metrics: list[Metric]) -> float: +def compute_acceptance_len( + metrics: list[Metric], prev_metrics: list[Metric] | None = None +) -> float: name2metric = {metric.name: metric for metric in metrics} - n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore - n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + n_drafts = name2metric["vllm:spec_decode_num_drafts"].value + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value if n_drafts 
== 0: return 1 + if prev_metrics is not None: + prev_name2metric = {metric.name: metric for metric in prev_metrics} + n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value + n_accepted_toks -= prev_name2metric[ + "vllm:spec_decode_num_accepted_tokens" + ].value + if n_drafts <= 0: + return 1 return 1 + (n_accepted_toks / n_drafts) + + +# Datasets in the format used in DFlash validations +def load_and_process_dataset(data_name: str): + from datasets import load_dataset + + if data_name == "gsm8k": + dataset = load_dataset("openai/gsm8k", "main", split="test") + prompt_fmt = ( + "{question}\nPlease reason step by step," + " and put your final answer within \\boxed{{}}." + ) + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + elif data_name == "mt-bench": + dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") + dataset = dataset.map(lambda x: {"turns": x["prompt"]}) + elif data_name == "humaneval": + dataset = load_dataset("openai/openai_humaneval", split="test") + prompt_fmt = ( + "Write a solution to the following problem and make sure" + " that it passes the tests:\n```python\n{prompt}\n```" + ) + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + return dataset + + +@pytest.fixture +def dflash_config(): + target_model = "Qwen/Qwen3-8B" + draft_model = "z-lab/Qwen3-8B-DFlash-b16" + + return dict( + model=target_model, + trust_remote_code=True, + speculative_config={ + "method": "dflash", + "model": draft_model, + "num_speculative_tokens": 16, + "max_model_len": 32768, + }, + max_model_len=32768, + max_num_seqs=128, + gpu_memory_utilization=0.85, + enforce_eager=False, + disable_log_stats=False, + ) + + +def test_dflash_acceptance_rates(dflash_config): + """ + E2E test for DFlash (block diffusion) speculative decoding. + Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval + comparing against baseline results from the paper (Table 1). 
+ See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology. + """ + spec_llm = LLM(**dflash_config) + + max_prompts_per_dataset = 200 # mt-bench has 80, humaneval has 164, truncates gsm8k + + # All scores from Table 1 in https://arxiv.org/pdf/2602.06036 + expected_acceptance_lengths = { + "mt-bench": 4.24, + "humaneval": 6.50, + "gsm8k": 6.54 * 0.95, # runs with a subset of prompts so extra wide tol here + } + + tokenizer = spec_llm.get_tokenizer() + for dataset_name, expected_len in expected_acceptance_lengths.items(): + dataset = load_and_process_dataset(dataset_name) + prev_metrics = None + acceptance_lengths = [] + for i in tqdm( + range(min(max_prompts_per_dataset, len(dataset))), + desc=f"Processing {dataset_name}", + ): + user_content = dataset[i]["turns"][0] + prompt_text = tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + # Temp=0, MaxTokens=2048 from the paper + spec_llm.generate( + [prompt_text], + SamplingParams(temperature=0, max_tokens=2048), + use_tqdm=False, + ) + current_metrics = spec_llm.get_metrics() + acceptance_len = compute_acceptance_len(current_metrics, prev_metrics) + prev_metrics = current_metrics + acceptance_lengths.append(acceptance_len) + + mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths) + expected_len = expected_len * 0.9 + print( + f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}" + f" (expected at least {expected_len:.2f})" + ) + + assert mean_acceptance_length >= expected_len, ( + f"DFlash acceptance_len for {dataset_name} is below expected threshold:" + f"{mean_acceptance_length:.2f} < {expected_len:.2f}" + ) + + del spec_llm + torch.accelerator.empty_cache() + cleanup_dist_env_and_memory() + + +def test_dflash_correctness(dflash_config): + """ + E2E test for DFlash (block diffusion) speculative decoding. 
+ Ensures output correctness on GSM8k, with cudagraphs and batching on. + """ + spec_llm = LLM(**dflash_config) + + # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k) + evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8) + + current_metrics = spec_llm.get_metrics() + acceptance_len = compute_acceptance_len(current_metrics) + + # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent + # with the DFlash paper. However, that test measures AL per-request and thus runs + # with a batch size of 1. To ensure that AL does not collapse with large batch sizes + # we enforce a baseline on the AL over the full lm-eval-style GSM8k test. + expected_len = 3.5 # Measured is 3.9 to 4.0 + print(f"DFlash GSM8k correctness test got AL {acceptance_len}") + assert acceptance_len >= expected_len, ( + "DFlash correctness check failed with" + f" {acceptance_len=}, expected at least {expected_len}" + ) + + del spec_llm + torch.accelerator.empty_cache() + cleanup_dist_env_and_memory() diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index fb4ea1bcecbd..9649c402e018 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -27,6 +27,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.spec_decode.dflash import DFlashProposer from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -36,6 +37,8 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting +dflash_target_dir = "Qwen/Qwen3-8B" +dflash_dir = "z-lab/Qwen3-8B-DFlash-b16" BLOCK_SIZE = 16 @@ -47,18 +50,29 @@ def _create_proposer( 
speculative_token_tree: list[tuple[int, ...]] | None = None, parallel_drafting: bool = False, ) -> EagleProposer: - model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) - # Method-dependent setup if method == "eagle": + target_model_dir = model_dir draft_model_dir = eagle_dir elif method == "eagle3": + target_model_dir = model_dir draft_model_dir = eagle3_dir elif method == "draft_model": + target_model_dir = model_dir draft_model_dir = ar_draft_model_dir + elif method == "dflash": + target_model_dir = dflash_target_dir + draft_model_dir = dflash_dir else: raise ValueError(f"Unknown method: {method}") + model_config = ModelConfig( + model=target_model_dir, + runner="generate", + max_model_len=100, + trust_remote_code=(method == "dflash"), + ) + spec_token_tree_str = None if speculative_token_tree is not None: assert num_speculative_tokens == len(speculative_token_tree) @@ -92,7 +106,9 @@ def _create_proposer( attention_config=AttentionConfig(backend=attention_backend), ) - if "eagle" in method: + if method == "dflash": + proposer = DFlashProposer(vllm_config=vllm_config, device=device) + elif "eagle" in method: proposer = EagleProposer(vllm_config=vllm_config, device=device) else: proposer = DraftModelProposer(vllm_config=vllm_config, device=device) @@ -1152,3 +1168,134 @@ def create_deterministic_logits(token_ids, k: int): # Verify that the draft tokens match our expectations. assert torch.equal(result, expected_tokens) + + +def test_set_inputs_first_pass_dflash(): + """ + Test for DFlash set_inputs_first_pass. + + DFlash uses cross-attention: context tokens become K/V and only + query tokens (bonus + mask) are Q. This tests the DFlash-specific + input preparation where: + - Context hidden states are copied as-is + - Query input_ids are [next_token, mask, mask, ...] 
per request + - Positions cover context (copied) + query (last_pos + 1 + offset) + - token_indices_to_sample points to mask token positions only + - A new CommonAttentionMetadata is returned with causal=False + + Setup: + - 3 requests with query_lens [3, 2, 4] + - num_speculative_tokens = 3 + - num_query_per_req = 4 (1 bonus + 3 mask tokens) + - next_token_ids: [100, 200, 300] + + Expected output layout (query tokens only, 12 total): + Request 0 (indices 0-3): [100, mask, mask, mask] + Request 1 (indices 4-7): [200, mask, mask, mask] + Request 2 (indices 8-11): [300, mask, mask, mask] + + Expected positions layout: + Context (first 9): copied from target_positions + Query (next 12): + Request 0: last_pos=9, query=[10, 11, 12, 13] + Request 1: last_pos=7, query=[8, 9, 10, 11] + Request 2: last_pos=11, query=[12, 13, 14, 15] + """ + device = torch.device(current_platform.device_type) + + num_speculative_tokens = 3 + proposer = _create_proposer("dflash", num_speculative_tokens) + mask_token_id = proposer.parallel_drafting_token_id + + # Setup batch with 3 requests + batch_spec = BatchSpec( + seq_lens=[10, 8, 12], + query_lens=[3, 2, 4], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=BLOCK_SIZE, + device=device, + arange_block_indices=True, + ) + + # Input tensors + # Request 0: tokens [10, 11, 12] at positions [7, 8, 9] + # Request 1: tokens [20, 21] at positions [6, 7] + # Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11] + target_token_ids = torch.tensor( + [10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device + ) + target_positions = torch.tensor( + [7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device + ) + target_hidden_states = torch.randn( + 9, proposer.hidden_size, dtype=proposer.dtype, device=device + ) + next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device) + + num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( + 
target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cad=common_attn_metadata, + num_rejected_tokens_gpu=None, + ) + + num_query_per_req = 1 + num_speculative_tokens # 4 + num_context = 9 + + # num_tokens is the query-only count + assert num_tokens == 3 * num_query_per_req # 12 + + # Verify input_ids (query tokens only) + # Each request: [next_token, mask, mask, mask] + M = mask_token_id + expected_input_ids = torch.tensor( + [100, M, M, M, 200, M, M, M, 300, M, M, M], + dtype=torch.int32, + device=device, + ) + assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids) + + # Verify context positions (first 9 slots): copied from target_positions + assert torch.equal(proposer.positions[:num_context], target_positions) + + # Verify query positions (next 12 slots): + # req0: last_pos=9, query=[10, 11, 12, 13] + # req1: last_pos=7, query=[8, 9, 10, 11] + # req2: last_pos=11, query=[12, 13, 14, 15] + expected_query_positions = torch.tensor( + [10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15], + dtype=torch.int64, + device=device, + ) + assert torch.equal( + proposer.positions[num_context : num_context + num_tokens], + expected_query_positions, + ) + + # Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0) + # req0: query indices 0-3, mask at 1,2,3 + # req1: query indices 4-7, mask at 5,6,7 + # req2: query indices 8-11, mask at 9,10,11 + expected_token_indices_to_sample = torch.tensor( + [1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device + ) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) + + # Verify the new CAD has DFlash-specific properties + assert output_cad.causal is False # DFlash requires non-causal attention + assert output_cad.num_actual_tokens == num_tokens # query-only count + assert output_cad.max_query_len == num_query_per_req + + 
expected_query_start_loc = torch.tensor( + [0, 4, 8, 12], dtype=torch.int32, device=device + ) + assert torch.equal(output_cad.query_start_loc, expected_query_start_loc) + + # Verify hidden states (context copied as-is) + assert torch.equal(proposer.hidden_states[:num_context], target_hidden_states) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index e9dc4cac5c11..f8454b13ed9f 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -46,8 +46,11 @@ "pangu_ultra_moe_mtp", "step3p5_mtp", ] -EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] NgramGPUTypes = Literal["ngram_gpu"] +DFlashModelTypes = Literal["dflash"] +EagleModelTypes = Literal[ + "eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes +] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -193,7 +196,11 @@ def compute_hash(self) -> str: factors: list[Any] = [] # Eagle3 and extract_hidden_states affect the computation graph because # they return intermediate hidden states in addition to the final hidden state. 
- uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states") + uses_aux_hidden_states = self.method in ( + "eagle3", + "extract_hidden_states", + "dflash", + ) factors.append(uses_aux_hidden_states) # The specific layers used also affect the computation graph @@ -477,7 +484,7 @@ def __post_init__(self): ) # Automatically detect the method - if self.method in ("eagle", "eagle3"): + if self.method in ("eagle", "eagle3", "dflash"): pass # examples: # yuhuili/EAGLE-LLaMA3-Instruct-8B @@ -487,6 +494,8 @@ def __post_init__(self): self.method = "eagle" elif "eagle3" in self.draft_model_config.model.lower(): self.method = "eagle3" + elif "dflash" in self.draft_model_config.model.lower(): + self.method = "dflash" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" elif self.draft_model_config.hf_config.model_type == "mlp_speculator": @@ -519,7 +528,7 @@ def __post_init__(self): ) # Replace hf_config for EAGLE draft_model - if self.method in ("eagle", "eagle3"): + if self.method in ("eagle", "eagle3", "dflash"): from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.speculators import ( SpeculatorsConfig, @@ -539,6 +548,9 @@ def __post_init__(self): self.draft_model_config.hf_config = eagle_config self.update_arch_() + if self.method == "dflash": + self.parallel_drafting = True + if self.num_speculative_tokens is not None and hasattr( self.draft_model_config.hf_config, "num_lookahead_tokens" ): @@ -794,7 +806,7 @@ def _verify_args(self) -> Self: "kimi_k25", ] if ( - self.method in ("eagle3", "extract_hidden_states") + self.method in ("eagle3", "extract_hidden_states", "dflash") and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type @@ -842,7 +854,10 @@ def max_num_new_slots_for_drafting(self) -> int: return slots_per_req def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "mtp") + return self.method in 
("eagle", "eagle3", "mtp", "dflash") + + def use_dflash(self) -> bool: + return self.method == "dflash" def uses_draft_model(self) -> bool: return self.method == "draft_model" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c7449840d525..9bd4609f3b44 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1298,6 +1298,26 @@ def _set_max_num_scheduled_tokens(self): max_num_batched_tokens - scheduled_token_delta ) + if self.scheduler_config.max_num_scheduled_tokens <= 0: + raise ValueError( + "max_num_scheduled_tokens is set to" + f" {self.scheduler_config.max_num_scheduled_tokens} based on" + " the speculative decoding settings, which does not allow" + " any tokens to be scheduled. Increase max_num_batched_tokens" + " to accommodate the additional draft token slots, or decrease" + " num_speculative_tokens or max_num_seqs." + ) + if self.scheduler_config.max_num_scheduled_tokens < 8192: + logger.warning_once( + "max_num_scheduled_tokens is set to" + f" {self.scheduler_config.max_num_scheduled_tokens} based on" + " the speculative decoding settings. This may lead to suboptimal" + " performance. 
Consider increasing max_num_batched_tokens to" + " accommodate the additional draft token slots, or decrease" + " num_speculative_tokens or max_num_seqs.", + scope="local", + ) + max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens if max_num_batched_tokens < max_num_scheduled_tokens + ( self.speculative_config.max_num_new_slots_for_drafting diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 91931f9f424f..6dec60232b1d 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -285,6 +285,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config + self.vllm_config = vllm_config self.quant_config = quant_config self.model = Qwen3Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py new file mode 100644 index 000000000000..ec2d20af5511 --- /dev/null +++ b/vllm/model_executor/models/qwen3_dflash.py @@ -0,0 +1,612 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Qwen3Config + +from vllm import _custom_ops as ops +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig 
+from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.multimodal.inputs import NestedTensors +from vllm.transformers_utils.config import set_default_rope_theta +from vllm.v1.attention.backend import AttentionType + +from .qwen2 import Qwen2MLP as Qwen3MLP +from .qwen3 import Qwen3ForCausalLM +from .utils import ( + AutoWeightsLoader, + get_draft_quant_config, + maybe_prefix, + process_eagle_weight, +) + +logger = init_logger(__name__) + + +class DFlashQwen3Attention(nn.Module): + """Attention for DFlash speculative decoding. + + Context KVs are pre-inserted into the KV cache before the forward pass. + This layer handles only query tokens via standard attention. + Adapted from Qwen3Attention.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_parameters: dict, + max_position: int = 4096 * 32, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + attention_bias: bool = False, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + self.layer_name = prefix + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + 
self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=attention_bias, # DFlash has o_proj bias when using attention bias + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position, + rope_parameters=rope_parameters, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """DFlash attention assumes that the KV cache is already populated + with the context K/V from the target model's hidden states. This forward op + computes attention for the query tokens only. 
+ See also: precompute_and_store_context_kv""" + qkv = F.linear(hidden_states, self.qkv_proj.weight, self.qkv_proj.bias) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Per-head RMSNorm + q_shape, k_shape = q.shape, k.shape + q = self.q_norm( + q.view(*q_shape[:-1], q_shape[-1] // self.head_dim, self.head_dim) + ).view(q_shape) + k = self.k_norm( + k.view(*k_shape[:-1], k_shape[-1] // self.head_dim, self.head_dim) + ).view(k_shape) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class DFlashQwen3DecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + *, + config: Qwen3Config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + set_default_rope_theta(config, default_theta=1000000) + attn_type = AttentionType.DECODER + + self.self_attn = DFlashQwen3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rms_norm_eps=config.rms_norm_eps, + attention_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), + cache_config=cache_config, + quant_config=quant_config, + rope_parameters=config.rope_parameters, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> 
tuple[torch.Tensor, torch.Tensor]: + if residual is not None: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + else: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class DFlashQwen3Model(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = vllm_config.speculative_config.draft_model_config.hf_config + self.vocab_size = self.config.vocab_size + self.quant_config = get_draft_quant_config(vllm_config) + + drafter_config = getattr(self.config, "eagle_config", {}) + drafter_config.update(getattr(self.config, "dflash_config", {})) + + if drafter_config is not None and "use_aux_hidden_state" in drafter_config: + self.use_aux_hidden_state = drafter_config["use_aux_hidden_state"] + else: + self.use_aux_hidden_state = True + + current_vllm_config = get_current_vllm_config() + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList( + [ + DFlashQwen3DecoderLayer( + current_vllm_config, + prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), + config=self.config, + ) + for layer_idx in range(self.config.num_hidden_layers) + ] + ) + if self.use_aux_hidden_state: + num_features_to_use = self.config.num_hidden_layers + if "target_layer_ids" in drafter_config: + num_features_to_use = len(drafter_config["target_layer_ids"]) + elif "layer_ids" in drafter_config: + num_features_to_use = len(drafter_config["layer_ids"]) + if hasattr(self.config, "target_hidden_size"): + fc_input_size = 
self.config.target_hidden_size * num_features_to_use + else: + fc_input_size = self.config.hidden_size * num_features_to_use + self.fc = ReplicatedLinear( + input_size=fc_input_size, + output_size=self.config.hidden_size, + bias=False, + params_dtype=vllm_config.model_config.dtype, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "fc"), + return_bias=False, + ) + self.hidden_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def _build_fused_kv_buffers(self) -> None: + """Build fused weight buffers for precompute_and_store_context_kv. + + Must be called after weights are loaded. Stacks the KV-projection + weights, K-norm weights, and RoPE parameters from every attention + layer so that precompute_and_store_context_kv can run one fused + GEMM for all layers at once. Also aliases the weight of the hidden_norm. + """ + layers_attn = [layer.self_attn for layer in self.layers] + attn0 = layers_attn[0] + has_bias = attn0.qkv_proj.bias is not None + + self._hidden_norm_weight = self.hidden_norm.weight.data + + # KV projection weights: [num_layers * 2 * kv_size, hidden_size] + kv_weights = [a.qkv_proj.weight[a.q_size :] for a in layers_attn] + self._fused_kv_weight = torch.cat(kv_weights, dim=0) + if has_bias: + kv_biases = [a.qkv_proj.bias[a.q_size :] for a in layers_attn] + self._fused_kv_bias: torch.Tensor | None = torch.cat(kv_biases, dim=0) + else: + self._fused_kv_bias = None + + # K-norm weights: list of [head_dim] tensors, one per layer. 
    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Precompute K/V for context states and write them into each layer's KV cache.

        Input context states are projected to K/V, normed, and have RoPE applied.
        Since the context shape is different than the query shape, we can't rely on the
        regular forward pass to apply torch.compile and CUDA graphs to this section.
        As such, this function is optimized to minimize the number of torch ops present:
        we use fused vLLM kernels for RMSNorm and RoPE, fuse the GEMM into one
        large projection, and avoid cloning buffers (with .contiguous()) where possible.

        When context_slot_mapping is None (e.g. during dummy_run) only
        the computation runs, and no K/V is written to cache.

        Args:
            context_states: hidden states of the context tokens
                (assumed shape [num_ctx, hidden_size] — see the flat indexing below).
            context_positions: per-token positions used for RoPE,
                shape [num_ctx].
            context_slot_mapping: KV-cache slot index per context token, or
                None to skip the cache write (profiling / dummy runs).
        """
        num_ctx = context_states.shape[0]
        L = self._num_attn_layers
        kv = self._kv_size
        hd = self._head_dim
        nkv = self._num_kv_heads

        # --- Fused KV projection (one GEMM for all layers) ---
        # Input RMSNorm is done out-of-place into a fresh buffer via the fused
        # vLLM kernel rather than a python-level norm module call.
        normed_context_states = torch.empty_like(context_states)
        ops.rms_norm(
            normed_context_states,
            context_states,
            self._hidden_norm_weight,
            self._rms_norm_eps,
        )
        all_kv_flat = F.linear(
            normed_context_states, self._fused_kv_weight, self._fused_kv_bias
        )
        # Single contiguous copy that separates K/V and transposes to
        # layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous.
        # Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V.
        # NOTE: this assumes the fused weight is laid out (k, v) per layer so
        # the view(num_ctx, L, 2, nkv, hd) split is valid — set up by
        # _build_fused_kv_buffers (not visible here).
        all_kv = (
            all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
        )
        all_k = all_kv[0]  # [L, num_ctx, nkv, hd], contiguous
        all_v = all_kv[1]  # [L, num_ctx, nkv, hd], contiguous

        # --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) ---
        # K-norm weights differ per layer, so this loop cannot be fused.
        all_k_normed = torch.empty_like(all_k)
        for i in range(L):
            ops.rms_norm(
                all_k_normed[i],
                all_k[i],
                self._k_norm_weights[i],
                self._rms_norm_eps,
            )

        # --- Fused RoPE across all layers ---
        # View as [L * num_ctx, kv] so RoPE sees one big batch (no copy).
        # In-place RoPE: pass K as the "query" arg with key=None.
        all_k_flat = all_k_normed.view(L * num_ctx, kv)
        # Positions are identical across layers, so repeat them L times to
        # match the flattened layer-major K batch.
        positions_repeated = context_positions.repeat(L)
        cos_sin_cache = self._rope_cos_sin_cache
        if cos_sin_cache.dtype != all_k_flat.dtype:
            cos_sin_cache = cos_sin_cache.to(dtype=all_k_flat.dtype)
        ops.rotary_embedding(
            positions_repeated,
            all_k_flat,
            None,
            self._rope_head_size,
            cos_sin_cache,
            self._rope_is_neox,
        )

        if context_slot_mapping is None:
            # Dummy-run / profiling path: computation only, no cache write.
            return

        # --- Per-layer cache insert ---
        # all_k_flat aliases all_k_normed, which now holds the RoPE'd keys.
        all_k_final = all_k_flat.view(L, num_ctx, nkv, hd)
        for i in range(L):
            attn = self._attn_layers[i]
            kv_cache = attn.kv_cache
            attn.impl.do_kv_cache_update(
                attn,
                all_k_final[i],
                all_v[i],
                kv_cache,
                context_slot_mapping,
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the draft decoder layers over the query tokens.

        Embeds input_ids when no precomputed embeddings are given, threads
        the (hidden_states, residual) pair through each decoder layer, and
        applies the final norm before returning.
        """
        if input_embeds is None:
            input_embeds = self.embed_input_ids(input_ids)

        hidden_states = input_embeds

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        # Final norm folds in the residual stream; the returned residual is
        # discarded.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
    """DFlash draft model for speculative decoding on Qwen3 targets.

    Wraps a DFlashQwen3Model and exposes the draft-model interface used by
    the speculative-decoding proposer: forward, compute_logits (with optional
    draft->target vocab remapping), aux-hidden-state combination, and the
    context-KV precomputation entry point.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Bypass Qwen3ForCausalLM.__init__ entirely; only nn.Module state is
        # set up here, then the DFlash-specific submodules are built.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Fall back to the full vocab when the draft config doesn't declare a
        # reduced draft vocabulary.
        if getattr(self.config, "draft_vocab_size", None) is None:
            self.config.draft_vocab_size = getattr(self.config, "vocab_size", None)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.config.target_layer_count = target_layer_num
        # Draft layers are numbered after the target model's layers.
        self.model = DFlashQwen3Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=target_layer_num,
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.draft_vocab_size, scale=logit_scale
        )
        # Optional draft->target token-id offset table; populated from the
        # checkpoint's "d2t" tensor in load_weights, None if absent.
        self.draft_id_to_target_id = None

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Multimodal arguments are accepted for interface compatibility but
        # ignored; only the text embedding table is used.
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate to the inner DFlashQwen3Model; returns hidden states."""
        return self.model(input_ids, positions, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute draft-vocab logits, scattered into target-vocab space.

        When a draft->target id mapping was loaded, the draft logits are
        placed at their target-vocab positions and all other entries are
        -inf, so sampling over the target vocabulary picks only draft tokens.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        if self.draft_id_to_target_id is None:
            return logits

        # d2t is stored as per-draft-id offsets: target_id = draft_id + offset.
        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
        targets = base + self.draft_id_to_target_id
        logits_new = logits.new_full(
            (logits.shape[0], self.config.vocab_size),
            float("-inf"),
        )
        logits_new[:, targets] = logits
        return logits_new

    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Precompute projected + RoPE'd K/V and write to cache."""
        self.model.precompute_and_store_context_kv(
            context_states, context_positions, context_slot_mapping
        )

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project concatenated aux hidden states down to the model width.

        No-op when aux hidden states are disabled. Accepts a 1-D vector
        (single token) or a 2-D batch; the fc projection expects 2-D input,
        so 1-D input is temporarily unsqueezed.
        """
        if not self.model.use_aux_hidden_state:
            return hidden_states
        needs_squeeze = hidden_states.dim() == 1
        if needs_squeeze:
            hidden_states = hidden_states.unsqueeze(0)
        result = self.model.fc(hidden_states)
        if needs_squeeze:
            result = result.squeeze(0)
        return result

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Remap checkpoint names to this module tree and load weights.

        Renames "d2t" -> "draft_id_to_target_id", prefixes non-lm_head
        weights with "model.", skips "t2d", and tracks which optional
        tensors are present so absent ones are skipped by the loader.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            assert "mask_hidden" not in name, (
                "DFlash should use mask_token_id to embed the padding hidden state"
            )
            if "t2d" in name:
                # Target->draft mapping is unused by DFlash.
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # Skip optional tensors that the checkpoint does not provide.
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())
        # Fused KV buffers depend on the loaded per-layer weights, so they
        # must be (re)built after loading.
        self.model._build_fused_kv_buffers()
= () - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1400,20 +1399,19 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - aux_hidden_states = [] + aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) for layer_idx, layer in enumerate( islice(self.layers, self.start_layer, self.end_layer), start=self.start_layer, ): - if layer_idx in self.aux_hidden_state_layers: - aux_hidden_states.append( - hidden_states + residual if residual is not None else hidden_states - ) hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, ) + self._maybe_add_hidden_state( + aux_hidden_states, layer_idx + 1, hidden_states, residual + ) if not get_pp_group().is_last_rank: return IntermediateTensors( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c3e7edb7da4a..03cabeb1123a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -542,6 +542,7 @@ "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), + "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 902e335cb632..682828c0a594 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -62,9 +62,20 @@ def __init__( else f"Eagle3{arch}" for arch in self.model.architectures ] + elif method == "dflash": + assert self.model is not 
None, ( + "model should not be None when method is dflash" + ) + kwargs["architectures"] = [ + arch + if arch.startswith("DFlash") or arch.endswith("DFlash") + else f"DFlash{arch}" + for arch in self.model.architectures + ] else: raise ValueError( - f"Invalid method {method}. Supported methods are eagle and eagle3." + f"Invalid method {method}. Supported methods are " + "eagle, eagle3, and dflash." ) super().__init__(**kwargs) diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index cd49ea30e6f4..b6d2db382bcb 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -220,6 +220,17 @@ def is_sparse(cls) -> bool: def supports_per_head_quant_scales(cls) -> bool: return False + @classmethod + def supports_non_causal(cls) -> bool: + """Check if backend supports non-causal (bidirectional) attention + for decoder models. + + Unlike ENCODER_ONLY attention type which implies a different + execution model, this refers to non-causal attention within the + standard paged-KV-cache decoder path. + """ + return False + @classmethod def supports_attn_type(cls, attn_type: str) -> bool: """Check if backend supports a given attention type. 
@@ -261,6 +272,7 @@ def validate_configuration( use_per_head_quant_scales: bool, device_capability: "DeviceCapability", attn_type: str, + use_non_causal: bool = False, ) -> list[str]: invalid_reasons = [] if not cls.supports_head_size(head_size): @@ -293,6 +305,8 @@ def validate_configuration( invalid_reasons.append("compute capability not supported") if not cls.supports_attn_type(attn_type): invalid_reasons.append(f"attention type {attn_type} not supported") + if use_non_causal and not cls.supports_non_causal(): + invalid_reasons.append("non-causal attention not supported") combination_reason = cls.supports_combination( head_size, dtype, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 245995be2642..67c68c508801 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -94,6 +94,10 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_name() -> str: return "FLASH_ATTN" + @classmethod + def supports_non_causal(cls) -> bool: + return True + @classmethod def supports_attn_type(cls, attn_type: str) -> bool: """FlashAttention supports all attention types.""" diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py index 40cc1027874d..ce7f8c30f907 100644 --- a/vllm/v1/attention/selector.py +++ b/vllm/v1/attention/selector.py @@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple): use_mm_prefix: bool = False use_per_head_quant_scales: bool = False attn_type: str = AttentionType.DECODER + use_non_causal: bool = False def __repr__(self): return ( @@ -41,7 +42,8 @@ def __repr__(self): f"use_sparse={self.use_sparse}, " f"use_mm_prefix={self.use_mm_prefix}, " f"use_per_head_quant_scales={self.use_per_head_quant_scales}, " - f"attn_type={self.attn_type})" + f"attn_type={self.attn_type}, " + f"use_non_causal={self.use_non_causal})" ) @@ -76,6 +78,11 @@ def get_attn_backend( else: block_size = None + speculative_config = 
class DFlashProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer for the DFlash draft method.

    DFlash drafts all speculative tokens in a single forward pass
    (parallel drafting): the context K/V is precomputed from target hidden
    states and written straight into the draft KV cache, and only the bonus
    token plus mask tokens flow through the model as queries, using
    non-causal attention over the full context.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

        # Only next_token_ids and mask tokens are query tokens, all other context is K/V
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        # Positions covers both context states + query states
        self.max_positions = self.max_num_tokens + self.max_query_tokens

        # Separate context buffers to keep query buffer addresses stable for CUDA graphs
        self._context_slot_mapping_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._slot_mapping_buffer = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._context_positions_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        # Query-token positions (overrides the base-class buffer size).
        self.positions = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )

        # +1 so the arange can also serve as query_start_loc (batch+1 entries).
        self.arange = torch.arange(
            self.max_positions + 1, device=device, dtype=torch.int32
        )

        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _raise_if_multimodal(self):
        # Override to allow multimodal inputs since DFlash supports Qwen3.5 models
        # Support for multimodal inputs has not been tested.
        pass

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Build DFlash query inputs and attention metadata for the draft pass.

        Returns (num_query_total, token_indices_to_sample, new_cad) where the
        new CommonAttentionMetadata describes only the query tokens
        (bonus + mask per request) with non-causal attention over the context.
        """
        # DFlash cross-attention: context K/V from target hidden states,
        # Q from query embeddings (bonus + mask tokens).
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        num_query_per_req = 1 + self.num_speculative_tokens
        num_query_total = batch_size * num_query_per_req

        # Store for build_model_inputs_first_pass to use
        self._dflash_num_context = num_context

        # We don't need to copy into a buffer here since the context preprocessing
        # does not run in a CUDA graph
        self._dflash_hidden_states = target_hidden_states

        # Overwritten in-place by the triton kernel below.
        token_indices_to_sample = torch.empty(
            batch_size * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        # Launch fused triton kernel for input_ids, positions, slot_mapping,
        # and token_indices_to_sample
        max_ctx_per_req = cad.max_query_len
        max_tokens_per_req = max_ctx_per_req + num_query_per_req
        BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
        num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
        # One program per (request, token-block).
        grid = (batch_size, num_blocks)

        has_num_rejected = num_rejected_tokens_gpu is not None
        copy_and_expand_dflash_inputs_kernel[grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=token_indices_to_sample,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            num_rejected_tokens_ptr=(
                num_rejected_tokens_gpu if has_num_rejected else 0
            ),
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=num_query_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=num_context,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_NUM_REJECTED=has_num_rejected,
        )

        query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
        # Uniform query counts per request make query_start_loc a scaled arange.
        new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req

        # In padded mode, cad.seq_lens includes rejected tokens. Subtract
        # them so attention only sees the valid prefix of context states.
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu

        new_cad = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc,
            seq_lens=effective_seq_lens + num_query_per_req,
            query_start_loc_cpu=(
                torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone()
                * num_query_per_req
            ),
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=num_query_total,
            max_query_len=num_query_per_req,
            max_seq_len=cad.max_seq_len + num_query_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # Non-causal attention is required for DFlash
        )

        return num_query_total, token_indices_to_sample, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Key differences to default dummy_run:
        - Only one forward pass due to parallel drafting
        - DFlash uses context states as unpadded metadata, so hidden_states will
          use the unpadded num_tokens instead of num_input_tokens
        - max_query_tokens is quite small, DFlash only sees spec tokens as queries
        - Multimodal inputs are not currently supported
        """
        num_query_tokens = min(num_tokens, self.max_query_tokens)
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_query_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        # Slot mapping sized to num_input_tokens (query only), matching
        # the K/V tensor size from the model forward. Context KVs are
        # pre-inserted separately and don't flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        # Context and query positions use separate buffers; no copy needed.
        context_positions = self._context_positions_buffer[:num_tokens]
        # Context states will be passed directly to the precomputation without
        # going through the buffer, since no CUDA graph is used for the precomputation.
        # For the dummy run, we use the dummy buffer.
        context_states = self.hidden_states[:num_tokens]

        # Run the KV projection (GEMM + norms + RoPE) for memory profiling;
        # no slot mapping is passed, so nothing is written to the KV cache.
        self.model.precompute_and_store_context_kv(context_states, context_positions)
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Write context KVs into the cache and return query-only model kwargs.

        Must run after set_inputs_first_pass, which populates
        _dflash_num_context, _dflash_hidden_states, and the context buffers.
        """
        # Context and query positions/slots were written to separate
        # buffers by the kernel — no copy needed.
        num_context = self._dflash_num_context

        # Pre-insert context KVs directly into cache
        self.model.precompute_and_store_context_kv(
            self._dflash_hidden_states,  # Shape is already [num_context, hidden_size]
            self._context_positions_buffer[:num_context],
            self._context_slot_mapping_buffer[:num_context],
        )
        return (
            dict(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            ),
            num_input_tokens,
        )

    @override
    def build_per_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build per-layer metadata and enforce that it is non-causal."""
        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
            cad, draft_index
        )
        # DFlash queries attend bidirectionally over the context; a backend
        # that silently builds causal metadata would produce wrong drafts.
        for layer_name, attn_metadata in per_layer_attention_metadata.items():
            assert getattr(attn_metadata, "causal", None) is False, (
                f"Attention metadata for layer {layer_name} does not have"
                " non-causal support, which is required for DFlash."
                " Consider using a different attention backend, such as FlashAttention."
            )
        return per_layer_attention_metadata

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        # DFlash defaults to using aux hidden states unless the draft
        # config's dflash_config explicitly disables it.
        use_aux_hidden_state = True
        dflash_config = getattr(
            self.draft_model_config.hf_config, "dflash_config", None
        )
        if dflash_config is not None:
            use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state
+ # DFlash always uses parallel drafting (all tokens in one pass), + # but has an additional slot for the next_token_id (does not shift like EAGLE) self.parallel_drafting: bool = self.speculative_config.parallel_drafting self.extra_slots_per_request = ( 1 if not self.parallel_drafting else self.num_speculative_tokens ) self.net_num_new_slots_per_request = self.extra_slots_per_request - ( - 1 if self.pass_hidden_states_to_model else 0 + 1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0 ) self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 @@ -101,10 +104,14 @@ def __init__( self.speculative_config.use_local_argmax_reduction ) - max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_batch_size = vllm_config.scheduler_config.max_num_seqs self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) + # Can be specialized by methods like DFlash to reduce the limit + self.max_query_tokens = self.max_num_tokens + self.max_positions = self.max_num_tokens + # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( @@ -146,18 +153,20 @@ def __init__( # 1D-RoPE. 
# See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = torch.zeros( - (3, self.max_num_tokens + 1), dtype=torch.int64, device=device + (3, self.max_positions + 1), dtype=torch.int64, device=device ) elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: self.xdrope_positions = torch.zeros( - (self.uses_xdrope_dim, self.max_num_tokens + 1), + (self.uses_xdrope_dim, self.max_positions + 1), dtype=torch.int64, device=device, ) else: # RoPE need (max_num_tokens,) self.positions = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=device + self.max_positions, + dtype=torch.int64, + device=device, ) self.hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device @@ -168,7 +177,7 @@ def __init__( # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens) self.arange = torch.arange( max_num_slots_for_arange, device=device, dtype=torch.int32 ) @@ -200,7 +209,7 @@ def __init__( ) self.backup_next_token_ids = CpuGpuBuffer( - max_batch_size, + self.max_batch_size, dtype=torch.int32, pin_memory=is_pin_memory_available(), device=device, @@ -208,7 +217,9 @@ def __init__( ) self._slot_mapping_buffer = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=device + self.max_positions, + dtype=torch.int64, + device=device, ) # Determine allowed attention backends once during initialization. @@ -275,7 +286,7 @@ def __init__( # Precompute draft position offsets in flattened tree. 
self.tree_draft_pos_offsets = torch.arange( 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 - ).repeat(max_batch_size, 1) + ).repeat(self.max_batch_size, 1) def _raise_if_padded_drafter_batch_disabled(self): if self.speculative_config.disable_padded_drafter_batch: @@ -305,14 +316,19 @@ def _init_parallel_drafting_params(self): # for those masked slots. model_hf_config = self.draft_model_config.hf_config - if hasattr(model_hf_config, "pard_token"): + # DFlash stores mask_token_id in dflash_config + dflash_config = getattr(model_hf_config, "dflash_config", None) + if dflash_config and "mask_token_id" in dflash_config: + self.parallel_drafting_token_id = dflash_config["mask_token_id"] + elif hasattr(model_hf_config, "pard_token"): self.parallel_drafting_token_id = model_hf_config.pard_token elif hasattr(model_hf_config, "ptd_token_id"): self.parallel_drafting_token_id = model_hf_config.ptd_token_id else: raise ValueError( "For parallel drafting, the draft model config must have " - "`pard_token` or `ptd_token_id` specified in its config.json." + "`pard_token`, `ptd_token_id`, or " + "`dflash_config.mask_token_id` specified in its config.json." 
) if self.pass_hidden_states_to_model: @@ -402,9 +418,14 @@ def propose( ) -> torch.Tensor: batch_size = common_attn_metadata.batch_size() - if self.method == "eagle3": + if self.method in ("eagle3", "dflash"): assert isinstance( - self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) + self.model, + ( + Eagle3LlamaForCausalLM, + Eagle3DeepseekV2ForCausalLM, + DFlashQwen3ForCausalLM, + ), ) target_hidden_states = self.model.combine_hidden_states( target_hidden_states @@ -423,40 +444,17 @@ def propose( ) ) - per_layer_attn_metadata: dict[str, object] = {} - for attn_group in self.draft_attn_groups: - attn_metadata = attn_group.get_metadata_builder().build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0 - ) - for layer_name in attn_group.layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata + per_layer_attn_metadata = self.build_per_layer_attn_metadata( + common_attn_metadata + ) cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( self._determine_batch_execution_and_padding(num_tokens) ) - if self.supports_mm_inputs: - mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) - - self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( - self.input_ids[:num_tokens], - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - - input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] - else: - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None - - model_kwargs = { - "input_ids": input_ids, - "positions": self._get_positions(num_input_tokens), - "inputs_embeds": inputs_embeds, - } - if self.pass_hidden_states_to_model: - model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] + model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass( + num_tokens, num_input_tokens, mm_embed_inputs + ) with set_forward_context( per_layer_attn_metadata, @@ -465,7 +463,7 @@ def propose( num_tokens_across_dp=num_tokens_across_dp, 
cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping( - num_input_tokens, common_attn_metadata.slot_mapping + slot_mapping_size, common_attn_metadata.slot_mapping ), ): ret_hidden_states = self.model(**model_kwargs) @@ -488,7 +486,10 @@ def propose( positions = self.positions[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample] - if isinstance(attn_metadata, TreeAttentionMetadata): + if any( + isinstance(attn_metadata, TreeAttentionMetadata) + for attn_metadata in per_layer_attn_metadata.values() + ): # Draft using tree attention - requires full logits for top-k logits = self.model.compute_logits(sample_hidden_states) draft_token_ids_list = self.propose_tree( @@ -504,15 +505,16 @@ def propose( draft_token_ids = self._greedy_sample(sample_hidden_states) - if self.allowed_attn_types is not None and not isinstance( - attn_metadata, self.allowed_attn_types - ): - raise ValueError( - f"Unsupported attention metadata type for speculative " - "decoding with num_speculative_tokens > 1: " - f"{type(attn_metadata)}. Supported types are: " - f"{self.allowed_attn_types}" - ) + for attn_metadata in per_layer_attn_metadata.values(): + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding with num_speculative_tokens > 1: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}" + ) # Generate the remaining draft tokens. 
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Assemble kwargs for the first draft forward pass.

        Returns (model_kwargs, slot_mapping_size). When multimodal inputs are
        supported, token ids are embedded here (merging mm embeddings) and
        passed as inputs_embeds; otherwise raw input_ids are passed.
        num_input_tokens may exceed num_tokens due to CUDA-graph padding.
        """
        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            # Embed only the real tokens; the padded tail of inputs_embeds
            # keeps whatever was there before.
            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        model_kwargs = {
            "input_ids": input_ids,
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": inputs_embeds,
        }
        if self.pass_hidden_states_to_model:
            model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

        return model_kwargs, num_input_tokens

    def build_per_layer_attn_metadata(
        self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build attention metadata for each draft layer.

        Each attention group builds one metadata object that is shared by
        all layers in that group.
        """
        per_layer_attn_metadata: dict[str, object] = {}
        for attn_group in self.draft_attn_groups:
            attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=draft_index
            )
            for layer_name in attn_group.layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata
        return per_layer_attn_metadata

    def model_returns_tuple(self) -> bool:
        # mtp / draft_model / dflash drafters return hidden states directly;
        # other methods return a (hidden_states, ...) tuple.
        return self.method not in ("mtp", "draft_model", "dflash")
+ only_one_forward_pass = is_graph_capturing or self.parallel_drafting for fwd_idx in range( - self.num_speculative_tokens if not is_graph_capturing else 1 + 1 if only_one_forward_pass else self.num_speculative_tokens ): if fwd_idx <= 1: cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 48840967b4b8..b273b7b5515b 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -466,6 +466,114 @@ def copy_and_expand_eagle_inputs_kernel( ) +@triton.jit +def copy_and_expand_dflash_inputs_kernel( + # Inputs + next_token_ids_ptr, # [num_reqs] + target_positions_ptr, # [num_context] + # Outputs + out_input_ids_ptr, # [num_query_total] (output) + out_context_positions_ptr, # [num_context] (output) + out_query_positions_ptr, # [num_query_total] (output) + out_context_slot_mapping_ptr, # [num_context] (output) + out_query_slot_mapping_ptr, # [num_query_total] (output) + out_token_indices_ptr, # [num_reqs * num_speculative_tokens] (output) + # Block table + block_table_ptr, # [max_reqs, max_blocks] + block_table_stride, # stride of block_table dim 0 (in elements) + # Metadata + query_start_loc_ptr, # [num_reqs + 1] + num_rejected_tokens_ptr, # [num_reqs] or null (0) when not padded + # Scalars + parallel_drafting_token_id, # tl.int32 + block_size, # tl.int32 + num_query_per_req, # tl.int32 + num_speculative_tokens, # tl.int32 + total_input_tokens, # tl.int32 + BLOCK_SIZE: tl.constexpr, + HAS_NUM_REJECTED: tl.constexpr = False, +): + """ + Fused kernel for DFlash first-pass input setup. + + Per request, this kernel: + 1. Copies context positions from target_positions to + out_context_positions. + 2. Computes query positions (last_target_pos + 1 + offset) and writes + them to out_query_positions. + 3. Writes input_ids for query tokens: [next_token, mask, mask, ...]. + 4. 
Computes slot_mapping for context and query positions into separate + buffers via block_table lookup. + 5. Writes token_indices_to_sample for the mask (speculative) tokens. + """ + req_idx = tl.program_id(axis=0) + block_idx = tl.program_id(axis=1) + + # Load context token range for this request + ctx_start = tl.load(query_start_loc_ptr + req_idx) + ctx_end = tl.load(query_start_loc_ptr + req_idx + 1) + num_ctx = ctx_end - ctx_start + total_tokens = num_ctx + num_query_per_req + + j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + in_bounds = j < total_tokens + is_ctx = j < num_ctx + is_query = (~is_ctx) & in_bounds + query_off = j - num_ctx # offset within query portion (0-indexed) + + # --- Positions --- + # Context: load from target_positions + ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1) + ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0) + + # Query: last_valid_pos + 1 + query_off + # In padded mode, ctx_end includes rejected tokens; use valid_ctx_end + # to find the last accepted context position. + if HAS_NUM_REJECTED: + num_rejected = tl.load(num_rejected_tokens_ptr + req_idx) + valid_ctx_end = ctx_end - num_rejected + else: + valid_ctx_end = ctx_end + last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1) + query_pos = last_pos + 1 + query_off + + positions = tl.where(is_ctx, ctx_pos, query_pos) + + # Context and query positions go to separate buffers. 
+ ctx_pos_out = ctx_start + j + tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx) + query_out = req_idx * num_query_per_req + query_off + tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query) + + # --- Slot mapping (block_table lookup for all positions) --- + block_num = positions // block_size + # Clamp block_num to avoid OOB when position is at max + block_num = tl.minimum(block_num, block_table_stride - 1) + block_id = tl.load( + block_table_ptr + req_idx * block_table_stride + block_num, + mask=in_bounds, + other=0, + ).to(tl.int64) + slot = block_id * block_size + (positions % block_size) + tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx) + tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query) + + # --- Input IDs (query tokens only) --- + bonus_token = tl.load(next_token_ids_ptr + req_idx) + is_bonus = is_query & (query_off == 0) + input_id = tl.where(is_bonus, bonus_token, parallel_drafting_token_id) + tl.store(out_input_ids_ptr + query_out, input_id, mask=is_query) + + # --- Token indices to sample (mask tokens, skip the bonus token) --- + is_sample = is_query & (query_off > 0) + sample_out_idx = req_idx * num_speculative_tokens + (query_off - 1) + tl.store( + out_token_indices_ptr + sample_out_idx, + query_out, + mask=is_sample, + ) + + @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def update_num_computed_tokens_for_batch_change( num_computed_tokens: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index be7734487791..34b038cb5639 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -160,6 +160,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.dflash import DFlashProposer from vllm.v1.spec_decode.draft_model import 
DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer @@ -515,6 +516,7 @@ def __init__( | NgramProposerGPU | SuffixDecodingProposer | EagleProposer + | DFlashProposer | DraftModelProposer | MedusaProposer | ExtractHiddenStatesProposer @@ -546,6 +548,9 @@ def __init__( self._ngram_pinned_val_buf = torch.zeros( self.max_num_reqs, dtype=torch.int32, pin_memory=True ) + elif self.speculative_config.use_dflash(): + self.drafter = DFlashProposer(self.vllm_config, self.device, self) + self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -2289,7 +2294,7 @@ def _build_attn_group_metadata( cm.slot_mapping = slot_mappings[kv_cache_gid] if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, EagleProposer): + if isinstance(self.drafter, (EagleProposer, DFlashProposer)): if self.drafter.kv_cache_gid == kv_cache_gid: spec_decode_common_attn_metadata = cm else: @@ -4202,7 +4207,10 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. 
assert isinstance( self.drafter, - EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + EagleProposer + | DFlashProposer + | DraftModelProposer + | ExtractHiddenStatesProposer, ) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: @@ -4589,8 +4597,14 @@ def propose_draft_token_ids( next_token_ids, valid_sampled_tokens_count ) - elif spec_config.use_eagle() or spec_config.uses_draft_model(): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + elif ( + spec_config.use_eagle() + or spec_config.use_dflash() + or spec_config.uses_draft_model() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer + ) if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -4889,10 +4903,13 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None hf_config = self.speculative_config.draft_model_config.hf_config - if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"): - return None - layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None) + if not layer_ids: + dflash_config = getattr(hf_config, "dflash_config", None) + if dflash_config and isinstance(dflash_config, dict): + layer_ids = dflash_config.get("target_layer_ids") + if layer_ids and isinstance(layer_ids, (list, tuple)): return tuple(layer_ids) @@ -5483,7 +5500,10 @@ def _dummy_run( ): assert isinstance( self.drafter, - EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + EagleProposer + | DFlashProposer + | DraftModelProposer + | ExtractHiddenStatesProposer, ) assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. 
@@ -6235,7 +6255,9 @@ def initialize_metadata_builders( self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() ): - assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer + ) self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) def _check_and_update_cudagraph_mode( @@ -6410,7 +6432,10 @@ def _check_and_update_cudagraph_mode( self.speculative_config.use_eagle() or self.speculative_config.uses_extract_hidden_states() ): - assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer) + assert isinstance( + self.drafter, + EagleProposer | DFlashProposer | ExtractHiddenStatesProposer, + ) self.drafter.initialize_cudagraph_keys(cudagraph_mode) def calculate_reorder_batch_threshold(self) -> None: