Skip to content

[Feature] Spec V2 DFlash Support#23000

Open
dcw02 wants to merge 23 commits into
mainfrom
dcw02/dflash-spec-v2
Open

[Feature] Spec V2 DFlash Support#23000
dcw02 wants to merge 23 commits into
mainfrom
dcw02/dflash-spec-v2

Conversation

@dcw02
Copy link
Copy Markdown
Collaborator

@dcw02 dcw02 commented Apr 16, 2026

Motivation

Add spec v2 to DFlash

Benchmarks

Run on gcp b200:8 node, using a gsm8k sweep script, qwen3-8b target, z-lab/Qwen3-8B-DFlash-b16 draft model, trtllm_mha target attention, fa4 draft attention, piecewise cuda graphs on.

v1 performance

DFLASH output tok/s
tp\conc       1         32
-------  ------  ---------
      1  845.98  11,405.85

DFLASH accuracy
tp\conc      1     32
-------  -----  -----
      1  0.852  0.844

DFLASH acceptance length (mean spec_accept_length)
tp\conc      1     32
-------  -----  -----
      1  6.345  6.487

v2 performance

DFLASH output tok/s
tp\conc         1         32
-------  --------  ---------
      1  1,161.88  15,326.81

DFLASH accuracy
tp\conc      1     32
-------  -----  -----
      1  0.852  0.844

DFLASH acceptance length (mean spec_accept_length)
tp\conc      1     32
-------  -----  -----
      1  6.352  6.482

this spec v2 version also brings in some extra optimizations compared to #20547 which brought bs1 performance from 900 -> 1161 tok/s and bs32 from 12,300 -> 15,326 tok/s.

Benchmarking is done with this script using the command SGLANG_ENABLE_SPEC_V2=1 SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --tp-sizes 1 --concurrencies 1,32 --attention-backends trtllm_mha --speculative-draft-attention-backend fa4 on 1xB200

i removed mamba memory calculations to add later once i figure out the best way to do that

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 16, 2026

/rerun-test test/registered/spec/dflash/test_dflash.py

@github-actions
Copy link
Copy Markdown
Contributor

1-gpu-5090 (1 test): View workflow run

cd test/ && python3 registered/spec/dflash/test_dflash.py

@ggg-s
Copy link
Copy Markdown

ggg-s commented Apr 16, 2026

What optimizations were made on top of PR #20547? PCG?

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 16, 2026

What optimizations were made on top of PR #20547? PCG?

I rewrote the fused kv helper, added some new triton ops, removed some syncs, etc. PCG already exists, I did not add it.

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 24, 2026

I am investigating accept length degradations for both v1 and v2 paths in this PR but not in #20547

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 25, 2026

the accept length degradation issue has been fixed, it was a rope config handling issue when transformers version got bumped

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 25, 2026

so i realized we can carry reserved kv allocation metadata through overlap draft state and let next-step prep use the prepared allocation watermark, we could get rid of a scheduling bubble that helps low concurrency a lot. for correctness, scheduler output processing applies the request watermark monotonically later.

prior v2 baseline:

DFLASH output tok/s
tp\conc       1         32
-------  ------  ---------
      1  975.04  12,956.91

decoupling:

DFLASH output tok/s
tp\conc         1         32
-------  --------  ---------
      1  1,073.64  12,995.40

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 27, 2026

@tugot17 yes, we can merge it after spec v2 dflash is merged. thanks for your contribution!

@liusy58
Copy link
Copy Markdown
Collaborator

liusy58 commented Apr 27, 2026

@dcw02 Thank you for your reply. Can we chat on slack?

@tugot17
Copy link
Copy Markdown
Contributor

tugot17 commented Apr 27, 2026

@dcw02
I added the LFM changes, but if it will be easier to add it after the DFLash is merge to main in the first place than let's wait

https://github.com/sgl-project/sglang/pull/23847/changes

@liusy58
Copy link
Copy Markdown
Collaborator

liusy58 commented Apr 28, 2026

@dcw02 Could you please resolve these merge conflicts?

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented Apr 28, 2026

@liusy58 fixed merged conflicts

@dcw02 dcw02 requested a review from kpham-sgl as a code owner May 1, 2026 04:56
@dcw02 dcw02 force-pushed the dcw02/dflash-spec-v2 branch from 8ae7dd3 to 9893ef8 Compare May 1, 2026 05:01
@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented May 1, 2026

I will put up separate PRs for the draft swa layers and gemma 4 support so they can be merged in first for v1

yahya010 added a commit to abdelfattah-lab/sglang that referenced this pull request May 4, 2026
…roject#23000

Cherry-picked the two files needed for smcsd's DFlash direct-load path:
- python/sglang/srt/models/dflash.py (DFlashDraftModel + DFlashDecoderLayer)
- python/sglang/srt/speculative/dflash_utils.py (helpers used by the model)

Copied from sglang upstream PR refs/pull/23000/head, which is the canonical
implementation of DFlash speculative decoding referenced by checkpoints like
z-lab/Qwen3.6-27B-DFlash. Adding the model class to our branch lets smcsd's
_init_dflash_direct load DFlash drafts directly via sglang's class registry
instead of transformers' trust_remote_code (which would 404 on dflash.py).

The other DFlash files in PR sgl-project#23000 (dflash_worker, dflash_info,
dflash_accept_bonus, etc.) are sglang-side speculative decoding scaffolding
not used by smcsd's SMC-DFlash worker.
@ggg-s
Copy link
Copy Markdown

ggg-s commented May 7, 2026

