diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index bd874d4c2..bfee1d2e3 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -3,10 +3,13 @@ from tilelang import tvm as tvm from tilelang.language import ptx_arrive_barrier, evaluate from tilelang.language.kernel import get_thread_bindings, get_block_extents +from tilelang.utils.target import check_hip_availability from tvm import tir from typing import Union, Any from tvm.tir import PrimExpr, Var, Call +_IS_HIP_AVAILABLE = check_hip_availability() + def create_list_of_mbarrier(*args: Any) -> Call: """ @@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, Returns: tir.Call: A handle to the shuffle operation """ - return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_xor", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): @@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr value: Optional[int, PrimExpr] The value to shuffle """ - return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_down", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): @@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, value: Optional[int, PrimExpr] The value to shuffle """ - return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) + if _IS_HIP_AVAILABLE: + return tir.call_extern(value.dtype, "__shfl_up", value, offset) + else: + return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) def sync_threads():