diff --git a/unsloth/kernels/moe/grouped_gemm/interface.py b/unsloth/kernels/moe/grouped_gemm/interface.py index 554e5fcc03..5588458973 100644 --- a/unsloth/kernels/moe/grouped_gemm/interface.py +++ b/unsloth/kernels/moe/grouped_gemm/interface.py @@ -39,10 +39,10 @@ # Precompute TMA support to avoid graph breaks # TMA requires both: -# 1. GPU capability >= 9 (Hopper+) +# 1. NVIDIA GPU with capability >= 9 (Hopper+) # 2. Triton version with TMA API (make_tensor_descriptor or _experimental_make_tensor_descriptor) def _check_tma_support(): - if DEVICE_TYPE == "xpu": + if DEVICE_TYPE in ("xpu", "hip"): return False import triton.language as tl