
[Feat][Spec Decode] DFlash #36847

Open

benchislett wants to merge 36 commits into vllm-project:main from CentML:dflash-attempt2

Conversation

@benchislett
Collaborator

@benchislett benchislett commented Mar 12, 2026

Purpose

Overview

DFlash works much like P-EAGLE (see #32887), but with a major architectural change: it uses bidirectional attention between the query tokens (the last sampled token from the base model plus a block of placeholder mask tokens) and the context states, which are the target model's hidden states from the prefill or accepted tokens.

To implement this, I introduce an extra operation outside of the main model execution that populates the KV cache with the context states directly. Though this operation is not exposed to the standard set of torch.compile and CUDA graph optimizations, handling the context this way lets us use async scheduling (by writing the full set of context states to the cache and then using seq_lens_gpu to ignore the rejected ones) and enables (piecewise) CUDA graphs for the main forward pass over the query tokens.
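The "write everything, mask later" idea above can be illustrated with a minimal sketch. All names here (`visible_context`, the cache list) are illustrative stand-ins, not the actual vLLM internals:

```python
# Hypothetical sketch of writing the full set of context states eagerly and
# using a per-sequence length to hide the rejected suffix at attention time.

def visible_context(context_states, seq_len):
    """Return only the context entries the attention kernel should see.

    All context states (including those for tokens that may later be
    rejected) are written to the cache up front; seq_len tells the kernel
    how many leading entries are valid, so rejected ones are never read.
    """
    return context_states[:seq_len]

# Six hidden states were written eagerly, but only 4 tokens were accepted.
cache_entries = ["h0", "h1", "h2", "h3", "h4_rej", "h5_rej"]
print(visible_context(cache_entries, 4))  # ['h0', 'h1', 'h2', 'h3']
```

Because the cache write is unconditional, it can overlap with verification on the host, which is what makes async scheduling possible in this design.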

(Solved) Because the attention is bidirectional and structured in this way, we cannot allow rejected tokens to stay in the batch: their states would corrupt the rest of the computation, unlike in standard causal attention where we can simply omit them and sample from an earlier position. Therefore, "disable_padded_drafter_batch" is required and is enabled automatically when using DFlash, disabling async scheduling as a consequence. At this time I cannot think of a good way around this problem.
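A toy illustration (plain Python, illustrative only, not real attention math) of why rejected tokens are harmless under causal attention but poisonous under bidirectional attention: every bidirectional output mixes every input, so one corrupted state taints all positions.

```python
def bidirectional_mix(states):
    # each output attends to ALL states (toy stand-in for attention)
    return [sum(states) / len(states) for _ in states]

def causal_mix(states):
    # output i attends only to states[0..i]
    return [sum(states[: i + 1]) / (i + 1) for i in range(len(states))]

clean = [1.0, 1.0, 1.0]
corrupted = [1.0, 1.0, 100.0]  # last token was "rejected" but left in batch

# causal: outputs before the corrupted position match the clean run,
# so we can simply sample from an earlier position
assert causal_mix(corrupted)[:2] == causal_mix(clean)[:2]

# bidirectional: every single output changes
assert all(a != b for a, b in zip(bidirectional_mix(corrupted),
                                  bidirectional_mix(clean)))
```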

Additionally, only a small selection of kernel backends actually supports non-causal attention, and a smaller subset also supports gpt-oss-style "sinks". Flash Attention with Qwen3-8B is used to test this implementation, since neither Triton Attention nor FlashInfer (TRTLLM) attention supports non-causality. A follow-up work here would be to allow different attention backends for the drafter and the target model. This is out of scope for this PR, but would enable broader compatibility for DFlash models.
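The compatibility constraint above amounts to a startup check along these lines. This is an illustrative sketch, not vLLM's actual backend-selection code; the backend names and the `check_dflash_backend` helper are assumptions for this example:

```python
# Illustrative sketch: DFlash needs a backend with non-causal support.
SUPPORTS_NON_CAUSAL = {"flash_attn"}  # assumption: the backend tested in this PR

def check_dflash_backend(backend: str) -> None:
    """Raise early if the chosen backend cannot serve the DFlash drafter."""
    if backend not in SUPPORTS_NON_CAUSAL:
        raise ValueError(
            f"DFlash requires non-causal attention; backend {backend!r} "
            "does not support it (e.g. triton_attn, flashinfer)"
        )

check_dflash_backend("flash_attn")  # ok; anything else raises
```

Splitting the backend choice per model (drafter vs. target) would relax this check to only the drafter's backend.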

(Solved) Finally, this new architecture requires extra logic to handle the new input shapes and attention metadata. It is not as simple as P-EAGLE, where we can share code with other modes of speculation, since the sizes and contents of the input tensors are all different. One compatibility component in which I am not absolutely confident is CUDA graph support: even in piecewise mode, I am not sure how the "padded" query tokens interact with the unpadded context. Is it safe to operate on the context slice of Q inside the attention op, or must this be moved into a custom op to avoid issues? Enabling torch.compile for the DFlash drafter has not yet been functional, likely for this reason.

Implementation Details

The model is implemented in qwen3_dflash.py and the core speculation logic is in dflash.py. Similarly to draft-model speculative decoding, DFlash has a unique speculation paradigm that requires some refactoring of eagle.py in order to support it cleanly.

Specifically, build_model_inputs_first_pass and build_per_layer_attn_metadata have been introduced in eagle.py so that dflash.py can override them.

Initial implementations inlined the logic for DFlash into eagle.py, but similarly to #24322 the added branching and complexity would (in my opinion) lead to a fairly cluttered EAGLE implementation. In this PR I have tried to separate the concerns to a reasonable degree, so that maintenance of the existing EAGLE pathway is not burdened.
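The hook-based separation described above can be sketched as follows. These classes and bodies are illustrative only (the real `build_model_inputs_first_pass` in eagle.py has a very different signature); the sketch only shows the override pattern that keeps DFlash branching out of the EAGLE path:

```python
# Illustrative sketch of the refactoring: eagle.py exposes overridable hooks
# so dflash.py can swap in its own input construction without if-branches.

class EagleProposerSketch:
    def build_model_inputs_first_pass(self, tokens):
        # EAGLE: causal drafting over the sampled tokens only.
        return {"input_ids": list(tokens), "causal": True}

    def propose(self, tokens):
        # shared driver logic lives in the base class
        return self.build_model_inputs_first_pass(tokens)

class DFlashProposerSketch(EagleProposerSketch):
    NUM_MASK_TOKENS = 15  # illustrative, mirrors num_speculative_tokens

    def build_model_inputs_first_pass(self, tokens):
        # DFlash: append placeholder mask tokens and run bidirectional
        # attention between this query block and the context states.
        return {
            "input_ids": list(tokens) + ["<mask>"] * self.NUM_MASK_TOKENS,
            "causal": False,
        }
```

The base class never needs to know DFlash exists, which is the maintenance property the PR description is aiming for.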

Testing

I add unit tests for both correctness (GSM8k) and acceptance rate for Qwen3-8B with DFlash. I am able to reproduce the acceptance-rate values reported in the DFlash paper almost exactly, and have included this in the test suite as an exact reproduction of their evaluation setup. This should be a valuable resource, and easily extensible to new DFlash models as the model family continues to evolve and expand.

In local tests on 1xB200, both Qwen3-8B and Qwen3.5-9B pass the test suite using both FA2 and FA4.

Usage

vllm serve Qwen/Qwen3-8B \
    --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3-8B-DFlash-b16", "num_speculative_tokens": 15}' \
    --attention-backend flash_attn \
    --max-num-batched-tokens 32768

Benchmarking

The latest DFlash implementation on this branch is optimized for low-latency performance. I measured the following speedups using the DFlash methodology (https://github.com/z-lab/dflash).

Datasets covered are Alpaca, GSM8k, HumanEval, Math500, MBPP, MT-Bench. Number of spec tokens for DFlash is 15.

Alpaca

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,268.64  6,210.65    1.179
  16  3,079.46  5,015.65    1.629
   8  1,581.42  3,339.16    2.111
   1    209.46    590.48    2.819

GSM8k

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,261.70  8,433.19    1.603
  16  3,079.63  6,534.04    2.122
   8  1,596.48  4,332.12    2.714
   1    211.35    741.56    3.509

HumanEval

Output tok/s
conc  baseline     dflash  speedup
----  --------  ---------  -------
  32  5,002.62  10,257.62    2.050
  16  3,043.47   8,131.09    2.672
   8  1,583.48   5,433.65    3.431
   1    209.16     969.07    4.633

Math500

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,267.27  9,836.42    1.867
  16  3,087.45  7,668.98    2.484
   8  1,598.57  5,096.36    3.188
   1    211.31    849.22    4.019

MBPP

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,120.32  8,289.34    1.619
  16  3,099.78  6,677.64    2.154
   8  1,584.55  4,487.97    2.832
   1    209.56    810.25    3.866

MT-Bench

Output tok/s
conc  baseline    dflash  speedup
----  --------  --------  -------
  32  5,217.68  6,787.14    1.301
  16  3,082.95  5,297.45    1.718
   8  1,591.69  3,536.00    2.222
   1    210.29    627.01    2.982

@mergify

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for DFlash speculative decoding, a new method that leverages bidirectional attention. The implementation is comprehensive, touching configuration, model definition, the speculative proposer, and the core model runner. Key changes include:

  • A new DFlashProposer and qwen3_dflash model to implement the DFlash architecture.
  • Refactoring of eagle.py to better support different speculative decoding methods.
  • Configuration updates for auto-detection and setup of DFlash.
  • Extensive end-to-end tests for both correctness and acceptance rate, which is great to see.

The code is well-structured, and the refactoring improves extensibility. My main concern, which you've also highlighted in the PR description and code comments, is the compatibility with CUDA graphs due to the torch.cat operation on potentially differently padded tensors. This is a critical point for performance and stability, and I've left a comment with a suggestion on how to address it.

@benchislett
Collaborator Author

Update on the graphs: I have a local workaround that I'm cleaning up. The solution is to put the context states in the forward_context and access them via a CustomOp, similar to unified_attention_with_output. Then all the ordinary logic can go in the main graph.

@mergify

mergify bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 13, 2026
@jianc99

jianc99 commented Mar 18, 2026

Thanks @benchislett for implementing DFlash in vLLM. We’ve just released DFlash checkpoints for Qwen3.5-4B, 9B, 27B, and 35B-A3B. They perform very well and are faster than MTP. Feel free to try them out. We’ll continue adding support for more models, and I’m really looking forward to seeing DFlash run in vLLM!

benchislett and others added 21 commits March 19, 2026 15:58
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify

mergify bot commented Mar 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 19, 2026
@benchislett
Collaborator Author

benchislett commented Mar 19, 2026

(force-pushed to fix DCO. manually checked all test cases still pass)

@mergify mergify bot removed the needs-rebase label Mar 19, 2026
@benchislett
Collaborator Author

benchislett commented Mar 19, 2026

@mgoin

Regarding this issue:
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
It is a side-effect of how we currently handle parallel drafting and max_num_batched_tokens. It can be resolved by increasing max_num_batched_tokens to 32768.

It will now raise a clear error on startup as of e99905a.
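The arithmetic behind that budget can be sketched as follows. This is illustrative reasoning under stated assumptions, not vLLM's actual scheduler code: with parallel drafting, every running request can schedule its bonus token plus all draft tokens in a single step.

```python
# Back-of-the-envelope sketch (illustrative only) of why a small
# max_num_batched_tokens can be exceeded under parallel drafting.

def tokens_per_step(num_requests, num_speculative_tokens):
    # each request verifies 1 sampled token + k draft tokens per step
    return num_requests * (1 + num_speculative_tokens)

# 256 concurrent requests with 15 draft tokens each need 4096 scheduled
# tokens in one step, so the batched-token budget must be at least that.
print(tokens_per_step(256, 15))  # 4096
```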


Labels

new-model (Requests to new models) · qwen (Related to Qwen models) · speculative-decoding · v1
