Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
3bedab6
init
Chen-0210 Jan 25, 2026
04b91bc
clean
Chen-0210 Jan 26, 2026
0e96e36
clean
Chen-0210 Jan 26, 2026
644da7a
clean
Chen-0210 Jan 27, 2026
e2c0d09
fix
Chen-0210 Jan 30, 2026
e9f1f27
Merge branch 'main' into fix-qwen3-next
Chen-0210 Jan 30, 2026
eeb629d
drop zero tensor
Chen-0210 Jan 30, 2026
c7b5770
fix
Chen-0210 Feb 2, 2026
4ed1dfd
Merge branch 'main' into fix-qwen3-next
Chen-0210 Feb 2, 2026
31873eb
add test
Chen-0210 Feb 2, 2026
454dc93
fix
Chen-0210 Feb 2, 2026
76c1527
Merge branch 'main' into fix-qwen3-next
Kangyan-Zhou Feb 6, 2026
01fd7ab
Merge branch 'main' into fix-qwen3-next
Chen-0210 Feb 7, 2026
c8c5ac5
Merge branch 'main' into fix-qwen3-next
Chen-0210 Feb 14, 2026
b4adfb9
fix
Chen-0210 Feb 14, 2026
a52e5ea
Merge origin/main into fix-qwen3-next
Chen-0210 Mar 27, 2026
dc1d6e8
Drop qwen3_next split-op changes
Chen-0210 Mar 27, 2026
3cbaa79
Align qwen3_next formatting with main
Chen-0210 Mar 27, 2026
2f624af
Preserve mamba tracking and empty output allocation
Chen-0210 Mar 27, 2026
04959bf
Merge branch 'main' into fix-qwen3-next
Chen-0210 Mar 27, 2026
d9f8f60
Fix linear attention output copy slice
Chen-0210 Mar 27, 2026
74fe3c7
Merge remote-tracking branch 'origin/main' into codex/rebase-fix-qwen…
Chen-0210 Mar 30, 2026
811d798
Use non-padded cache loc in piecewise replay
Chen-0210 Mar 30, 2026
f2e5b8f
Apply pre-commit formatting
Chen-0210 Mar 30, 2026
22ecabd
Merge branch 'main' into fix-qwen3-next
Chen-0210 Mar 30, 2026
945db4f
Merge branch 'main' into fix-qwen3-next
ispobock Apr 1, 2026
0cb68c4
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 7, 2026
8896c2f
Remove static qo_indptr padding in piecewise replay
Chen-0210 Apr 7, 2026
bf5b440
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 7, 2026
f8b4a2e
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 7, 2026
05d71d4
Slice cache loc in unified attention split ops
Chen-0210 Apr 8, 2026
58241d3
Remove static out_cache_loc model whitelist
Chen-0210 Apr 8, 2026
68fc842
Simplify piecewise cache loc replay
Chen-0210 Apr 9, 2026
ff22072
Drop piecewise cache loc branching and test overrides
Chen-0210 Apr 9, 2026
1107938
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 9, 2026
71db919
Preserve forward batch identity in unified attention
Chen-0210 Apr 9, 2026
d5486e4
Merge remote-tracking branch 'chen0210/fix-qwen3-next' into codex/fix…
Chen-0210 Apr 9, 2026
8563107
Simplify unified attention cache loc restore
Chen-0210 Apr 9, 2026
2f84232
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 9, 2026
9427dd7
Merge branch 'main' into fix-qwen3-next
Chen-0210 Apr 10, 2026
00c2192
Slice SWA cache loc in unified attention
Chen-0210 Apr 10, 2026
c2dd538
Merge remote-tracking branch 'chen0210/fix-qwen3-next' into codex/fix…
Chen-0210 Apr 10, 2026
53bec38
Merge remote-tracking branch 'origin/main' into fix-qwen3-next
Chen-0210 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/sglang/srt/compilation/piecewise_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self):
self.quant_config = None
self.moe_layers = None
self.moe_fusions = None
self.num_tokens: Optional[int] = None