hi @dcw02 Is the current PR compatible with DFLASH + FlashInfer + mixed batches?

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented May 7, 2026

hi @dcw02 Is the current PR compatible with DFLASH + FlashInfer + mixed batches?

I haven't tested that myself so I'm unsure

@ggg-s
Copy link
Copy Markdown

ggg-s commented May 8, 2026

hi @dcw02 Can the current PCG be used?

class TestDFlashServerSpecV2(TestDFlashServerBase):
spec_v2 = True

@unittest.skip
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: why do we need to skip this?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tested and it passes, so I re-enabled it. I think prior to the merge commit f39d86d4 there was a bug with PCG in flashinfer backend that caused it to fail.

@@ -26,6 +28,8 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin):
attention_backend = "flashinfer"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: Does dflash only support flashinfer?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DFlash supports fa3, fa4, flashinfer, and triton as speculative draft attention backends. For best performance on b200s you can mix trtllm_mha target attention backend with fa4 draft attention backend. trtllm_mha attention isn't supported for DFlash draft since it requires non-causal full_attention/ENCODER_ONLY.

@@ -110,6 +97,23 @@ def _lazy_init_buf(self, draft_input: EagleDraftInput):
device=self.device,
)

if self.spec_algo.is_dflash():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I prefer adding a more general function (something like need_topk) instead of checking whether it's dflash here. What do you think?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes agreed, I changed it to SpeculativeAlgorithm.need_topk() instead of special casing it to DFlash

Comment thread python/sglang/srt/server_args.py Outdated
logger.warning(
"Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)."
)
if envs.SGLANG_ENABLE_SPEC_V2.get():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spec v2 is opened by default. the logic here may need to be changed

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes let me do a merge from main and I will update the logic here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok updated the logic now that spec v2 is default

@dssugar
Copy link
Copy Markdown

dssugar commented May 11, 2026

FYI: I tried this PR's gemma4_causal.py / gemma4_mm.py changes against
dense Gemma 4 31B (RedHatAI/gemma-4-31B-it-NVFP4) with z-lab/gemma-4-31B-it-DFlash
as the drafter on a single RTX 5090 (sm120) with attention_backend=triton
(fa4 / trtllm_mha are not available on this device). v1 DFlash path.

Results (warm, temperature=0, short prompts):

  • code 100w: 158.4 tok/s (vs MTP baseline 83.8 = 1.89x; vs the device's
    vLLM main + MTP at 164 = 97%)
  • haiku: 92.9 tok/s, jp: 64.0 tok/s
  • server-log accept length (code peak): 4.47, accept rate 0.23
  • jp accept length: 1.80, rate 0.05 (predictable: JP is harder to draft)

One small thing that wasn't covered by this PR: Gemma 4 ties lm_head to
embed_tokens (a plain nn.Embedding subclass, not VocabParallelEmbedding),
so dflash_worker._prepare_for_speculative_decoding rejects it at
hasattr(lm_head, "shard_indices"). I worked around it by setattr-ing a
trivial VocabParallelEmbeddingShardIndices (tp=1, num_added=0) onto
lm_head, which lets the fast path (tp_size == 1 and num_added == 0) match
without touching the TP / added-vocab branches. Approximate diff:

# in gemma4_causal.py
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbeddingShardIndices,
)

def _ensure_dflash_shard_indices(lm_head, vocab_size: int) -> None:
    if getattr(lm_head, \"shard_indices\", None) is not None:
        return
    lm_head.shard_indices = VocabParallelEmbeddingShardIndices(
        padded_org_vocab_start_index=0,
        padded_org_vocab_end_index=vocab_size,
        padded_added_vocab_start_index=vocab_size,
        padded_added_vocab_end_index=vocab_size,
        org_vocab_start_index=0,
        org_vocab_end_index=vocab_size,
        added_vocab_start_index=vocab_size,
        added_vocab_end_index=vocab_size,
    )

# call after the lm_head assignment in __init__ of
# Gemma4ForCausalLM and Gemma4ForConditionalGeneration:
_ensure_dflash_shard_indices(self.lm_head, vocab_size)

Not asking for changes to this PR (its scope is V2 / overlap scheduling).
Just sharing in case someone hits the same wall on a tied-embedding model.
I've only verified short/temp=0 generation so far — haven't checked for the
gibberish loop reported on vLLM #41262 (TP=2). Thanks for the work on V2!

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented May 11, 2026

@dssugar feel free to put up a PR to add gemma 4 support to v1!

@dcw02
Copy link
Copy Markdown
Collaborator Author

dcw02 commented May 13, 2026

I added a DFlash only prefill refill heuristic for online scheduling, I think it can be removed when mixed prefill decode is fully implemented. Previously, the scheduler would admit new prefill as soon as a single running request finished, at max concurrency this produced many one request prefill batches near full decode occupancy that massively reduced throughput. The heuristic now waits until a small target number of running request slots are free before admitting prefill work to refill together.

I ran extensive sweeps and ablations using many target / draft model combinations and set a good default target of 2/3/4/4 for max-running 8/16/32/64. An env override SGLANG_DFLASH_PREFILL_REFILL_TARGET remains for benchmarking/restoring old immediate-refill behavior.

This could be more generalized to other spec methods but I'm keeping it DFlash specific since all the benchmarking evidence and tuning is for DFlash's refill cadence. Other spec algorithms are unchanged.

Qwen/Qwen3-8B gsm8k concurrency 32 performance: 13,879.59 tok/s -> 15,326.81 tok/s

@tugot17
Copy link
Copy Markdown
Contributor

tugot17 commented May 20, 2026

@dcw02 any progess on this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants