[Feature] Spec V2 DFlash Support #23000
|
/rerun-test test/registered/spec/dflash/test_dflash.py |
|
✅ |
|
What optimizations were made on top of PR #20547? PCG? |
I rewrote the fused kv helper, added some new triton ops, removed some syncs, etc. PCG already exists, I did not add it. |
|
I am investigating accept length degradations for both v1 and v2 paths in this PR but not in #20547 |
|
the accept length degradation issue has been fixed. it was a rope config handling issue that appeared when the transformers version got bumped
|
so i realized we can carry reserved kv allocation metadata through the overlap draft state and let next-step prep use the prepared allocation watermark. this gets rid of a scheduling bubble, which helps low concurrency a lot. for correctness, scheduler output processing applies the request watermark monotonically later.
[results omitted: prior v2 baseline vs. decoupling] |
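A minimal sketch of the watermark idea described in this comment, not the PR's actual code. `Request`, `DraftState`, `prepare_next_step`, and `process_output` are hypothetical names used only for illustration. The point it shows is that next-step prep can read the already-prepared allocation, while the later scheduler update uses `max` so a stale value never lowers the watermark.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Request:
    """Hypothetical request record; only the KV watermark matters here."""
    rid: str
    kv_watermark: int = 0  # KV slots the scheduler knows are reserved


@dataclass
class DraftState:
    """Hypothetical overlap draft state carrying reserved-allocation metadata."""
    prepared_watermark: Dict[str, int] = field(default_factory=dict)


def prepare_next_step(req: Request, state: DraftState) -> int:
    """Next-step prep uses the prepared watermark instead of waiting for
    the scheduler to finish processing the previous step's output."""
    return max(req.kv_watermark, state.prepared_watermark.get(req.rid, 0))


def process_output(req: Request, observed: int) -> None:
    """Late scheduler update. It is monotone: the watermark only grows."""
    req.kv_watermark = max(req.kv_watermark, observed)
```

Because the update is a `max`, applying scheduler results late or out of order cannot roll the watermark back below what prep already relied on.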
|
@tugot17 yes, we can merge it after spec v2 dflash is merged. thanks for your contribution! |
|
@dcw02 Thank you for your reply. Can we chat on slack? |
|
@dcw02 https://github.com/sgl-project/sglang/pull/23847/changes |
|
@dcw02 Could you please resolve these merge conflicts? |
|
@liusy58 fixed merge conflicts
8ae7dd3 to 9893ef8
|
I will put up separate PRs for the draft swa layers and gemma 4 support so they can be merged in first for v1 |
…roject#23000
Cherry-picked the two files needed for smcsd's DFlash direct-load path:
- python/sglang/srt/models/dflash.py (DFlashDraftModel + DFlashDecoderLayer)
- python/sglang/srt/speculative/dflash_utils.py (helpers used by the model)
Copied from sglang upstream PR refs/pull/23000/head, which is the canonical implementation of DFlash speculative decoding referenced by checkpoints like z-lab/Qwen3.6-27B-DFlash. Adding the model class to our branch lets smcsd's _init_dflash_direct load DFlash drafts directly via sglang's class registry instead of transformers' trust_remote_code (which would 404 on dflash.py).
The other DFlash files in PR sgl-project#23000 (dflash_worker, dflash_info, dflash_accept_bonus, etc.) are sglang-side speculative decoding scaffolding not used by smcsd's SMC-DFlash worker.
|
hi @dcw02 Is the current PR compatible with DFLASH + FlashInfer + mixed batches? |
I haven't tested that myself so I'm unsure |
|
hi @dcw02 Can the current PCG be used? |
class TestDFlashServerSpecV2(TestDFlashServerBase):
    spec_v2 = True

    @unittest.skip
qq: why do we need to skip this?
I just tested and it passes, so I re-enabled it. I think prior to the merge commit f39d86d4 there was a bug with PCG in flashinfer backend that caused it to fail.
@@ -26,6 +28,8 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin):
    attention_backend = "flashinfer"
qq: Does dflash only support flashinfer?
DFlash supports fa3, fa4, flashinfer, and triton as speculative draft attention backends. For best performance on B200s you can pair the trtllm_mha target attention backend with the fa4 draft attention backend. trtllm_mha isn't supported for the DFlash draft, since the draft requires non-causal full_attention/ENCODER_ONLY.
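The constraint in this reply can be sketched as a small validation helper. This is an illustration only: `resolve_draft_backend` is a hypothetical function, and the fallback to the target backend when no draft backend is given is an assumption, not confirmed sglang behavior. Only the set of supported draft backends comes from the reply above.

```python
from typing import Optional

# Draft attention backends listed as supported in the reply above.
DFLASH_DRAFT_BACKENDS = frozenset({"fa3", "fa4", "flashinfer", "triton"})


def resolve_draft_backend(
    target_backend: str, draft_backend: Optional[str] = None
) -> str:
    """Hypothetical helper: pick and validate the DFlash draft backend."""
    # Assumption: with no explicit draft backend, the target's backend is reused.
    chosen = draft_backend if draft_backend is not None else target_backend
    if chosen not in DFLASH_DRAFT_BACKENDS:
        raise ValueError(
            f"{chosen!r} cannot run the DFlash draft, which needs non-causal "
            f"attention; choose one of {sorted(DFLASH_DRAFT_BACKENDS)}"
        )
    return chosen
```

Under that assumption, `resolve_draft_backend("trtllm_mha", "fa4")` returns `"fa4"`, while `resolve_draft_backend("trtllm_mha")` raises, which is why the recommended B200 setup names the draft backend explicitly.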
@@ -110,6 +97,23 @@ def _lazy_init_buf(self, draft_input: EagleDraftInput):
        device=self.device,
    )

    if self.spec_algo.is_dflash():
nit: I prefer adding a more general function (something like need_topk) instead of checking whether it's dflash here. What do you think?
yes agreed, I changed it to SpeculativeAlgorithm.need_topk() instead of special casing it to DFlash
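A rough sketch of the refactor discussed in this thread: asking the algorithm for a capability rather than checking its identity. The enum below is a stand-in, not sglang's real `SpeculativeAlgorithm`. Which members return `True` from `need_topk()` is an assumption made for illustration, and `lazy_init_buffers` is a hypothetical caller.

```python
from enum import Enum, auto


class SpeculativeAlgorithm(Enum):
    """Stand-in enum for illustration; the real class has more members."""
    NONE = auto()
    EAGLE = auto()
    DFLASH = auto()

    def is_dflash(self) -> bool:
        return self is SpeculativeAlgorithm.DFLASH

    def need_topk(self) -> bool:
        # Assumption: only tree-style drafters such as EAGLE need top-k buffers.
        return self is SpeculativeAlgorithm.EAGLE


def lazy_init_buffers(algo: SpeculativeAlgorithm) -> dict:
    """Hypothetical caller: branch on the capability, not on the algorithm."""
    buffers = {"hidden_states": "allocated"}
    if algo.need_topk():
        buffers["topk"] = "allocated"
    return buffers
```

With this shape, a future algorithm that also skips top-k only has to change `need_topk()`, and the buffer setup code stays untouched.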
logger.warning(
    "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)."
)
if envs.SGLANG_ENABLE_SPEC_V2.get():
spec v2 is now enabled by default, so the logic here may need to change
yes let me do a merge from main and I will update the logic here
ok updated the logic now that spec v2 is default
|
FYI: I tried this PR's Results (warm, temperature=0, short prompts):
One small thing that wasn't covered by this PR: Gemma 4 ties
# in gemma4_causal.py
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbeddingShardIndices,
)


def _ensure_dflash_shard_indices(lm_head, vocab_size: int) -> None:
    if getattr(lm_head, "shard_indices", None) is not None:
        return
    lm_head.shard_indices = VocabParallelEmbeddingShardIndices(
        padded_org_vocab_start_index=0,
        padded_org_vocab_end_index=vocab_size,
        padded_added_vocab_start_index=vocab_size,
        padded_added_vocab_end_index=vocab_size,
        org_vocab_start_index=0,
        org_vocab_end_index=vocab_size,
        added_vocab_start_index=vocab_size,
        added_vocab_end_index=vocab_size,
    )


# call after the lm_head assignment in __init__ of
# Gemma4ForCausalLM and Gemma4ForConditionalGeneration:
_ensure_dflash_shard_indices(self.lm_head, vocab_size)
Not asking for changes to this PR (its scope is V2 / overlap scheduling). |
|
@dssugar feel free to put up a PR to add gemma 4 support to v1! |
…fore target verify
|
I added a DFlash-only prefill refill heuristic for online scheduling; I think it can be removed once mixed prefill/decode is fully implemented. Previously, the scheduler would admit new prefill as soon as a single running request finished. At max concurrency this produced many one-request prefill batches near full decode occupancy, which massively reduced throughput. The heuristic now waits until a small target number of running request slots are free before admitting prefill work, so the slots are refilled together. I ran extensive sweeps and ablations across many target / draft model combinations and set default targets of 2/3/4/4 for max-running 8/16/32/64. An env override, SGLANG_DFLASH_PREFILL_REFILL_TARGET, remains for benchmarking or restoring the old immediate-refill behavior. This could be generalized to other spec methods, but I'm keeping it DFlash-specific since all the benchmarking evidence and tuning is for DFlash's refill cadence. Other spec algorithms are unchanged.
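The refill heuristic described above might look roughly like the sketch below. It is a sketch under stated assumptions, not the scheduler code from the PR. The default targets and the env variable name come from the comment. The helper names, the bucket lookup for in-between `max_running` values, and treating an override of `1` as the old immediate-refill behavior are all assumptions.

```python
import os

# Defaults quoted in the comment above: max-running -> free-slot target.
DEFAULT_REFILL_TARGETS = {8: 2, 16: 3, 32: 4, 64: 4}
ENV_VAR = "SGLANG_DFLASH_PREFILL_REFILL_TARGET"


def refill_target(max_running: int) -> int:
    """Free slots to wait for before admitting prefill (hypothetical helper)."""
    override = os.environ.get(ENV_VAR)
    if override is not None:
        # Assumption: an override of 1 means "refill as soon as one slot frees".
        return max(1, int(override))
    # Assumption: sizes between the tuned points use the next larger bucket.
    for size in sorted(DEFAULT_REFILL_TARGETS):
        if max_running <= size:
            return min(DEFAULT_REFILL_TARGETS[size], max_running)
    return DEFAULT_REFILL_TARGETS[64]


def should_admit_prefill(num_running: int, max_running: int, num_waiting: int) -> bool:
    """Admit prefill only once enough running slots are free to refill together."""
    if num_waiting == 0:
        return False
    if num_running == 0:
        return True  # nothing is decoding, so there is nothing to protect
    free_slots = max_running - num_running
    return free_slots >= refill_target(max_running)
```

With `max_running=16` and no override, one finished request (15 running) does not trigger a prefill, but three finished requests (13 running) do, so the refill happens as one batch of three instead of three single-request batches.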
|
|
@dcw02 any progress on this? |
Motivation
Add spec v2 to DFlash
Benchmarks
Run on gcp b200:8 node, using a gsm8k sweep script, qwen3-8b target, z-lab/Qwen3-8B-DFlash-b16 draft model, trtllm_mha target attention, fa4 draft attention, piecewise cuda graphs on.
v1 performance
v2 performance
This spec v2 version also brings in some extra optimizations compared to #20547, which brought bs1 performance from 900 -> 1161 tok/s and bs32 from 12,300 -> 15,326 tok/s.
Benchmarking is done with this script using the command
SGLANG_ENABLE_SPEC_V2=1 SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --tp-sizes 1 --concurrencies 1,32 --attention-backends trtllm_mha --speculative-draft-attention-backend fa4
on 1xB200
i removed mamba memory calculations to add later once i figure out the best way to do that