[Feature] Add DFLASH speculative decoding support#16818

Closed
dcw02 wants to merge 67 commits into sgl-project:main from modal-labs:dflash

Conversation

@dcw02 (Contributor) commented Jan 9, 2026

DFlash Speculative Decoding Support

This PR adds support for DFlash speculative decoding.

Overview

New Files

python/sglang/srt/models/dflash.py

  • Draft model implementation using SGLang primitives (RadixAttention, RMSNorm, rotary embeddings)
  • DFlashAttention: Non-causal attention (AttentionType.ENCODER_ONLY) with per-head Q/K normalization
  • DFlashDraftModel: No embedding/LM head (uses target model's). Projects concatenated target-layer features via fc + hidden_norm
  • Optimized kv_proj_only() method skips Q computation when materializing context tokens into draft KV cache
  • Configurable via dflash_config (target_layer_ids, block_size, mask_token)
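
As a sketch of the fc + hidden_norm projection described above (a minimal NumPy stand-in; the function names and shapes here are illustrative, and the learned RMSNorm gain is omitted):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned gain, for illustration only
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def project_target_features(layer_feats, fc_weight):
    """Hypothetical sketch: concatenate selected target-layer hidden
    states along the feature dim, project with fc, then normalize.

    layer_feats: list of [seq_len, hidden] arrays, one per selected
    target layer; fc_weight: [len(layer_feats) * hidden, hidden].
    """
    concat = np.concatenate(layer_feats, axis=-1)  # [seq, k * hidden]
    return rms_norm(concat @ fc_weight)            # [seq, hidden]
```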

python/sglang/srt/speculative/dflash_worker.py

  • Main worker coordinating draft and target models (spec-v1 only)
  • Separate draft KV cache pool, shared req_to_token_pool and allocator (EAGLE3-style)
  • Draft generation: fills block with mask token embeddings, runs draft model with TARGET_VERIFY mode
  • TP-safe greedy sampling over vocab-parallel LM head with chunking and cross-rank all_gather
  • Materializes target hidden states into draft KV cache (applies projection, K/V norms, RoPE, writes to pool)
  • Preallocated buffers for draft blocks (CUDA graph compatible)
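
The TP-safe greedy sampling step can be sketched as follows (a single-process NumPy illustration; tp_greedy_argmax is a hypothetical name, and the loop over shards stands in for the cross-rank all_gather of per-shard maxima):

```python
import numpy as np

def tp_greedy_argmax(logit_shards):
    """Each TP rank holds the logits for a contiguous vocab shard.
    Every rank computes its local (max logit, global token id) pair;
    comparing the gathered pairs yields the same token on every rank,
    so greedy sampling stays deterministic across ranks.
    """
    best_token, best_logit = -1, -np.inf
    offset = 0
    for shard in logit_shards:          # stands in for gathered pairs
        local = int(np.argmax(shard))
        if shard[local] > best_logit:
            best_logit, best_token = shard[local], offset + local
        offset += len(shard)
    return best_token
```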

python/sglang/srt/speculative/dflash_info.py

  • DFlashDraftInput: Per-batch state tracking verified tokens, target hidden features, draft cache lengths
  • DFlashVerifyInput: Verify-forward inputs with custom attention masks, positions, draft tokens
  • Implements batch filtering/merging, KV allocation for verify blocks
  • verify(): Greedy verification computing accept lengths, committing tokens, updating caches
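
The greedy commit rule used by verify() can be sketched as (hypothetical helper; the actual code also commits the target's correction at the first mismatch as a bonus token and updates the KV caches):

```python
def greedy_accept_length(draft_tokens, target_tokens):
    """Count draft tokens accepted under greedy verification: accept
    while the draft token matches the target model's greedy choice at
    the same position, and stop at the first mismatch.
    """
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    return n
```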

python/sglang/srt/speculative/dflash_utils.py

  • build_target_layer_ids(): Selects evenly spaced target layers for context features (mirrors the reference implementation)
  • compute_dflash_accept_len_and_bonus(): Accept length calculation (accepts while draft == target)
  • Config helpers for dflash_config resolution
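
A minimal sketch of the evenly spaced selection (illustrative only; the real build_target_layer_ids() mirrors the reference implementation, and its exact rounding may differ):

```python
def build_target_layer_ids(num_target_layers, num_selected):
    """Pick num_selected layer indices evenly spaced over
    [0, num_target_layers - 1], whose hidden states feed the draft
    model's feature projection.
    """
    if num_selected == 1:
        return [num_target_layers // 2]
    step = (num_target_layers - 1) / (num_selected - 1)
    return [round(i * step) for i in range(num_selected)]
```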

benchmark/dflash/bench_dflash_gsm8k_sweep.py

  • GSM8K evaluation with block size sweeps and acceptance rate tracking

Modified Files

python/sglang/srt/server_args.py

  • Added --speculative-dflash-block-size argument
  • Algorithm enum supports DFLASH, creates DFlashWorker instance
  • Reserved memory estimation for draft model

python/sglang/srt/speculative/spec_info.py

  • Added SpeculativeAlgorithm.DFLASH enum variant
  • Worker creation logic routes to DFlashWorker

python/sglang/srt/managers/schedule_batch.py

  • Batch state supports DFlashDraftInput and DFlashVerifyInput as spec_info

python/sglang/srt/model_executor/model_runner.py

  • CaptureHiddenMode.FULL support for capturing intermediate layer features during verify

Key Features

  • TP Support: Draft model and greedy sampling work across multiple TP ranks
  • Radix Cache: Draft KV materialized before radix updates for consistency
  • CUDA Graphs: Fixed-size verify forwards with TARGET_VERIFY mode
  • Multiple Backends: Draft model supports flashinfer, fa3
  • Memory Efficient: EAGLE3-style shared allocator, temporary block allocations freed after drafting

Limitations

  • Spec-v1 only (no overlap scheduling)
  • Greedy verification only
  • No grammar/logprobs support
  • Page size must be 1

Testing Setup

Hardware:

  • AWS EC2 instances: p5en.48xlarge (H200) and p6-b200.48xlarge (B200)
  • NVIDIA driver: 580.95.05
  • CUDA toolkit: 13.1

Models:

  • Target: Qwen/Qwen3-8B
  • Draft: z-lab/Qwen3-8B-DFlash-b16

Configuration:

  • TP sizes: 1, 2, 4, 8 (Note: Qwen3-8B is a small model; TP > 1 is primarily to demonstrate support rather than a recommended configuration. The draft worker also runs with tensor parallelism.)
  • Concurrencies: 1, 2, 4, 8, 16, 32
  • Attention backends: flashinfer (H200 and B200), fa3 (H200 only)

Workload:

  • Dataset: GSM8K
  • Max new tokens: 2048 per request
  • Prompt style: chat (zero-shot, not few-shot QA)
  • Sampling parameters: temperature=0.0, top_p=1.0, top_k=1

Accuracy Testing

Test configuration:

  • Number of samples: 128 for all TP size / concurrency pairs
  • Metrics: GSM8K answer accuracy, average accept length per verify step

Note on numerical differences:

DFlash uses prefill kernels for multi-token verification (TARGET_VERIFY mode), while the baseline tests use decode kernels for single-token generation. These different kernel implementations can produce small numerical differences, but these do not affect overall accuracy or generation quality.

Summary:

  • Baseline and DFlash achieve comparable accuracy (82-87% on GSM8K) across all configurations
  • Accuracy differences between baseline and DFlash are within normal variance (~±2%)
  • Accept length is highly consistent at ~6.3-6.4 tokens per verify step across all hardware, backends, TP sizes, and concurrency levels
  • With a block size of 16, this corresponds to an acceptance rate of ~40% of draft tokens
  • No systematic accuracy degradation was observed across configurations or attention backends
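
The ~40% figure follows directly from the accept lengths above, assuming the accept length counts tokens committed per verify step out of a 16-token draft block:

```python
block_size = 16
avg_accept_len = 6.35        # midpoint of the observed 6.3-6.4 range
acceptance_rate = avg_accept_len / block_size
print(f"{acceptance_rate:.1%}")   # just under 40% of drafted tokens
```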

H200 Accuracy Results (fa3 backend)

Baseline accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.852 0.852 0.852 0.844 0.859 0.852
2 0.844 0.859 0.859 0.852 0.867 0.844
4 0.836 0.859 0.836 0.828 0.852 0.852
8 0.820 0.820 0.820 0.852 0.836 0.828

DFlash accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.852 0.852 0.844 0.852 0.859 0.836
2 0.844 0.859 0.836 0.859 0.852 0.859
4 0.836 0.836 0.852 0.867 0.828 0.867
8 0.820 0.820 0.852 0.844 0.859 0.859

Average accept length (tokens per verify step)

TP \ Concurrency 1 2 4 8 16 32
1 6.36 6.34 6.34 6.36 6.29 6.38
2 6.36 6.40 6.36 6.44 6.41 6.43
4 6.36 6.36 6.36 6.31 6.31 6.37
8 6.33 6.31 6.29 6.36 6.29 6.30

H200 Accuracy Results (flashinfer backend)

Baseline accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.844 0.844 0.844 0.844 0.852 0.852
2 0.852 0.852 0.836 0.836 0.852 0.852
4 0.836 0.836 0.836 0.836 0.844 0.852
8 0.836 0.836 0.836 0.844 0.836 0.852

DFlash accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.844 0.844 0.852 0.844 0.852 0.836
2 0.859 0.812 0.859 0.844 0.844 0.852
4 0.852 0.852 0.859 0.852 0.852 0.836
8 0.852 0.836 0.828 0.844 0.852 0.852

Average accept length (tokens per verify step)

TP \ Concurrency 1 2 4 8 16 32
1 6.37 6.37 6.37 6.32 6.33 6.30
2 6.31 6.31 6.33 6.32 6.35 6.31
4 6.32 6.37 6.37 6.37 6.32 6.38
8 6.29 6.28 6.35 6.30 6.34 6.39

B200 Accuracy Results (flashinfer backend)

Baseline accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.859 0.859 0.859 0.859 0.859 0.836
2 0.875 0.875 0.875 0.875 0.875 0.875
4 0.859 0.859 0.859 0.852 0.859 0.859
8 0.844 0.844 0.844 0.844 0.859 0.844

DFlash accuracy

TP \ Concurrency 1 2 4 8 16 32
1 0.844 0.844 0.844 0.844 0.844 0.852
2 0.852 0.859 0.859 0.867 0.859 0.859
4 0.852 0.852 0.859 0.852 0.852 0.852
8 0.836 0.836 0.836 0.836 0.844 0.836

Average accept length (tokens per verify step)

TP \ Concurrency 1 2 4 8 16 32
1 6.34 6.35 6.34 6.34 6.35 6.35
2 6.34 6.28 6.31 6.31 6.30 6.32
4 6.30 6.35 6.38 6.38 6.36 6.37
8 6.38 6.33 6.36 6.35 6.36 6.36

Benchmarks

Test configuration:

  • Number of samples: 128, 256, 512, 1024, 1024, 1024 for concurrencies 1, 2, 4, 8, 16, 32 respectively
  • Speedup metric: time-based

Summary:

  • Peak speedup: 3.73x (H200, fa3, TP=1, concurrency=1)
  • Peak throughput: 8,335 tok/s (B200, flashinfer, TP=1, concurrency=32)
  • Average accept length: ~6.3-6.5 tokens per verify step across all configurations
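
As a rough sanity check (a simplistic cost model, not how the benchmark measures speedup): if each verify step commits L tokens on average but costs about (1 + c) baseline decode steps of wall-clock time (drafting plus verification overhead), then speedup ≈ L / (1 + c). Plugging in the peak observed numbers gives the implied overhead c:

```python
avg_accept_len = 6.36   # tokens committed per verify step (H200, fa3, TP=1)
peak_speedup = 3.73     # observed at concurrency=1
# speedup ≈ accept_len / (1 + overhead)  =>  overhead ≈ accept_len / speedup - 1
implied_overhead = avg_accept_len / peak_speedup - 1
print(f"implied per-verify-step overhead ≈ {implied_overhead:.0%}")
```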

H200 Results (fa3 backend)

Baseline throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 183 355 698 1,312 2,525 4,384
2 249 476 920 1,727 3,088 5,030
4 322 600 1,157 2,133 3,654 5,878
8 354 657 1,237 2,224 3,743 5,575

DFlash throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 681 1,309 2,381 4,094 5,957 7,318
2 732 1,371 2,497 4,119 5,964 7,474
4 765 1,423 2,521 4,211 6,132 8,188
8 746 1,345 2,438 4,032 6,068 8,260

Speedup (DFlash / baseline)

TP \ Concurrency 1 2 4 8 16 32
1 3.73x 3.69x 3.41x 3.12x 2.36x 1.67x
2 2.94x 2.88x 2.71x 2.39x 1.93x 1.49x
4 2.37x 2.37x 2.18x 1.97x 1.68x 1.39x
8 2.10x 2.05x 1.97x 1.81x 1.62x 1.48x

H200 Results (flashinfer backend)

Baseline throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 182 359 700 1,298 2,436 4,109
2 256 486 930 1,719 2,996 4,870
4 339 627 1,184 2,126 3,583 5,387
8 388 724 1,352 2,395 3,927 5,606

DFlash throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 672 1,244 2,238 3,698 5,283 6,297
2 701 1,305 2,339 3,825 5,432 6,921
4 720 1,333 2,352 3,883 5,604 7,426
8 715 1,276 2,289 3,745 5,562 7,504

Speedup (DFlash / baseline)

TP \ Concurrency 1 2 4 8 16 32
1 3.70x 3.46x 3.20x 2.85x 2.17x 1.53x
2 2.74x 2.69x 2.52x 2.23x 1.81x 1.42x
4 2.12x 2.13x 1.99x 1.83x 1.56x 1.38x
8 1.84x 1.76x 1.69x 1.56x 1.42x 1.34x

B200 Results (flashinfer backend)

Baseline throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 211 416 814 1,555 2,929 5,027
2 248 491 953 1,767 3,302 5,194
4 323 609 1,137 2,119 3,764 5,963
8 348 679 1,277 2,331 4,041 4,975

DFlash throughput (tok/s)

TP \ Concurrency 1 2 4 8 16 32
1 774 1,422 2,562 4,335 6,366 8,335
2 803 1,469 2,577 4,173 5,947 7,777
4 825 1,505 2,468 4,061 5,699 7,626
8 810 1,293 2,337 3,839 5,684 7,426

Speedup (DFlash / baseline)

TP \ Concurrency 1 2 4 8 16 32
1 3.67x 3.42x 3.15x 2.79x 2.17x 1.66x
2 3.23x 2.99x 2.71x 2.36x 1.80x 1.50x
4 2.55x 2.47x 2.17x 1.92x 1.51x 1.28x
8 2.33x 1.91x 1.83x 1.65x 1.41x 1.49x

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the blackwell SM100/SM120 label Jan 9, 2026
@gemini-code-assist (bot) commented:

Summary of Changes

Hello @dcw02, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates DFLASH speculative decoding into the system, aiming to significantly boost the generation throughput of large language models. It achieves this by introducing a native DFLASH draft model, adapting existing attention mechanisms and the core model execution pipeline to support DFLASH's unique verification process, and providing robust configuration options and performance benchmarks. The changes enable faster token generation while maintaining output quality.

Highlights

  • DFLASH Speculative Decoding Integration: Introduced comprehensive support for DFLASH speculative decoding, a technique designed to accelerate large language model inference by using a smaller, faster draft model to propose tokens that are then verified by the larger target model.
  • Native DFLASH Draft Model Implementation: Implemented the DFlashDraftModel and associated components (DFlashAttention, DFlashMLP) directly within SGLang, allowing for efficient native execution of the draft model without token embeddings or an LM head, leveraging the target model's components.
  • Attention Backend and CUDA Graph Enhancements: Modified FlashInfer and TRTLLM attention backends to correctly handle DFLASH's specific non-causal masking requirements for draft blocks and ensure proper CUDA graph capture for performance optimization, including auxiliary hidden state capture.
  • Scheduler and Model Runner Adaptations: Updated the scheduler to manage DFLASH-specific KV cache release and draft state, and integrated DFLASH into the ModelRunner and CudaGraphRunner for auxiliary hidden state capture and efficient execution flow.
  • Command-Line Argument and Configuration: Added a new --speculative-algorithm DFLASH option, along with specific argument handling for DFLASH, including validation for distributed attention, pipeline parallelism, and automatic inference of speculative_num_draft_tokens from the draft model's configuration.
  • Performance Benchmarking and Testing: Included new manual tests for DFLASH correctness against a target-only baseline and a GSM8K benchmark to measure speedup and acceptance rates, demonstrating significant throughput improvements (e.g., 3.729x speedup on h200 fa3).


@dcw02 (Author) commented Jan 9, 2026:

The code is still being cleaned up; I'll mark the PR as ready and update the summary once it is.

@gemini-code-assist (bot) left a review comment:
Code Review

This pull request introduces support for DFLASH speculative decoding, a significant new feature. The changes are extensive, including a new DFLASH draft model, a worker for its execution, and integration into the existing speculative decoding framework. The implementation appears solid and correctly follows the DFLASH algorithm, with necessary modifications to attention backends, the CUDA graph runner, and server arguments. I've identified a minor logic issue in the weight loading mechanism for the new DFLASH model and have suggested a refactoring to improve clarity and correctness. The rest of the changes are well-implemented.

@nutriarch commented:

will this support NVFP4?

dcw02 marked this pull request as ready for review on January 10, 2026 at 07:17
@yuyangxie96 commented:

Can dflash support enabling dp-attention simultaneously in the future?

@dcw02 (Author) commented Mar 3, 2026:

Can dflash support enabling dp-attention simultaneously in the future?

Yes, that can be done. It already exists for EAGLE3; I would have to take a look at how it's implemented.

harryjing pushed a commit to harryjing/sglang that referenced this pull request Mar 19, 2026
Cherry-pick from sgl-project#16818

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
harryjing pushed a commit to harryjing/sglang that referenced this pull request Mar 19, 2026
…project#20547)

Cherry-pick from sgl-project#20547, resolved conflicts with PR sgl-project#16818.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
harryjing pushed a commit to harryjing/sglang that referenced this pull request Mar 19, 2026
Cherry-pick from sgl-project#16818 onto v0.5.9

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -0,0 +1,749 @@
"""DFLASH vs baseline GSM8K sweep.
A collaborator commented:

This should be added in a follow-up PR. Instead of adding standalone benchmark scripts, please add CI unit tests first. We can discuss how to integrate the benchmark/eval scripts of DFlash into SGLang later.

Ref: https://github.com/sgl-project/sglang/blob/main/.claude/skills/write-sglang-test/SKILL.md

The author replied:

cleaned up and removed for now

The author replied:

OK, I also added DFlash tests, modeled after the EAGLE ones: a basic MMLU + accept-length test, an infer-a correctness file with radix/page-size variants, and an infer-beta file for stop conditions, radix attention, GSM8K, and paged mode.

self._add_request_to_queue(req)
return

if self.spec_algorithm.is_dflash() and req.return_logprob:
A collaborator commented:

I think logprob support can be added quite easily after #21048 is done. cc @Qiaolin-Yu

Comment on lines +1636 to +1654
if self.spec_algorithm.is_dflash() and req.return_logprob:
req.set_finish_with_abort(
"DFLASH speculative decoding does not support return_logprob yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
if self.spec_algorithm.is_dflash() and (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
req.set_finish_with_abort(
"DFLASH speculative decoding does not support grammar-constrained decoding yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
A collaborator commented:

Move this to a DFlash compatibility-checker helper in scheduler.py.

The author replied:

Refactored it to follow the style of validate_input_length.

@@ -832,11 +851,6 @@ def forward_extend(
)

else:
A collaborator commented:

cc @ClawSeven. I think this is a nicer fix for the general dllm forward.

]

appended = 0
if (
A collaborator commented:

The DFlash (spec-v1) implementation of stop strings is quite messy; please try to add some unit tests to verify it. All the related test kits can be found in the EAGLE UTs.

The author replied:

I found no performance benefit from this implementation, so I simplified it to be more similar to EAGLE3 and the other spec methods.

Comment on lines +476 to +479
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
or model_runner.spec_algorithm.is_dflash()
A collaborator commented:

You can add a new method called is_spec or something similar instead of this flat chain of calls.

The author replied:

I added an is_speculative() function to python/sglang/srt/speculative/spec_info.py and used it where it made sense, leaving is_none() and the other individual helpers (is_eagle(), etc.) in place.

)
# EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not
# capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and
# reuses TARGET_VERIFY graphs for performance.
A collaborator commented:

I think this only means it "reuses the TARGET_VERIFY mode", not the actual graph instances.

The author replied:

fixed the comment

self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]

if isinstance(output, torch.Tensor):
A collaborator commented:

Why not also put the output tensor in the LogitsProcessorOutput?

The author replied:

Good catch, this was implementation debt. I removed the DFlash-specific raw-tensor path and restored the typed output path.

@hnyls2002 (Collaborator) commented:

@dcw02 @gongy This is a great PR! Thanks for that.

register_cuda_ci(est_time=561, suite="stage-b-test-large-1-gpu")


class TestDFlashEngine(CustomTestCase):
A collaborator commented:

No engine test is needed for DFlash. The previous eagle_infer_a / eagle_infer_b tests are duplicated and will be merged and simplified.

The same collaborator added:

For DFlash, each engine start would cost one file (or one test class).

Labels: blackwell SM100/SM120, npu