Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
70bdac0
feat(spec): add dflash spec v2
dcw02 Apr 16, 2026
9e87ef4
remove benchmark sweep
dcw02 Apr 16, 2026
c0a329d
remove dflash spec v2 specific env
dcw02 Apr 16, 2026
de0372a
clean up
dcw02 Apr 16, 2026
a722fee
remove mamba memory calculations
dcw02 Apr 16, 2026
ad6f0bb
update test for spec v2 and overlap plan streams
dcw02 Apr 16, 2026
7ea9fbd
fix dflash rope config for transformers v5
dcw02 Apr 25, 2026
6465189
small cleanup
dcw02 Apr 25, 2026
ce09806
decouple dflash v2 next step planning from lagging host metadata
dcw02 Apr 25, 2026
89a4a26
draft swa layer support
dcw02 Apr 26, 2026
2b0f324
fix dflash swa flashinfer
dcw02 Apr 26, 2026
f39d86d
Merge branch 'main' of github.com:sgl-project/sglang into dcw02/dflas…
dcw02 Apr 28, 2026
9893ef8
gemma 4 support
dcw02 May 1, 2026
fe8ceef
clean up dead methods from gemma 4 dflash support
dcw02 May 11, 2026
bf853ca
re-enable greedy determinism test
dcw02 May 11, 2026
d89e71e
spec algo need_topk() for future map instead of special casing dflash
dcw02 May 11, 2026
ebb2526
Merge remote-tracking branch 'origin/main' into dcw02/dflash-spec-v2
dcw02 May 12, 2026
c07177a
clean up gemma 4 changes
dcw02 May 12, 2026
35cbfea
update dflash server_args.py for spec v2 default
dcw02 May 12, 2026
f5ba4cb
simplify dflash server_args
dcw02 May 12, 2026
5192b04
enable pcg for dflash
dcw02 May 12, 2026
55a7980
remove useless verify prep side stream that was immediately joined be…
dcw02 May 13, 2026
93758bb
DFlash prefill refill heuristic
dcw02 May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 39 additions & 17 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
Expand Down Expand Up @@ -606,9 +606,7 @@ def init_forward_metadata_capture_cuda_graph(
elif forward_mode.is_target_verify():
# FlashInfer's prefill wrapper decides mask mode based on whether
# `custom_mask_buf` is initialized (not whether a custom mask is provided).
# For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a
# custom mask, so we must avoid initializing `custom_mask_buf`, otherwise
# FlashInfer will treat the (zero) buffer as a real mask and block attention.
# DFlash relies on layer causal/window metadata instead of a custom mask.
use_custom_mask = (
spec_info is not None
and getattr(spec_info, "custom_mask", None) is not None
Expand Down Expand Up @@ -1241,7 +1239,7 @@ def update(
seq_lens: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefix_lens: Optional[torch.Tensor],
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
Expand All @@ -1259,7 +1257,7 @@ def update_single_wrapper(
seq_lens: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefix_lens: Optional[torch.Tensor],
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
Expand All @@ -1269,6 +1267,7 @@ def update_single_wrapper(
cross_attention_custom_mask: Optional[torch.Tensor] = None,
):
if use_ragged:
assert prefix_lens is not None
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
# and forward_batch.extend_seq_lens_cpu
paged_kernel_lens = prefix_lens
Expand Down Expand Up @@ -1300,7 +1299,7 @@ def update_sliding_window(
seq_lens: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefix_lens: Optional[torch.Tensor],
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
Expand All @@ -1309,12 +1308,23 @@ def update_sliding_window(
multi_item_params: Optional[MultiItemScoringParams] = None,
cross_attention_custom_mask: Optional[torch.Tensor] = None,
):
if prefix_lens is None:
accept_length = getattr(spec_info, "accept_length", None)
prefix_lens = (
seq_lens
if accept_length is None
else seq_lens
- accept_length[: seq_lens.shape[0]].to(
device=seq_lens.device, dtype=seq_lens.dtype
)
)
window_size = seq_lens.new_tensor(self.sliding_window_size)
for wrapper_id in range(2):
if wrapper_id == 0:
# window attention use paged only
paged_kernel_lens = torch.minimum(
seq_lens,
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
window_size + seq_lens - prefix_lens,
)
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:
Expand Down Expand Up @@ -1350,7 +1360,7 @@ def update_cross_attention(
seq_lens: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefix_lens: Optional[torch.Tensor],
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
Expand Down Expand Up @@ -1398,7 +1408,7 @@ def call_begin_forward(
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
seq_lens: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens: Optional[torch.Tensor],
kv_start_idx: torch.Tensor,
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
Expand All @@ -1411,6 +1421,7 @@ def call_begin_forward(
):
bs = len(seq_lens)
if spec_info is None:
assert prefix_lens is not None
assert len(seq_lens) == len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
Expand All @@ -1435,14 +1446,25 @@ def call_begin_forward(
custom_mask = cross_attention_custom_mask
else:
assert isinstance(spec_info, SpecInput)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
if spec_info.spec_input_type == SpecInputType.DFLASH_VERIFY:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
kv_start_idx=kv_start_idx,
)
)
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
)

# extend part
if use_ragged:
Expand Down
89 changes: 55 additions & 34 deletions python/sglang/srt/managers/overlap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,41 +83,54 @@ def __init__(
def _lazy_init_buf(self, draft_input: EagleDraftInput):
self.buf_initialized = True

# Get a reference for each tensor
topk_p0 = draft_input.topk_p[0]
topk_index0 = draft_input.topk_index[0]
bonus_token0 = draft_input.bonus_tokens[0]
self.need_verified_id = getattr(draft_input, "verified_id", None) is not None
self.need_bonus_tokens = getattr(draft_input, "bonus_tokens", None) is not None

if self.need_verified_id:
verified_id0 = draft_input.verified_id[0]
self.verified_id_buf = torch.empty(
(self.future_buffer_len, *verified_id0.shape),
dtype=verified_id0.dtype,
device=self.device,
)
if self.need_bonus_tokens:
bonus_token0 = draft_input.bonus_tokens[0]
self.bonus_tokens_buf = torch.empty(
(self.future_buffer_len, *bonus_token0.shape),
dtype=bonus_token0.dtype,
device=self.device,
)
new_seq_lens0 = draft_input.new_seq_lens[0]

self.topk_p_buf = torch.empty(
(self.future_buffer_len, *topk_p0.shape),
dtype=topk_p0.dtype,
device=self.device,
)
self.topk_index_buf = torch.empty(
(self.future_buffer_len, *topk_index0.shape),
dtype=topk_index0.dtype,
device=self.device,
)
self.bonus_tokens_buf = torch.empty(
(self.future_buffer_len, *bonus_token0.shape),
dtype=bonus_token0.dtype,
device=self.device,
)
self.new_seq_lens_buf = torch.empty(
(self.future_buffer_len, *new_seq_lens0.shape),
dtype=new_seq_lens0.dtype,
device=self.device,
)

if spec_need_hidden_states():
hidden_states0 = draft_input.hidden_states[0]
self.hidden_states_buf = torch.empty(
(self.future_buffer_len, *hidden_states0.shape),
dtype=hidden_states0.dtype,
if self.spec_algo.need_topk():
# Get a reference for each tensor
topk_p0 = draft_input.topk_p[0]
topk_index0 = draft_input.topk_index[0]
self.topk_p_buf = torch.empty(
(self.future_buffer_len, *topk_p0.shape),
dtype=topk_p0.dtype,
device=self.device,
)
self.topk_index_buf = torch.empty(
(self.future_buffer_len, *topk_index0.shape),
dtype=topk_index0.dtype,
device=self.device,
)

if spec_need_hidden_states():
hidden_states0 = draft_input.hidden_states[0]
self.hidden_states_buf = torch.empty(
(self.future_buffer_len, *hidden_states0.shape),
dtype=hidden_states0.dtype,
device=self.device,
)

def alloc_future_indices(self, bs: int) -> FutureIndices:
"""Update the circular buffer pointer and allocate future indices."""
cur_future_ct = self.future_ct
Expand All @@ -144,12 +157,16 @@ def resolve_future(self, model_worker_batch: ModelWorkerBatch):
# caching allocator (torch GC) could reclaim the memory before
# the GPU finishes reading it.
indices.record_stream(torch.get_device_module(self.device).current_stream())
draft_input.topk_p = self.topk_p_buf[indices]
draft_input.topk_index = self.topk_index_buf[indices]
draft_input.bonus_tokens = self.bonus_tokens_buf[indices]
if self.need_verified_id:
draft_input.verified_id = self.verified_id_buf[indices]
if self.need_bonus_tokens:
draft_input.bonus_tokens = self.bonus_tokens_buf[indices]
draft_input.new_seq_lens = self.new_seq_lens_buf[indices]
if spec_need_hidden_states():
draft_input.hidden_states = self.hidden_states_buf[indices]
if self.spec_algo.need_topk():
draft_input.topk_p = self.topk_p_buf[indices]
draft_input.topk_index = self.topk_index_buf[indices]
if spec_need_hidden_states():
draft_input.hidden_states = self.hidden_states_buf[indices]

def is_empty_slice(self, s: slice) -> bool:
start, stop, step = s.indices(self.future_buffer_len)
Expand Down Expand Up @@ -179,9 +196,13 @@ def store_to_map_for_new_batch(
if not self.buf_initialized:
self._lazy_init_buf(draft_input)

self.topk_p_buf[intv] = draft_input.topk_p
self.topk_index_buf[intv] = draft_input.topk_index
self.bonus_tokens_buf[intv] = draft_input.bonus_tokens
if self.need_verified_id:
self.verified_id_buf[intv] = draft_input.verified_id
if self.need_bonus_tokens:
self.bonus_tokens_buf[intv] = draft_input.bonus_tokens
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
if spec_need_hidden_states():
self.hidden_states_buf[intv] = draft_input.hidden_states
if self.spec_algo.need_topk():
self.topk_p_buf[intv] = draft_input.topk_p
self.topk_index_buf[intv] = draft_input.topk_index
if spec_need_hidden_states():
self.hidden_states_buf[intv] = draft_input.hidden_states
18 changes: 14 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,10 +2388,20 @@ def prepare_for_decode(self):
)

def maybe_wait_verify_done(self):
if self.is_spec_v2:
draft_input: EagleDraftInput = self.spec_info
if draft_input.verify_done is not None:
draft_input.verify_done.synchronize()
if not self.is_spec_v2:
return

draft_input: EagleDraftInput = self.spec_info
verify_done = getattr(draft_input, "verify_done", None)
if verify_done is None:
return

if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
torch.get_device_module(self.device).current_stream().wait_event(
verify_done
)
else:
verify_done.synchronize()

def filter_batch(
self,
Expand Down
31 changes: 28 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.session.session_controller import SessionController
from sglang.srt.session.streaming_session import StreamingSession
from sglang.srt.speculative.dflash_utils import (
resolve_dflash_prefill_refill_target,
should_delay_dflash_prefill_for_batching,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
DynamicGradMode,
Expand Down Expand Up @@ -303,10 +307,12 @@ def copy_to_cpu(self):
self.copy_done.record()


def validate_dflash_request(req: Req) -> Optional[str]:
def validate_dflash_request(req: Req, enable_overlap: bool) -> Optional[str]:
if req.return_logprob:
return "DFLASH speculative decoding does not support return_logprob yet."

if enable_overlap and req.return_hidden_states:
return "DFLASH speculative decoding does not support return_hidden_states yet."
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
Expand Down Expand Up @@ -726,6 +732,11 @@ def init_model_worker(self):
_,
_,
) = self.tp_worker.get_worker_info()
self.dflash_prefill_refill_target = (
resolve_dflash_prefill_refill_target(self.max_running_requests)
if self.spec_algorithm.is_dflash()
else 1
)
if not get_global_server_args().pp_max_micro_batch_size:
get_global_server_args().pp_max_micro_batch_size = max(
self.max_running_requests // self.pp_size, 1
Expand Down Expand Up @@ -2085,13 +2096,12 @@ def handle_generate_request(
return

if self.spec_algorithm.is_dflash():
error_msg = validate_dflash_request(req)
error_msg = validate_dflash_request(req, self.enable_overlap)
if error_msg is not None:
req.set_finish_with_abort(error_msg)
self.init_req_max_new_tokens(req)
self._add_request_to_queue(req)
return

# Handle multimodal inputs
if recv_req.mm_inputs is not None:
image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs)
Expand Down Expand Up @@ -2580,6 +2590,19 @@ def get_num_allocatable_reqs(self, running_bs):
res = min(res, self.req_to_token_pool.available_size())
return res

def _should_delay_dflash_prefill_for_batching(self, running_bs: int) -> bool:
if not self.spec_algorithm.is_dflash():
return False
if running_bs <= 0 or self.chunked_req is not None:
return False

return should_delay_dflash_prefill_for_batching(
running_bs=running_bs,
num_allocatable_reqs=self.get_num_allocatable_reqs(running_bs),
max_running_requests=self.max_running_requests,
prefill_refill_target=self.dflash_prefill_refill_target,
)

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
prefill_delayer_single_pass = None
if self.prefill_delayer:
Expand Down Expand Up @@ -2620,6 +2643,8 @@ def _get_new_batch_prefill_raw(
return None

running_bs = len(self.running_batch.reqs)
if self._should_delay_dflash_prefill_for_batching(running_bs):
return None

# Ignore the check if self.chunked_req is not None.
# In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
Expand Down
Loading
Loading