diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index f4ba8281d63c..17834529259f 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -42,6 +42,7 @@ from sglang.srt.distributed.parallel_state import graph_capture from sglang.srt.layers.dp_attention import ( DpPaddingMode, + get_attention_cp_size, get_attention_tp_rank, get_attention_tp_size, set_dp_buffer_len, @@ -58,7 +59,12 @@ PPProxyTensors, ) from sglang.srt.model_executor.input_buffers import ForwardInputBuffers -from sglang.srt.utils import get_available_gpu_memory, is_npu, log_info_on_rank0 +from sglang.srt.utils import ( + get_available_gpu_memory, + is_npu, + log_info_on_rank0, + require_gathered_buffer, +) # Suppress Dynamo warning about tracing through lru_cache-wrapped functions (e.g., is_arch_support_pdl). warnings.filterwarnings("ignore", message=".*lru_cache.*", module="torch._dynamo") @@ -189,6 +195,21 @@ def __init__(self, model_runner: ModelRunner): # Batch sizes to capture self.capture_num_tokens = self.compile_config.get_capture_sizes() + # When the layer communicator scatters/gathers across the attention TP + # group (e.g. with --moe-dense-tp-size 1), the model's reduce_scatter + # requires the token count to be divisible by attn_tp_size * attn_cp_size. + # Drop captures that would violate this (mirrors the filter used by + # the regular CUDA graph runner in get_batch_sizes_to_capture). + if require_gathered_buffer(self.model_runner.server_args): + mul_base = self.attn_tp_size + attn_cp_size = get_attention_cp_size() + if mul_base % attn_cp_size != 0: + mul_base *= attn_cp_size + filtered = [n for n in self.capture_num_tokens if n % mul_base == 0] + assert ( + len(filtered) > 0 + ), f"No piecewise CUDA graph capture sizes are multiples of {mul_base}" + self.capture_num_tokens = filtered log_info_on_rank0( logger, f"Capture cuda graph num tokens {self.capture_num_tokens}" )