Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,6 @@ def replay_prepare(
require_gathered_buffer=self.require_gathered_buffer,
num_tokens_per_bs=self.num_tokens_per_bs,
nsa_enable_prefill_cp=self.nsa_enable_prefill_cp,
attn_tp_rank=self.attn_tp_rank,
attn_tp_size=self.attn_tp_size,
enable_num_token_non_padded_flag=enable_num_token_non_padded(
self.model_runner.server_args
),
Expand Down
42 changes: 42 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
get_attention_dp_rank,
get_attention_tp_rank,
get_attention_tp_size,
set_dp_buffer_len,
set_is_extend_in_batch,
Expand Down Expand Up @@ -202,6 +203,26 @@ def __lt__(self, other):
return self.value < other.value


def compute_local_num_token_non_padded(
    global_num_token_non_padded: torch.Tensor | int,
    num_tokens_per_dp: int,
) -> torch.Tensor:
    """Map a global (per-DP-rank) non-padded token count to this attention-TP rank.

    Within one DP rank, the ``num_tokens_per_dp`` token slots are split evenly
    across the attention-TP group; rank ``r`` owns the contiguous slice
    ``[r * per_rank, (r + 1) * per_rank)``.  The local non-padded count is the
    overlap of ``[0, global_num_token_non_padded)`` with that slice, i.e. the
    global count shifted by the slice start and clamped into ``[0, per_rank]``.

    Args:
        global_num_token_non_padded: Non-padded token count across the whole
            attention-TP group of the current DP rank.
        num_tokens_per_dp: Total (padded) token count of the current DP rank.
            Assumed divisible by the attention-TP size.

    Returns:
        The clamped local count.  NOTE(review): ``torch.clamp`` requires a
        tensor operand, so at least one of the two arguments is expected to be
        a ``torch.Tensor`` at runtime — confirm against callers.
    """
    rank = get_attention_tp_rank()
    world = get_attention_tp_size()
    per_rank = num_tokens_per_dp // world

    shifted = global_num_token_non_padded - rank * per_rank
    return torch.clamp(shifted, min=0, max=per_rank)


@dataclass
class ForwardBatch:
"""Store all inputs of a forward pass."""
Expand Down Expand Up @@ -503,6 +524,27 @@ def init_new(

return ret

def adjust_num_token_non_padded_for_attn_tp(self, server_args) -> None:
    """Rewrite ``num_token_non_padded`` (GPU and CPU copies) as rank-local counts.

    Picks the per-DP token total from ``self.global_num_tokens_gpu`` — indexed
    by this DP rank when MLP-TP gather is required, otherwise entry 0 — and
    converts both counters via ``compute_local_num_token_non_padded``.

    Args:
        server_args: Server configuration forwarded to ``require_mlp_tp_gather``.
    """
    # Local import mirrors the original code; presumably avoids a circular
    # dependency at module load time.
    from sglang.srt.utils.common import require_mlp_tp_gather

    dp_rank = get_attention_dp_rank()
    index = dp_rank if require_mlp_tp_gather(server_args) else 0
    num_tokens_per_dp = self.global_num_tokens_gpu[index]

    # Convert both the device-side and host-side counters in lockstep.
    for attr in ("num_token_non_padded", "num_token_non_padded_cpu"):
        setattr(
            self,
            attr,
            compute_local_num_token_non_padded(
                global_num_token_non_padded=getattr(self, attr),
                num_tokens_per_dp=num_tokens_per_dp,
            ),
        )

def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
"""
Merge all multimodal inputs in the batch into a single MultiModalInputs object.
Expand Down
22 changes: 11 additions & 11 deletions python/sglang/srt/model_executor/input_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
PPProxyTensors,
compute_local_num_token_non_padded,
)


@dataclass
Expand Down Expand Up @@ -124,8 +128,6 @@ def populate_from_forward_batch(
require_gathered_buffer: bool,
num_tokens_per_bs: int,
nsa_enable_prefill_cp: bool,
attn_tp_rank: int,
attn_tp_size: int,
enable_num_token_non_padded_flag: bool,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Optional[torch.Tensor]:
Expand Down Expand Up @@ -158,17 +160,15 @@ def populate_from_forward_batch(
self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs)

if enable_num_token_non_padded_flag:
num_token_non_padded = forward_batch.num_token_non_padded
if require_gathered_buffer and not nsa_enable_prefill_cp:
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs
num_local_token_non_padded = torch.clamp(
num_token_non_padded - tokens_per_rank * attn_tp_rank,
min=0,
max=tokens_per_rank,
num_tokens_per_dp = bs * num_tokens_per_bs
local = compute_local_num_token_non_padded(
global_num_token_non_padded=forward_batch.num_token_non_padded,
num_tokens_per_dp=num_tokens_per_dp,
)
self.num_token_non_padded.copy_(num_local_token_non_padded)
self.num_token_non_padded.copy_(local)
else:
self.num_token_non_padded.copy_(num_token_non_padded)
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)

# Pipeline-parallel proxy tensors.
if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
ATTENTION_BACKENDS,
attn_backend_wrapper,
)
from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
Expand Down Expand Up @@ -2697,6 +2698,17 @@ def _forward_raw(
else:
forward_batch.prepare_attn_tp_scatter_input(self)

# Normalize num_token_non_padded to be local to this attention TP rank if needed.
if (
forward_batch.num_token_non_padded is not None
and forward_batch.global_num_tokens_gpu is not None
and require_gathered_buffer
and not is_nsa_enable_prefill_cp()
):
forward_batch.adjust_num_token_non_padded_for_attn_tp(
server_args=self.server_args,
)

if forward_batch.forward_mode.is_decode():
ret = self.forward_decode(
forward_batch,
Expand Down
Loading