Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions python/sglang/srt/layers/moe/moe_runner/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,25 @@
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
_use_aiter = bool(int(os.getenv("SGLANG_USE_AITER", "0")))
_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


if _is_cuda:
if _is_cuda or _is_hip:
from sgl_kernel import gelu_and_mul, silu_and_mul

if _is_hip:
if _use_aiter:
try:
from aiter import moe_sum
except ImportError:
raise ImportError(
"aiter is required when SGLANG_USE_AITER is set to True"
)
else:
from vllm import _custom_ops as vllm_ops # moe_sum
elif _is_cpu and _is_cpu_amx_available:
pass
elif _is_hip:
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul

if _use_aiter:
try:
from aiter import moe_sum
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")


if _is_cuda or _is_hip:
from sgl_kernel import ( # noqa: F401
Expand Down Expand Up @@ -206,7 +208,7 @@ def run(
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda:
elif _is_cuda or _is_hip:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.silu_and_mul(
Expand All @@ -215,7 +217,7 @@ def run(
elif activation == "gelu":
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
if _is_cuda:
if _is_cuda or _is_hip:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.gelu_and_mul(
Expand Down
Loading