def set_forward_batch(self, forward_batch: ForwardBatch):
self.forward_batch = forward_batch
Expand Down Expand Up @@ -105,7 +104,6 @@ def set_forward_context(
quant_config: Any,
moe_layers: List[Any],
moe_fusions: List[Any],
num_tokens: Optional[int] = None,
):
global _forward_context
_forward_context = ForwardContext()
Expand All @@ -114,7 +112,6 @@ def set_forward_context(
_forward_context.set_quant_config(quant_config)
_forward_context.set_moe_layers(moe_layers)
_forward_context.set_moe_fusions(moe_fusions)
_forward_context.num_tokens = num_tokens
try:
yield
finally:
Expand Down
52 changes: 3 additions & 49 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
import torch

from sglang.kernel_api_logging import debug_kernel_api
from sglang.srt.compilation.piecewise_context_manager import (
get_forward_context,
is_in_piecewise_cuda_graph,
)
from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
from sglang.srt.dllm.config import DllmConfig
from sglang.srt.environ import envs
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
Expand Down Expand Up @@ -150,8 +147,6 @@ def __init__(
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.page_size = model_runner.page_size

assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
Expand Down Expand Up @@ -1210,8 +1205,6 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken
self.q_data_type = model_runner.dtype
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
self.page_size = attn_backend.page_size

# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
Expand Down Expand Up @@ -1400,13 +1393,8 @@ def call_begin_forward(
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
# Reserve extra space in kv_indices for a potential piecewise CUDA graph
# dummy request (see below). Worst case: static_num_tokens extra pages.
fwd_ctx = get_forward_context()
pcg_num_tokens = fwd_ctx.num_tokens if fwd_ctx is not None else None
extra_kv = pcg_num_tokens if pcg_num_tokens is not None else 0
kv_indices = torch.empty(
paged_kernel_lens_sum + extra_kv + 256,
paged_kernel_lens_sum + 256,
dtype=torch.int32,
device=req_pool_indices.device,
)
Expand All @@ -1422,39 +1410,6 @@ def call_begin_forward(
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]

# Piecewise CUDA graph padding: input_ids are padded to static_num_tokens,
# so q.shape[0] == static_num_tokens but qo_indptr[-1] == actual tokens.
# Append a dummy request for the padding tokens so that
# qo_indptr[-1] == static_num_tokens, satisfying flashinfer's shape check
# without corrupting the causal masks of real requests.
# The dummy request's KV indices all point to slot 0 (a scratch location);
# its attention output is discarded via the [:raw_num_tokens] slice in replay.
bs_eff = bs
# extend_num_tokens is a Python int (== sum of seq_lens - prefix_lens),
# and paged_kernel_lens_sum is also a Python int (== kv_indptr[-1]),
# so this block requires no CPU-GPU synchronisation.
actual_qo_tokens = (
fwd_ctx.forward_batch.extend_num_tokens if fwd_ctx is not None else None
)
if (
pcg_num_tokens is not None
and actual_qo_tokens is not None
and pcg_num_tokens > actual_qo_tokens
):
pad_tokens = pcg_num_tokens - actual_qo_tokens
num_dummy_pages = (pad_tokens + self.page_size - 1) // self.page_size
kv_start = (
paged_kernel_lens_sum # equals kv_indptr[-1], no .item() needed
)
kv_indices[kv_start : kv_start + num_dummy_pages] = 0
qo_indptr = torch.cat(
[qo_indptr, qo_indptr.new_tensor([pcg_num_tokens])]
)
kv_indptr = torch.cat(
[kv_indptr, kv_indptr.new_tensor([kv_start + num_dummy_pages])]
)
bs_eff = bs + 1

custom_mask = None
else:
assert isinstance(spec_info, SpecInput)
Expand All @@ -1466,7 +1421,6 @@ def call_begin_forward(
self.req_to_token,
)
)
bs_eff = bs

# extend part
if use_ragged:
Expand Down Expand Up @@ -1508,7 +1462,7 @@ def call_begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
self.kv_last_page_len[:bs_eff],
self.kv_last_page_len[:bs],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
Expand Down
40 changes: 33 additions & 7 deletions python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,47 @@ def unified_attention_with_output(
forward_batch = context.forward_batch
attention_layers = context.attention_layers
attention_layer = attention_layers[layer_id]
real_num_tokens = forward_batch.num_token_non_padded_cpu

query = query[:real_num_tokens]
key = key[:real_num_tokens]
value = value[:real_num_tokens]

kwargs = {}
if q_rope is not None:
kwargs["q_rope"] = q_rope
kwargs["q_rope"] = q_rope[:real_num_tokens]
if k_rope is not None:
kwargs["k_rope"] = k_rope
kwargs["k_rope"] = k_rope[:real_num_tokens]
if sinks is not None:
kwargs["sinks"] = sinks

original_out_cache_loc = forward_batch.out_cache_loc
original_out_cache_loc_swa = forward_batch.out_cache_loc_swa
token_to_kv_pool = forward_batch.token_to_kv_pool
original_swa_loc = getattr(token_to_kv_pool, "swa_loc", None)
# Keep the original ForwardBatch object and only narrow cache locations for
# this backend call so model/backend state is still written to the same batch.
forward_batch.out_cache_loc = original_out_cache_loc[:real_num_tokens]
if original_out_cache_loc_swa is not None:
forward_batch.out_cache_loc_swa = original_out_cache_loc_swa[:real_num_tokens]
if hasattr(token_to_kv_pool, "set_swa_loc"):
token_to_kv_pool.set_swa_loc(forward_batch.out_cache_loc_swa)

ret = forward_batch.attn_backend.forward(
query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
query,
key,
value,
attention_layer,
forward_batch,
save_kv_cache,
**kwargs,
)
assert (
output.numel() == ret.numel()
), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
forward_batch.out_cache_loc = original_out_cache_loc
forward_batch.out_cache_loc_swa = original_out_cache_loc_swa
if original_out_cache_loc_swa is not None and hasattr(
token_to_kv_pool, "set_swa_loc"
):
token_to_kv_pool.set_swa_loc(original_swa_loc)

output.view(ret.shape).copy_(ret)
output[:real_num_tokens].view(ret.shape).copy_(ret)
return
31 changes: 23 additions & 8 deletions python/sglang/srt/layers/radix_linear_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,33 @@ def unified_linear_attention_with_output(
forward_batch = context.forward_batch
attention_layers = context.attention_layers
attention_layer = attention_layers[layer_id]
real_num_tokens = forward_batch.num_token_non_padded_cpu

original_out_cache_loc = forward_batch.out_cache_loc
original_out_cache_loc_swa = forward_batch.out_cache_loc_swa
token_to_kv_pool = forward_batch.token_to_kv_pool
original_swa_loc = getattr(token_to_kv_pool, "swa_loc", None)
# Keep the original ForwardBatch object and only narrow cache locations for
# this backend call so model/backend state is still written to the same batch.
forward_batch.out_cache_loc = original_out_cache_loc[:real_num_tokens]
if original_out_cache_loc_swa is not None:
forward_batch.out_cache_loc_swa = original_out_cache_loc_swa[:real_num_tokens]
if hasattr(token_to_kv_pool, "set_swa_loc"):
token_to_kv_pool.set_swa_loc(forward_batch.out_cache_loc_swa)

ret = forward_batch.attn_backend.forward(
layer=attention_layer,
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
b=b,
mixed_qkv=mixed_qkv[:real_num_tokens],
a=a[:real_num_tokens],
b=b[:real_num_tokens],
)
forward_batch.out_cache_loc = original_out_cache_loc
forward_batch.out_cache_loc_swa = original_out_cache_loc_swa
if original_out_cache_loc_swa is not None and hasattr(
token_to_kv_pool, "set_swa_loc"
):
token_to_kv_pool.set_swa_loc(original_swa_loc)

assert (
output.numel() == ret.numel()
), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"

output.view(ret.shape).copy_(ret)
output[:, :real_num_tokens].copy_(ret)
return
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def warmup_compile(self, num_tokens: int):
spec_info=None,
capture_hidden_mode=CaptureHiddenMode.NULL,
num_token_non_padded=None,
num_token_non_padded_cpu=num_tokens,
global_forward_mode=ForwardMode.EXTEND,
lora_ids=None,
)
Expand Down Expand Up @@ -547,6 +548,7 @@ def capture_one_batch_size(self, num_tokens: int):
spec_info=None,
capture_hidden_mode=CaptureHiddenMode.NULL,
num_token_non_padded=None,
num_token_non_padded_cpu=num_tokens,
global_forward_mode=ForwardMode.EXTEND,
lora_ids=None,
)
Expand Down Expand Up @@ -736,6 +738,7 @@ def replay_prepare(
spec_info=forward_batch.spec_info,
capture_hidden_mode=forward_batch.capture_hidden_mode,
num_token_non_padded=forward_batch.num_token_non_padded,
num_token_non_padded_cpu=forward_batch.num_token_non_padded_cpu,
global_forward_mode=pcg_global_forward_mode,
lora_ids=forward_batch.lora_ids,
sampling_info=forward_batch.sampling_info,
Expand All @@ -757,13 +760,7 @@ def replay(
forward_batch: ForwardBatch,
**kwargs,
) -> Union[LogitsProcessorOutput, PPProxyTensors, EmbeddingPoolerOutput]:
num_tokens = len(forward_batch.input_ids)
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
static_num_tokens = self.capture_num_tokens[index]
with enable_piecewise_cuda_graph():
# Prepare static buffers first so set_forward_context can carry num_tokens
# into call_begin_forward (via ForwardContext.num_tokens), eliminating the
# need for a separate global and allowing pre-calculation of dummy-page count.
static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
# Replay
with set_forward_context(
Expand All @@ -772,7 +769,6 @@ def replay(
self.quant_config,
self.moe_layers,
self.moe_fusions,
num_tokens=static_num_tokens,
):
# Because of the dispatch kernel used by MLA models, the attention metadata is initialized with the original forward_batch (not static_forward_batch)
self.model_runner.attn_backend.init_forward_metadata(forward_batch)
Expand Down
3 changes: 0 additions & 3 deletions test/registered/scheduler/test_priority_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def setUpClass(cls):
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
# Disable PCG to avoid padding in flashinfer backend. Ref: https://github.com/sgl-project/sglang/pull/21452
"--disable-piecewise-cuda-graph",
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
Expand Down Expand Up @@ -249,7 +247,6 @@ def setUpClass(cls):
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
"--disable-piecewise-cuda-graph",
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
Expand Down
2 changes: 0 additions & 2 deletions test/registered/scheduler/test_scheduler_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-running-requests=1",
# Disable PCG to avoid padding in flashinfer backend. Ref: https://github.com/sgl-project/sglang/pull/21452
"--disable-piecewise-cuda-graph",
],
)

Expand Down
Loading