-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[Feat][Spec Decode] DFlash #36847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
benchislett
wants to merge
36
commits into
vllm-project:main
Choose a base branch
from
CentML:dflash-attempt2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,572
−106
Open
[Feat][Spec Decode] DFlash #36847
Changes from all commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
b83ff20
wip - agentic draft by claude
benchislett a4b8106
less broken wip - agent involved
benchislett 4c3c96e
bugfix - still wip with broken AR
benchislett 520323f
cleanup post merge
benchislett 44a838f
add DFlash regression test
benchislett bb63e78
update dflash test
benchislett ec8521a
fixes for latest version
benchislett cc8c1da
more reliable test
benchislett 2c978cf
cleanup
benchislett a1ede86
more cleanup and usability improvements
benchislett f83a649
add missing file
benchislett e8b1d10
optimize qwen3_dflash prepare inputs
benchislett c1fada0
slight refactor to enable easy check if causal attention is active
benchislett ca37453
more cleanup
benchislett 3f2092d
more cleanup
benchislett 9ec2eba
more cleanup
benchislett cf2514d
more cleanup
benchislett 4f4edfb
optimize prepare inputs dflash
benchislett b45ae1e
optimize dflash using customop
benchislett 7453da4
aggressive optimization
benchislett 18f343c
leverage some torch.compile
benchislett a275e6c
optimize with triton kernel
benchislett f558afa
remove customop and store context states directly into KV cache
benchislett 4e63bbd
async scheduling support
benchislett dcb236c
fix qwen3-next aux hidden states
benchislett 427f4d8
don't need to add humaneval here
benchislett ec2ee15
Apply suggestion from @benchislett
benchislett 5bf7f54
better documentation in qwen3_dflash.py
benchislett 3a88015
Remove redundant comment
benchislett c71e140
remove fc_extras
benchislett 54a8b96
remove redundant comment
benchislett a66e074
Remove redundant comment
benchislett 97f2606
test for dflash prepare inputs
benchislett d9a63c2
fix issue from rebase
benchislett d6ace6a
Merge remote-tracking branch 'upstream/main' into dflash-attempt2
benchislett e99905a
warnings for bad max_num_scheduled_tokens
benchislett File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
|
|
||
| import pytest | ||
| import torch | ||
| from datasets import load_dataset | ||
| from tqdm import tqdm | ||
|
|
||
| from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline | ||
| from tests.utils import ( | ||
|
|
@@ -1015,19 +1017,177 @@ def some_high_acceptance_metrics() -> dict: | |
| } | ||
|
|
||
|
|
||
| def compute_acceptance_rate(metrics: list[Metric]) -> float: | ||
| def compute_acceptance_rate( | ||
| metrics: list[Metric], prev_metrics: list[Metric] | None = None | ||
| ) -> float: | ||
| name2metric = {metric.name: metric for metric in metrics} | ||
| n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore | ||
| n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value | ||
| if n_draft_toks == 0: | ||
| return float("nan") | ||
| n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore | ||
| n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value | ||
| if prev_metrics is not None: | ||
| prev_name2metric = {metric.name: metric for metric in prev_metrics} | ||
| n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value | ||
| n_accepted_toks -= prev_name2metric[ | ||
| "vllm:spec_decode_num_accepted_tokens" | ||
| ].value | ||
| if n_draft_toks <= 0: | ||
| return float("nan") | ||
| return n_accepted_toks / n_draft_toks | ||
|
|
||
|
|
||
| def compute_acceptance_len(metrics: list[Metric]) -> float: | ||
| def compute_acceptance_len( | ||
| metrics: list[Metric], prev_metrics: list[Metric] | None = None | ||
| ) -> float: | ||
| name2metric = {metric.name: metric for metric in metrics} | ||
| n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore | ||
| n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore | ||
| n_drafts = name2metric["vllm:spec_decode_num_drafts"].value | ||
| n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value | ||
| if n_drafts == 0: | ||
| return 1 | ||
| if prev_metrics is not None: | ||
| prev_name2metric = {metric.name: metric for metric in prev_metrics} | ||
| n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value | ||
| n_accepted_toks -= prev_name2metric[ | ||
| "vllm:spec_decode_num_accepted_tokens" | ||
| ].value | ||
| if n_drafts <= 0: | ||
| return 1 | ||
| return 1 + (n_accepted_toks / n_drafts) | ||
|
|
||
|
|
||
| # Datasets in the format used in DFlash validations | ||
| def load_and_process_dataset(data_name: str): | ||
| if data_name == "gsm8k": | ||
| dataset = load_dataset("openai/gsm8k", "main", split="test") | ||
| prompt_fmt = ( | ||
| "{question}\nPlease reason step by step," | ||
| " and put your final answer within \\boxed{{}}." | ||
| ) | ||
| dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) | ||
| elif data_name == "mt-bench": | ||
| dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") | ||
| dataset = dataset.map(lambda x: {"turns": x["prompt"]}) | ||
| elif data_name == "humaneval": | ||
| dataset = load_dataset("openai/openai_humaneval", split="test") | ||
| prompt_fmt = ( | ||
| "Write a solution to the following problem and make sure" | ||
| " that it passes the tests:\n```python\n{prompt}\n```" | ||
| ) | ||
| dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) | ||
|
|
||
| return dataset | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def dflash_config(): | ||
| target_model = "Qwen/Qwen3-8B" | ||
| draft_model = "z-lab/Qwen3-8B-DFlash-b16" | ||
|
|
||
| return dict( | ||
| model=target_model, | ||
| trust_remote_code=True, | ||
| speculative_config={ | ||
| "method": "dflash", | ||
| "model": draft_model, | ||
| "num_speculative_tokens": 16, | ||
| "max_model_len": 32768, | ||
| }, | ||
| max_model_len=32768, | ||
| max_num_seqs=128, | ||
| gpu_memory_utilization=0.85, | ||
| enforce_eager=False, | ||
| disable_log_stats=False, | ||
| attention_config={"backend": "FLASH_ATTN"}, # Required for non-causal attention | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we get this working across the board without needing to specify this arg? We should be able to resolve this internally by querying the attention backend's supports_attn_type during selection |
||
| ) | ||
|
|
||
|
|
||
| def test_dflash_acceptance_rates(dflash_config): | ||
| """ | ||
| E2E test for DFlash (block diffusion) speculative decoding. | ||
| Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval | ||
| comparing against baseline results from the paper (Table 1). | ||
| See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology. | ||
| """ | ||
| spec_llm = LLM(**dflash_config) | ||
|
|
||
| max_prompts_per_dataset = 200 # mt-bench has 80, humaneval has 164, truncates gsm8k | ||
|
|
||
| # All scores from Table 1 in https://arxiv.org/pdf/2602.06036 | ||
| expected_acceptance_lengths = { | ||
| "mt-bench": 4.24, | ||
| "humaneval": 6.50, | ||
| "gsm8k": 6.54 * 0.95, # runs with a subset of prompts so extra wide tol here | ||
| } | ||
|
|
||
| tokenizer = spec_llm.get_tokenizer() | ||
| for dataset_name, expected_len in expected_acceptance_lengths.items(): | ||
| dataset = load_and_process_dataset(dataset_name) | ||
| prev_metrics = None | ||
| acceptance_lengths = [] | ||
| for i in tqdm( | ||
| range(min(max_prompts_per_dataset, len(dataset))), | ||
| desc=f"Processing {dataset_name}", | ||
| ): | ||
| user_content = dataset[i]["turns"][0] | ||
| prompt_text = tokenizer.apply_chat_template( | ||
| [{"role": "user", "content": user_content}], | ||
| tokenize=False, | ||
| add_generation_prompt=True, | ||
| enable_thinking=False, | ||
| ) | ||
|
|
||
| # Temp=0, MaxTokens=2048 from the paper | ||
| spec_llm.generate( | ||
| [prompt_text], | ||
| SamplingParams(temperature=0, max_tokens=2048), | ||
| use_tqdm=False, | ||
| ) | ||
| current_metrics = spec_llm.get_metrics() | ||
| acceptance_len = compute_acceptance_len(current_metrics, prev_metrics) | ||
| prev_metrics = current_metrics | ||
| acceptance_lengths.append(acceptance_len) | ||
|
|
||
| mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths) | ||
| expected_len = expected_len * 0.9 | ||
| print( | ||
| f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}" | ||
| f" (expected at least {expected_len:.2f})" | ||
| ) | ||
|
|
||
| assert mean_acceptance_length >= expected_len, ( | ||
| f"DFlash acceptance_len for {dataset_name} is below expected threshold:" | ||
| f"{mean_acceptance_length:.2f} < {expected_len:.2f}" | ||
| ) | ||
|
|
||
| del spec_llm | ||
| torch.accelerator.empty_cache() | ||
| cleanup_dist_env_and_memory() | ||
|
|
||
|
|
||
| def test_dflash_correctness(dflash_config): | ||
| """ | ||
| E2E test for DFlash (block diffusion) speculative decoding. | ||
| Ensures output correctness on GSM8k, with cudagraphs and batching on. | ||
| """ | ||
| spec_llm = LLM(**dflash_config) | ||
|
|
||
| # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k) | ||
| evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8) | ||
|
|
||
| current_metrics = spec_llm.get_metrics() | ||
| acceptance_len = compute_acceptance_len(current_metrics) | ||
|
|
||
| # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent | ||
| # with the DFlash paper. However, that test measures AL per-request and thus runs | ||
| # with a batch size of 1. To ensure that AL does not collapse with large batch sizes | ||
| # we enforce a baseline on the AL over the full lm-eval-style GSM8k test. | ||
| expected_len = 3.5 # Measured is 3.9 to 4.0 | ||
| print(f"DFlash GSM8k correctness test got AL {acceptance_len}") | ||
| assert acceptance_len >= expected_len, ( | ||
| "DFlash correctness check failed with" | ||
| f" {acceptance_len=}, expected at least {expected_len}" | ||
| ) | ||
|
|
||
| del spec_llm | ||
| torch.accelerator.empty_cache() | ||
| cleanup_dist_env_and_memory() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is a specific flash attention version needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, worked for me with both FA2 and FA4. In theory any backend with non-causal support will work (for now)