diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 44c868eb994f..130209e753ea 100755
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -186,8 +186,15 @@ def __init__(
         )
 
         # aiter kernel related initialization
+        # Cap effective sequence length by actual KV cache capacity to avoid
+        # over-allocating the workspace buffer on memory-constrained GPUs.
+        # No single sequence can exceed max_total_num_tokens.
+        effective_max_seq_len = min(
+            self.max_context_len,
+            getattr(model_runner, "max_total_num_tokens", self.max_context_len),
+        )
         self.max_num_partitions = (
-            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
+            effective_max_seq_len + _AITER_PARTITION_SIZE_ROCM - 1
         ) // _AITER_PARTITION_SIZE_ROCM
 
         nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index a1edcf5b4d44..0b5a982f2118 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -52,6 +52,7 @@
     from aiter import ActivationType
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
+    from aiter.tuned_gemm import tgemm
 
 if _is_npu:
     from sglang.srt.hardware_backend.npu.utils import npu_format_cast
@@ -150,4 +151,7 @@ def apply(
             output = output.view(x_shapes[0], x_shapes[1], -1)
             return output
 
+        if _use_aiter and type(layer.weight.data) is torch.Tensor:
+            return tgemm.mm(x, layer.weight, bias, otype=x.dtype)
+
         return F.linear(x, layer.weight, bias)