21 changes: 19 additions & 2 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,11 @@
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +259,9 @@ def __init__(self, model_runner: ModelRunner):
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
 
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -749,7 +756,17 @@ def replay_prepare(
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
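The new replay_prepare branch splits the padded global batch evenly across attention-TP ranks: each rank owns a contiguous slice of tokens_per_rank token slots, and its local non-padded count is the overlap of that slice with the first num_token_non_padded slots, which the torch.clamp expression computes directly. Below is a minimal standalone sketch of that arithmetic; the values of bs, num_tokens_per_bs, attn_tp_size, and the global non-padded count are illustrative and not taken from the PR.

import torch

bs, num_tokens_per_bs, attn_tp_size = 8, 1, 4
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs  # 2 token slots per rank

# Suppose 5 of the 8 padded token slots hold real (non-padded) tokens.
num_token_non_padded = torch.tensor(5)

for attn_tp_rank in range(attn_tp_size):
    # Rank r owns slots [r * tokens_per_rank, (r + 1) * tokens_per_rank);
    # clamping the global count into that window yields the local count.
    num_local_token_non_padded = torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )
    print(attn_tp_rank, int(num_local_token_non_padded))  # ranks get 2, 2, 1, 0

The per-rank counts sum back to the global count (2 + 2 + 1 + 0 = 5), so no non-padded token is dropped or double-counted across ranks.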