Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces support for DFlash speculative decoding, a new method that leverages bidirectional attention. The implementation is comprehensive, touching configuration, model definition, the speculative proposer, and the core model runner. Key changes include:
- A new `DFlashProposer` and `qwen3_dflash` model to implement the DFlash architecture.
- Refactoring of `eagle.py` to better support different speculative decoding methods.
- Configuration updates for auto-detection and setup of DFlash.
- Extensive end-to-end tests for both correctness and acceptance rate, which is great to see.
The code is well-structured, and the refactoring improves extensibility. My main concern, which you've also highlighted in the PR description and code comments, is the compatibility with CUDA graphs due to the torch.cat operation on potentially differently padded tensors. This is a critical point for performance and stability, and I've left a comment with a suggestion on how to address it.
Update on the graphs: I have a local workaround that I'm working on cleaning up. The solution is to put the context states in the forward_context and access them via CustomOp, similar to unified_attention_with_output. Then all the ordinary logic can go in the main graph.
Thanks @benchislett for implementing DFlash in vLLM. We’ve just released DFlash checkpoints for Qwen3.5-4B, 9B, 27B, and 35B-A3B. They perform very well and are faster than MTP. Feel free to try them out. We’ll continue adding support for more models, and I’m really looking forward to seeing DFlash run in vLLM!
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Force-pushed from 27773c5 to d9a63c2.
(Force-pushed to fix DCO; manually checked all test cases still pass.)
Regarding this issue: it will now raise a clear error on startup as of e99905a.
Purpose
Overview
DFlash works much like P-EAGLE (see #32887), but with a major architectural change: it uses bidirectional attention between the query tokens (the last sampled token from the base model plus a set of placeholder mask tokens) and the context states, which are the target model's hidden states from the prefill or from accepted tokens.
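To make the attention pattern concrete, here is a toy sketch (function names and shapes are illustrative, not the PR's actual code) contrasting the bidirectional mask a DFlash-style drafter would use with a standard causal mask: every query token sees the full context and every other query token.

```python
import numpy as np

def dflash_attention_mask(num_ctx: int, num_query: int) -> np.ndarray:
    """Boolean mask (True = may attend) for the drafter's query tokens.

    Rows are query tokens (1 sampled token + num_query - 1 mask tokens);
    columns are [context tokens | query tokens]. Bidirectional: every
    query position sees the full context and all other query positions.
    """
    total = num_ctx + num_query
    return np.ones((num_query, total), dtype=bool)

def causal_mask(num_ctx: int, num_query: int) -> np.ndarray:
    """Standard causal mask over the same layout, for comparison."""
    total = num_ctx + num_query
    rows = np.arange(num_ctx, num_ctx + num_query)[:, None]
    cols = np.arange(total)[None, :]
    return cols <= rows

bi = dflash_attention_mask(4, 3)
ca = causal_mask(4, 3)
# Bidirectional: the first query token also sees the later query tokens.
assert bi[0].all()
# Causal: the first query token cannot see the two tokens after it.
assert not ca[0, -1] and not ca[0, -2]
```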
To implement this, I introduce an extra operation that lives outside of the main model execution, which populates the KV cache with the context states directly. Though not exposed to the standard set of torch.compile and CUDA graph optimizations, handling the context in this way allows us to use async scheduling (by writing the full set of context states to the cache and then using seq_lens_gpu to ignore the rejected ones), as well as enabling (piecewise) CUDA graphs for the main forward pass over the query tokens.
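A toy illustration of the scheme described above, using pure Python with hypothetical names rather than vLLM's actual interfaces: the full set of candidate context states is written into the cache eagerly, outside the compiled graph, and a per-sequence length (in the spirit of seq_lens_gpu) tells the main forward pass how many of those entries are valid, so rejected tokens are simply ignored rather than rewritten.

```python
import numpy as np

def write_context_states(kv_cache, states, start: int) -> None:
    # Outside the compiled graph: copy all candidate context states
    # into the cache, including ones that may later be rejected.
    kv_cache[start:start + len(states)] = states

# Cache with room for 8 positions of 4-dim states (illustrative sizes).
kv_cache = np.zeros((8, 4))
states = np.arange(20, dtype=float).reshape(5, 4)  # 5 candidate tokens
write_context_states(kv_cache, states, start=0)

# Async scheduling later determines that only 3 tokens were accepted, so
# the effective sequence length masks out the rejected tail of the cache.
seq_len = 3
valid = kv_cache[:seq_len]
assert valid.shape == (3, 4)
assert (kv_cache[3] == states[3]).all()  # rejected states remain written...
assert len(valid) == seq_len             # ...but are excluded by seq_len
```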
(Solved)
Because the attention is bidirectional and structured in this way, we cannot allow rejected tokens to stay in the batch: their states would corrupt the rest of the calculation, unlike in standard causal attention where we can simply omit them and sample from an earlier position. Therefore, "disable_padded_drafter_batch" is required and will be enabled when using DFlash, disabling async scheduling as a consequence. At this time I cannot think of a good way around this problem.

Additionally, only a small selection of kernel backends actually supports non-causal attention, and a smaller set also supports gpt-oss style "sinks". Flash Attention with Qwen3-8B is used as a test for this implementation, since neither Triton Attention nor FlashInfer (TRTLLM) attention supports non-causality. A follow-up work here would be to allow different attention backends for the drafter and the target model. This is out of scope for this PR, but would enable broader compatibility for DFlash models.
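The corruption argument above can be checked numerically with a minimal single-head attention sketch (illustrative code, not the PR's kernels): perturbing one "rejected" token leaves earlier positions untouched under a causal mask, but changes every position under a bidirectional mask.

```python
import numpy as np

def attend(q, k, v, mask):
    """Minimal single-head attention; mask[i, j] = True lets i see j."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
causal = np.tril(np.ones((4, 4), dtype=bool))
bidir = np.ones((4, 4), dtype=bool)

# Perturb the last token, as if it were a rejected draft left in the batch.
x_bad = x.copy()
x_bad[-1] += 1.0

causal_out = attend(x, x, x, causal)
causal_bad = attend(x_bad, x_bad, x_bad, causal)
# Causal: earlier positions never see the rejected token, so their
# outputs are unchanged and we can simply sample from an earlier position.
assert np.allclose(causal_out[:-1], causal_bad[:-1])

bidir_out = attend(x, x, x, bidir)
bidir_bad = attend(x_bad, x_bad, x_bad, bidir)
# Bidirectional: every position attends to the rejected token, so all
# outputs are corrupted; the token must be removed from the batch.
assert not np.allclose(bidir_out[:-1], bidir_bad[:-1])
```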
(Solved)
Finally, this new architecture requires extra logic to handle the new input shapes and attention metadata. It is not as simple as P-EAGLE, where we can share code with other modes of speculation, since the sizes and contents of the input tensors are all different. One compatibility component in which I am not absolutely confident is the CUDA graph support: even in piecewise mode, I am not sure how the "padded" query tokens interact with the unpadded context. Is it safe to have operations on the context slice of Q inside the attention op, or must it be moved into a custom op to avoid issues? Enabling torch.compile for the DFlash drafter has not yet been functional, likely for this reason.

Implementation Details
The model is implemented in `qwen3_dflash.py` and the core speculation logic is in `dflash.py`. Similarly to draft-model speculative decoding, DFlash has a unique speculation paradigm that requires some refactoring of `eagle.py` in order to support it cleanly.

Specifically, `build_model_inputs_first_pass` and `build_per_layer_attn_metadata` have been introduced in `eagle.py` so that `dflash.py` can override them.

Initial implementations inlined the logic for DFlash into `eagle.py`, but similarly to #24322 the added branching and complexity would (in my opinion) lead to a fairly cluttered EAGLE implementation. In this PR I have tried to separate the concerns to a reasonable degree, so that maintenance of the existing EAGLE pathway is not burdened.

Testing
I add unit tests for both correctness (GSM8k) and acceptance-rate checks for Qwen3-8B with DFlash. I am able to reproduce almost exactly the acceptance rate values reported in the DFlash paper, and have included this in the test suite with an exact reproduction of their evaluation setup. This should be a valuable resource, and easily extensible to new DFlash models as the DFlash model family continues to evolve and expand.
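For reference, acceptance metrics of the kind checked in such tests reduce to simple counting. A hypothetical helper (not the PR's actual code) for the mean acceptance length, under the usual convention that each drafter step emits the accepted draft tokens plus one bonus token from the target model:

```python
def mean_acceptance_length(accepted_per_step: list[int]) -> float:
    """Average number of tokens emitted per speculation step: the
    accepted draft tokens plus the one token sampled from the target."""
    if not accepted_per_step:
        return 0.0
    return sum(n + 1 for n in accepted_per_step) / len(accepted_per_step)

# Three speculation steps accepting 3, 1, and 5 draft tokens each.
assert mean_acceptance_length([3, 1, 5]) == 4.0
```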
In local tests on 1xB200, both Qwen3-8B and Qwen3.5-9B pass the test suite using both FA2 and FA4.
Usage
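One plausible way to enable it, mirroring how other speculative decoding methods are configured in vLLM via `speculative_config`. The exact `method` string and the checkpoint name below are assumptions for illustration, not confirmed by this page:

```python
# Hypothetical configuration dict, following vLLM's speculative_config
# convention; the method name and model id are assumptions.
speculative_config = {
    "method": "dflash",                 # assumed method name for this PR
    "model": "z-lab/DFlash-Qwen3-8B",   # hypothetical checkpoint id
    "num_speculative_tokens": 15,       # matches the benchmark setup below
}

# With vLLM installed, this would be passed to the engine, e.g.:
#   LLM(model="Qwen/Qwen3-8B", speculative_config=speculative_config)
assert speculative_config["method"] == "dflash"
```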
Benchmarking
The latest DFlash implementation on this branch is optimized for low-latency performance. I measured the following speedups using the DFlash methodology (https://github.com/z-lab/dflash).
Datasets covered are Alpaca, GSM8k, HumanEval, Math500, MBPP, MT-Bench. Number of spec tokens for DFlash is 15.