[Feature] Add DFLASH speculative decoding support #16818
dcw02 wants to merge 67 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request integrates DFLASH speculative decoding into the system, aiming to significantly boost the generation throughput of large language models. It achieves this by introducing a native DFLASH draft model, adapting existing attention mechanisms and the core model execution pipeline to support DFLASH's unique verification process, and providing robust configuration options and performance benchmarks. The changes enable faster token generation while maintaining output quality.
The code is currently being cleaned up, but I'll mark the PR as ready and edit the summary once the cleanup is done.
Code Review
This pull request introduces support for DFLASH speculative decoding, a significant new feature. The changes are extensive, including a new DFLASH draft model, a worker for its execution, and integration into the existing speculative decoding framework. The implementation appears solid and correctly follows the DFLASH algorithm, with necessary modifications to attention backends, the CUDA graph runner, and server arguments. I've identified a minor logic issue in the weight loading mechanism for the new DFLASH model and have suggested a refactoring to improve clarity and correctness. The rest of the changes are well-implemented.
Will this support NVFP4?
Can dflash support enabling dp-attention simultaneously in the future?
Yes, that can be done. It exists for EAGLE3; I would have to take a look at how it's implemented.
Cherry-pick from sgl-project#16818 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…project#20547) Cherry-pick from sgl-project#20547, resolved conflicts with PR sgl-project#16818. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Cherry-pick from sgl-project#16818 onto v0.5.9 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -0,0 +1,749 @@
"""DFLASH vs baseline GSM8K sweep.
This should be added in the following PRs. Instead of adding standalone benchmark scripts, please add CI (UT) first. We can discuss how to integrate the benchmark/eval scripts of dflash into SGLang later.
cleaned up and removed for now
OK, also added dflash tests, modeled after the EAGLE ones: a basic MMLU + accept-length test, an infer-a correctness file with radix/page-size variants, and an infer-b file for stop conditions, radix attention, GSM8K, and paged mode.
    self._add_request_to_queue(req)
    return

if self.spec_algorithm.is_dflash() and req.return_logprob:
I think logprob support can be added quite easily after #21048 is done. cc @Qiaolin-Yu
if self.spec_algorithm.is_dflash() and req.return_logprob:
    req.set_finish_with_abort(
        "DFLASH speculative decoding does not support return_logprob yet."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
if self.spec_algorithm.is_dflash() and (
    req.sampling_params.json_schema is not None
    or req.sampling_params.regex is not None
    or req.sampling_params.ebnf is not None
    or req.sampling_params.structural_tag is not None
):
    req.set_finish_with_abort(
        "DFLASH speculative decoding does not support grammar-constrained decoding yet."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
Move this to a dflash compatibility checker helper in scheduler.py
Refactored it to follow the style of validate_input_length.
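A sketch of what such a helper could look like, expressed as a pure predicate so it is easy to unit-test. The name and exact structure are illustrative, not necessarily the PR's final code:

```python
from typing import Optional

# Illustrative sketch of a dflash compatibility check, pulled out of the
# scheduler hot path in the style of validate_input_length. The function
# name is hypothetical; the actual PR code may differ.
def dflash_unsupported_reason(req) -> Optional[str]:
    """Return an abort message if the request uses a feature DFLASH lacks."""
    if req.return_logprob:
        return "DFLASH speculative decoding does not support return_logprob yet."
    if any(
        getattr(req.sampling_params, attr, None) is not None
        for attr in ("json_schema", "regex", "ebnf", "structural_tag")
    ):
        return (
            "DFLASH speculative decoding does not support "
            "grammar-constrained decoding yet."
        )
    return None
```

The scheduler would then call this once per request and, on a non-None result, run the existing set_finish_with_abort / requeue sequence.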
@@ -832,11 +851,6 @@ def forward_extend(
    )

else:
cc @ClawSeven. I think this is a nicer fix for the general dllm forward.
]

appended = 0
if (
The dflash (spec v1) implementation of stop strings is quite messy; please add some UTs to verify your implementation. All the related test kits can be found in the eagle UTs.
I found no performance benefit from this implementation, so I simplified it to be more similar to eagle3/other spec methods.
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
or model_runner.spec_algorithm.is_dflash()
You can add a new method called is_spec or something similar instead of chaining the individual checks.
I added an is_speculative() method to python/sglang/srt/speculative/spec_info.py and used it where it made sense, leaving is_none() and the other individual helpers (is_eagle(), etc.) in place.
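The consolidation described here can be sketched like this. This is a simplified enum mirroring only the algorithms named in this PR; the real SpeculativeAlgorithm in spec_info.py carries more machinery:

```python
from enum import Enum, auto

# Simplified sketch of consolidating per-algorithm checks behind one
# predicate. Members mirror the algorithms discussed in this PR.
class SpeculativeAlgorithm(Enum):
    NONE = auto()
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()
    NGRAM = auto()
    DFLASH = auto()

    def is_none(self) -> bool:
        return self is SpeculativeAlgorithm.NONE

    def is_dflash(self) -> bool:
        return self is SpeculativeAlgorithm.DFLASH

    def is_speculative(self) -> bool:
        # Any algorithm other than NONE enables speculative decoding,
        # so callers no longer need to chain is_eagle() / is_ngram() / ...
        return self is not SpeculativeAlgorithm.NONE
```

Call sites like the chained check above then collapse to a single `model_runner.spec_algorithm.is_speculative()`.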
)
# EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not
# capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and
# reuses TARGET_VERIFY graphs for performance.
I think this only refers to "reuses the TARGET_VERIFY mode" rather than the actual graph instances.
self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]

if isinstance(output, torch.Tensor):
Why not put the output tensor also in the LogitsProcessorOutput?
Good catch, this was implementation debt. I removed the DFlash-specific raw tensor path and restored the typed output path.
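The typed-output pattern being restored here can be illustrated with a minimal sketch. The dataclass below is a hypothetical stand-in, not sglang's actual LogitsProcessorOutput definition: the point is that wrapping every graph-replay result in one output type lets downstream code drop isinstance checks on raw tensors.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical stand-in for a typed logits output; the real
# LogitsProcessorOutput in sglang carries more fields than shown here.
@dataclass
class LogitsOutput:
    next_token_logits: Any
    hidden_states: Optional[Any] = None

def replay_output(output_buffers: dict, graph_key) -> LogitsOutput:
    """Return a graph's replayed output as a typed object, never a raw buffer."""
    output = output_buffers[graph_key]
    if isinstance(output, LogitsOutput):
        return output
    # Raw-buffer path removed upstream; wrap any remaining raw output so
    # callers see exactly one type.
    return LogitsOutput(next_token_logits=output)
```

With this shape, the `isinstance(output, torch.Tensor)` branch in the runner disappears: every consumer reads `output.next_token_logits`.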
register_cuda_ci(est_time=561, suite="stage-b-test-large-1-gpu")


class TestDFlashEngine(CustomTestCase):
No engine test should be needed for dflash. The previous eagle_infer_a / eagle_infer_b files are duplicated and will be merged and simplified.
For dflash, each engine start would cost one file (or one test class).
DFlash Speculative Decoding Support
This PR adds support for DFlash speculative decoding.
Overview
New Files
- `python/sglang/srt/models/dflash.py`
  - `DFlashAttention`: Non-causal attention (`AttentionType.ENCODER_ONLY`) with per-head Q/K normalization
  - `DFlashDraftModel`: No embedding/LM head (uses the target model's). Projects concatenated target-layer features via `fc` + `hidden_norm`
  - `kv_proj_only()` method skips Q computation when materializing context tokens into the draft KV cache
  - `dflash_config` (`target_layer_ids`, `block_size`, `mask_token`)
- `python/sglang/srt/speculative/dflash_worker.py`
  - `req_to_token_pool` and allocator (EAGLE3-style)
  - `TARGET_VERIFY` mode
  - `all_gather`
- `python/sglang/srt/speculative/dflash_info.py`
  - `DFlashDraftInput`: Per-batch state tracking verified tokens, target hidden features, draft cache lengths
  - `DFlashVerifyInput`: Verify-forward inputs with custom attention masks, positions, draft tokens
  - `verify()`: Greedy verification computing accept lengths, committing tokens, updating caches
- `python/sglang/srt/speculative/dflash_utils.py`
  - `build_target_layer_ids()`: Select evenly-spaced target layers for context features (mirrors the reference impl)
  - `compute_dflash_accept_len_and_bonus()`: Accept-length calculation (accepts while draft == target)
  - `dflash_config` resolution
- `benchmark/dflash/bench_dflash_gsm8k_sweep.py`

Modified Files
- `python/sglang/srt/server_args.py`: `--speculative-dflash-block-size` argument; `DFLASH` creates a `DFlashWorker` instance
- `python/sglang/srt/speculative/spec_info.py`: `SpeculativeAlgorithm.DFLASH` enum variant; `DFlashWorker`
- `python/sglang/srt/managers/schedule_batch.py`: `DFlashDraftInput` and `DFlashVerifyInput` as `spec_info`
- `python/sglang/srt/model_executor/model_runner.py`: `CaptureHiddenMode.FULL` support for capturing intermediate layer features during verify

Key Features
- `TARGET_VERIFY` mode

Limitations
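The two `dflash_utils` helpers listed under New Files above can be sketched roughly as follows. This is a minimal illustration with simplified signatures, not the PR's actual code: it assumes layer selection spreads picks evenly across the target model's depth, and that verification greedily accepts draft tokens while they match the target's argmax.

```python
def build_target_layer_ids(num_target_layers: int, num_feature_layers: int) -> list:
    """Pick evenly spaced target-layer indices whose hidden states feed
    the draft model (sketch; the reference impl's exact rule may differ)."""
    if num_feature_layers <= 1:
        return [num_target_layers - 1]
    step = (num_target_layers - 1) / (num_feature_layers - 1)
    return [round(i * step) for i in range(num_feature_layers)]


def greedy_accept_len(draft_tokens: list, target_tokens: list) -> int:
    """Count leading positions where the draft token matches the target's
    greedy prediction. The token at the first mismatch (or one past the
    accepted prefix) becomes the committed 'bonus' token."""
    accept = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        accept += 1
    return accept
```

For example, with a 32-layer target and 4 feature layers this selection rule picks layers 0, 10, 21, and 31; the accept loop is what makes the "average accept length" metric reported in the benchmarks below well-defined.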
Testing Setup
Hardware:
`p5en.48xlarge` (H200) and `p6-b200.48xlarge` (B200)

Models:
`Qwen/Qwen3-8B` (target) and `z-lab/Qwen3-8B-DFlash-b16` (draft)

Configuration:
`flashinfer` (H200 and B200), `fa3` (H200 only)

Workload:
`temperature=0.0`, `top_p=1.0`, `top_k=1`

Accuracy Testing
Test configuration:
Note on numerical differences:
Dflash uses prefill kernels for multi-token verification (`TARGET_VERIFY` mode), while baseline tests use decode kernels for single-token generation. These different kernel implementations can produce small numerical differences, but this does not affect overall accuracy or generation quality.

Summary:
H200 Accuracy Results (fa3 backend)
Baseline accuracy
Dflash accuracy
Average accept length (tokens per verify step)
H200 Accuracy Results (flashinfer backend)
Baseline accuracy
Dflash accuracy
Average accept length (tokens per verify step)
B200 Accuracy Results (flashinfer backend)
Baseline accuracy
Dflash accuracy
Average accept length (tokens per verify step)
Benchmarks
Test configuration:
Summary:
H200 Results (fa3 backend)
Baseline throughput (tok/s)
Dflash throughput (tok/s)
Speedup (Dflash / baseline)
H200 Results (flashinfer backend)
Baseline throughput (tok/s)
Dflash throughput (tok/s)
Speedup (Dflash / baseline)
B200 Results (flashinfer backend)
Baseline throughput (tok/s)
Dflash throughput (tok/s)
Speedup (Dflash / baseline)
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci