Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
10e563f
starting dflash impl
dcw02 Jan 6, 2026
289e748
fix verify mismatch
dcw02 Jan 7, 2026
f1efc03
add gsm8k bench
dcw02 Jan 7, 2026
e807216
support more backends, investigate accuracy
dcw02 Jan 7, 2026
99e140a
native sglang backend
dcw02 Jan 8, 2026
2c64b0e
remove hf backend
dcw02 Jan 8, 2026
f1a4262
dflash support flashinfer
dcw02 Jan 8, 2026
2c5b346
remove manual management of dflash kv pool
dcw02 Jan 8, 2026
6a38e63
add cuda graph
dcw02 Jan 8, 2026
40a81af
add cuda graph to draft worker
dcw02 Jan 8, 2026
510bf0c
update test
dcw02 Jan 8, 2026
c54f336
fix flashinfer backend
dcw02 Jan 9, 2026
8c8ee9c
initial radix cache support
dcw02 Jan 9, 2026
0edea3f
tp_size > 1 support
dcw02 Jan 9, 2026
f23555b
add optional dflash_config for overrides, add --speculative-dflash-bl…
dcw02 Jan 10, 2026
63c0b9a
fix OOMs with default settings
dcw02 Jan 10, 2026
9309764
clean up
dcw02 Jan 10, 2026
644ab29
clean up dflash load_weights
dcw02 Jan 10, 2026
ff6876a
attention selection logic
dcw02 Jan 10, 2026
32c3dd0
Merge remote-tracking branch 'upstream/main' into dflash
dcw02 Jan 10, 2026
d808ac9
decouple context feature count K from draft num layers
dcw02 Jan 12, 2026
e589ac1
clean up naming
dcw02 Jan 12, 2026
074efb2
performance optimizations
dcw02 Jan 12, 2026
fcc9bf7
skip Q, fused mlp
dcw02 Jan 12, 2026
a79264f
reuse buffers for decode
dcw02 Jan 12, 2026
ad5adbf
optimize greedy sampling
dcw02 Jan 12, 2026
37fc3f1
preallocate for tp>1
dcw02 Jan 12, 2026
72cbd9d
more buffers for tp>1
dcw02 Jan 12, 2026
5a577a3
dflash gsm8k benchmark sweep
dcw02 Jan 12, 2026
d968532
fix benchmark
dcw02 Jan 12, 2026
3e4177d
use device tensors for ctx_lens/draft_seq_lens, vectorize kv append a…
dcw02 Jan 20, 2026
b5b4bd6
precommit fixes
dcw02 Jan 20, 2026
ed8b16d
feat(dflash): add fused KV materialization kernel and optimize D2H
xiaomin-D Jan 20, 2026
f2a6dbc
add support for qwen3_moe
dcw02 Jan 20, 2026
117352d
support dflash_config.mask_token_id
dcw02 Feb 5, 2026
5ba316c
add llama3.1 support and fix config block_size logic
dcw02 Feb 5, 2026
5f8d0ec
Merge pull request #15 from yilian49/pr16818
dcw02 Feb 12, 2026
189f177
guards for fused path
dcw02 Feb 12, 2026
0841db6
add support for gpt oss
dcw02 Feb 12, 2026
56477b9
clean up
dcw02 Feb 12, 2026
9c0242d
Merge upstream sgl-project/main -> dflash and fix conflicts
dcw02 Feb 12, 2026
d9c68a1
add qwen3-coder-next support (mamba)
dcw02 Feb 24, 2026
7e189bd
add page size > 1 support
dcw02 Feb 24, 2026
7a739f8
non greedy
dcw02 Feb 25, 2026
8a6bec9
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Feb 25, 2026
0fe389d
rope rotation support
dcw02 Feb 26, 2026
a134f0a
clean up schedule_batch.py
dcw02 Feb 26, 2026
f62e5de
fix auto memory oom, cleanup
dcw02 Feb 28, 2026
2cc5f07
clean up
dcw02 Feb 28, 2026
760870c
Merge branch 'sgl-project:main' into dflash
dcw02 Feb 28, 2026
3b66746
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 1, 2026
6913603
Merge branch 'main' of github.com:sgl-project/sglang into dflash
dcw02 Mar 2, 2026
26441b8
initial fa4 support to dflash, clean up benchmarking script
dcw02 Mar 2, 2026
74814de
clean up
dcw02 Mar 2, 2026
e493353
only run baseline once
dcw02 Mar 2, 2026
619b59c
[feat] add qwen3-5 dflash support
EanWang211123 Mar 3, 2026
a4b1e1e
[fix] fix offset of layer-ids
EanWang211123 Mar 3, 2026
5106aa5
Merge branch 'main' into dflash-qwen3_5
EanWang211123 Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
720 changes: 720 additions & 0 deletions benchmark/dflash/bench_dflash_gsm8k_sweep.py

Large diffs are not rendered by default.

30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,24 @@ def init_forward_metadata_capture_cuda_graph(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
# FlashInfer's prefill wrapper decides mask mode based on whether
# `custom_mask_buf` is initialized (not whether a custom mask is provided).
# For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a
# custom mask, so we must avoid initializing `custom_mask_buf`, otherwise
# FlashInfer will treat the (zero) buffer as a real mask and block attention.
use_custom_mask = (
spec_info is not None
and getattr(spec_info, "custom_mask", None) is not None
)
prefill_wrappers = []
for i in range(self.num_wrappers):
wrapper_kwargs = {}
if use_custom_mask:
wrapper_kwargs = {
"custom_mask_buf": self.cuda_graph_custom_mask,
"mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1],
}

prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
Expand All @@ -603,8 +619,7 @@ def init_forward_metadata_capture_cuda_graph(
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
**wrapper_kwargs,
)
)
seq_lens_sum = seq_lens.sum().item()
Expand Down Expand Up @@ -777,10 +792,14 @@ def forward_extend(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

causal = (
not layer.is_cross_attention
and layer.attn_type != AttentionType.ENCODER_ONLY
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
causal=causal,
sm_scale=layer.scaling,
# Disable sliding window attention for multi-item scoring:
# - Sliding window could cut across item boundaries, breaking semantic coherence
Expand Down Expand Up @@ -832,11 +851,6 @@ def forward_extend(
)

else:
if not self.is_dllm_model:
# TODO: design a better interface
# For other models, use causal attention for the ragged part as previously
causal = True

o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
Expand Down
20 changes: 20 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,26 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

if self.spec_algorithm.is_dflash() and req.return_logprob:
req.set_finish_with_abort(
"DFLASH speculative decoding does not support return_logprob yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
if self.spec_algorithm.is_dflash() and (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
req.set_finish_with_abort(
"DFLASH speculative decoding does not support grammar-constrained decoding yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
Comment on lines +1636 to +1654
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The two `if` blocks that check for DFLASH-unsupported features are repetitive. They can be combined to reduce code duplication and improve readability.

Suggested change
if self.spec_algorithm.is_dflash() and req.return_logprob:
req.set_finish_with_abort(
"DFLASH speculative decoding does not support return_logprob yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
if self.spec_algorithm.is_dflash() and (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
req.set_finish_with_abort(
"DFLASH speculative decoding does not support grammar-constrained decoding yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return
if self.spec_algorithm.is_dflash():
unsupported_reason = None
if req.return_logprob:
unsupported_reason = "return_logprob"
elif (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
unsupported_reason = "grammar-constrained decoding"
if unsupported_reason:
req.set_finish_with_abort(
f"DFLASH speculative decoding does not support {unsupported_reason} yet."
)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return


# Handle multimodal inputs
if recv_req.mm_inputs is not None:
image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs)
Expand Down
79 changes: 73 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,14 +509,18 @@ def __init__(self, model_runner: ModelRunner):
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
or model_runner.spec_algorithm.is_dflash()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
# EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not
# capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and
# reuses TARGET_VERIFY graphs for performance.
if not self.model_runner.spec_algorithm.is_dflash():
raise RuntimeError("This should not happen")
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
elif self.is_dllm:
self.capture_forward_mode = ForwardMode.DLLM_EXTEND
self.num_tokens_per_bs = self.dllm_config.block_size
Expand Down Expand Up @@ -593,6 +597,18 @@ def __init__(self, model_runner: ModelRunner):
and model_runner.eagle_use_aux_hidden_state
):
self.model_runner.model.set_eagle3_layers_to_capture()
if (
model_runner.spec_algorithm.is_dflash()
and model_runner.dflash_use_aux_hidden_state
):
if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"):
raise ValueError(
f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, "
"which is required for DFLASH aux hidden capture."
)
self.model_runner.model.set_dflash_layers_to_capture(
self.model_runner.dflash_target_layer_ids
)

# Capture
try:
Expand All @@ -618,6 +634,7 @@ def can_run(self, forward_batch: ForwardBatch):
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max(forward_batch.global_num_tokens_cpu)
)
else:
Expand Down Expand Up @@ -945,6 +962,12 @@ def run_once():
kwargs["pp_proxy_tensors"] = PPProxyTensors(
{k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and "input_embeds" in inspect.signature(forward).parameters
):
kwargs["input_embeds"] = buffers.input_embeds[:num_tokens]

logits_output_or_pp_proxy_tensors = forward(
input_ids,
Expand Down Expand Up @@ -1021,6 +1044,7 @@ def replay_prepare(
max_num_tokens / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
or self.model_runner.spec_algorithm.is_dflash()
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
Expand All @@ -1042,6 +1066,13 @@ def replay_prepare(
),
pp_proxy_tensors=pp_proxy_tensors,
)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds)
# Padded tokens aren't read, so skip zeroing them.
if self.enable_two_batch_overlap:
self.tbo_plugin.replay_prepare(
forward_mode=self.capture_forward_mode,
Expand Down Expand Up @@ -1087,6 +1118,14 @@ def replay(
# In speculative decoding, these two fields are still needed.
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)
if (
self.model_runner.spec_algorithm.is_dflash()
and self.model_runner.is_draft_worker
and forward_batch.input_embeds is not None
):
self.buffers.input_embeds[: self.raw_num_token].copy_(
forward_batch.input_embeds
)

# Replay
if self.enable_pdmux:
Expand All @@ -1096,6 +1135,8 @@ def replay(
self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]

if isinstance(output, torch.Tensor):
return output[: self.raw_num_token]
if isinstance(output, LogitsProcessorOutput):
if self.is_dllm:
next_token_logits = None
Expand Down Expand Up @@ -1144,6 +1185,32 @@ def get_spec_info(self, num_tokens: int):
seq_lens_sum=None,
seq_lens_cpu=None,
)
elif self.model_runner.spec_algorithm.is_dflash():
from sglang.srt.speculative.dflash_info import DFlashVerifyInput
from sglang.srt.speculative.dflash_utils import (
resolve_dflash_verify_mask_policy,
)

# Avoid enabling custom-mask modes during graph capture for backends that
# can express DFLASH verify via their built-in causal path.
_, build_custom_mask = resolve_dflash_verify_mask_policy(
self.model_runner.attn_backend
)
spec_info = DFlashVerifyInput(
draft_token=None,
positions=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
custom_mask=(
None
if (self.model_runner.is_draft_worker or not build_custom_mask)
else self.buffers.custom_mask
),
capture_hidden_mode=(
CaptureHiddenMode.NULL
if self.model_runner.is_draft_worker
else CaptureHiddenMode.FULL
),
)

elif self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_info import NgramVerifyInput
Expand Down
86 changes: 82 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.dflash_utils import (
parse_dflash_draft_config,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
MultiprocessingSerializer,
Expand Down Expand Up @@ -345,6 +348,9 @@ def __init__(
self.remote_instance_transfer_engine_weight_info = None
# auxiliary hidden capture mode. TODO: expose this to server args?
self.eagle_use_aux_hidden_state = False
self.dflash_use_aux_hidden_state = False
self.dflash_target_layer_ids = None
self.dflash_draft_num_layers = None
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
# load draft config
draft_model_config = ModelConfig.from_server_args(
Expand All @@ -370,6 +376,51 @@ def __init__(
# if there is no aux layer, set to None
self.eagle_aux_hidden_state_layer_ids = None

if self.spec_algorithm.is_dflash() and not self.is_draft_worker:
# Select target layers to capture for building DFlash context features.
draft_model_config = ModelConfig.from_server_args(
server_args,
model_path=(server_args.speculative_draft_model_path),
model_revision=server_args.speculative_draft_model_revision,
is_draft_model=True,
)
dflash_draft_config = parse_dflash_draft_config(
draft_hf_config=draft_model_config.hf_config
)
draft_num_layers = dflash_draft_config.require_num_layers()
trained_target_layers = dflash_draft_config.num_target_layers

target_num_layers = getattr(
self.model_config.hf_config, "num_hidden_layers", None
)
if target_num_layers is None:
# VLM (e.g. Qwen3.5) has num_hidden_layers in text_config, not top-level
target_num_layers = self.model_config.num_hidden_layers
if target_num_layers is None:
raise ValueError(
"DFLASH requires target num_hidden_layers in config. "
f"Got target={target_num_layers}."
)
target_num_layers = int(target_num_layers)

if (
trained_target_layers is not None
and trained_target_layers != target_num_layers
):
logger.warning(
"DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; "
"selecting capture layers based on the runtime target model.",
trained_target_layers,
target_num_layers,
)

self.dflash_use_aux_hidden_state = True
self.dflash_draft_num_layers = int(draft_num_layers)
self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids(
target_num_layers=int(target_num_layers),
draft_num_layers=int(draft_num_layers),
)

# Apply the rank zero filter to logger
if server_args.show_time_cost:
enable_show_time_cost()
Expand Down Expand Up @@ -636,6 +687,14 @@ def initialize(self, pre_model_load_memory: float):
self.eagle_aux_hidden_state_layer_ids
)

if self.dflash_use_aux_hidden_state:
if not hasattr(self.model, "set_dflash_layers_to_capture"):
raise ValueError(
f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, "
"which is required for DFLASH."
)
self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids)

# Initialize piecewise CUDA graph
self.init_piecewise_cuda_graphs()

Expand Down Expand Up @@ -1847,6 +1906,7 @@ def _should_run_flashinfer_autotune(self) -> bool:
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
or self.spec_algorithm.is_dflash()
):
return not self.is_draft_worker

Expand Down Expand Up @@ -1881,12 +1941,13 @@ def _dummy_run(self, batch_size: int, run_ctx=None):
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
or self.spec_algorithm.is_dflash()
):
if self.is_draft_worker:
raise RuntimeError("This should not happen")
else:
capture_forward_mode = ForwardMode.TARGET_VERIFY
num_tokens_per_bs = self.server_args.speculative_num_draft_tokens
if not self.spec_algorithm.is_dflash():
raise RuntimeError("This should not happen")
capture_forward_mode = ForwardMode.TARGET_VERIFY
num_tokens_per_bs = self.server_args.speculative_num_draft_tokens

if self.server_args.enable_return_hidden_states:
capture_hidden_mode = CaptureHiddenMode.FULL
Expand All @@ -1906,6 +1967,8 @@ def _dummy_run(self, batch_size: int, run_ctx=None):

if self.eagle_use_aux_hidden_state:
self.model.set_eagle3_layers_to_capture()
if self.dflash_use_aux_hidden_state:
self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids)

require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args)
if require_gathered_buffer(self.server_args):
Expand Down Expand Up @@ -2015,6 +2078,21 @@ def get_spec_info():
seq_lens_sum=None,
seq_lens_cpu=None,
)
elif self.spec_algorithm.is_dflash():
from sglang.srt.speculative.dflash_info import DFlashVerifyInput

# Dummy warmup only needs shape metadata; avoid forcing custom-mask mode.
spec_info = DFlashVerifyInput(
draft_token=None,
positions=None,
draft_token_num=self.server_args.speculative_num_draft_tokens,
custom_mask=None,
capture_hidden_mode=(
CaptureHiddenMode.NULL
if self.is_draft_worker
else CaptureHiddenMode.FULL
),
)

elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_info import NgramVerifyInput
Expand Down
Loading
Loading