
[Feature] Add spec v2 (overlap scheduling) to DFlash speculative decoding support#20547

Open
dcw02 wants to merge 74 commits into sgl-project:main from modal-labs:dflash_v2
Conversation

@dcw02
Contributor

@dcw02 dcw02 commented Mar 13, 2026

Motivation

Add the spec v2 path for DFlash. This should be merged after #16818.

TLDR
B200, GSM8K, Qwen3-8B, TP size 1, concurrency 32, max new tokens 2k, greedy decoding:
9,688.26 tok/s -> 12,360.49 tok/s
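For reference, the quoted numbers work out to roughly a 27.6% throughput gain:

```python
# Sanity check of the throughput gain quoted above (numbers from this PR description).
v1_tok_s = 9688.26   # regular v1
v2_tok_s = 12360.49  # overlap scheduling (spec v2)
speedup = v2_tok_s / v1_tok_s
print(f"{speedup:.3f}x ({(speedup - 1) * 100:.1f}% faster)")  # 1.276x (27.6% faster)
```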

Modifications

Adds the v2 worker and related files.

Accuracy and Benchmarks

Tested on a GCP B200 machine.

Commands:

# regular v1
python benchmark/dflash/bench_dflash_gsm8k_sweep.py --tp-sizes 1 --concurrencies 32 --attention-backends trtllm_mha --speculative-draft-attention-backend fa4 --page-size 64 --skip-baseline

# overlap scheduling (spec v2)
SGLANG_ENABLE_SPEC_V2=1 SGLANG_ENABLE_DFLASH_SPEC_V2=1 SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 python benchmark/dflash/bench_dflash_gsm8k_sweep.py --tp-sizes 1 --concurrencies 32 --attention-backends trtllm_mha --speculative-draft-attention-backend fa4 --page-size 64 --skip-baseline

v1 performance

=== DFLASH GSM8K Sweep Summary ===
target_model=Qwen/Qwen3-8B
draft_model=z-lab/Qwen3-8B-DFlash-b16
max_new_tokens=2048
sampling=temperature:0.0, top_p:1.0, top_k:1
attention_backends=trtllm_mha
speculative_draft_attention_backend=fa4
speculative_dflash_draft_window_size=None
tp_sizes=1
concurrencies=32
questions_per_concurrency_base=128
device_sm=100
skip_baseline=True

=== Backend: trtllm_mha ===

Baseline output tok/s
tp\conc   32
-------  ---
      1  N/A

Baseline accuracy
tp\conc   32
-------  ---
      1  N/A

DFLASH output tok/s
tp\conc        32
-------  --------
      1  9,688.26

DFLASH accuracy
tp\conc     32
-------  -----
      1  0.850

Speedup (DFLASH / baseline)
tp\conc   32
-------  ---
      1  N/A

DFLASH acceptance length (mean spec_accept_length)
tp\conc     32
-------  -----
      1  6.470

overlap scheduling (spec v2) performance

=== DFLASH GSM8K Sweep Summary ===
target_model=Qwen/Qwen3-8B
draft_model=z-lab/Qwen3-8B-DFlash-b16
max_new_tokens=2048
sampling=temperature:0.0, top_p:1.0, top_k:1
attention_backends=trtllm_mha
speculative_draft_attention_backend=fa4
speculative_dflash_draft_window_size=None
tp_sizes=1
concurrencies=32
questions_per_concurrency_base=128
device_sm=100
skip_baseline=True

=== Backend: trtllm_mha ===

Baseline output tok/s
tp\conc   32
-------  ---
      1  N/A

Baseline accuracy
tp\conc   32
-------  ---
      1  N/A

DFLASH output tok/s
tp\conc         32
-------  ---------
      1  12,360.49

DFLASH accuracy
tp\conc     32
-------  -----
      1  0.850

Speedup (DFLASH / baseline)
tp\conc   32
-------  ---
      1  N/A

DFLASH acceptance length (mean spec_accept_length)
tp\conc     32
-------  -----
      1  6.467

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement for DFlash speculative decoding by integrating a new version of overlap scheduling, referred to as spec v2. The changes involve adding specialized worker logic and data structures for DFlash, optimizing KV cache operations with fused Triton kernels, and enabling auxiliary hidden state capture in various models. This update aims to boost token generation throughput, as demonstrated by the provided benchmarks, while also laying the groundwork for more advanced speculative decoding capabilities.

Highlights

  • DFLASH Speculative Decoding v2: Implemented DFLASH speculative decoding with overlap scheduling (spec v2) to significantly improve token generation throughput.
  • Performance Improvement: Achieved a notable speedup from 9,688.26 tok/s to 12,360.49 tok/s on a B200 machine with Qwen3-8B and concurrency 32, representing a 27.6% increase.
  • New DFlash Worker and Components: Introduced dedicated DFlash worker implementations (DFlashWorker and DFlashWorkerV2) and new data structures (DFlashDraftInput, DFlashVerifyInput, DFlashDraftInputV2) to manage the speculative decoding process.
  • Auxiliary Hidden State Capture: Added support for capturing auxiliary hidden states in target models (e.g., GPT-OSS, Llama, Qwen3, Qwen3.5, Qwen3-MoE, Qwen3-Next, Qwen3-VL) required for DFlash context feature projection.
  • Fused KV Materialization: Integrated a Triton kernel for fused KV materialization, optimizing the process of projecting and storing Key/Value states in the draft model's cache.
  • Configuration and Restrictions: Added new server arguments (--speculative-dflash-block-size, --speculative-dflash-draft-window-size) and enforced restrictions for DFlash spec v2, such as supporting only greedy decoding and disallowing logprobs, hidden states, and grammar constraints in phase 1.
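The overlap-scheduling idea behind spec v2 can be pictured with a toy sketch (hypothetical names, not SGLang's real API): CPU-side preparation of step i+1 runs concurrently with step i's GPU work, instead of strictly alternating schedule → run → schedule → run.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# Toy illustration of overlap scheduling (all names here are hypothetical):
# batch preparation for step i+1 runs on a worker thread while step i's
# "GPU" work is in flight.

def schedule_batch(i: int) -> str:
    time.sleep(0.01)  # stand-in for CPU-side scheduling work
    return f"batch-{i}"

def run_on_gpu(batch: str) -> str:
    time.sleep(0.02)  # stand-in for draft + verify kernels
    return f"{batch}-done"

def serial(n: int) -> list[str]:
    # schedule -> run -> schedule -> run ...
    return [run_on_gpu(schedule_batch(i)) for i in range(n)]

def overlapped(n: int) -> list[str]:
    # Schedule step i+1 while step i runs; results match serial() exactly.
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(schedule_batch, 0)
        for i in range(n):
            batch = future.result()
            if i + 1 < n:
                future = pool.submit(schedule_batch, i + 1)
            results.append(run_on_gpu(batch))
    return results

assert serial(3) == overlapped(3)
```

In the real spec v2 path the overlap is between CPU scheduling/planning and GPU execution (note the SGLANG_ENABLE_OVERLAP_PLAN_STREAM flag in the benchmark command), with maybe_wait_verify_done as the synchronization point per the changelog below.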


Changelog
  • benchmark/dflash/bench_dflash_gsm8k_sweep.py
    • Added a new benchmark script to evaluate DFlash performance on the GSM8K dataset, including support for speculative decoding v2.
  • python/sglang/srt/environ.py
    • Added a new environment variable, SGLANG_ENABLE_DFLASH_SPEC_V2, to control the experimental DFlash spec v2 overlap scheduling.
  • python/sglang/srt/layers/attention/flashinfer_backend.py
    • Modified the FlashInfer attention backend to correctly handle custom masks for DFlash, ensuring proper non-causal attention for draft blocks.
  • python/sglang/srt/managers/schedule_batch.py
    • Updated the maybe_wait_verify_done method to support synchronization for DFlash spec v2 overlap scheduling.
  • python/sglang/srt/managers/scheduler.py
    • Implemented new validation checks and restrictions for DFlash speculative decoding requests, particularly for spec v2, to ensure compatibility with supported features like greedy decoding.
  • python/sglang/srt/model_executor/cuda_graph_runner.py
    • Extended CUDA graph capture logic to include DFlash, allowing for efficient execution of DFlash draft and verify steps, and integrated auxiliary hidden state capture.
  • python/sglang/srt/model_executor/model_runner.py
    • Integrated DFlash-specific configuration parsing, enabled auxiliary hidden state capture for DFlash, and adjusted KV cache scaling to accommodate combined target and draft KV pools.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Modified KV cache profiling to account for the combined memory footprint of target and DFlash draft KV pools, ensuring accurate memory allocation.
  • python/sglang/srt/models/dflash.py
    • Added a new DFlash draft model implementation, including DFlashAttention, DFlashMLP, and DFlashDecoderLayer, designed for efficient speculative decoding.
  • python/sglang/srt/models/gpt_oss.py
    • Added get_input_embeddings and set_dflash_layers_to_capture methods to support DFlash auxiliary hidden state capture in GPT-OSS models.
  • python/sglang/srt/models/llama.py
    • Added the set_dflash_layers_to_capture method to enable DFlash auxiliary hidden state capture in Llama models.
  • python/sglang/srt/models/qwen3.py
    • Added the set_dflash_layers_to_capture method to enable DFlash auxiliary hidden state capture in Qwen3 models.
  • python/sglang/srt/models/qwen3_5.py
    • Modified the Qwen3.5 model to support DFlash auxiliary hidden state capture and added the set_dflash_layers_to_capture method.
  • python/sglang/srt/models/qwen3_moe.py
    • Added the set_dflash_layers_to_capture method to enable DFlash auxiliary hidden state capture in Qwen3-MoE models.
  • python/sglang/srt/models/qwen3_next.py
    • Added get_input_embeddings and set_dflash_layers_to_capture methods to support DFlash auxiliary hidden state capture in Qwen3-Next models.
  • python/sglang/srt/models/qwen3_vl.py
    • Modified the Qwen3-VL model to support DFlash auxiliary hidden state capture and added the set_dflash_layers_to_capture method.
  • python/sglang/srt/server_args.py
    • Added new server arguments, --speculative-dflash-block-size and --speculative-dflash-draft-window-size, and updated DFlash configuration logic.
  • python/sglang/srt/speculative/dflash_info.py
    • Added new data structures, DFlashDraftInput and DFlashVerifyInput, to manage DFlash speculative decoding state for non-overlap scheduling.
  • python/sglang/srt/speculative/dflash_info_v2.py
    • Added new data structures, DFlashDraftInputV2, specifically designed for DFlash speculative decoding with overlap scheduling (spec v2).
  • python/sglang/srt/speculative/dflash_utils.py
    • Added utility functions for DFlash configuration parsing, KV cell size scaling, and accept length computation, including support for non-greedy sampling verification.
  • python/sglang/srt/speculative/dflash_worker.py
    • Added the DFlash speculative decoding worker for non-overlap scheduling, handling draft generation and target verification steps.
  • python/sglang/srt/speculative/dflash_worker_v2.py
    • Added the DFlash speculative decoding worker for overlap scheduling (spec v2), designed for improved performance by overlapping draft and verify phases.
  • python/sglang/srt/speculative/spec_info.py
    • Updated the SpeculativeAlgorithm enum to include DFLASH and extended the supports_spec_v2 logic to conditionally enable DFlash spec v2.
  • python/sglang/srt/speculative/triton_ops/init.py
    • Added FusedKVMaterializeHelper to the Triton operations module for DFlash optimizations.
  • python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py
    • Added a Triton kernel for fused KV materialization, combining KV projection, RMSNorm, and RoPE operations for DFlash performance.
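The "fused KV materialization" entry above combines three steps — KV projection, RMSNorm, and RoPE — before writing into the draft model's KV cache. A minimal numpy sketch of the unfused math (illustrative only; the function names, shapes, and norm placement here are assumptions, and the PR's actual implementation is a single Triton kernel that avoids materializing intermediate K/V tensors):

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    # Standard RMSNorm: scale by the reciprocal root-mean-square, then weight.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

def rope(x, positions, base=10000.0):
    # Standard RoPE: rotate dim pairs by position-dependent angles.
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    angles = positions[:, None] * inv_freq[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def materialize_kv(hidden, w_k, w_v, norm_w, positions, k_cache, v_cache, slots):
    # The fused Triton version performs these steps in one kernel launch,
    # writing straight into the cache slots.
    k = rope(rmsnorm(hidden @ w_k, norm_w), positions)
    v = hidden @ w_v
    k_cache[slots] = k
    v_cache[slots] = v
```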
Activity
  • The pull request introduces DFlash speculative decoding with overlap scheduling (spec v2).
  • Benchmarks provided in the description show a significant performance improvement from 9,688.26 tok/s to 12,360.49 tok/s.
  • The author has included a checklist, with items for code formatting, unit tests, documentation, and code style remaining unchecked.
  • No explicit comments or reviews have been made on the pull request yet.

@dcw02
Contributor Author

dcw02 commented Mar 13, 2026

Currently the only thing missing compared to v1 is non-greedy decoding support; it is being worked on.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces DFLASH speculative decoding, adding a specialized DFlashDraftModel and updating core components like ModelRunner, CudaGraphRunner, and Scheduler to support its specific requirements, including auxiliary hidden state capture and handling DFLASH-specific server arguments. New data structures (DFlashDraftInput, DFlashVerifyInput) and worker implementations (DFlashWorker, DFlashWorkerV2) manage the drafting and verification process, with optimizations like fused KV materialization. A new benchmark script is also included. An improvement opportunity exists in the scheduler to refactor duplicated logic for aborting requests with unsupported DFLASH features into a helper method for better maintainability.

Comment on lines +1636 to +1684
if self.spec_algorithm.is_dflash() and req.return_logprob:
    req.set_finish_with_abort(
        "DFLASH speculative decoding does not support return_logprob yet."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
if (
    self.spec_algorithm.is_dflash()
    and self.enable_overlap
    and req.return_hidden_states
):
    req.set_finish_with_abort(
        "DFLASH spec-v2 phase 1 does not support return_hidden_states yet."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
if self.spec_algorithm.is_dflash() and (
    req.sampling_params.json_schema is not None
    or req.sampling_params.regex is not None
    or req.sampling_params.ebnf is not None
    or req.sampling_params.structural_tag is not None
):
    req.set_finish_with_abort(
        "DFLASH speculative decoding does not support grammar-constrained decoding yet."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
if (
    self.spec_algorithm.is_dflash()
    and self.enable_overlap
    and (
        req.sampling_params.top_k > 1
        or req.sampling_params.frequency_penalty != 0.0
        or req.sampling_params.presence_penalty != 0.0
        or req.sampling_params.repetition_penalty != 1.0
        or req.sampling_params.logit_bias is not None
        or req.custom_logit_processor is not None
    )
):
    req.set_finish_with_abort(
        "DFLASH spec-v2 phase 1 only supports plain greedy decoding yet. "
        "Non-greedy sampling, penalties, logit_bias, and custom logit processors are not enabled."
    )
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)
    return
Contributor


medium

The logic for aborting requests with unsupported DFLASH features is duplicated across several if blocks. This can be refactored into a helper method to reduce code repetition and improve maintainability.

For example, you could create a helper like this:

def _abort_dflash_request(self, req: Req, message: str):
    req.set_finish_with_abort(message)
    self.init_req_max_new_tokens(req)
    self._add_request_to_queue(req)

Then you can simplify the checks:

if self.spec_algorithm.is_dflash():
    if req.return_logprob:
        self._abort_dflash_request(req, "DFLASH speculative decoding does not support return_logprob yet.")
        return
    if self.enable_overlap and req.return_hidden_states:
        self._abort_dflash_request(req, "DFLASH spec-v2 phase 1 does not support return_hidden_states yet.")
        return
    # ... and so on

harryjing pushed a commit to harryjing/sglang that referenced this pull request Mar 19, 2026
…project#20547)

Cherry-pick from sgl-project#20547, resolved conflicts with PR sgl-project#16818.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
harryjing pushed a commit to harryjing/sglang that referenced this pull request Mar 19, 2026
…project#20547)

Cherry-pick from sgl-project#20547 onto v0.5.9, resolved conflicts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@dcw02 dcw02 requested review from ch-wan and fzyzcjy as code owners April 7, 2026 23:31
@ggg-s

ggg-s commented Apr 9, 2026

@dcw02 Does it currently support PCG?

@dcw02
Contributor Author

dcw02 commented Apr 9, 2026

@dcw02 Does it currently support PCG?

I've enabled it without issues with --enforce-piecewise-cuda-graph.

@dcw02
Contributor Author

dcw02 commented Apr 9, 2026

I'm closing this PR and reopening it soon from another branch; I have some extra improvements.
