diff --git a/python/sglang/srt/layers/moe/moe_runner/triton.py b/python/sglang/srt/layers/moe/moe_runner/triton.py index ebb6ba1b4a45..cdf3e9a471f3 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton.py @@ -32,23 +32,25 @@ _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() -_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0"))) +_use_aiter = bool(int(os.getenv("SGLANG_USE_AITER", "0"))) _MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 -if _is_cuda: +if _is_cuda or _is_hip: from sgl_kernel import gelu_and_mul, silu_and_mul + + if _is_hip: + if _use_aiter: + try: + from aiter import moe_sum + except ImportError: + raise ImportError( + "aiter is required when SGLANG_USE_AITER is set to True" + ) + else: + from vllm import _custom_ops as vllm_ops # moe_sum elif _is_cpu and _is_cpu_amx_available: pass -elif _is_hip: - from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul - - if _use_aiter: - try: - from aiter import moe_sum - except ImportError: - raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") - if _is_cuda or _is_hip: from sgl_kernel import ( # noqa: F401 @@ -206,7 +208,7 @@ def run( gemm1_alpha, gemm1_limit, ) - elif _is_cuda: + elif _is_cuda or _is_hip: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: vllm_ops.silu_and_mul( @@ -215,7 +217,7 @@ def run( elif activation == "gelu": assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" assert gemm1_limit is None, "gemm1_limit is not supported for gelu" - if _is_cuda: + if _is_cuda or _is_hip: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: vllm_ops.gelu_and_mul(