Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions python/sglang/srt/layers/rocm_linear_utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,17 @@
import torch
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic

from sglang.srt.utils import BumpAllocator
from aiter.tuned_gemm import tgemm

__all__ = ["fused_qk_rope_cat", "fused_qk_rope_cat_and_cache_mla"]


def aiter_dsv3_router_gemm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    gemm_output_zero_allocator: BumpAllocator = None,
):
    """Compute DeepSeek-V3 router logits on ROCm via aiter's tuned GEMM.

    Dispatches to ``tgemm.mm``, which selects an appropriate GEMM kernel
    for the problem size automatically, replacing the previous manual
    split between ``gemm_a16w16_atomic`` (small M, pre-zeroed fp32 output)
    and ``gemm_a16w16``.

    Args:
        hidden_states: 2-D activation tensor; ``shape[0]`` is the token
            count M (per the previous implementation's use of
            ``hidden_states.shape[0]``).
        weight: router weight tensor; ``shape[0]`` is the output width N.
            NOTE(review): presumably laid out (N, K) as for ``F.linear`` —
            confirm against the caller.
        gemm_output_zero_allocator: retained with its original default so
            existing three-argument call sites keep working; no longer
            consulted now that ``tgemm`` manages the output buffer itself.

    Returns:
        Logits tensor of shape (M, N) in ``hidden_states.dtype``.
    """
    return tgemm.mm(hidden_states, weight, otype=hidden_states.dtype)


def get_dsv3_gemm_output_zero_allocator_size(
Expand Down
14 changes: 5 additions & 9 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@
use_intel_amx_backend,
)

if _use_aiter:
from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm

if _use_aiter_gfx95:
from sglang.srt.layers.rocm_linear_utils import (
aiter_dsv3_router_gemm,
get_dsv3_gemm_output_zero_allocator_size,
)

Expand Down Expand Up @@ -327,14 +329,8 @@ def forward(
logits = dsv3_router_gemm(
hidden_states, self.weight, out_dtype=torch.float32
)
elif (
_use_aiter_gfx95
and hidden_states.shape[0] <= 256
and self.weight.shape[0] <= 256
):
logits = aiter_dsv3_router_gemm(
hidden_states, self.weight, gemm_output_zero_allocator
)
elif _use_aiter:
logits = aiter_dsv3_router_gemm(hidden_states, self.weight)
else:
logits = F.linear(hidden_states, self.weight, None)

Expand Down
Loading