Merged

86 commits
10e563f
starting dflash impl
dcw02 Jan 6, 2026
289e748
fix verify mismatch
dcw02 Jan 7, 2026
f1efc03
add gsm8k bench
dcw02 Jan 7, 2026
e807216
support more backends, investigate accuracy
dcw02 Jan 7, 2026
99e140a
native sglang backend
dcw02 Jan 8, 2026
2c64b0e
remove hf backend
dcw02 Jan 8, 2026
f1a4262
dflash support flashinfer
dcw02 Jan 8, 2026
2c5b346
remove manual management of dflash kv pool
dcw02 Jan 8, 2026
6a38e63
add cuda graph
dcw02 Jan 8, 2026
40a81af
add cuda graph to draft worker
dcw02 Jan 8, 2026
510bf0c
update test
dcw02 Jan 8, 2026
c54f336
fix flashinfer backend
dcw02 Jan 9, 2026
8c8ee9c
initial radix cache support
dcw02 Jan 9, 2026
0edea3f
tp_size > 1 support
dcw02 Jan 9, 2026
f23555b
add optional dflash_config for overrides, add --speculative-dflash-bl…
dcw02 Jan 10, 2026
63c0b9a
fix OOMs with default settings
dcw02 Jan 10, 2026
9309764
clean up
dcw02 Jan 10, 2026
644ab29
clean up dflash load_weights
dcw02 Jan 10, 2026
ff6876a
attention selection logic
dcw02 Jan 10, 2026
32c3dd0
Merge remote-tracking branch 'upstream/main' into dflash
dcw02 Jan 10, 2026
d808ac9
decouple context feature count K from draft num layers
dcw02 Jan 12, 2026
e589ac1
clean up naming
dcw02 Jan 12, 2026
074efb2
performance optimizations
dcw02 Jan 12, 2026
fcc9bf7
skip Q, fused mlp
dcw02 Jan 12, 2026
a79264f
reuse buffers for decode
dcw02 Jan 12, 2026
ad5adbf
optimize greedy sampling
dcw02 Jan 12, 2026
37fc3f1
preallocate for tp>1
dcw02 Jan 12, 2026
72cbd9d
more buffers for tp>1
dcw02 Jan 12, 2026
5a577a3
dflash gsm8k benchmark sweep
dcw02 Jan 12, 2026
d968532
fix benchmark
dcw02 Jan 12, 2026
3e4177d
use device tensors for ctx_lens/draft_seq_lens, vectorize kv append a…
dcw02 Jan 20, 2026
b5b4bd6
precommit fixes
dcw02 Jan 20, 2026
ed8b16d
feat(dflash): add fused KV materialization kernel and optimize D2H
xiaomin-D Jan 20, 2026
f2a6dbc
add support for qwen3_moe
dcw02 Jan 20, 2026
117352d
support dflash_config.mask_token_id
dcw02 Feb 5, 2026
5ba316c
add llama3.1 support and fix config block_size logic
dcw02 Feb 5, 2026
5f8d0ec
Merge pull request #15 from yilian49/pr16818
dcw02 Feb 12, 2026
189f177
guards for fused path
dcw02 Feb 12, 2026
0841db6
add support for gpt oss
dcw02 Feb 12, 2026
56477b9
clean up
dcw02 Feb 12, 2026
9c0242d
Merge upstream sgl-project/main -> dflash and fix conflicts
dcw02 Feb 12, 2026
d9c68a1
add qwen3-coder-next support (mamba)
dcw02 Feb 24, 2026
7e189bd
add page size > 1 support
dcw02 Feb 24, 2026
7a739f8
non greedy
dcw02 Feb 25, 2026
8a6bec9
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Feb 25, 2026
0fe389d
rope rotation support
dcw02 Feb 26, 2026
a134f0a
clean up schedule_batch.py
dcw02 Feb 26, 2026
f62e5de
fix auto memory oom, cleanup
dcw02 Feb 28, 2026
2cc5f07
clean up
dcw02 Feb 28, 2026
760870c
Merge branch 'sgl-project:main' into dflash
dcw02 Feb 28, 2026
3b66746
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 1, 2026
6913603
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 2, 2026
26441b8
initial fa4 support to dflash, clean up benchmarking script
dcw02 Mar 2, 2026
74814de
clean up
dcw02 Mar 2, 2026
e493353
only run baseline once
dcw02 Mar 2, 2026
bdb2fda
Merge branch 'main' into dflash
dcw02 Mar 8, 2026
7878ec4
fix server startup timeout for autotuning
dcw02 Mar 8, 2026
5d689d7
qwen3_5 support
dcw02 Mar 8, 2026
277fe05
add draft swa
dcw02 Mar 12, 2026
e5ef869
avoid OOB in masked req_to_token gathers
dcw02 Mar 16, 2026
8063b65
clean up model support
dcw02 Mar 24, 2026
fd3aa66
clean up benchmarking script
dcw02 Mar 24, 2026
180d2a8
clean up dflash cuda graph runner paths
dcw02 Mar 24, 2026
ab97a91
clean up dflash request validation
dcw02 Mar 24, 2026
6231c4b
clean up stop string handling
dcw02 Mar 24, 2026
2a48abf
inline stop strings logic
dcw02 Mar 24, 2026
bea986b
dflash tests
dcw02 Mar 25, 2026
5e234f8
lazy import dflash_utils in model_runner
hnyls2002 Apr 7, 2026
aff48a2
lazy import dflash_utils in kv_cache_mixin
hnyls2002 Apr 7, 2026
0f0132b
migrate dflash infer_a to infer_b; rename beta to b; delete engine test
hnyls2002 Apr 7, 2026
dab5401
use GSM8KMixin; rename to test_dflash.py
hnyls2002 Apr 7, 2026
fca00c7
delete test_dflash_basic; add accept_length_thres; remove unused fixture
hnyls2002 Apr 7, 2026
d40a6a4
Merge branch 'main' into dflash
hnyls2002 Apr 7, 2026
085bbb7
fix suite name: stage-b-test-1-gpu-small
hnyls2002 Apr 7, 2026
80a238a
fix GSM8KMixin import path
hnyls2002 Apr 7, 2026
2acfdde
Merge branch 'main' into dflash
hnyls2002 Apr 7, 2026
97017f2
pass memory_pool_config to draft worker
hnyls2002 Apr 7, 2026
2e2e349
Merge remote-tracking branch 'modal-projects/dflash' into dflash
hnyls2002 Apr 7, 2026
213bf69
set gsm8k accept_length_thres to 2.8
hnyls2002 Apr 7, 2026
2b833bb
gsm8k thresholds: accuracy 0.75, accept_length 2.8
hnyls2002 Apr 7, 2026
16c61a6
reduce CI time: drop page64; radix test only on page256 with 50 nodes
hnyls2002 Apr 7, 2026
5fcae0e
Merge branch 'main' into dflash
hnyls2002 Apr 7, 2026
ef360dd
increase max_running_requests to 128
hnyls2002 Apr 7, 2026
2a3f304
Merge remote-tracking branch 'modal-projects/dflash' into dflash
hnyls2002 Apr 7, 2026
a4a2504
add mem-fraction-static 0.65 for 5090
hnyls2002 Apr 7, 2026
f937b28
revert to max_running 64; est_time 300
hnyls2002 Apr 7, 2026
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -596,8 +596,24 @@ def init_forward_metadata_capture_cuda_graph(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
# FlashInfer's prefill wrapper decides mask mode based on whether
# `custom_mask_buf` is initialized (not whether a custom mask is provided).
# For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a
# custom mask, so we must avoid initializing `custom_mask_buf`, otherwise
# FlashInfer will treat the (zero) buffer as a real mask and block attention.
use_custom_mask = (
spec_info is not None
and getattr(spec_info, "custom_mask", None) is not None
)
prefill_wrappers = []
for i in range(self.num_wrappers):
wrapper_kwargs = {}
if use_custom_mask:
wrapper_kwargs = {
"custom_mask_buf": self.cuda_graph_custom_mask,
"mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1],
}

prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
@@ -608,8 +624,7 @@ def init_forward_metadata_capture_cuda_graph(
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
**wrapper_kwargs,
)
)
seq_lens_sum = seq_lens.sum().item()
@@ -783,10 +798,14 @@ def forward_extend(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

causal = (
not layer.is_cross_attention
and layer.attn_type != AttentionType.ENCODER_ONLY
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
causal=causal,
sm_scale=layer.scaling,
# Disable sliding window attention for multi-item scoring:
# - Sliding window could cut across item boundaries, breaking semantic coherence
@@ -838,11 +857,6 @@ def forward_extend(
)

else:
if not self.is_dllm_model:
# TODO: design a better interface
# For other models, use causal attention for the ragged part as previously
causal = True

o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
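The comment in this hunk explains the key subtlety: FlashInfer's prefill wrapper picks its mask mode based on whether `custom_mask_buf` was supplied at construction, not on whether a mask is later provided, so a zero-filled buffer silently masks out all attention. A minimal standalone sketch of the same gating pattern (the helper and `SpecInfo` stub are illustrative, not the actual FlashInfer or sglang API):

```python
# Sketch: only pass mask buffers into the wrapper kwargs when a real
# custom mask exists. Passing them unconditionally would make the wrapper
# treat a zero-filled buffer as a genuine mask and block all attention.


class SpecInfo:
    """Stub standing in for sglang's speculative-decoding spec_info."""

    def __init__(self, custom_mask=None):
        self.custom_mask = custom_mask


def build_wrapper_kwargs(spec_info, mask_buf, indptr_buf):
    """Return extra wrapper kwargs; empty when no custom mask is in play."""
    use_custom_mask = (
        spec_info is not None
        and getattr(spec_info, "custom_mask", None) is not None
    )
    if not use_custom_mask:
        # ENCODER_ONLY / non-causal DFLASH draft: keep the buffers out so
        # the wrapper falls back to its built-in (non-custom) mask mode.
        return {}
    return {"custom_mask_buf": mask_buf, "mask_indptr_buf": indptr_buf}
```

The wrapper construction in the diff then simply splats these kwargs (`**wrapper_kwargs`), so the custom-mask arguments are absent entirely, rather than present-but-zero, on the non-masked path.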
26 changes: 26 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -276,6 +276,24 @@ def copy_to_cpu(self):
self.copy_done.record()


def validate_dflash_request(req: Req) -> Optional[str]:
if req.return_logprob:
return "DFLASH speculative decoding does not support return_logprob yet."

if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
return (
"DFLASH speculative decoding does not support "
"grammar-constrained decoding yet."
)

return None


class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
@@ -1861,6 +1879,14 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

if self.spec_algorithm.is_dflash():
error_msg = validate_dflash_request(req)
if error_msg is not None:
req.set_finish_with_abort(error_msg)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return

# Handle multimodal inputs
if recv_req.mm_inputs is not None:
image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs)
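The new `validate_dflash_request` helper rejects requests DFLASH cannot serve yet (logprob return and any grammar-constrained decoding), and `handle_generate_request` aborts the request with the returned message. A self-contained sketch of that flow, with dataclass stubs standing in for sglang's `Req` and `SamplingParams`:

```python
# Standalone sketch of the request-validation flow; Req/SamplingParams here
# are minimal stubs, not the real sglang classes.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class SamplingParams:
    json_schema: Optional[str] = None
    regex: Optional[str] = None
    ebnf: Optional[str] = None
    structural_tag: Optional[str] = None


@dataclass
class Req:
    return_logprob: bool = False
    sampling_params: SamplingParams = field(default_factory=SamplingParams)


def validate_dflash_request(req: Req) -> Optional[str]:
    """Return an error message for unsupported requests, else None."""
    if req.return_logprob:
        return "DFLASH speculative decoding does not support return_logprob yet."
    sp = req.sampling_params
    # Any grammar constraint (JSON schema, regex, EBNF, structural tag)
    # is rejected the same way.
    if any(
        x is not None
        for x in (sp.json_schema, sp.regex, sp.ebnf, sp.structural_tag)
    ):
        return (
            "DFLASH speculative decoding does not support "
            "grammar-constrained decoding yet."
        )
    return None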
92 changes: 79 additions & 13 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -547,18 +547,15 @@ def __init__(self, model_runner: ModelRunner):
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
):
if model_runner.spec_algorithm.is_speculative():
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
# DFLASH draft workers reuse this runner for TARGET_VERIFY mode.
if not self.model_runner.spec_algorithm.is_dflash():
raise RuntimeError("This should not happen")
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
elif self.is_dllm:
self.capture_forward_mode = ForwardMode.DLLM_EXTEND
self.num_tokens_per_bs = self.dllm_config.block_size
@@ -646,6 +643,18 @@ def __init__(self, model_runner: ModelRunner):
and model_runner.eagle_use_aux_hidden_state
):
self.model_runner.model.set_eagle3_layers_to_capture()
if (
model_runner.spec_algorithm.is_dflash()
and model_runner.dflash_use_aux_hidden_state
):
if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"):
raise ValueError(
f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, "
"which is required for DFLASH aux hidden capture."
)
self.model_runner.model.set_dflash_layers_to_capture(
self.model_runner.dflash_target_layer_ids
)

# Capture
try:
@@ -671,6 +680,7 @@ def can_run(self, forward_batch: ForwardBatch):
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max(forward_batch.global_num_tokens_cpu)
)
else:
@@ -1007,6 +1017,12 @@ def run_once():
kwargs["pp_proxy_tensors"] = PPProxyTensors(
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and "input_embeds" in inspect.signature(forward).parameters
):
kwargs["input_embeds"] = buffers.input_embeds[:num_tokens]

logits_output_or_pp_proxy_tensors = forward(
input_ids,
@@ -1083,6 +1099,7 @@ def replay_prepare(
max_num_tokens / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
@@ -1104,6 +1121,13 @@
),
pp_proxy_tensors=pp_proxy_tensors,
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds)
# Padded tokens aren't read, so skip zeroing them.
if self.enable_two_batch_overlap:
self.tbo_plugin.replay_prepare(
forward_mode=self.capture_forward_mode,
@@ -1152,6 +1176,14 @@ def replay(
# In speculative decoding, these two fields are still needed.
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
self.buffers.input_embeds[: self.raw_num_token].copy_(
forward_batch.input_embeds
)

# Replay
if self.enable_pdmux:
@@ -1164,10 +1196,18 @@
if isinstance(output, LogitsProcessorOutput):
if self.is_dllm:
next_token_logits = None
full_logits = output.full_logits[: self.raw_num_token]
full_logits = (
output.full_logits[: self.raw_num_token]
if output.full_logits is not None
else None
)
else:
full_logits = None
next_token_logits = output.next_token_logits[: self.raw_num_token]
next_token_logits = (
output.next_token_logits[: self.raw_num_token]
if output.next_token_logits is not None
else None
)

return LogitsProcessorOutput(
next_token_logits=next_token_logits,
@@ -1209,6 +1249,32 @@ def get_spec_info(self, num_tokens: int):
seq_lens_sum=None,
seq_lens_cpu=None,
)
elif self.model_runner.spec_algorithm.is_dflash():
from sglang.srt.speculative.dflash_info import DFlashVerifyInput
from sglang.srt.speculative.dflash_utils import (
resolve_dflash_verify_mask_policy,
)

# Avoid enabling custom-mask modes during graph capture for backends that
# can express DFLASH verify via their built-in causal path.
_, build_custom_mask = resolve_dflash_verify_mask_policy(
self.model_runner.attn_backend
)
spec_info = DFlashVerifyInput(
draft_token=None,
positions=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
custom_mask=(
None
if (self.model_runner.is_draft_worker or not build_custom_mask)
else self.buffers.custom_mask
),
capture_hidden_mode=(
CaptureHiddenMode.NULL
if self.model_runner.is_draft_worker
else CaptureHiddenMode.FULL
),
)

elif self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_info import NgramVerifyInput
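One of the quieter fixes in `replay()` above is the None-guarding of the logits slices: with speculative decoding, a `LogitsProcessorOutput` may carry `next_token_logits` or `full_logits` as `None`, and the old unconditional `[: self.raw_num_token]` slice would raise. A standalone sketch of that guard, using plain lists in place of tensors (the helper name is illustrative):

```python
# Sketch of the None-safe slicing added to CudaGraphRunner.replay().
# Lists stand in for torch tensors; slicing semantics are the same.


def slice_logits(next_token_logits, full_logits, raw_num_token, is_dllm):
    """Trim padded-graph outputs to the real token count, tolerating None."""
    if is_dllm:
        # dLLM path: only full_logits is meaningful.
        full = full_logits[:raw_num_token] if full_logits is not None else None
        return None, full
    # Default path: only next_token_logits is meaningful.
    nxt = (
        next_token_logits[:raw_num_token]
        if next_token_logits is not None
        else None
    )
    return nxt, None
```

The guard matters because CUDA graphs always run at a padded batch size; the slice discards padded rows, and `None` simply means that branch's logits were never requested (e.g. `CaptureHiddenMode.NULL` on a DFLASH draft worker).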