diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1fd483da3752..cd686d0c8d06 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -772,8 +772,6 @@ def replay_prepare( require_gathered_buffer=self.require_gathered_buffer, num_tokens_per_bs=self.num_tokens_per_bs, nsa_enable_prefill_cp=self.nsa_enable_prefill_cp, - attn_tp_rank=self.attn_tp_rank, - attn_tp_size=self.attn_tp_size, enable_num_token_non_padded_flag=enable_num_token_non_padded( self.model_runner.server_args ), diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d4ed92d0ec2d..0b3d74b78332 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -47,6 +47,7 @@ from sglang.srt.layers.dp_attention import ( DpPaddingMode, get_attention_dp_rank, + get_attention_tp_rank, get_attention_tp_size, set_dp_buffer_len, set_is_extend_in_batch, @@ -202,6 +203,26 @@ def __lt__(self, other): return self.value < other.value +def compute_local_num_token_non_padded( + global_num_token_non_padded: torch.Tensor | int, + num_tokens_per_dp: int, +) -> torch.Tensor: + """Compute local non-padded token count for this attention-TP rank. + + Converts a global count (across all TP ranks) to a local count for this rank. + The "global" scope is within the current DP rank; DP is handled via num_tokens_per_dp. + """ + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + tokens_per_rank = num_tokens_per_dp // attn_tp_size + + return torch.clamp( + global_num_token_non_padded - tokens_per_rank * attn_tp_rank, + 0, + tokens_per_rank, + ) + + @dataclass class ForwardBatch: """Store all inputs of a forward pass.""" @@ -503,6 +524,27 @@ def init_new( return ret + def adjust_num_token_non_padded_for_attn_tp(self, server_args) -> None: + """Make num_token_non_padded local to this attention-TP rank.""" + from sglang.srt.utils.common import require_mlp_tp_gather + + dp_rank = get_attention_dp_rank() + + if require_mlp_tp_gather(server_args): + num_tokens_per_dp = self.global_num_tokens_gpu[dp_rank] + else: + num_tokens_per_dp = self.global_num_tokens_gpu[0] + + self.num_token_non_padded = compute_local_num_token_non_padded( + global_num_token_non_padded=self.num_token_non_padded, + num_tokens_per_dp=num_tokens_per_dp, + ) + + self.num_token_non_padded_cpu = compute_local_num_token_non_padded( + global_num_token_non_padded=self.num_token_non_padded_cpu, + num_tokens_per_dp=num_tokens_per_dp, + ) + def merge_mm_inputs(self) -> Optional[MultimodalInputs]: """ Merge all multimodal inputs in the batch into a single MultiModalInputs object. diff --git a/python/sglang/srt/model_executor/input_buffers.py b/python/sglang/srt/model_executor/input_buffers.py index ec8d23476a5a..b070028e6b74 100644 --- a/python/sglang/srt/model_executor/input_buffers.py +++ b/python/sglang/srt/model_executor/input_buffers.py @@ -5,7 +5,11 @@ import torch -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + compute_local_num_token_non_padded, +) @dataclass @@ -124,8 +128,6 @@ def populate_from_forward_batch( require_gathered_buffer: bool, num_tokens_per_bs: int, nsa_enable_prefill_cp: bool, - attn_tp_rank: int, - attn_tp_size: int, enable_num_token_non_padded_flag: bool, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Optional[torch.Tensor]: @@ -158,17 +160,15 @@ def populate_from_forward_batch( self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs) if enable_num_token_non_padded_flag: - num_token_non_padded = forward_batch.num_token_non_padded if require_gathered_buffer and not nsa_enable_prefill_cp: - tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs - num_local_token_non_padded = torch.clamp( - num_token_non_padded - tokens_per_rank * attn_tp_rank, - min=0, - max=tokens_per_rank, + num_tokens_per_dp = bs * num_tokens_per_bs + local = compute_local_num_token_non_padded( + global_num_token_non_padded=forward_batch.num_token_non_padded, + num_tokens_per_dp=num_tokens_per_dp, ) - self.num_token_non_padded.copy_(num_local_token_non_padded) + self.num_token_non_padded.copy_(local) else: - self.num_token_non_padded.copy_(num_token_non_padded) + self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) # Pipeline-parallel proxy tensors. if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b8b1b96272fd..1bd2ae43b866 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -84,6 +84,7 @@ ATTENTION_BACKENDS, attn_backend_wrapper, ) +from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp from sglang.srt.layers.attention.tbo_backend import TboAttnBackend from sglang.srt.layers.dp_attention import ( DpPaddingMode, @@ -2697,6 +2698,17 @@ def _forward_raw( else: forward_batch.prepare_attn_tp_scatter_input(self) + # Normalize num_token_non_padded to be local to this attention TP rank if needed. + if ( + forward_batch.num_token_non_padded is not None + and forward_batch.global_num_tokens_gpu is not None + and require_gathered_buffer + and not is_nsa_enable_prefill_cp() + ): + forward_batch.adjust_num_token_non_padded_for_attn_tp( + server_args=self.server_args, + ) + if forward_batch.forward_mode.is_decode(): ret = self.forward_decode( forward_batch,