Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sglang.srt.distributed.parallel_state import graph_capture
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
get_attention_cp_size,
get_attention_tp_rank,
get_attention_tp_size,
set_dp_buffer_len,
Expand All @@ -58,7 +59,12 @@
PPProxyTensors,
)
from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
from sglang.srt.utils import get_available_gpu_memory, is_npu, log_info_on_rank0
from sglang.srt.utils import (
get_available_gpu_memory,
is_npu,
log_info_on_rank0,
require_gathered_buffer,
)

# Suppress Dynamo warning about tracing through lru_cache-wrapped functions (e.g., is_arch_support_pdl).
warnings.filterwarnings("ignore", message=".*lru_cache.*", module="torch._dynamo")
Expand Down Expand Up @@ -189,6 +195,21 @@ def __init__(self, model_runner: ModelRunner):

# Batch sizes to capture
self.capture_num_tokens = self.compile_config.get_capture_sizes()
# When the layer communicator scatters/gathers across the attention TP
# group (e.g. with --moe-dense-tp-size 1), the model's reduce_scatter
# requires the token count to be divisible by attn_tp_size * attn_cp_size.
# Drop captures that would violate this (mirrors the filter used by
# the regular CUDA graph runner in get_batch_sizes_to_capture).
if require_gathered_buffer(self.model_runner.server_args):
mul_base = self.attn_tp_size
attn_cp_size = get_attention_cp_size()
if mul_base % attn_cp_size != 0:
mul_base *= attn_cp_size
filtered = [n for n in self.capture_num_tokens if n % mul_base == 0]
assert (
len(filtered) > 0
), f"No piecewise CUDA graph capture sizes are multiples of {mul_base}"
self.capture_num_tokens = filtered
log_info_on_rank0(
logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
)
Expand Down
Loading