From 236f1244f2a2b8b59f5641bae493bb5db33603af Mon Sep 17 00:00:00 2001 From: Michael <13900043+michaelzhang-ai@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:35:17 -0500 Subject: [PATCH] fix(aiter): workspace buffer OOM and tuned GEMM torchao compatibility Two fixes for aiter backend failures surfaced by PR #20392: 1. aiter_backend.py: Cap max_num_partitions by min(max_context_len, max_total_num_tokens). The workspace buffer was sized for the model's theoretical max context (e.g. 131K = 512 partitions = 16 GiB) when the KV cache only held 25K tokens (100 partitions = 3 GiB), causing OOM on memory-constrained CI GPUs. 2. unquant.py: Add aiter tgemm.mm fast path for unquantized linear ops, guarded by type(layer.weight.data) is torch.Tensor. Torchao-quantized weights (AffineQuantizedTensor) fail the strict type() check and fall through to F.linear, preventing NotImplementedError on gemm_a16w16. --- python/sglang/srt/layers/attention/aiter_backend.py | 9 ++++++++- python/sglang/srt/layers/quantization/unquant.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 44c868eb994f..130209e753ea 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -186,8 +186,15 @@ def __init__( ) # aiter kernel related initialization + # Cap effective sequence length by actual KV cache capacity to avoid + # over-allocating the workspace buffer on memory-constrained GPUs. + # No single sequence can exceed max_total_num_tokens. + effective_max_seq_len = min( + self.max_context_len, + getattr(model_runner, "max_total_num_tokens", self.max_context_len), + ) self.max_num_partitions = ( - self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + effective_max_seq_len + _AITER_PARTITION_SIZE_ROCM - 1 ) // _AITER_PARTITION_SIZE_ROCM nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index a1edcf5b4d44..0b5a982f2118 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -52,6 +52,7 @@ from aiter import ActivationType from aiter.fused_moe import fused_moe from aiter.ops.shuffle import shuffle_weight + from aiter.tuned_gemm import tgemm if _is_npu: from sglang.srt.hardware_backend.npu.utils import npu_format_cast @@ -150,6 +151,9 @@ def apply( output = output.view(x_shapes[0], x_shapes[1], -1) return output + if _use_aiter and type(layer.weight.data) is torch.Tensor: + return tgemm.mm(x, layer.weight, bias, otype=x.dtype) + return F.linear(x, layer.weight, bias)