Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import os
from typing import Any, Dict, List, Optional

Expand All @@ -20,9 +21,11 @@
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
get_device_name,
is_cpu,
is_cuda,
is_hip,
is_sm90_supported,
)

try:
Expand Down Expand Up @@ -52,6 +55,24 @@ def support_tensor_descriptor():
return _support_tensor_descriptor


# In theory, swap_ab should benefit all SM90 GPUs.
# However, since it has only been verified on H20 (not H100/H200),
# it is currently enabled only on H20.
@functools.lru_cache(maxsize=8)
def should_enable_swap_ab(
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
) -> bool:
device_name = get_device_name()
is_h20_device = device_name and "H20" in device_name and "H200" not in device_name
return (
is_h20_device
and is_sm90_supported()
and BLOCK_SIZE_M < 64
and BLOCK_SIZE_N >= 64
)


@triton.jit
def write_zeros_to_output(
c_ptr,
Expand Down Expand Up @@ -360,6 +381,7 @@ def fused_moe_kernel(
even_Ks: tl.constexpr,
c_sorted: tl.constexpr,
filter_expert: tl.constexpr,
swap_ab: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Expand Down Expand Up @@ -498,7 +520,10 @@ def fused_moe_kernel(
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if swap_ab:
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
else:
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k_start in range(0, K, BLOCK_SIZE_K):
# Load the next block of A and B, generate a mask by checking the
Expand Down Expand Up @@ -539,12 +564,17 @@ def fused_moe_kernel(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
if swap_ab:
a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0))
a_scale, b_scale = b_scale, a_scale
if BLOCK_SIZE_N > group_n:
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
else:
if use_fp8_w8a8:
if swap_ab:
a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0))
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
Expand All @@ -556,6 +586,9 @@ def fused_moe_kernel(
if b_desc is None:
b_ptrs += BLOCK_SIZE_K * stride_bk

if swap_ab:
accumulator = tl.trans(accumulator, (1, 0))

if use_int8_w8a16:
accumulator *= b_scale
elif use_fp8_w8a8 or use_int8_w8a8:
Expand Down Expand Up @@ -615,6 +648,11 @@ def invoke_fused_moe_kernel(
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

if use_fp8_w8a8:
swap_ab = should_enable_swap_ab(config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"])
else:
swap_ab = False

padded_size = 0
if use_fp8_w8a8:
assert B_scale is not None
Expand Down Expand Up @@ -786,6 +824,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
even_Ks=even_Ks,
c_sorted=c_sorted,
filter_expert=filter_expert,
swap_ab=swap_ab,
**config,
)

Expand Down
Loading