From 36de6d70fd609d238a80751223f6cb75e314729e Mon Sep 17 00:00:00 2001 From: Michael <13900043+michaelzhang-ai@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:21:57 -0500 Subject: [PATCH] fix(aiter): cap workspace buffer partitions by KV cache capacity to prevent OOM The aiter backend computes max_num_partitions from max_context_len (e.g. 131K for Llama 3.1), which can produce a workspace buffer exceeding available GPU memory on constrained setups. Since no single sequence can exceed max_total_num_tokens (the actual KV cache capacity), use that as an upper bound to right-size the allocation. --- python/sglang/srt/layers/attention/aiter_backend.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 44c868eb994f..130209e753ea 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -186,8 +186,15 @@ def __init__( ) # aiter kernel related initialization + # Cap effective sequence length by actual KV cache capacity to avoid + # over-allocating the workspace buffer on memory-constrained GPUs. + # No single sequence can exceed max_total_num_tokens. + effective_max_seq_len = min( + self.max_context_len, + getattr(model_runner, "max_total_num_tokens", self.max_context_len), + ) self.max_num_partitions = ( - self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + effective_max_seq_len + _AITER_PARTITION_SIZE_ROCM - 1 ) // _AITER_PARTITION_SIZE_ROCM nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8