diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 44c868eb994f..130209e753ea 100755
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -186,8 +186,15 @@ def __init__(
         )

         # aiter kernel related initialization
+        # Cap effective sequence length by actual KV cache capacity to avoid
+        # over-allocating the workspace buffer on memory-constrained GPUs.
+        # No single sequence can exceed max_total_num_tokens.
+        effective_max_seq_len = min(
+            self.max_context_len,
+            getattr(model_runner, "max_total_num_tokens", self.max_context_len),
+        )
         self.max_num_partitions = (
-            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
+            effective_max_seq_len + _AITER_PARTITION_SIZE_ROCM - 1
         ) // _AITER_PARTITION_SIZE_ROCM

         nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8