Skip to content
Merged
4 changes: 3 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ForwardBatch,
ForwardMode,
PPProxyTensors,
enable_num_token_non_padded,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
Expand Down Expand Up @@ -550,7 +551,8 @@ def replay_prepare(
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
num_token_non_padded = len(forward_batch.input_ids)
self.num_token_non_padded[...] = num_token_non_padded
if enable_num_token_non_padded(self.model_runner.server_args):
self.num_token_non_padded[...] = num_token_non_padded
self.tbo_plugin.replay_prepare(
forward_mode=forward_batch.forward_mode,
bs=bs,
Expand Down
15 changes: 12 additions & 3 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ def init_new(
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)

num_token_non_padded = None
if enable_num_token_non_padded(model_runner.server_args):
num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)

ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
Expand Down Expand Up @@ -300,9 +307,7 @@ def init_new(
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
num_token_non_padded=torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True),
num_token_non_padded=num_token_non_padded,
tbo_split_seq_index=batch.tbo_split_seq_index,
)

Expand Down Expand Up @@ -606,6 +611,10 @@ def can_run_tbo(self):
return self.tbo_split_seq_index is not None


def enable_num_token_non_padded(server_args):
    """Whether the ``num_token_non_padded`` tensor should be maintained.

    Returns a truthy value when either expert-parallel MoE or DeepEP MoE is
    enabled in *server_args* (the only modes that need to distinguish real
    tokens from CUDA-graph padding), and a falsy value otherwise.
    """
    # Short-circuit exactly like `a or b`: return the first truthy flag,
    # else the second flag as-is.
    if server_args.enable_ep_moe:
        return server_args.enable_ep_moe
    return server_args.enable_deepep_moe


class PPProxyTensors:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
tensors: Dict[str, torch.Tensor]
Expand Down
Loading