Changes from all commits (54 commits)
10e563f
starting dflash impl
dcw02 Jan 6, 2026
289e748
fix verify mismatch
dcw02 Jan 7, 2026
f1efc03
add gsm8k bench
dcw02 Jan 7, 2026
e807216
support more backends, investigate accuracy
dcw02 Jan 7, 2026
99e140a
native sglang backend
dcw02 Jan 8, 2026
2c64b0e
remove hf backend
dcw02 Jan 8, 2026
f1a4262
dflash support flashinfer
dcw02 Jan 8, 2026
2c5b346
remove manual management of dflash kv pool
dcw02 Jan 8, 2026
6a38e63
add cuda graph
dcw02 Jan 8, 2026
40a81af
add cuda graph to draft worker
dcw02 Jan 8, 2026
510bf0c
update test
dcw02 Jan 8, 2026
c54f336
fix flashinfer backend
dcw02 Jan 9, 2026
8c8ee9c
initial radix cache support
dcw02 Jan 9, 2026
0edea3f
tp_size > 1 support
dcw02 Jan 9, 2026
f23555b
add optional dflash_config for overrides, add --speculative-dflash-bl…
dcw02 Jan 10, 2026
63c0b9a
fix OOMs with default settings
dcw02 Jan 10, 2026
9309764
clean up
dcw02 Jan 10, 2026
644ab29
clean up dflash load_weights
dcw02 Jan 10, 2026
ff6876a
attention selection logic
dcw02 Jan 10, 2026
32c3dd0
Merge remote-tracking branch 'upstream/main' into dflash
dcw02 Jan 10, 2026
d808ac9
decouple context feature count K from draft num layers
dcw02 Jan 12, 2026
e589ac1
clean up naming
dcw02 Jan 12, 2026
074efb2
performance optimizations
dcw02 Jan 12, 2026
fcc9bf7
skip Q, fused mlp
dcw02 Jan 12, 2026
a79264f
reuse buffers for decode
dcw02 Jan 12, 2026
ad5adbf
optimize greedy sampling
dcw02 Jan 12, 2026
37fc3f1
preallocate for tp>1
dcw02 Jan 12, 2026
72cbd9d
more buffers for tp>1
dcw02 Jan 12, 2026
5a577a3
dflash gsm8k benchmark sweep
dcw02 Jan 12, 2026
d968532
fix benchmark
dcw02 Jan 12, 2026
3e4177d
use device tensors for ctx_lens/draft_seq_lens, vectorize kv append a…
dcw02 Jan 20, 2026
b5b4bd6
precommit fixes
dcw02 Jan 20, 2026
ed8b16d
feat(dflash): add fused KV materialization kernel and optimize D2H
xiaomin-D Jan 20, 2026
f2a6dbc
add support for qwen3_moe
dcw02 Jan 20, 2026
117352d
support dflash_config.mask_token_id
dcw02 Feb 5, 2026
5ba316c
add llama3.1 support and fix config block_size logic
dcw02 Feb 5, 2026
3988d4e
[feat] dflash for vlm-qwen3
EanWang211123 Feb 6, 2026
5f8d0ec
Merge pull request #15 from yilian49/pr16818
dcw02 Feb 12, 2026
189f177
guards for fused path
dcw02 Feb 12, 2026
0841db6
add support for gpt oss
dcw02 Feb 12, 2026
56477b9
clean up
dcw02 Feb 12, 2026
9c0242d
Merge upstream sgl-project/main -> dflash and fix conflicts
dcw02 Feb 12, 2026
d9c68a1
add qwen3-coder-next support (mamba)
dcw02 Feb 24, 2026
7e189bd
add page size > 1 support
dcw02 Feb 24, 2026
7a739f8
non greedy
dcw02 Feb 25, 2026
8a6bec9
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Feb 25, 2026
d21afd5
Merge branch 'pr-dflash-16818' into vlm-dflash-test
EanWang211123 Feb 26, 2026
b046451
Merge branch 'main' into vlm-dflash-test
EanWang211123 Feb 26, 2026
72cd001
“restore global server_args after DFlash worker initialization to pre…
EanWang211123 Feb 26, 2026
71a1879
Merge branch 'vlm-dflash-test' of https://github.com/EanWang211123/sg…
EanWang211123 Feb 26, 2026
a1fab19
Merge branch 'main' into vlm-dflash-test
EanWang211123 Feb 26, 2026
5b3a56e
Merge branch 'vlm-dflash-test' of https://github.com/EanWang211123/sg…
EanWang211123 Feb 26, 2026
5557c2e
Merge branch 'main' into vlm-dflash-test
EanWang211123 Feb 26, 2026
6c3e70e
Merge branch 'main' into vlm-dflash-test
EanWang211123 Feb 26, 2026
784 changes: 784 additions & 0 deletions benchmark/dflash/bench_dflash_gsm8k_sweep.py

Large diffs are not rendered by default.

576 changes: 576 additions & 0 deletions benchmark/dflash/bench_dflash_mmstar.py

Large diffs are not rendered by default.

30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -588,8 +588,24 @@ def init_forward_metadata_capture_cuda_graph(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
# FlashInfer's prefill wrapper decides mask mode based on whether
# `custom_mask_buf` is initialized (not whether a custom mask is provided).
# For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a
# custom mask, so we must avoid initializing `custom_mask_buf`, otherwise
# FlashInfer will treat the (zero) buffer as a real mask and block attention.
use_custom_mask = (
spec_info is not None
and getattr(spec_info, "custom_mask", None) is not None
)
prefill_wrappers = []
for i in range(self.num_wrappers):
wrapper_kwargs = {}
if use_custom_mask:
wrapper_kwargs = {
"custom_mask_buf": self.cuda_graph_custom_mask,
"mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1],
}

prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
@@ -600,8 +616,7 @@ def init_forward_metadata_capture_cuda_graph(
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
**wrapper_kwargs,
)
)
seq_lens_sum = seq_lens.sum().item()
@@ -774,10 +789,14 @@ def forward_extend(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

causal = (
not layer.is_cross_attention
and layer.attn_type != AttentionType.ENCODER_ONLY
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
causal=causal,
sm_scale=layer.scaling,
# Disable sliding window attention for multi-item scoring:
# - Sliding window could cut across item boundaries, breaking semantic coherence
Expand Down Expand Up @@ -829,11 +848,6 @@ def forward_extend(
)

else:
if not self.is_dllm_model:
# TODO: design a better interface
# For other models, use causal attention for the ragged part as previously
causal = True

o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
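The comment in the hunk above pins down a subtle FlashInfer contract: the mask mode is selected by whether `custom_mask_buf` was initialized at wrapper construction, not by whether a custom mask is later provided. A toy sketch of the resulting guard pattern (the wrapper below is a stand-in, not the FlashInfer API):

def make_prefill_wrapper(workspace_buffer, custom_mask_buf=None, mask_indptr_buf=None):
    # Stand-in for BatchPrefillWithPagedKVCacheWrapper: like the real wrapper,
    # the mask mode is fixed at construction time by whether the mask buffers
    # were passed at all, not by whether a mask arrives later.
    mask_mode = "custom" if custom_mask_buf is not None else "causal"
    return {"workspace": workspace_buffer, "mask_mode": mask_mode}

use_custom_mask = False  # e.g. DFLASH draft: ENCODER_ONLY, no custom mask
wrapper_kwargs = {}
if use_custom_mask:
    wrapper_kwargs = {
        "custom_mask_buf": bytearray(1024),  # placeholder buffers
        "mask_indptr_buf": bytearray(64),
    }

wrapper = make_prefill_wrapper(workspace_buffer=None, **wrapper_kwargs)
# With no buffers initialized, attention falls back to the causal path
# instead of treating a zeroed buffer as a mask that blocks everything.
assert wrapper["mask_mode"] == "causal"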
6 changes: 5 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -2055,7 +2055,11 @@ def filter_batch(
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
# Use CPU copy to avoid GPU sync.
if self.seq_lens_cpu is not None:
self.seq_lens_sum = int(self.seq_lens_cpu.sum().item())
else:
self.seq_lens_sum = int(self.seq_lens.sum().item())

if self.output_ids is not None:
self.output_ids = self.output_ids[keep_indices_device]
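For context on the `seq_lens_cpu` branch above: calling `.item()` on a CUDA tensor blocks the host until the device catches up, so summing a CPU mirror keeps `filter_batch` off the synchronization path. A standalone illustration (not SGLang code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([3, 5, 7], device=device)
seq_lens_cpu = seq_lens.to("cpu")  # mirror maintained alongside the device tensor

# Reading a scalar out of a device tensor forces a host-device sync:
total_with_sync = int(seq_lens.sum().item())

# The CPU mirror yields the same value without touching the device stream:
total_without_sync = int(seq_lens_cpu.sum().item())
assert total_with_sync == total_without_sync == 15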
20 changes: 20 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -1561,6 +1561,26 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

if self.spec_algorithm.is_dflash() and req.return_logprob:
req.set_finish_with_abort(
"DFLASH speculative decoding does not support return_logprob yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
if self.spec_algorithm.is_dflash() and (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
req.set_finish_with_abort(
"DFLASH speculative decoding does not support grammar-constrained decoding yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return

# Handle multimodal inputs
if recv_req.mm_inputs is not None:
image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs)
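The second guard above rejects any of four grammar-constraint fields on the sampling params. A condensed, self-contained version of that check might look like the following (the `SamplingParams` stand-in is illustrative, not sglang's class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:  # minimal stand-in for sglang's sampling params
    json_schema: Optional[str] = None
    regex: Optional[str] = None
    ebnf: Optional[str] = None
    structural_tag: Optional[str] = None

GRAMMAR_FIELDS = ("json_schema", "regex", "ebnf", "structural_tag")

def uses_grammar(sp: SamplingParams) -> bool:
    # DFLASH cannot yet verify draft tokens under grammar constraints,
    # so any constrained-decoding field disqualifies the request.
    return any(getattr(sp, f) is not None for f in GRAMMAR_FIELDS)

assert uses_grammar(SamplingParams(regex=r"\d+"))
assert not uses_grammar(SamplingParams())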
5 changes: 5 additions & 0 deletions python/sglang/srt/mem_cache/common.py
@@ -474,6 +474,9 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = True
req.mamba_pool_idx.unsqueeze(-1)
)
req.mamba_pool_idx = None
# DFLASH tracks per-request draft progress on Req.
if hasattr(req, "dflash_draft_seq_len"):
req.dflash_draft_seq_len = 0
return

tree_cache.cache_finished_req(req, is_insert=is_insert)
@@ -506,6 +509,8 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = True
), "mamba state is freed while the tree cache does not manage mamba states"
tree_cache.req_to_token_pool.free_mamba_cache(req)
tree_cache.req_to_token_pool.free(req)
if hasattr(req, "dflash_draft_seq_len"):
req.dflash_draft_seq_len = 0


def available_and_evictable_str(tree_cache: BasePrefixCache) -> str:
79 changes: 73 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -476,14 +476,18 @@ def __init__(self, model_runner: ModelRunner):
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
or model_runner.spec_algorithm.is_dflash()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
# EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not
# capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and
# reuses TARGET_VERIFY graphs for performance.
if not self.model_runner.spec_algorithm.is_dflash():
raise RuntimeError("This should not happen")
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
elif self.is_dllm:
self.capture_forward_mode = ForwardMode.DLLM_EXTEND
self.num_tokens_per_bs = self.dllm_config.block_size
@@ -560,6 +564,18 @@ def __init__(self, model_runner: ModelRunner):
and model_runner.eagle_use_aux_hidden_state
):
self.model_runner.model.set_eagle3_layers_to_capture()
if (
model_runner.spec_algorithm.is_dflash()
and model_runner.dflash_use_aux_hidden_state
):
if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"):
raise ValueError(
f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, "
"which is required for DFLASH aux hidden capture."
)
self.model_runner.model.set_dflash_layers_to_capture(
self.model_runner.dflash_target_layer_ids
)

# Capture
try:
@@ -912,6 +928,12 @@ def run_once():
kwargs["pp_proxy_tensors"] = PPProxyTensors(
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and "input_embeds" in inspect.signature(forward).parameters
):
kwargs["input_embeds"] = buffers.input_embeds[:num_tokens]

logits_output_or_pp_proxy_tensors = forward(
input_ids,
@@ -1009,6 +1031,13 @@ def replay_prepare(
),
pp_proxy_tensors=pp_proxy_tensors,
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds)
# Padded tokens aren't read, so skip zeroing them.
if self.enable_two_batch_overlap:
self.tbo_plugin.replay_prepare(
forward_mode=self.capture_forward_mode,
@@ -1054,6 +1083,14 @@ def replay(
# In speculative decoding, these two fields are still needed.
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
self.buffers.input_embeds[: self.raw_num_token].copy_(
forward_batch.input_embeds
)

# Replay
if self.enable_pdmux:
@@ -1063,6 +1100,8 @@
self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]

if isinstance(output, torch.Tensor):
return output[: self.raw_num_token]
if isinstance(output, LogitsProcessorOutput):
if self.is_dllm:
next_token_logits = None
Expand Down Expand Up @@ -1111,6 +1150,34 @@ def get_spec_info(self, num_tokens: int):
seq_lens_sum=None,
seq_lens_cpu=None,
)
elif self.model_runner.spec_algorithm.is_dflash():
from sglang.srt.speculative.dflash_info import DFlashVerifyInput

backend_name = type(self.model_runner.attn_backend).__name__
# Avoid enabling custom-mask modes during graph capture for backends that
# can express DFLASH verify via their built-in causal path.
skip_custom_mask = backend_name in {
"FlashInferAttnBackend",
"FlashInferMLAAttnBackend",
"FlashAttentionBackend",
"TRTLLMHAAttnBackend",
"TRTLLMMLABackend",
}
spec_info = DFlashVerifyInput(
draft_token=None,
positions=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
custom_mask=(
None
if (self.model_runner.is_draft_worker or skip_custom_mask)
else self.buffers.custom_mask
),
capture_hidden_mode=(
CaptureHiddenMode.NULL
if self.model_runner.is_draft_worker
else CaptureHiddenMode.FULL
),
)

elif self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_info import NgramVerifyInput
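One recurring idiom in the cuda_graph_runner.py changes above is feature detection via `inspect.signature`, so `input_embeds` is only forwarded when the captured model's `forward` actually declares that parameter. A self-contained sketch of the idiom (toy `forward` functions, not the SGLang models):

import inspect

def forward_text(input_ids, positions):
    return ("text", input_ids, positions)

def forward_vlm(input_ids, positions, input_embeds=None):
    return ("vlm", input_ids, positions, input_embeds)

def call_forward(forward, input_ids, positions, input_embeds):
    kwargs = {}
    # Pass input_embeds only if this forward declares the parameter,
    # mirroring the DFLASH draft-worker path in run_once()/replay().
    if "input_embeds" in inspect.signature(forward).parameters:
        kwargs["input_embeds"] = input_embeds
    return forward(input_ids, positions, **kwargs)

print(call_forward(forward_text, [1, 2], [0, 1], None))      # embeds dropped
print(call_forward(forward_vlm, [1, 2], [0, 1], "embeds"))   # embeds passed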