Changes from all commits
Commits (36)
- b83ff20: wip - agentic draft by claude (benchislett, Mar 3, 2026)
- a4b8106: less broken wip - agent involved (benchislett, Mar 3, 2026)
- 4c3c96e: bugfix - still wip with broken AR (benchislett, Mar 3, 2026)
- 520323f: cleanup post merge (benchislett, Mar 11, 2026)
- 44a838f: add DFlash regression test (benchislett, Mar 11, 2026)
- bb63e78: update dflash test (benchislett, Mar 11, 2026)
- ec8521a: fixes for latest version (benchislett, Mar 11, 2026)
- cc8c1da: more reliable test (benchislett, Mar 11, 2026)
- 2c978cf: cleanup (benchislett, Mar 12, 2026)
- a1ede86: more cleanup and usability improvements (benchislett, Mar 12, 2026)
- f83a649: add missing file (benchislett, Mar 12, 2026)
- e8b1d10: optimize qwen3_dflash prepare inputs (benchislett, Mar 12, 2026)
- c1fada0: slight refactor to enable easy check if causal attention is active (benchislett, Mar 12, 2026)
- ca37453: more cleanup (benchislett, Mar 12, 2026)
- 3f2092d: more cleanup (benchislett, Mar 12, 2026)
- 9ec2eba: more cleanup (benchislett, Mar 12, 2026)
- cf2514d: more cleanup (benchislett, Mar 12, 2026)
- 4f4edfb: optimize prepare inputs dflash (benchislett, Mar 12, 2026)
- b45ae1e: optimize dflash using customop (benchislett, Mar 13, 2026)
- 7453da4: aggressive optimization (benchislett, Mar 13, 2026)
- 18f343c: leverage some torch.compile (benchislett, Mar 13, 2026)
- a275e6c: optimize with triton kernel (benchislett, Mar 13, 2026)
- f558afa: remove customop and store context states directly into KV cache (benchislett, Mar 18, 2026)
- 4e63bbd: async scheduling support (benchislett, Mar 18, 2026)
- dcb236c: fix qwen3-next aux hidden states (benchislett, Mar 18, 2026)
- 427f4d8: don't need to add humaneval here (benchislett, Mar 19, 2026)
- ec2ee15: Apply suggestion from @benchislett (benchislett, Mar 19, 2026)
- 5bf7f54: better documentation in qwen3_dflash.py (benchislett, Mar 19, 2026)
- 3a88015: Remove redundant comment (benchislett, Mar 19, 2026)
- c71e140: remove fc_extras (benchislett, Mar 19, 2026)
- 54a8b96: remove redundant comment (benchislett, Mar 19, 2026)
- a66e074: Remove redundant comment (benchislett, Mar 19, 2026)
- 97f2606: test for dflash prepare inputs (benchislett, Mar 19, 2026)
- d9a63c2: fix issue from rebase (benchislett, Mar 19, 2026)
- d6ace6a: Merge remote-tracking branch 'upstream/main' into dflash-attempt2 (benchislett, Mar 19, 2026)
- e99905a: warnings for bad max_num_scheduled_tokens (benchislett, Mar 19, 2026)
172 changes: 166 additions & 6 deletions tests/v1/e2e/spec_decode/test_spec_decode.py
@@ -7,6 +7,8 @@

import pytest
import torch
from datasets import load_dataset
from tqdm import tqdm

from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import (
@@ -1015,19 +1017,177 @@ def some_high_acceptance_metrics() -> dict:
}


def compute_acceptance_rate(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
name2metric = {metric.name: metric for metric in metrics}
    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
if n_draft_toks == 0:
return float("nan")
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
if prev_metrics is not None:
prev_name2metric = {metric.name: metric for metric in prev_metrics}
n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
n_accepted_toks -= prev_name2metric[
"vllm:spec_decode_num_accepted_tokens"
].value
if n_draft_toks <= 0:
return float("nan")
return n_accepted_toks / n_draft_toks


def compute_acceptance_len(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
if n_drafts == 0:
return 1
if prev_metrics is not None:
prev_name2metric = {metric.name: metric for metric in prev_metrics}
n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
n_accepted_toks -= prev_name2metric[
"vllm:spec_decode_num_accepted_tokens"
].value
if n_drafts <= 0:
return 1
return 1 + (n_accepted_toks / n_drafts)
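
As a quick illustration of the new prev_metrics delta handling (a sketch, not part of this diff; _FakeMetric is a hypothetical stand-in exposing only the name/value fields these helpers actually read — real Metric objects come from LLM.get_metrics()):

from dataclasses import dataclass

@dataclass
class _FakeMetric:
    name: str
    value: int

prev = [
    _FakeMetric("vllm:spec_decode_num_drafts", 10),
    _FakeMetric("vllm:spec_decode_num_accepted_tokens", 30),
]
curr = [
    _FakeMetric("vllm:spec_decode_num_drafts", 20),
    _FakeMetric("vllm:spec_decode_num_accepted_tokens", 80),
]
# 10 new drafts and 50 newly accepted tokens since prev: AL = 1 + 50 / 10 = 6.0
assert compute_acceptance_len(curr, prev) == 6.0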


# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
if data_name == "gsm8k":
dataset = load_dataset("openai/gsm8k", "main", split="test")
prompt_fmt = (
"{question}\nPlease reason step by step,"
" and put your final answer within \\boxed{{}}."
)
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
elif data_name == "mt-bench":
dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
dataset = dataset.map(lambda x: {"turns": x["prompt"]})
elif data_name == "humaneval":
dataset = load_dataset("openai/openai_humaneval", split="test")
prompt_fmt = (
"Write a solution to the following problem and make sure"
" that it passes the tests:\n```python\n{prompt}\n```"
)
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

return dataset
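
For example (illustrative, assuming the Hugging Face datasets above are reachable), every record then exposes a turns list whose first element is the formatted user prompt consumed by the tests below:

ds = load_and_process_dataset("gsm8k")
print(ds[0]["turns"][0])  # "<question>\nPlease reason step by step, ..."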


@pytest.fixture
def dflash_config():
target_model = "Qwen/Qwen3-8B"
draft_model = "z-lab/Qwen3-8B-DFlash-b16"

return dict(
model=target_model,
trust_remote_code=True,
speculative_config={
"method": "dflash",
"model": draft_model,
"num_speculative_tokens": 16,
"max_model_len": 32768,
},
max_model_len=32768,
max_num_seqs=128,
gpu_memory_utilization=0.85,
enforce_eager=False,
disable_log_stats=False,
attention_config={"backend": "FLASH_ATTN"}, # Required for non-causal attention
Review comment (Member):
    Is a specific flash attention version needed?

Reply (Collaborator, Author):
    Nope, worked for me with both FA2 and FA4. In theory any backend with
    non-causal support will work (for now).

Review comment (Member):
    Can we get this working across the board without needing to specify this
    arg? We should be able to resolve this internally by querying the
    attention backend's supports_attn_type during selection.
)
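
A rough sketch of the reviewer's suggestion (hypothetical: the real candidate list, selection hook, and attention-type values live in vLLM's backend-selection code and may differ):

def pick_backend(candidates, attn_type):
    # Prefer the first backend whose supports_attn_type hook (if present)
    # accepts the requested attention type; backends without the hook are
    # conservatively treated as causal-only here.
    for backend_cls in candidates:
        supports = getattr(backend_cls, "supports_attn_type", None)
        if supports is not None and supports(attn_type):
            return backend_cls
    raise ValueError(f"No attention backend supports {attn_type!r}")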


def test_dflash_acceptance_rates(dflash_config):
"""
E2E test for DFlash (block diffusion) speculative decoding.
    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval,
comparing against baseline results from the paper (Table 1).
See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
"""
spec_llm = LLM(**dflash_config)

    max_prompts_per_dataset = 200  # mt-bench has 80, humaneval 164; gsm8k gets truncated

# All scores from Table 1 in https://arxiv.org/pdf/2602.06036
expected_acceptance_lengths = {
"mt-bench": 4.24,
"humaneval": 6.50,
"gsm8k": 6.54 * 0.95, # runs with a subset of prompts so extra wide tol here
}

tokenizer = spec_llm.get_tokenizer()
for dataset_name, expected_len in expected_acceptance_lengths.items():
dataset = load_and_process_dataset(dataset_name)
prev_metrics = None
acceptance_lengths = []
for i in tqdm(
range(min(max_prompts_per_dataset, len(dataset))),
desc=f"Processing {dataset_name}",
):
user_content = dataset[i]["turns"][0]
prompt_text = tokenizer.apply_chat_template(
[{"role": "user", "content": user_content}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)

# Temp=0, MaxTokens=2048 from the paper
spec_llm.generate(
[prompt_text],
SamplingParams(temperature=0, max_tokens=2048),
use_tqdm=False,
)
current_metrics = spec_llm.get_metrics()
acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
prev_metrics = current_metrics
acceptance_lengths.append(acceptance_len)

mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)
expected_len = expected_len * 0.9
print(
f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
f" (expected at least {expected_len:.2f})"
)

    assert mean_acceptance_length >= expected_len, (
        f"DFlash acceptance_len for {dataset_name} is below expected threshold: "
        f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
    )

del spec_llm
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()


def test_dflash_correctness(dflash_config):
"""
E2E test for DFlash (block diffusion) speculative decoding.
Ensures output correctness on GSM8k, with cudagraphs and batching on.
"""
spec_llm = LLM(**dflash_config)

# Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)

current_metrics = spec_llm.get_metrics()
acceptance_len = compute_acceptance_len(current_metrics)

    # Acceptance length (AL) is thoroughly validated in
    # test_dflash_acceptance_rates, in a manner consistent with the DFlash
    # paper. However, that test measures AL per-request and thus runs with a
    # batch size of 1. To ensure that AL does not collapse at large batch
    # sizes, we enforce a baseline on the AL over the full lm-eval-style
    # GSM8k test.
    expected_len = 3.5  # measured AL is 3.9 to 4.0
print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
assert acceptance_len >= expected_len, (
"DFlash correctness check failed with"
f" {acceptance_len=}, expected at least {expected_len}"
)

del spec_llm
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
153 changes: 150 additions & 3 deletions tests/v1/spec_decode/test_eagle.py
@@ -27,6 +27,7 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -36,6 +37,8 @@
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
dflash_target_dir = "Qwen/Qwen3-8B"
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"

BLOCK_SIZE = 16

@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False,
) -> EagleProposer:

# Method-dependent setup
if method == "eagle":
target_model_dir = model_dir
draft_model_dir = eagle_dir
elif method == "eagle3":
target_model_dir = model_dir
draft_model_dir = eagle3_dir
elif method == "draft_model":
target_model_dir = model_dir
draft_model_dir = ar_draft_model_dir
elif method == "dflash":
target_model_dir = dflash_target_dir
draft_model_dir = dflash_dir
else:
raise ValueError(f"Unknown method: {method}")

model_config = ModelConfig(
model=target_model_dir,
runner="generate",
max_model_len=100,
trust_remote_code=(method == "dflash"),
)

spec_token_tree_str = None
if speculative_token_tree is not None:
assert num_speculative_tokens == len(speculative_token_tree)
@@ -92,7 +106,9 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend),
)

if "eagle" in method:
if method == "dflash":
proposer = DFlashProposer(vllm_config=vllm_config, device=device)
elif "eagle" in method:
proposer = EagleProposer(vllm_config=vllm_config, device=device)
else:
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
@@ -1152,3 +1168,134 @@ def create_deterministic_logits(token_ids, k: int):

# Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens)
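
Before the test below, a small worked sketch of the DFlash query-position rule it verifies (illustrative only; it mirrors the docstring's numbers rather than calling any vLLM API):

import torch

last_positions = [9, 7, 11]  # final context position per request
num_query_per_req = 4  # 1 bonus token + 3 mask tokens

query_positions = torch.cat(
    [last + 1 + torch.arange(num_query_per_req) for last in last_positions]
)
# tensor([10, 11, 12, 13,  8,  9, 10, 11, 12, 13, 14, 15])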


def test_set_inputs_first_pass_dflash():
"""
Test for DFlash set_inputs_first_pass.
DFlash uses cross-attention: context tokens become K/V and only
query tokens (bonus + mask) are Q. This tests the DFlash-specific
input preparation where:
- Context hidden states are copied as-is
- Query input_ids are [next_token, mask, mask, ...] per request
- Positions cover context (copied) + query (last_pos + 1 + offset)
- token_indices_to_sample points to mask token positions only
- A new CommonAttentionMetadata is returned with causal=False
Setup:
- 3 requests with query_lens [3, 2, 4]
- num_speculative_tokens = 3
- num_query_per_req = 4 (1 bonus + 3 mask tokens)
- next_token_ids: [100, 200, 300]
Expected output layout (query tokens only, 12 total):
Request 0 (indices 0-3): [100, mask, mask, mask]
Request 1 (indices 4-7): [200, mask, mask, mask]
Request 2 (indices 8-11): [300, mask, mask, mask]
Expected positions layout:
Context (first 9): copied from target_positions
Query (next 12):
Request 0: last_pos=9, query=[10, 11, 12, 13]
Request 1: last_pos=7, query=[8, 9, 10, 11]
Request 2: last_pos=11, query=[12, 13, 14, 15]
"""
device = torch.device(current_platform.device_type)

num_speculative_tokens = 3
proposer = _create_proposer("dflash", num_speculative_tokens)
mask_token_id = proposer.parallel_drafting_token_id

# Setup batch with 3 requests
batch_spec = BatchSpec(
seq_lens=[10, 8, 12],
query_lens=[3, 2, 4],
)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
arange_block_indices=True,
)

# Input tensors
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
# Request 1: tokens [20, 21] at positions [6, 7]
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
target_token_ids = torch.tensor(
[10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device
)
target_positions = torch.tensor(
[7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device
)
target_hidden_states = torch.randn(
9, proposer.hidden_size, dtype=proposer.dtype, device=device
)
next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)

num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=None,
cad=common_attn_metadata,
num_rejected_tokens_gpu=None,
)

num_query_per_req = 1 + num_speculative_tokens # 4
num_context = 9

# num_tokens is the query-only count
assert num_tokens == 3 * num_query_per_req # 12

# Verify input_ids (query tokens only)
# Each request: [next_token, mask, mask, mask]
M = mask_token_id
expected_input_ids = torch.tensor(
[100, M, M, M, 200, M, M, M, 300, M, M, M],
dtype=torch.int32,
device=device,
)
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)

# Verify context positions (first 9 slots): copied from target_positions
assert torch.equal(proposer.positions[:num_context], target_positions)

# Verify query positions (next 12 slots):
# req0: last_pos=9, query=[10, 11, 12, 13]
# req1: last_pos=7, query=[8, 9, 10, 11]
# req2: last_pos=11, query=[12, 13, 14, 15]
expected_query_positions = torch.tensor(
[10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
dtype=torch.int64,
device=device,
)
assert torch.equal(
proposer.positions[num_context : num_context + num_tokens],
expected_query_positions,
)

# Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0)
# req0: query indices 0-3, mask at 1,2,3
# req1: query indices 4-7, mask at 5,6,7
# req2: query indices 8-11, mask at 9,10,11
expected_token_indices_to_sample = torch.tensor(
[1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device
)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

# Verify the new CAD has DFlash-specific properties
assert output_cad.causal is False # DFlash requires non-causal attention
assert output_cad.num_actual_tokens == num_tokens # query-only count
assert output_cad.max_query_len == num_query_per_req

expected_query_start_loc = torch.tensor(
[0, 4, 8, 12], dtype=torch.int32, device=device
)
assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)

# Verify hidden states (context copied as-is)
assert torch.equal(proposer.hidden_states[:num_context], target_hidden_states)