diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py index ee81b7939cb0..5afba6c01b9e 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools import os from typing import Any, Dict, List, Optional @@ -21,11 +20,9 @@ from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, - get_device_name, is_cpu, is_cuda, is_hip, - is_sm90_supported, ) try: @@ -55,24 +52,6 @@ def support_tensor_descriptor(): return _support_tensor_descriptor -# In theory, swap_ab should benefit all SM90 GPUs. -# However, since it has only been verified on H20 (not H100/H200), -# it is currently enabled only on H20. -@functools.lru_cache(maxsize=8) -def should_enable_swap_ab( - BLOCK_SIZE_M: int, - BLOCK_SIZE_N: int, -) -> bool: - device_name = get_device_name() - is_h20_device = device_name and "H20" in device_name and "H200" not in device_name - return ( - is_h20_device - and is_sm90_supported() - and BLOCK_SIZE_M < 64 - and BLOCK_SIZE_N >= 64 - ) - - @triton.jit def write_zeros_to_output( c_ptr, @@ -381,7 +360,6 @@ def fused_moe_kernel( even_Ks: tl.constexpr, c_sorted: tl.constexpr, filter_expert: tl.constexpr, - swap_ab: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -520,10 +498,7 @@ def fused_moe_kernel( # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. - if swap_ab: - accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32) - else: - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k_start in range(0, K, BLOCK_SIZE_K): # Load the next block of A and B, generate a mask by checking the @@ -564,17 +539,12 @@ def fused_moe_kernel( a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - if swap_ab: - a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0)) - a_scale, b_scale = b_scale, a_scale if BLOCK_SIZE_N > group_n: accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale) else: if use_fp8_w8a8: - if swap_ab: - a, b = tl.trans(b, (1, 0)), tl.trans(a, (1, 0)) accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) @@ -586,9 +556,6 @@ def fused_moe_kernel( if b_desc is None: b_ptrs += BLOCK_SIZE_K * stride_bk - if swap_ab: - accumulator = tl.trans(accumulator, (1, 0)) - if use_int8_w8a16: accumulator *= b_scale elif use_fp8_w8a8 or use_int8_w8a8: @@ -648,11 +615,6 @@ def invoke_fused_moe_kernel( assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if use_fp8_w8a8: - swap_ab = should_enable_swap_ab(config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"]) - else: - swap_ab = False - padded_size = 0 if use_fp8_w8a8: assert B_scale is not None @@ -824,7 +786,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): even_Ks=even_Ks, c_sorted=c_sorted, filter_expert=filter_expert, - swap_ab=swap_ab, **config, )