diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d65f9adea86f..45350c5a65c7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -29,7 +29,8 @@ import numpy as np # type: ignore from tvm import tir -from tvm.ir import Range, Type +from tvm import ir +from tvm.ir import Type from tvm.ir.base import deprecated from tvm.runtime import String, convert, ndarray from tvm.target import Target @@ -496,7 +497,7 @@ def alloc_buffer( ) -def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: +def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: """The range constructor. Parameters @@ -509,13 +510,13 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: res : Range The Range. """ - if isinstance(dom, Range): + if isinstance(dom, ir.Range): return dom if isinstance(dom, (list, tuple)): - return Range(dom[0], dom[1]) + return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): - return Range(IntImm(dom.dtype, 0), dom) - return Range(0, dom) + return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(0, dom) class axis: # pylint: disable=invalid-name @@ -523,7 +524,7 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -551,7 +552,7 @@ def spatial( @staticmethod def reduce( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -579,7 +580,7 @@ def reduce( @staticmethod def scan( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -607,7 +608,7 @@ def scan( @staticmethod def opaque( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -1288,7 +1289,7 @@ def buffer_store( def prefetch( buffer: Buffer, # pylint: disable=redefined-outer-name - bounds: List[Range], + bounds: List[ir.Range], ) -> None: """The prefetch hint for a buffer. @@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member -def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar: +def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar: """The iteration variable. Parameters @@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target: return Target(target_config) +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name + """ + Create a Range object. + + Parameters + ---------- + begin : PrimExpr + The begin value of the range. + + end : Optional[PrimExpr] + The end value of the range. + """ + return ir.Range(begin, end) + + class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable does not appear in the final TIR, but only stays in the parser. @@ -1782,6 +1798,11 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_warp_shuffle = _tir_op.tvm_warp_shuffle +tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up +tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down +tvm_warp_activemask = _tir_op.tvm_warp_activemask ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) assume = _op_wrapper(_tir_op.assume) @@ -2042,6 +2063,11 @@ def wrapped(*args, **kwargs): "tvm_bmma_sync", "tvm_fill_fragment", "tvm_store_matrix_sync", + "tvm_storage_sync", + "tvm_warp_shuffle", + "tvm_warp_shuffle_up", + "tvm_warp_shuffle_down", + "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", "ptx_ldmatrix", @@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs): "Let", "IterVar", "CommReducer", + "Range", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0a9c4fdfaa52..0fe460c085d7 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -569,7 +569,8 @@ def lookup_param(param_name, span=None): def tvm_thread_allreduce(*freduce_args): - """ + """Perform allreduce inside threadblock. + Parameters ---------- freduce_args : Expr @@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args): return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args) +def tvm_storage_sync(storage_scope): + """Perform synchronization in specified scope. + + Parameters + ---------- + storage_scope : str + The storage scope to perform synchronization. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_storage_sync", storage_scope) + + +def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): + """Exchange value between threads inside a warp. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + warp_id : PrimExpr + The source lane index to fetch value. + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + + +def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): + """Copy value from a lane with lower (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = dst_lane_idx - src_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size + ) + + +def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): + """Copy value from a lane with higher (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = src_lane_idx - dst_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size + ) + + +def tvm_warp_activemask(): + """Return a 32-bit mask indicates currently active threads in a calling warp. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("uint32", "tir.tvm_warp_activemask") + + def type_annotation(dtype): """Create a type annotation expression diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c956f3bb02b9..6f07b6a75aeb 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3623,6 +3623,60 @@ def main(A: T.handle, B: T.handle): return main +def tvm_shfl_builtins(): + @T.prim_func + def func( + A: T.handle("float32"), + B: T.handle("float32"), + C: T.handle("float32"), + ): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + A_warp = T.allocate([1], "float32", "local") + B_warp = T.allocate([1], "float32", "local") + red_buf0 = T.allocate([1], "float32", "local") + A_warp_1 = T.Buffer((32,), data=A_warp, scope="local") + A_1 = T.Buffer((32,), data=A) + A_warp_1[0] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") + T.tvm_storage_sync("warp") + B_warp_1[0] = T.tvm_warp_shuffle( + T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 + ) + T.float32(1) + red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1[0] = A_warp_1[0] + mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_1[0] = T.tvm_warp_activemask() + t0_1 = T.Buffer((1,), data=t0, scope="local") + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32) + # NOTE(Zihao): test tvm_warp_shuffle_up + red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32) + if threadIdx_x == 0: + C_1 = T.Buffer((1,), data=C) + C_1[0] = red_buf0_1[0] + B_1 = T.Buffer((32,), data=B) + B_1[threadIdx_x] = B_warp_1[0] + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3686,6 +3740,7 @@ def main(A: T.handle, B: T.handle): let_stmt_value, string_stride, merge_shape_var_def, + tvm_shfl_builtins, )