Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
10e563f
starting dflash impl
dcw02 Jan 6, 2026
289e748
fix verify mismatch
dcw02 Jan 7, 2026
f1efc03
add gsm8k bench
dcw02 Jan 7, 2026
e807216
support more backends, investigate accuracy
dcw02 Jan 7, 2026
99e140a
native sglang backend
dcw02 Jan 8, 2026
2c64b0e
remove hf backend
dcw02 Jan 8, 2026
f1a4262
dflash support flashinfer
dcw02 Jan 8, 2026
2c5b346
remove manual management of dflash kv pool
dcw02 Jan 8, 2026
6a38e63
add cuda graph
dcw02 Jan 8, 2026
40a81af
add cuda graph to draft worker
dcw02 Jan 8, 2026
510bf0c
update test
dcw02 Jan 8, 2026
c54f336
fix flashinfer backend
dcw02 Jan 9, 2026
8c8ee9c
initial radix cache support
dcw02 Jan 9, 2026
0edea3f
tp_size > 1 support
dcw02 Jan 9, 2026
f23555b
add optional dflash_config for overrides, add --speculative-dflash-bl…
dcw02 Jan 10, 2026
63c0b9a
fix OOMs with default settings
dcw02 Jan 10, 2026
9309764
clean up
dcw02 Jan 10, 2026
644ab29
clean up dflash load_weights
dcw02 Jan 10, 2026
ff6876a
attention selection logic
dcw02 Jan 10, 2026
32c3dd0
Merge remote-tracking branch 'upstream/main' into dflash
dcw02 Jan 10, 2026
d808ac9
decouple context feature count K from draft num layers
dcw02 Jan 12, 2026
e589ac1
clean up naming
dcw02 Jan 12, 2026
074efb2
performance optimizations
dcw02 Jan 12, 2026
fcc9bf7
skip Q, fused mlp
dcw02 Jan 12, 2026
a79264f
reuse buffers for decode
dcw02 Jan 12, 2026
ad5adbf
optimize greedy sampling
dcw02 Jan 12, 2026
37fc3f1
preallocate for tp>1
dcw02 Jan 12, 2026
72cbd9d
more buffers for tp>1
dcw02 Jan 12, 2026
5a577a3
dflash gsm8k benchmark sweep
dcw02 Jan 12, 2026
d968532
fix benchmark
dcw02 Jan 12, 2026
3e4177d
use device tensors for ctx_lens/draft_seq_lens, vectorize kv append a…
dcw02 Jan 20, 2026
b5b4bd6
precommit fixes
dcw02 Jan 20, 2026
ed8b16d
feat(dflash): add fused KV materialization kernel and optimize D2H
xiaomin-D Jan 20, 2026
f2a6dbc
add support for qwen3_moe
dcw02 Jan 20, 2026
117352d
support dflash_config.mask_token_id
dcw02 Feb 5, 2026
5ba316c
add llama3.1 support and fix config block_size logic
dcw02 Feb 5, 2026
5f8d0ec
Merge pull request #15 from yilian49/pr16818
dcw02 Feb 12, 2026
189f177
guards for fused path
dcw02 Feb 12, 2026
0841db6
add support for gpt oss
dcw02 Feb 12, 2026
56477b9
clean up
dcw02 Feb 12, 2026
9c0242d
Merge upstream sgl-project/main -> dflash and fix conflicts
dcw02 Feb 12, 2026
d9c68a1
add qwen3-coder-next support (mamba)
dcw02 Feb 24, 2026
7e189bd
add page size > 1 support
dcw02 Feb 24, 2026
7a739f8
non greedy
dcw02 Feb 25, 2026
8a6bec9
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Feb 25, 2026
0fe389d
rope rotation support
dcw02 Feb 26, 2026
a134f0a
clean up schedule_batch.py
dcw02 Feb 26, 2026
f62e5de
fix auto memory oom, cleanup
dcw02 Feb 28, 2026
2cc5f07
clean up
dcw02 Feb 28, 2026
760870c
Merge branch 'sgl-project:main' into dflash
dcw02 Feb 28, 2026
3b66746
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 1, 2026
6913603
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 2, 2026
26441b8
initial fa4 support to dflash, clean up benchmarking script
dcw02 Mar 2, 2026
74814de
clean up
dcw02 Mar 2, 2026
e493353
only run baseline once
dcw02 Mar 2, 2026
bdb2fda
Merge branch 'main' into dflash
dcw02 Mar 8, 2026
7878ec4
fix server startup timeout for autotuning
dcw02 Mar 8, 2026
5d689d7
qwen3_5 support
dcw02 Mar 8, 2026
277fe05
add draft swa
dcw02 Mar 12, 2026
a992cda
initial spec v2 dflash port
dcw02 Mar 12, 2026
c92f0c2
enabled fused kv for v2
dcw02 Mar 12, 2026
e57de00
add draft swa to overlap
dcw02 Mar 12, 2026
8017f43
add optimized non-greedy decoding
dcw02 Mar 14, 2026
7759843
avoid OOB in masked req_to_token gathers
dcw02 Mar 16, 2026
5b7ebd2
clean up dflash cuda graph runner paths
dcw02 Mar 24, 2026
9314e5e
dflash spec v2 changes for cuda graph runner changes
dcw02 Mar 29, 2026
6b6683c
clean up dflash request validation
dcw02 Mar 24, 2026
4860317
clean up stop string handling
dcw02 Mar 24, 2026
e825041
inline stop strings logic
dcw02 Mar 24, 2026
c62c37c
fix cuda IMA?
dcw02 Mar 29, 2026
fed6b3e
fix cuda IMA for bs > 1 and overlap plan streams
dcw02 Mar 31, 2026
339f25e
messy auto memory calculation for hybrid models for dflash (include s…
dcw02 Mar 31, 2026
4926ca2
update auto memory sizing
dcw02 Apr 1, 2026
e67a0d4
add dflash support for kimi k2.5
dcw02 Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
802 changes: 802 additions & 0 deletions benchmark/dflash/bench_dflash_gsm8k_sweep.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ class Envs:
# Overlap Spec V2
SGLANG_ENABLE_SPEC_V2 = EnvBool(False)
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
SGLANG_ENABLE_DFLASH_SPEC_V2 = EnvBool(False)

# Spec Config
SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK = EnvBool(True)
Expand Down
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,24 @@ def init_forward_metadata_capture_cuda_graph(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
# FlashInfer's prefill wrapper decides mask mode based on whether
# `custom_mask_buf` is initialized (not whether a custom mask is provided).
# For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a
# custom mask, so we must avoid initializing `custom_mask_buf`, otherwise
# FlashInfer will treat the (zero) buffer as a real mask and block attention.
use_custom_mask = (
spec_info is not None
and getattr(spec_info, "custom_mask", None) is not None
)
prefill_wrappers = []
for i in range(self.num_wrappers):
wrapper_kwargs = {}
if use_custom_mask:
wrapper_kwargs = {
"custom_mask_buf": self.cuda_graph_custom_mask,
"mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1],
}

prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
Expand All @@ -603,8 +619,7 @@ def init_forward_metadata_capture_cuda_graph(
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
**wrapper_kwargs,
)
)
seq_lens_sum = seq_lens.sum().item()
Expand Down Expand Up @@ -777,10 +792,14 @@ def forward_extend(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

causal = (
not layer.is_cross_attention
and layer.attn_type != AttentionType.ENCODER_ONLY
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
causal=causal,
sm_scale=layer.scaling,
# Disable sliding window attention for multi-item scoring:
# - Sliding window could cut across item boundaries, breaking semantic coherence
Expand Down Expand Up @@ -832,11 +851,6 @@ def forward_extend(
)

else:
if not self.is_dllm_model:
# TODO: design a better interface
# For other models, use causal attention for the ragged part as previously
causal = True

o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
Expand Down
18 changes: 14 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2036,10 +2036,20 @@ def prepare_for_decode(self):
)

def maybe_wait_verify_done(self):
if self.is_spec_v2:
draft_input: EagleDraftInput = self.spec_info
if draft_input.verify_done is not None:
draft_input.verify_done.synchronize()
if not self.is_spec_v2:
return

draft_input: EagleDraftInput = self.spec_info
verify_done = getattr(draft_input, "verify_done", None)
if verify_done is None:
return

if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
torch.get_device_module(self.device).current_stream().wait_event(
verify_done
)
else:
verify_done.synchronize()

def filter_batch(
self,
Expand Down
28 changes: 28 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,27 @@ def copy_to_cpu(self):
self.copy_done.record()


def validate_dflash_request(req: Req, enable_overlap: bool) -> Optional[str]:
    """Check whether a request is compatible with DFLASH speculative decoding.

    Returns a human-readable error message describing the first unsupported
    feature found, or ``None`` when the request can be served.
    """
    # Logprob return is not implemented for the DFLASH verify path.
    if req.return_logprob:
        return "DFLASH speculative decoding does not support return_logprob yet."

    # Hidden-state return conflicts with overlap scheduling only.
    if enable_overlap and req.return_hidden_states:
        return "DFLASH speculative decoding does not support return_hidden_states yet."

    # Any form of grammar-constrained decoding is rejected.
    params = req.sampling_params
    constrained_fields = (
        params.json_schema,
        params.regex,
        params.ebnf,
        params.structural_tag,
    )
    if any(field is not None for field in constrained_fields):
        return (
            "DFLASH speculative decoding does not support "
            "grammar-constrained decoding yet."
        )

    return None


class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
Expand Down Expand Up @@ -1633,6 +1654,13 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

if self.spec_algorithm.is_dflash():
error_msg = validate_dflash_request(req, self.enable_overlap)
if error_msg is not None:
req.set_finish_with_abort(error_msg)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
# Handle multimodal inputs
if recv_req.mm_inputs is not None:
image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs)
Expand Down
92 changes: 79 additions & 13 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,18 +472,15 @@ def __init__(self, model_runner: ModelRunner):
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
):
if model_runner.spec_algorithm.is_speculative():
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
# DFLASH draft workers reuse this runner for TARGET_VERIFY mode.
if not self.model_runner.spec_algorithm.is_dflash():
raise RuntimeError("This should not happen")
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
elif self.is_dllm:
self.capture_forward_mode = ForwardMode.DLLM_EXTEND
self.num_tokens_per_bs = self.dllm_config.block_size
Expand Down Expand Up @@ -560,6 +557,18 @@ def __init__(self, model_runner: ModelRunner):
and model_runner.eagle_use_aux_hidden_state
):
self.model_runner.model.set_eagle3_layers_to_capture()
if (
model_runner.spec_algorithm.is_dflash()
and model_runner.dflash_use_aux_hidden_state
):
if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"):
raise ValueError(
f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, "
"which is required for DFLASH aux hidden capture."
)
self.model_runner.model.set_dflash_layers_to_capture(
self.model_runner.dflash_target_layer_ids
)

