diff --git a/python/sglang/srt/layers/rocm_linear_utils.py b/python/sglang/srt/layers/rocm_linear_utils.py
index 6c8a6a367e54..ae31553c95e9 100644
--- a/python/sglang/srt/layers/rocm_linear_utils.py
+++ b/python/sglang/srt/layers/rocm_linear_utils.py
@@ -1,10 +1,7 @@
 import torch
 from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
 from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
-from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
-from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
-
-from sglang.srt.utils import BumpAllocator
+from aiter.tuned_gemm import tgemm
 
 __all__ = ["fused_qk_rope_cat", "fused_qk_rope_cat_and_cache_mla"]
 
@@ -12,26 +9,9 @@
 def aiter_dsv3_router_gemm(
     hidden_states: torch.Tensor,
     weight: torch.Tensor,
-    gemm_output_zero_allocator: BumpAllocator = None,
 ):
-    M = hidden_states.shape[0]
-    N = weight.shape[0]
-    y = None
-
-    if M <= 256:
-        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
-        # for now it is also coupled with zero allocator.
-        if gemm_output_zero_allocator != None:
-            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
-        else:
-            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
-
-    if y is not None:
-        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
-    else:
-        logits = gemm_a16w16(hidden_states, weight)
-
-    return logits
+    """Use aiter tuned GEMM dispatcher (tgemm.mm) to automatically select the GEMM kernel."""
+    return tgemm.mm(hidden_states, weight, otype=hidden_states.dtype)
 
 
 def get_dsv3_gemm_output_zero_allocator_size(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 0a2aa07e00b6..6fae78248777 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -153,9 +153,11 @@
     use_intel_amx_backend,
 )
 
+if _use_aiter:
+    from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm
+
 if _use_aiter_gfx95:
     from sglang.srt.layers.rocm_linear_utils import (
-        aiter_dsv3_router_gemm,
         get_dsv3_gemm_output_zero_allocator_size,
     )
 
@@ -327,14 +329,8 @@ def forward(
             logits = dsv3_router_gemm(
                 hidden_states, self.weight, out_dtype=torch.float32
             )
-        elif (
-            _use_aiter_gfx95
-            and hidden_states.shape[0] <= 256
-            and self.weight.shape[0] <= 256
-        ):
-            logits = aiter_dsv3_router_gemm(
-                hidden_states, self.weight, gemm_output_zero_allocator
-            )
+        elif _use_aiter:
+            logits = aiter_dsv3_router_gemm(hidden_states, self.weight)
         else:
             logits = F.linear(hidden_states, self.weight, None)