diff --git a/python/sglang/srt/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py
index a49e9ad47a37..20a08a9972b9 100644
--- a/python/sglang/srt/compilation/piecewise_context_manager.py
+++ b/python/sglang/srt/compilation/piecewise_context_manager.py
@@ -71,7 +71,6 @@ def __init__(self):
         self.quant_config = None
         self.moe_layers = None
         self.moe_fusions = None
-        self.num_tokens: Optional[int] = None
 
     def set_forward_batch(self, forward_batch: ForwardBatch):
         self.forward_batch = forward_batch
@@ -105,7 +104,6 @@ def set_forward_context(
     quant_config: Any,
     moe_layers: List[Any],
     moe_fusions: List[Any],
-    num_tokens: Optional[int] = None,
 ):
     global _forward_context
     _forward_context = ForwardContext()
@@ -114,7 +112,6 @@ def set_forward_context(
    _forward_context.set_quant_config(quant_config)
    _forward_context.set_moe_layers(moe_layers)
    _forward_context.set_moe_fusions(moe_fusions)
-    _forward_context.num_tokens = num_tokens
    try:
        yield
    finally:
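For context, `set_forward_context` installs a module-global `ForwardContext` for the duration of one forward pass and always clears it afterwards; the hunks above only drop the now-unused `num_tokens` field from that flow. A minimal sketch of the surrounding pattern (a simplified stand-in, not the actual sglang class):

```python
from contextlib import contextmanager
from typing import Any, List, Optional


class ForwardContext:
    """Simplified stand-in for the real class; only two fields shown."""

    def __init__(self) -> None:
        self.forward_batch: Any = None
        self.attention_layers: Optional[List[Any]] = None


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> Optional[ForwardContext]:
    return _forward_context


@contextmanager
def set_forward_context(forward_batch: Any, attention_layers: List[Any]):
    # Install a fresh context and always tear it down, even if the forward
    # pass raises (mirrors the try/yield/finally in the hunk above).
    global _forward_context
    _forward_context = ForwardContext()
    _forward_context.forward_batch = forward_batch
    _forward_context.attention_layers = attention_layers
    try:
        yield
    finally:
        _forward_context = None
```
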
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index c1e2ea4fcdab..c8128058ea34 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -17,10 +17,7 @@
 import torch
 
 from sglang.kernel_api_logging import debug_kernel_api
-from sglang.srt.compilation.piecewise_context_manager import (
-    get_forward_context,
-    is_in_piecewise_cuda_graph,
-)
+from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
 from sglang.srt.dllm.config import DllmConfig
 from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -150,8 +147,6 @@ def __init__(
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.page_size = model_runner.page_size
-
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.model_config.is_encoder_decoder
@@ -1210,8 +1205,6 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend
-        self.page_size = attn_backend.page_size
-
        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -1400,13 +1393,8 @@ def call_begin_forward(
                # Normal extend
                kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
-                # Reserve extra space in kv_indices for a potential piecewise CUDA graph
-                # dummy request (see below). Worst case: static_num_tokens extra pages.
-                fwd_ctx = get_forward_context()
-                pcg_num_tokens = fwd_ctx.num_tokens if fwd_ctx is not None else None
-                extra_kv = pcg_num_tokens if pcg_num_tokens is not None else 0
                kv_indices = torch.empty(
-                    paged_kernel_lens_sum + extra_kv + 256,
+                    paged_kernel_lens_sum + 256,
                    dtype=torch.int32,
                    device=req_pool_indices.device,
                )
@@ -1422,39 +1410,6 @@
                qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
                qo_indptr = qo_indptr[: bs + 1]
 
-                # Piecewise CUDA graph padding: input_ids are padded to static_num_tokens,
-                # so q.shape[0] == static_num_tokens but qo_indptr[-1] == actual tokens.
-                # Append a dummy request for the padding tokens so that
-                # qo_indptr[-1] == static_num_tokens, satisfying flashinfer's shape check
-                # without corrupting the causal masks of real requests.
-                # The dummy request's KV indices all point to slot 0 (a scratch location);
-                # its attention output is discarded via the [:raw_num_tokens] slice in replay.
-                bs_eff = bs
-                # extend_num_tokens is a Python int (== sum of seq_lens - prefix_lens),
-                # and paged_kernel_lens_sum is also a Python int (== kv_indptr[-1]),
-                # so this block requires no CPU-GPU synchronisation.
-                actual_qo_tokens = (
-                    fwd_ctx.forward_batch.extend_num_tokens if fwd_ctx is not None else None
-                )
-                if (
-                    pcg_num_tokens is not None
-                    and actual_qo_tokens is not None
-                    and pcg_num_tokens > actual_qo_tokens
-                ):
-                    pad_tokens = pcg_num_tokens - actual_qo_tokens
-                    num_dummy_pages = (pad_tokens + self.page_size - 1) // self.page_size
-                    kv_start = (
-                        paged_kernel_lens_sum  # equals kv_indptr[-1], no .item() needed
-                    )
-                    kv_indices[kv_start : kv_start + num_dummy_pages] = 0
-                    qo_indptr = torch.cat(
-                        [qo_indptr, qo_indptr.new_tensor([pcg_num_tokens])]
-                    )
-                    kv_indptr = torch.cat(
-                        [kv_indptr, kv_indptr.new_tensor([kv_start + num_dummy_pages])]
-                    )
-                    bs_eff = bs + 1
-
                custom_mask = None
            else:
                assert isinstance(spec_info, SpecInput)
@@ -1466,7 +1421,6 @@
                        self.req_to_token,
                    )
                )
-                bs_eff = bs
 
            # extend part
            if use_ragged:
@@ -1508,7 +1462,7 @@
                qo_indptr,
                kv_indptr,
                kv_indices,
-                self.kv_last_page_len[:bs_eff],
+                self.kv_last_page_len[:bs],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
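The deleted branch existed because flashinfer checks that `qo_indptr[-1]` matches `q.shape[0]`. The toy numbers below (illustrative only, not sglang code) show that invariant and the ceiling division the branch used to size the dummy request's KV pages:

```python
import torch

# flashinfer requires qo_indptr[-1] == q.shape[0]. With piecewise CUDA graphs,
# input_ids were padded from 6 real tokens up to a captured size of 8, so a
# dummy request had to absorb the 2 padding tokens.
static_num_tokens = 8  # padded (captured) token count
actual_qo_tokens = 6   # real extend tokens in the batch
page_size = 4

qo_indptr = torch.tensor([0, 4, 6], dtype=torch.int32)  # two real requests

pad_tokens = static_num_tokens - actual_qo_tokens
# Ceiling division: pages needed to cover the padding tokens' KV slots.
num_dummy_pages = (pad_tokens + page_size - 1) // page_size

# Append the dummy request so the shape check passes.
qo_indptr = torch.cat([qo_indptr, qo_indptr.new_tensor([static_num_tokens])])
assert qo_indptr[-1].item() == static_num_tokens

print(pad_tokens, num_dummy_pages)  # 2 1
```

After this PR, q/k/v are narrowed to the real token count before they reach the backend (see the radix_attention.py hunk below), so the invariant holds with no dummy request and the whole branch can be deleted.
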
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 449bc867067e..b16ac22c78b3 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -153,21 +153,47 @@ def unified_attention_with_output(
     forward_batch = context.forward_batch
     attention_layers = context.attention_layers
     attention_layer = attention_layers[layer_id]
+    real_num_tokens = forward_batch.num_token_non_padded_cpu
+
+    query = query[:real_num_tokens]
+    key = key[:real_num_tokens]
+    value = value[:real_num_tokens]
 
     kwargs = {}
     if q_rope is not None:
-        kwargs["q_rope"] = q_rope
+        kwargs["q_rope"] = q_rope[:real_num_tokens]
     if k_rope is not None:
-        kwargs["k_rope"] = k_rope
+        kwargs["k_rope"] = k_rope[:real_num_tokens]
     if sinks is not None:
         kwargs["sinks"] = sinks
 
+    original_out_cache_loc = forward_batch.out_cache_loc
+    original_out_cache_loc_swa = forward_batch.out_cache_loc_swa
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+    original_swa_loc = getattr(token_to_kv_pool, "swa_loc", None)
+    # Keep the original ForwardBatch object and only narrow cache locations for
+    # this backend call so model/backend state is still written to the same batch.
+    forward_batch.out_cache_loc = original_out_cache_loc[:real_num_tokens]
+    if original_out_cache_loc_swa is not None:
+        forward_batch.out_cache_loc_swa = original_out_cache_loc_swa[:real_num_tokens]
+        if hasattr(token_to_kv_pool, "set_swa_loc"):
+            token_to_kv_pool.set_swa_loc(forward_batch.out_cache_loc_swa)
+
     ret = forward_batch.attn_backend.forward(
-        query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
+        query,
+        key,
+        value,
+        attention_layer,
+        forward_batch,
+        save_kv_cache,
+        **kwargs,
     )
-    assert (
-        output.numel() == ret.numel()
-    ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
+    forward_batch.out_cache_loc = original_out_cache_loc
+    forward_batch.out_cache_loc_swa = original_out_cache_loc_swa
+    if original_out_cache_loc_swa is not None and hasattr(
+        token_to_kv_pool, "set_swa_loc"
+    ):
+        token_to_kv_pool.set_swa_loc(original_swa_loc)
 
-    output.view(ret.shape).copy_(ret)
+    output[:real_num_tokens].view(ret.shape).copy_(ret)
     return
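The hunk above is the heart of the change: slice the padded activations and cache locations down to `real_num_tokens`, call the backend, then restore the full buffers. A standalone sketch of that save/narrow/restore pattern (toy `Batch` class and hypothetical `backend_forward` callable; the patch itself restores unconditionally after the call rather than in a `finally`):

```python
import torch


class Batch:
    """Toy stand-in for ForwardBatch: one padded cache-location buffer."""

    def __init__(self, out_cache_loc: torch.Tensor) -> None:
        self.out_cache_loc = out_cache_loc


def run_narrowed(batch: Batch, real_num_tokens: int, backend_forward):
    # Point the shared batch at only the real tokens' cache slots ...
    original = batch.out_cache_loc
    batch.out_cache_loc = original[:real_num_tokens]
    try:
        return backend_forward(batch)
    finally:
        # ... and restore, so later consumers see the full padded buffer.
        batch.out_cache_loc = original


batch = Batch(torch.arange(8))
ret = run_narrowed(batch, 6, lambda b: b.out_cache_loc.clone())
assert ret.numel() == 6 and batch.out_cache_loc.numel() == 8
```

Only the first `real_num_tokens` rows of `output` are written back, so the padding rows keep whatever stale values they held; those rows are never consumed downstream.
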
diff --git a/python/sglang/srt/layers/radix_linear_attention.py b/python/sglang/srt/layers/radix_linear_attention.py
index 89821e774460..1e860d5f7c5d 100644
--- a/python/sglang/srt/layers/radix_linear_attention.py
+++ b/python/sglang/srt/layers/radix_linear_attention.py
@@ -117,18 +117,33 @@ def unified_linear_attention_with_output(
     forward_batch = context.forward_batch
     attention_layers = context.attention_layers
     attention_layer = attention_layers[layer_id]
+    real_num_tokens = forward_batch.num_token_non_padded_cpu
+
+    original_out_cache_loc = forward_batch.out_cache_loc
+    original_out_cache_loc_swa = forward_batch.out_cache_loc_swa
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+    original_swa_loc = getattr(token_to_kv_pool, "swa_loc", None)
+    # Keep the original ForwardBatch object and only narrow cache locations for
+    # this backend call so model/backend state is still written to the same batch.
+    forward_batch.out_cache_loc = original_out_cache_loc[:real_num_tokens]
+    if original_out_cache_loc_swa is not None:
+        forward_batch.out_cache_loc_swa = original_out_cache_loc_swa[:real_num_tokens]
+        if hasattr(token_to_kv_pool, "set_swa_loc"):
+            token_to_kv_pool.set_swa_loc(forward_batch.out_cache_loc_swa)
 
     ret = forward_batch.attn_backend.forward(
         layer=attention_layer,
         forward_batch=forward_batch,
-        mixed_qkv=mixed_qkv,
-        a=a,
-        b=b,
+        mixed_qkv=mixed_qkv[:real_num_tokens],
+        a=a[:real_num_tokens],
+        b=b[:real_num_tokens],
     )
+    forward_batch.out_cache_loc = original_out_cache_loc
+    forward_batch.out_cache_loc_swa = original_out_cache_loc_swa
+    if original_out_cache_loc_swa is not None and hasattr(
+        token_to_kv_pool, "set_swa_loc"
+    ):
+        token_to_kv_pool.set_swa_loc(original_swa_loc)
 
-    assert (
-        output.numel() == ret.numel()
-    ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
-
-    output.view(ret.shape).copy_(ret)
+    output[:, :real_num_tokens].copy_(ret)
     return
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index efec70dc3e11..932a15e71a01 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -387,6 +387,7 @@ def warmup_compile(self, num_tokens: int):
            spec_info=None,
            capture_hidden_mode=CaptureHiddenMode.NULL,
            num_token_non_padded=None,
+            num_token_non_padded_cpu=num_tokens,
            global_forward_mode=ForwardMode.EXTEND,
            lora_ids=None,
        )
@@ -547,6 +548,7 @@ def capture_one_batch_size(self, num_tokens: int):
            spec_info=None,
            capture_hidden_mode=CaptureHiddenMode.NULL,
            num_token_non_padded=None,
+            num_token_non_padded_cpu=num_tokens,
            global_forward_mode=ForwardMode.EXTEND,
            lora_ids=None,
        )
@@ -736,6 +738,7 @@ def replay_prepare(
            spec_info=forward_batch.spec_info,
            capture_hidden_mode=forward_batch.capture_hidden_mode,
            num_token_non_padded=forward_batch.num_token_non_padded,
+            num_token_non_padded_cpu=forward_batch.num_token_non_padded_cpu,
            global_forward_mode=pcg_global_forward_mode,
            lora_ids=forward_batch.lora_ids,
            sampling_info=forward_batch.sampling_info,
@@ -757,13 +760,7 @@ def replay(
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors, EmbeddingPoolerOutput]:
-        num_tokens = len(forward_batch.input_ids)
-        index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
-        static_num_tokens = self.capture_num_tokens[index]
        with enable_piecewise_cuda_graph():
-            # Prepare static buffers first so set_forward_context can carry num_tokens
-            # into call_begin_forward (via ForwardContext.num_tokens), eliminating the
-            # need for a separate global and allowing pre-calculation of dummy-page count.
            static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
 
            # Replay
@@ -772,7 +769,6 @@ def replay(
            with set_forward_context(
                static_forward_batch,
                self.attention_layers,
                self.quant_config,
                self.moe_layers,
                self.moe_fusions,
-                num_tokens=static_num_tokens,
            ):
                # Due to the dispatch kernel for MLA model, we init the metadata with original forward_batch
                self.model_runner.attn_backend.init_forward_metadata(forward_batch)
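With `num_token_non_padded_cpu` threaded through `ForwardBatch`, the real token count travels with the padded batch as a plain Python int. The sketch below (a hypothetical `pad_for_replay` helper, not the runner's API) shows how a replay pairs a captured bucket size, chosen the same way the deleted `bisect` lines did, with that count:

```python
import bisect

import torch

capture_num_tokens = [1, 2, 4, 8, 16]  # sizes the graph was captured for


def pad_for_replay(input_ids: torch.Tensor):
    num_tokens = len(input_ids)
    # Smallest captured size that fits this batch.
    index = bisect.bisect_left(capture_num_tokens, num_tokens)
    static_num_tokens = capture_num_tokens[index]
    padded = torch.zeros(static_num_tokens, dtype=input_ids.dtype)
    padded[:num_tokens] = input_ids
    # Keeping the real count as a plain Python int means downstream slicing
    # like query[:real_num_tokens] never triggers a GPU->CPU sync.
    return padded, num_tokens


padded, num_token_non_padded_cpu = pad_for_replay(torch.tensor([5, 7, 9]))
assert len(padded) == 4 and num_token_non_padded_cpu == 3
```
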
diff --git a/test/registered/scheduler/test_priority_scheduling.py b/test/registered/scheduler/test_priority_scheduling.py
index c1244596ccc4..d9328c85749b 100644
--- a/test/registered/scheduler/test_priority_scheduling.py
+++ b/test/registered/scheduler/test_priority_scheduling.py
@@ -41,8 +41,6 @@ def setUpClass(cls):
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
-                # Disable PCG to avoid padding in flashinfer backend. Ref: https://github.com/sgl-project/sglang/pull/21452
-                "--disable-piecewise-cuda-graph",
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )
@@ -249,7 +247,6 @@ def setUpClass(cls):
                "--max-queued-requests",  # Enforce max queued request number is 3
                "3",
                "--enable-priority-scheduling",  # Enable priority scheduling
-                "--disable-piecewise-cuda-graph",
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )
diff --git a/test/registered/scheduler/test_scheduler_control.py b/test/registered/scheduler/test_scheduler_control.py
index 2042ff00febf..3e6d967d2328 100644
--- a/test/registered/scheduler/test_scheduler_control.py
+++ b/test/registered/scheduler/test_scheduler_control.py
@@ -318,8 +318,6 @@ def setUpClass(cls):
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--max-running-requests=1",
-                # Disable PCG to avoid padding in flashinfer backend. Ref: https://github.com/sgl-project/sglang/pull/21452
-                "--disable-piecewise-cuda-graph",
            ],
        )