# Capture
try:
Expand All @@ -585,6 +594,7 @@ def can_run(self, forward_batch: ForwardBatch):
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max(forward_batch.global_num_tokens_cpu)
)
else:
Expand Down Expand Up @@ -912,6 +922,12 @@ def run_once():
kwargs["pp_proxy_tensors"] = PPProxyTensors(
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and "input_embeds" in inspect.signature(forward).parameters
):
kwargs["input_embeds"] = buffers.input_embeds[:num_tokens]

logits_output_or_pp_proxy_tensors = forward(
input_ids,
Expand Down Expand Up @@ -988,6 +1004,7 @@ def replay_prepare(
max_num_tokens / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
Expand All @@ -1009,6 +1026,13 @@ def replay_prepare(
),
pp_proxy_tensors=pp_proxy_tensors,
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds)
# Padded tokens aren't read, so skip zeroing them.
if self.enable_two_batch_overlap:
self.tbo_plugin.replay_prepare(
forward_mode=self.capture_forward_mode,
Expand Down Expand Up @@ -1054,6 +1078,14 @@ def replay(
# In speculative decoding, these two fields are still needed.
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
self.buffers.input_embeds[: self.raw_num_token].copy_(
forward_batch.input_embeds
)

# Replay
if self.enable_pdmux:
Expand All @@ -1066,10 +1098,18 @@ def replay(
if isinstance(output, LogitsProcessorOutput):
if self.is_dllm:
next_token_logits = None
full_logits = output.full_logits[: self.raw_num_token]
full_logits = (
output.full_logits[: self.raw_num_token]
if output.full_logits is not None
else None
)
else:
full_logits = None
next_token_logits = output.next_token_logits[: self.raw_num_token]
next_token_logits = (
output.next_token_logits[: self.raw_num_token]
if output.next_token_logits is not None
else None
)

return LogitsProcessorOutput(
next_token_logits=next_token_logits,
Expand Down Expand Up @@ -1111,6 +1151,32 @@ def get_spec_info(self, num_tokens: int):
seq_lens_sum=None,
seq_lens_cpu=None,
)
elif self.model_runner.spec_algorithm.is_dflash():
from sglang.srt.speculative.dflash_info import DFlashVerifyInput
from sglang.srt.speculative.dflash_utils import (
resolve_dflash_verify_mask_policy,
)

# Avoid enabling custom-mask modes during graph capture for backends that
# can express DFLASH verify via their built-in causal path.
_, build_custom_mask = resolve_dflash_verify_mask_policy(
self.model_runner.attn_backend
)
spec_info = DFlashVerifyInput(
draft_token=None,
positions=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
custom_mask=(
None
if (self.model_runner.is_draft_worker or not build_custom_mask)
else self.buffers.custom_mask
),
capture_hidden_mode=(
CaptureHiddenMode.NULL
if self.model_runner.is_draft_worker
else CaptureHiddenMode.FULL
),
)

elif self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_info import NgramVerifyInput
Expand Down
Loading
Loading