diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py index 8ce947d43a96..797332fef5dc 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import os from typing import Any, Dict, List, Optional @@ -20,9 +21,11 @@ from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, + get_device_name, is_cpu, is_cuda, is_hip, + is_sm90_supported, ) try: @@ -52,6 +55,24 @@ def support_tensor_descriptor(): return _support_tensor_descriptor +# In theory, swap_ab should benefit all SM90 GPUs. +# However, since it has only been verified on H20 (not H100/H200), +# it is currently enabled only on H20. +@functools.lru_cache(maxsize=8) +def should_enable_swap_ab( + BLOCK_SIZE_M: int, + BLOCK_SIZE_N: int, +) -> bool: + device_name = get_device_name() + is_h20_device = device_name and "H20" in device_name and "H200" not in device_name + return ( + is_h20_device + and is_sm90_supported() + and BLOCK_SIZE_M < 64 + and BLOCK_SIZE_N >= 64 + ) + + @triton.jit def write_zeros_to_output( c_ptr, @@ -360,6 +381,7 @@ def fused_moe_kernel( even_Ks: tl.constexpr, c_sorted: tl.constexpr, filter_expert: tl.constexpr, + swap_ab: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -498,7 +520,10 @@ def fused_moe_kernel( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if swap_ab: + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k_start in range(0, K, BLOCK_SIZE_K): # Load the next block of A and B, generate a mask by checking the @@ -539,12 +564,17 @@ def fused_moe_kernel( a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + if swap_ab: + a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0)) + a_scale, b_scale = b_scale, a_scale if BLOCK_SIZE_N > group_n: accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale) else: if use_fp8_w8a8: + if swap_ab: + a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0)) accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) @@ -556,6 +586,9 @@ def fused_moe_kernel( if b_desc is None: b_ptrs += BLOCK_SIZE_K * stride_bk + if swap_ab: + accumulator = tl.trans(accumulator, (1, 0)) + if use_int8_w8a16: accumulator *= b_scale elif use_fp8_w8a8 or use_int8_w8a8: @@ -615,6 +648,11 @@ def invoke_fused_moe_kernel( assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8_w8a8: + swap_ab = should_enable_swap_ab(config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"]) + else: + swap_ab = False + padded_size = 0 if use_fp8_w8a8: assert B_scale is not None @@ -786,6 +824,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): even_Ks=even_Ks, c_sorted=c_sorted, filter_expert=filter_expert, + swap_ab=swap_ab, **config, )