diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index a1edcf5b4d44..0b5a982f2118 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -52,6 +52,7 @@ from aiter import ActivationType from aiter.fused_moe import fused_moe from aiter.ops.shuffle import shuffle_weight + from aiter.tuned_gemm import tgemm if _is_npu: from sglang.srt.hardware_backend.npu.utils import npu_format_cast @@ -150,6 +151,9 @@ def apply( output = output.view(x_shapes[0], x_shapes[1], -1) return output + if _use_aiter and type(layer.weight.data) is torch.Tensor: + return tgemm.mm(x, layer.weight, bias, otype=x.dtype) + return F.linear(x, layer.weight, bias)