diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index d1fa340ac37d2..fb0f596d65527 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.4.1 \ + synr==0.5.0 \ six \ tornado diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 23057f7140e4b..e4a3d3d1e21b6 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -199,9 +199,10 @@ class LinkedParam : public ObjectRef { * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: * A = T.match_buffer(a, (m, n), "float32") * B = T.match_buffer(b, (m, n), "float32") - * - * with T.block([m, n], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(m, n): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode * * Then we can make it specialized with given shapes or buffers. @@ -218,9 +219,10 @@ class LinkedParam : public ObjectRef { * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: * A = T.match_buffer(a, (16, 16), "float32") * B = T.match_buffer(b, (16, 16), "float32") - * - * with T.block([16, 16], "") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + * for i, j in T.grid(16, 16): + * with T.block(): + * vi, vj = T.axis.remap("SS", [i, j]) + * B[vi, vj] = A[vi, vj] * \endcode */ PrimFunc Specialize(PrimFunc func, const Map& param_map); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5cd860b8e9294..4f5772822d9e9 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1078,9 +1078,9 @@ class MatchBufferRegion : public ObjectRef { * \note Block's body is parameterized by iter vars. * \code * - * with T.block([extent0, extent1, ...], name) as [v0, v1, ...]: - * T.bind(v0, value0) - * T.bind(v1, value1) + * with T.block(name): + * v0 = T.axis.S(domain, value0) + * v1 = T.axis.R(domain, value1) * ... * T.reads([buffer0[start:end, ...], ...]) * T.writes([buffer1[start:end, ...], ...]) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 017078bd7bf7c..e6b0af9773d96 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -388,7 +388,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(16, 16) * for j in range(0, 16): * B[i, j] = A[i, j] + 1 @@ -404,7 +404,7 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with T.block([]): + * with T.block(): * B = T.alloc_buffer(1, 16) * for j in range(0, 16): * B[0, j] = A[i, j] + 1 diff --git a/python/gen_requirements.py b/python/gen_requirements.py index fa94d6a641304..cdffdd95e5467 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -250,7 +250,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.4.1"), + ("synr", "==0.5.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 75566cf6e2c5e..56d080857a7d3 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -22,8 +22,10 @@ import tvm from tvm.ir import Span +from tvm.ir.expr import Range from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object +from tvm.tir.expr import IterVar from .tir.node import BufferSlice @@ -41,10 +43,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(a, (16, 16), "float32") for i, j, k in T.grid(16, 16, 16): - with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k} + with T.block("matmul"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k} T.where(True) # predicate of the block_realize @@ -72,8 +74,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" - iter_bindings: Mapping[Var, PrimExpr] = {} - """Mapping[Var, PrimExpr]: map of block iter var to its values""" + iter_values: List[PrimExpr] = [] + """List[PrimExpr]: list of binding values for iter vars""" + iter_vars: List[IterVar] = [] + """List[PrimExpr]: list of iter vars in the block""" reads: Optional[List[BufferSlice]] = None """Optional[List[BufferSlice]]: list of T.reads statements in the block signature, None for not-visited""" @@ -91,7 +95,8 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: def __init__(self): self.alloc_buffers = [] self.match_buffers = [] - self.iter_bindings = {} + self.iter_values = [] + self.iter_vars = [] self.reads = None self.writes = None self.annotations = None @@ -112,8 +117,8 @@ class ContextMaintainer: """List[List[synr.ast.Node]]: The ast nodes insides the current scope""" block_info_stack: List[BlockInfo] = [] """List[BlockInfo]: The block info for the current block scope""" - loop_stack: List[List[Var]] = [] - """List[List[Var]]: List of loop vars inside the current block scope""" + loop_stack: Dict[Var, Range] = {} + """Dict[Var, Range]: The dict from loop var to its domain outside the block""" symbols: List[Dict[str, Union[Var, Buffer]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" @@ -137,7 +142,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # scope context self.node_stack = [] self.block_info_stack = [] - self.loop_stack = [] + self.loop_stack = {} self.symbols = [] # function context self.func_params = [] @@ -183,8 +188,6 @@ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): The synr AST nodes in new scope """ self.enter_scope(nodes) - # Create a new loop stack for the new block - self.loop_stack.append([]) # Create a new BlockInfo for the new block self.block_info_stack.append(BlockInfo()) @@ -196,8 +199,6 @@ def exit_scope(self): def exit_block_scope(self): """Pop the inner most block scope, the function will call `exit_scope` implicitly""" self.exit_scope() - # Pop loop stack - self.loop_stack.pop() # Pop block_info self.block_info_stack.pop() diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index d5e79e8676c16..8610d91e9f075 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -377,12 +377,13 @@ def A(): """ if len(node.assignments) == 1: if not ( - isinstance(node.assignments[0].lhs, ast.Var) - and node.assignments[0].lhs.id.name == "__tvm_meta__" + len(node.assignments[0].lhs) == 1 + and isinstance(node.assignments[0].lhs[0], ast.Var) + and node.assignments[0].lhs[0].id.name == "__tvm_meta__" ): self.report_error( "The only top level assignments allowed are `__tvm_meta__ = ...`", - node.assignments[0].lhs.span, + node.assignments[0].span, ) self.init_meta( MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) @@ -526,18 +527,19 @@ def transform_Assign(self, node): return self.parse_body(node) else: value = self.transform(node.rhs) - if not isinstance(node.lhs, ast.Var): + if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): # This is a little confusing because it only is true when # we have taken this branch. We might need to clarify what # exectly is allowed in Assignments in tvmscript. self.report_error( "Left hand side of assignment must be an unqualified variable", - node.lhs.span, + node.span, ) + ast_var = node.lhs[0] var = tvm.te.var( - node.lhs.id.name, - self.parse_type(node.ty, node.lhs), - span=tvm_span_from_synr(node.lhs.span), + ast_var.id.name, + self.parse_type(node.ty, ast_var), + span=tvm_span_from_synr(ast_var.span), ) self.context.update_symbol(var.name, var, node) body = self.parse_body(node) @@ -596,7 +598,7 @@ def transform_For(self, node): For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now 1 pattern of For is supported: 1. for scope handler - for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/ + for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/ T.grid()/T.thread_binding() """ @@ -892,9 +894,20 @@ def transform_Attr(self, node): namespace. """ - if isinstance(node.object, ast.Var): - if self.match_tir_namespace(node.object.id.name): - func_name = "tir." + node.field.name + def get_full_attr_name(node: ast.Attr) -> str: + reverse_field_names = [node.field.name] + while isinstance(node.object, ast.Attr): + node = node.object + reverse_field_names.append(node.field.name) + if isinstance(node.object, ast.Var): + reverse_field_names.append(node.object.id.name) + return ".".join(reversed(reverse_field_names)) + + if isinstance(node.object, (ast.Var, ast.Attr)): + full_attr_name = get_full_attr_name(node) + attr_object, fields = full_attr_name.split(".", maxsplit=1) + if self.match_tir_namespace(attr_object): + func_name = "tir." + fields res = Registry.lookup(func_name) if res is not None: return res @@ -903,9 +916,7 @@ def transform_Attr(self, node): except TVMError as e: # Check if we got an attribute error if e.args[0].find("AttributeError"): - self.report_error( - f"Unregistered function `tir.{node.field.name}`.", node.field.span - ) + self.report_error(f"Unregistered function `tir.{fields}`.", node.span) else: raise e diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 487a71d4f0779..4750ad7626e2d 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -134,12 +134,14 @@ def enter_scope( if isinstance(node, synr.ast.With): vars = WithScopeHandler.get_optional_vars(node, context) if len(vars) != 1: - context.report_error("Unexpected number of vars", node.span) + context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) name = vars[0].id.name var_span = vars[0].id.span elif isinstance(node, synr.ast.Assign): - name = node.lhs.id.name - var_span = node.lhs.id.span + if len(node.lhs) != 1: + context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) + name = node.lhs[0].id.name + var_span = node.lhs[0].id.span else: raise Exception("Internal Bug") @@ -247,42 +249,16 @@ def let(var, value, span): @register class Block(WithScopeHandler): - """With scope handler T.block(extents, name) as iter_vars""" + """With scope handler T.block(name)""" def __init__(self): - def block(axes=None, name_hint: str = "", span: Optional[Span] = None): + def block(name_hint: str = "", span: Optional[Span] = None): assert ( self.node and self.context and self.body ), "call 'exit_scope' before 'enter_scope'" block_info = self.context.block_info_stack[-1] - if axes is None: - axes = [] - if len(axes) != len(self.block_vars): - self.context.report_error( - "Inconsistent number of block vars, " - + f"there are {len(axes)} axes but {len(self.block_vars)} block vars. " - + "The number of block vars should match the number of axes.", - self.node.span, - ) - block_iters: List[IterVar] = [] - for i, axis in enumerate(axes): - axis = tvm.runtime.convert(axis) - if isinstance(axis, tvm.tir.PrimExpr): - block_var_dom = Range.from_min_extent(0, axis) - block_iters.append(IterVar(block_var_dom, self.block_vars[i], 0)) - elif isinstance(axis, Range): - block_iters.append(IterVar(axis, self.block_vars[i], 0)) - elif isinstance(axis, IterVar): - block_iters.append(IterVar(axis.dom, self.block_vars[i], axis.iter_type)) - else: - self.context.report_error( - "Invalid argument of T.block(), " - + f"expected PrimExpr, Range or IterVar, but got {type(axis)}", - self.node.span, - ) # create block read/write regions - reads: List[BufferRegion] = ( [buffer_slice_to_region(read) for read in block_info.reads] if block_info.reads @@ -301,7 +277,7 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): if region_detect_mask != 0: annotations["tir.script_parsing_detect_access"] = region_detect_mask inner = tvm.tir.Block( - block_iters, + block_info.iter_vars, reads, writes, name_hint, @@ -312,35 +288,13 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): annotations, span, ) - # create block var iter binding - values: List[PrimExpr] - if not block_info.iter_bindings: - values = self.context.loop_stack[-2].copy() - if len(block_iters) == 0: - # It is an opaque block without any bindings - values = [] - elif len(values) == 0: - values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters) - elif len(values) != len(block_iters): - self.context.report_error( - "Number of block iter var and outer loop nesting mismatch, " - + f"{len(block_iters)} block iter vars but {len(values)} loops", - self.node.span, - ) - else: - for block_var in self.block_vars: - if block_var not in block_info.iter_bindings: - self.context.report_error( - "Missing block iter var binding for " + block_var.name, - self.node.span, - ) - values = [block_info.iter_bindings[block_var] for block_var in self.block_vars] + assert len(block_info.iter_vars) == len(block_info.iter_values) predicate = ( tvm.tir.const(True, "bool") if block_info.predicate is None else block_info.predicate ) - body = tvm.tir.BlockRealize(values, predicate, inner, span) + body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span) return body super().__init__(func=block, concise_scope=False, def_symbol=True) @@ -358,10 +312,13 @@ def enter_scope( node, synr.ast.With ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}" - vars = WithScopeHandler.get_optional_vars(node, context) - self.block_vars = [tvm.te.var(var.id.name) for var in vars] - for block_var in self.block_vars: - context.update_symbol(block_var.name, block_var, node) + optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)] + if optional_vars: + context.report_error( + f"Block expected no optional_vars (e.g., `x` in `with block() as x`), " + f"but got {optional_vars}", + node.span, + ) @register @@ -378,12 +335,38 @@ def init(span: Span = None): super().__init__(func=init, concise_scope=False, def_symbol=True) +class LoopInfo: + """Helper class for loop information""" + + loop_var: Var + begin: PrimExpr + extent: PrimExpr + kind: ForKind + thread_binding: Optional[str] + annotations: Optional[Mapping[str, Object]] + + def __init__( + self, + begin: PrimExpr, + extent: PrimExpr, + kind: ForKind, + thread_binding: Optional[str] = None, + annotations: Optional[Mapping[str, Object]] = None, + ) -> None: + self.begin = begin + self.extent = extent + self.kind = kind + self.thread_binding = thread_binding + self.annotations = annotations + + class ForScopeHandler(ScopeHandler): """Base class for all for scope handlers""" def __init__(self, func): super().__init__(func) - self.loop_vars: Optional[List[Var]] = None + self.loop_vars: List[Var] = [] + self.loop_info: List[LoopInfo] = [] def enter_scope( self, @@ -415,12 +398,23 @@ def enter_scope( span, ) + self.node = node + self.context = context + # generate loop vars self.loop_vars = [ tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans) ] - for loop_var in self.loop_vars: + # collect loop infos by calling self.func + call_with_error_reporting(context.report_error, span, self.func, *arg_list) + if len(self.loop_vars) != len(self.loop_info): + self.context.report_error( + f"Inconsistent number of vars and loops, got {len(self.loop_vars)} " + + f"vs {len(self.loop_info)}", + self.node.span, + ) + for loop_var, loop_info in zip(self.loop_vars, self.loop_info): context.update_symbol(loop_var.name, loop_var, node) - context.loop_stack[-1].append(loop_var) + context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent) def exit_scope( self, @@ -430,19 +424,34 @@ def exit_scope( span: synr.ast.Span, ): assert self.loop_vars, "call 'exit_scope' before 'enter_scope'" - for _ in self.loop_vars: - context.loop_stack[-1].pop() - return super().exit_scope(node, context, arg_list, span) + for loop_var in self.loop_vars: + context.loop_stack.pop(loop_var) + # Use assert here since we have check it in `enter_scope` + assert len(self.loop_vars) == len(self.loop_info) + + body = self.body + for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)): + body = tvm.tir.For( + var, + info.begin, + info.extent, + info.kind, + body, + info.thread_binding, + info.annotations, + span=tvm_span_from_synr(span), + ) - def create_loop( + return body + + def create_loop_info( self, begin: PrimExpr, end: PrimExpr, kind: ForKind, thread_binding: Optional[str] = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, - ) -> tvm.tir.For: + ) -> None: """ Helper function for creating For in TVM Script parser. @@ -471,30 +480,16 @@ def create_loop( for : For The constructed For. """ - assert ( - self.loop_vars and self.context and self.node - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != 1: - self.context.report_error( - f"Expected exactly one loop var, but got {self.loop_vars}", self.node.span - ) + assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) - annos: Mapping[str, Object] = {} + self.annotations: Mapping[str, Object] = {} if annotations is not None: - annos = { + self.annotations = { key: tvm.tir.StringImm(val) if isinstance(val, str) else val for key, val in annotations.items() } - return tvm.tir.For( - self.loop_vars[0], - begin, - extent, - kind, - self.body, - thread_binding=thread_binding, - annotations=annos, - span=span, - ) + + self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) @register @@ -506,9 +501,8 @@ def serial( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(serial) @@ -522,11 +516,8 @@ def parallel( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.PARALLEL, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations) super().__init__(parallel) @@ -540,11 +531,8 @@ def vectorized( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.VECTORIZED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations) super().__init__(vectorized) @@ -558,11 +546,8 @@ def unroll( begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - return self.create_loop( - begin, end, ForKind.UNROLLED, annotations=annotations, span=span - ) + self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations) super().__init__(unroll) @@ -577,16 +562,14 @@ def thread_binding( end: PrimExpr, thread: str, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): - thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread, span=span) - return self.create_loop( + thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread) + self.create_loop_info( begin, end, ForKind.THREAD_BINDING, thread_binding=thread_iter_var, annotations=annotations, - span=span, ) super().__init__(thread_binding) @@ -603,12 +586,11 @@ def for_range( begin: PrimExpr, end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, ): if end is None: end = begin begin = 0 - return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span) + self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(for_range) @@ -621,19 +603,8 @@ class Grid(ForScopeHandler): """For scope handler T.grid(extents)""" def __init__(self): - def grid(*extents: List[PrimExpr], span: Span): - assert ( - self.node and self.context and self.loop_vars - ), "call 'exit_scope' before 'enter_scope'" - if len(self.loop_vars) != len(extents): - self.context.report_error( - "Inconsistent number of loop vars and extents, " - + f"got {len(self.loop_vars)} vs {len(extents)}", - self.node.span, - ) - body = self.body - for loop_var, extent in zip(reversed(self.loop_vars), reversed(extents)): - body = tvm.tir.For(loop_var, 0, extent, ForKind.SERIAL, body, span=span) - return body + def grid(*extents: List[PrimExpr]): + for extent in extents: + self.create_loop_info(0, extent, ForKind.SERIAL) super().__init__(grid) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 69cf15f493de4..de212352f3e43 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -21,17 +21,18 @@ import synr from synr import ast +from tvm.ir.expr import PrimExpr, Range import tvm.tir from tvm.runtime import Object from tvm import te from tvm.ir import Span -from tvm.tir import IntImm +from tvm.tir import IntImm, IterVar from .node import BufferSlice from .utils import buffer_slice_to_region -from ..context_maintainer import ContextMaintainer +from ..context_maintainer import BlockInfo, ContextMaintainer from ..registry import register from ..utils import ( get_param_list, @@ -132,9 +133,10 @@ def match_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", + "`match_buffer` must be assigned to a single buffer, " + "e.g. A = match_buffer(...)", self.node.span, ) if strides is None: @@ -143,10 +145,11 @@ def match_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -173,7 +176,7 @@ def match_buffer( + str(type(param)), self.node.rhs.params[0].span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -201,9 +204,9 @@ def buffer_decl( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "buffer_decl must be assigned to a buffer, e.g. A = buffer_decl(...)", + "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)", self.node.span, ) @@ -213,10 +216,11 @@ def buffer_decl( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -226,7 +230,7 @@ def buffer_decl( buffer_type, span=span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) return buffer super().__init__(buffer_decl, def_symbol=True) @@ -257,9 +261,10 @@ def alloc_buffer( buffer_type="default", span=None, ): - if not isinstance(self.node, ast.Assign): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( - "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)", + "`alloc_buffer` must be assigned to a single buffer, " + "e.g. A = alloc_buffer(...)", self.node.span, ) @@ -269,10 +274,11 @@ def alloc_buffer( offset_factor = convert_to_int( offset_factor, "offset_factor", self.context.report_error, self.node.span ) + buffer_name: str = self.node.lhs[0].id.name buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.lhs.id.name, + buffer_name, data, strides, elem_offset, @@ -283,32 +289,11 @@ def alloc_buffer( span=span, ) self.context.current_block_scope().alloc_buffers.append(buffer) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) super().__init__(alloc_buffer, def_symbol=True) -@register -class BlockVarBind(SpecialStmt): - """Special function bind(block_iter, binding_value) - - Example - ------- - .. code-block:: python - - T.bind(vx, i) - """ - - def __init__(self): - def bind(iter_var, values, span=None): - block_scope = self.context.current_block_scope() - if iter_var in block_scope.iter_bindings: - self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span) - block_scope.iter_bindings[iter_var] = values - - super().__init__(bind, def_symbol=False) - - @register class BlockReads(SpecialStmt): """Special function reads([read_buffer_regions]) @@ -412,6 +397,315 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None): super().__init__(block_attr, def_symbol=False) +class BlockAxis(SpecialStmt): + """Special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, i * 4 + j) + """ + + def axis( + self, + var_name: str, + dom: Union[PrimExpr, Range], + value: PrimExpr, + iter_type: int, + span: Optional[Span] = None, + ) -> None: + """ + Helper function for creating block axis + + Parameters + ---------- + var_name : str + The name_hint of var + + dom : Union[PrimExpr, Range] + The iter domain. + + value : PrimExpr + The binding value + + iter_type : int + The iteration type. + + span : Optional[Span] + The location of this for in the source code. + """ + assert self.context, "call 'exit_scope' before 'enter_scope'" + block_scope: BlockInfo = self.context.current_block_scope() + if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]: + self.context.report_error("Duplicate block axis " + var_name, self.node.span) + + block_var = tvm.tir.Var(var_name, dtype="int32") + dom = tvm.runtime.convert(dom) + if isinstance(dom, PrimExpr): + dom = tvm.ir.Range.from_min_extent(0, dom) + elif not isinstance(dom, tvm.ir.Range): + self.context.report_error( + f"Block axis domain expected PrimExpr or Range, but got {type(value)}", + self.node.span, + ) + value = tvm.runtime.convert(value) + if not isinstance(value, PrimExpr): + self.context.report_error( + f"Block axis value expected PrimExpr, but got {type(value)}", + self.node.span, + ) + iter_var = tvm.tir.IterVar(dom, block_var, iter_type) + block_scope.iter_vars.append(iter_var) + block_scope.iter_values.append(value) + self.context.update_symbol(var_name, block_var, self.node) + + +@register +class BlockAxisSpatial(BlockAxis): + """Special stmt for defining a spatial block axis + axis.spatial(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.spatial(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.spatial", get_param_list(self.func) + + +@register +class BlockAxisS(BlockAxis): + """The sugar special stmt for defining a spatial block axis + axis.S(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.S(128, k) + """ + + def __init__(self): + def axis_spatial( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) + + super().__init__(axis_spatial, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.S", get_param_list(self.func) + + +@register +class BlockAxisReduce(BlockAxis): + """Special stmt for defining a reduce block axis + axis.reduce(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.reduce(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.reduce", get_param_list(self.func) + + +@register +class BlockAxisR(BlockAxis): + """The sugar special stmt for defining a reduce block axis + axis.R(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.R(128, k) + """ + + def __init__(self): + def axis_reduce( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) + + super().__init__(axis_reduce, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.R", get_param_list(self.func) + + +@register +class BlockAxisScan(BlockAxis): + """Special stmt for defining a ordered block axis + axis.scan(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.scan(128, k) + """ + + def __init__(self): + def axis_scan( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered) + + super().__init__(axis_scan, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.scan", get_param_list(self.func) + + +@register +class BlockAxisOpaque(BlockAxis): + """Special stmt for defining a opaque block axis + axis.opaque(dom, iter_value) + + Example + ------- + .. code-block:: python + + vi = T.axis.opaque(128, k) + """ + + def __init__(self): + def axis_opaque( + dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)", + self.node.span, + ) + self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo) + + super().__init__(axis_opaque, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.opaque", get_param_list(self.func) + + +@register +class BlockAxisRemap(BlockAxis): + """Special stmt for remapping loops vars to block axes. + axis.remap(iter_type, iter_value) + + Note + ---- + Iter_type is a string consisting of 'S' and 'R', where 'S' means + for spatial and 'R' means for reduce. + + Example + ------- + .. code-block:: python + + vi, vj = T.axis.remap("SS", [i, j]) + """ + + def __init__(self): + def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1: + self.context.report_error( + "`axis.remap` must be assigned to one or more vars, " + "e.g. vi, vj = axis.remap(...)", + self.node.span, + ) + var_num: int = len(self.node.lhs) + if var_num != len(iter_types): + self.context.report_error( + f"`iter_type` expected {var_num} charactor(s), " + f"but got {len(iter_types)}: {iter_types}", + span, + ) + if var_num != len(loop_vars): + self.context.report_error( + f"`iter_type` expected {var_num} loop var(s), " + f"but got {len(loop_vars)}: {loop_vars}", + span, + ) + for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars): + iter_type: int + if iter_ty == "S": + iter_type = IterVar.DataPar + elif iter_ty == "R": + iter_type = IterVar.CommReduce + else: + self.context.report_error( + f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), ' + f'but got "{iter_ty}"', + span, + ) + + if not isinstance(loop_var, tvm.tir.expr.Var): + self.context.report_error( + f"Values of `axis.remap` expected single loop var, but got {loop_var}", + loop_var.span, + ) + loops = self.context.loop_stack + if loop_var not in loops: + self.context.report_error( + f"Cannot find loop var {loop_var} in loop nesting.", + span, + ) + self.axis(var.id.name, loops[loop_var], loop_var, iter_type) + + super().__init__(axis_remap, def_symbol=True) + + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: + return "tir.axis.remap", get_param_list(self.func) + + @register class BlockPredicate(SpecialStmt): """Special function where(predicate) @@ -449,7 +743,12 @@ def var(dtype, span): assert isinstance( self.node, ast.Assign ), f"VarDef expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, dtype, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], dtype, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -464,8 +763,13 @@ def buffer_var(dtype, storage_scope, span): assert isinstance( self.node, ast.Assign ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = te.var(self.node.lhs.id.name, ptr_type, span=span) + v = te.var(names[0], ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(buffer_var, def_symbol=True) @@ -480,7 +784,12 @@ def env_thread(env_name, span): assert isinstance( self.node, ast.Assign ), f"EnvThread expected ast.Assign but got {type(self.node)}" - v = te.var(self.node.lhs.id.name, span=span) + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"VarDef expected assign to only one var, but got {names}", span + ) + v = te.var(names[0], span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 681e322b20823..cb0305d49e4aa 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -467,10 +467,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i, j, k in T.grip(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] Returns ------- diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 6a90924912b11..b002ace0e4006 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -108,8 +108,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: A = T.match_buffer(a, (m, n), "float32") B = T.match_buffer(b, (m, n), "float32") - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Then we can make it specialized with given shapes or buffers. @@ -129,8 +131,10 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] Returns ------- diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 09a52d2e7037f..786982cf704c1 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -397,7 +397,8 @@ def before_fuse(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: @@ -419,9 +420,9 @@ def after_fuse(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the 2 loops are fused into 1 for i_j_fused in T.serial(0, 16384): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, tir.floordiv(i_j_fused, 128)) - T.bind(vj, T.floormod(i_j_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_fused, 128)) + vj = T.axis.S(128, T.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -468,7 +469,8 @@ def before_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B") as [vi, vj]: + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do split: @@ -490,9 +492,9 @@ def after_split(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # the original loop is split into 2 loops for i0, i1, j in T.grid(2, 64, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, ((i0*64) + i1)) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i0 * 64 + i1) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -529,7 +531,8 @@ def before_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do reorder: @@ -551,9 +554,8 @@ def after_reorder(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) # Here j and i are reordered for j, i in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -586,9 +588,8 @@ def before_parallel(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do parallel: @@ -609,9 +610,8 @@ def after_parallel(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.parallel(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -642,9 +642,8 @@ def before_vectorize(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do vectorize: @@ -665,9 +664,8 @@ def after_vectorize(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.serial(0, 128): for j in T.vectorized(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -706,9 +704,8 @@ def before_bind(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do bind: @@ -730,9 +727,8 @@ def after_bind(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.thread_binding(0, 128, thread = "blockIdx.x"): for j in T.thread_binding(0, 128, thread = "threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -758,9 +754,8 @@ def before_unroll(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do unroll: @@ -781,9 +776,8 @@ def after_unroll(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i in T.unroll(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -825,7 +819,8 @@ def before_cache_read(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_read: @@ -847,10 +842,12 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) A_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) A_local[vi, vj] = A[vi, vj] for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vi, vj] * 2.0 """ @@ -893,7 +890,8 @@ def before_cache_write(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_write: @@ -915,10 +913,12 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") for i, j in T.grid(128, 128): - with T.block([128, 128], "A_local") as [vi, vj]: + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) B_local[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi, vj] """ @@ -974,10 +974,14 @@ def before_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-at: @@ -1000,14 +1004,12 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1061,10 +1063,14 @@ def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-at: @@ -1087,14 +1093,12 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1135,10 +1139,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-inline: @@ -1156,8 +1164,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1195,10 +1205,14 @@ def before_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-inline: @@ -1216,8 +1230,10 @@ def before_inline(a: T.handle, c: T.handle) -> None: def after_inline(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member @@ -1384,8 +1400,9 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: def before_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128), - T.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + for ii, i, j in T.grid(128, 128, 128): + with T.block("B"): + vii, vi, vj = T.axis.remap("SRR", [ii, i, j]) with T.init(): B[vii] = 0.0 B[vii] = B[vii] + A[vii, vi, vj] @@ -1408,14 +1425,18 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128]) B_rf = T.alloc_buffer([128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: - with T.init(): - B_rf[vi2, vii] = 0.0 - B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: - with T.init(): - B[vii_1] = 0.0 - B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + for i2, ii, i in T.grid(128, 128, 128): + with T.block("B_rf"): + vi2, vii, vi = T.axis.remap("SSR", [i2, ii, i]) + with T.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + for ii, i2 in T.grid(128, 128): + with T.block("B"): + vii, vi2 = T.axis.remap("SR", [ii, i2]) + with T.init(): + B[vii] = 0.0 + B[vii] = B[vii] + B_rf[vi2, vii] Note @@ -1483,10 +1504,14 @@ def before_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do storage_align: @@ -1505,11 +1530,15 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 After lowering passes, buffer B will have strides as [129, 1]. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1abba77a801f0..722810e9aa5bc 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -628,7 +628,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(16, 16) for j in range(0, 16): B[i, j] = A[i, j] + 1 @@ -643,7 +643,7 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with T.block([]): + with T.block(): B = T.alloc_buffer(1, 16) for j in range(0, 16): B[0, j] = A[i, j] + 1 diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index fa74e56f491c4..13e4cfcd30bab 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -22,10 +22,10 @@ * \brief Printer class to print Tensor IR to python syntax script */ -#include #include #include #include +#include #include #include #include @@ -128,7 +128,17 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief the number of current node */ int current_num_; /*! \brief loop stack without annotations */ - std::vector loop_stack_; + std::vector simple_loop_stack_; + /*! \brief the maps from loop_vars to the loops */ + std::unordered_map loop_var_map_; + /*! + * \brief simple block vars remap from loop vars + * simple_remap requires: + * 1. block var iter type is kDataPar or kCommReduce + * 2. value is a single Var, which is a loop_var outside the block + * 3. The iter range is equal to loop range + */ + std::vector> block_var_remaps_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; @@ -193,7 +203,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); Doc AllocBufferDeclaration(const Buffer& buf); - Doc PrintBlockVar(const BlockNode* op); + Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); + Doc PrintBlockVarRemaps(); + Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); Doc PrintBufferRegion(const BufferRegionNode* op); @@ -821,21 +833,23 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers_.insert(op->loop_var.get()); + loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); - if (simple_loop) loop_stack_.push_back(GetRef(op)); + if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out if (simple_loop && body != nullptr) { Doc result = Print(GetRef(body)); TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return result; } // It is a loop that can not be compressed - bool print_above = !loop_stack_.empty(); + bool print_above = !simple_loop_stack_.empty(); // print loops above if needed if (print_above) { doc << PrintLoopStack(); - loop_stack_.clear(); + simple_loop_stack_.clear(); } if (!simple_loop) { // print current loop if needed @@ -847,6 +861,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return doc; } @@ -901,52 +916,99 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } -Doc TVMScriptPrinter::PrintBlockVar(const BlockNode* op) { +Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; - doc << "with " << tir_prefix_ << ".block(["; - std::vector block_var_docs; - for (const auto& iter_var : op->iter_vars) { - Doc block_var_doc; - if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { - block_var_doc << Print(iter_var->dom->extent); + doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; + switch (iter_var->iter_type) { + case kDataPar: + doc << "spatial"; + break; + case kCommReduce: + doc << "reduce"; + break; + case kOrdered: + doc << "scan"; + break; + case kOpaque: + doc << "opaque"; + break; + default: + LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; + break; + } + doc << "("; + const Range& dom = iter_var->dom; + if (is_zero(dom->min)) { + doc << Print(dom->extent); + } else { + doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")"; + } + doc << ", " << Print(value) << ")"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVarRemaps() { + ICHECK(!block_var_remaps_.empty()); + if (block_var_remaps_.size() == 1) { + const IterVar& iter_var = block_var_remaps_[0].first; + const PrimExpr& value = block_var_remaps_[0].second; + return PrintBlockVar(iter_var, value); + } + Doc doc; + std::vector iter_vars, iter_values; + std::string iter_type; + for (const auto& pair : block_var_remaps_) { + const IterVar& iter_var = pair.first; + const PrimExpr& value = pair.second; + iter_vars.push_back(Print(iter_var->var)); + iter_values.push_back(Print(value)); + if (iter_var->iter_type == kDataPar) { + iter_type += "S"; + } else if (iter_var->iter_type == kCommReduce) { + iter_type += "R"; } else { - block_var_doc << tir_prefix_ << "."; - switch (iter_var->iter_type) { - case kDataPar: - block_var_doc << "range"; - break; - case kCommReduce: - block_var_doc << "reduce_axis"; - break; - case kOrdered: - block_var_doc << "scan_axis"; - break; - case kOpaque: - block_var_doc << "opaque_axis"; - break; - default: - LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; - break; - } - block_var_doc << "(" << Print(iter_var->dom->min) << ", " - << Print(iter_var->dom->min + iter_var->dom->extent) << ")"; + ICHECK(false); } - block_var_docs.push_back(block_var_doc); - } - doc << PrintSep(block_var_docs, Doc::Text(", ")) << "]"; - if (!op->name_hint.empty()) { - doc << ", " << Doc::StrLiteral(op->name_hint); } - doc << ")"; - std::vector block_var_names; - for (const auto& iter_var : op->iter_vars) { + doc << PrintSep(iter_vars, Doc::Text(", ")) << " = " << tir_prefix_ << ".axis.remap(" + << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", ")) << "])"; + return doc; +} + +Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { + Doc doc; + const auto* block_op = op->block.as(); + ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size()); + tir::ExprDeepEqual expr_equal; + + auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, + const PrimExpr& value) -> bool { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; + if (!value->IsInstance()) return false; + const Var& var = Downcast(value); + auto it = loop_var_map_.find(var.get()); + return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && + expr_equal(it->second->extent, iter_var->dom->extent); + }; + + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + const IterVar& iter_var = block_op->iter_vars[i]; + const PrimExpr& value = op->iter_values[i]; var_not_in_headers_.insert(iter_var->var.get()); - block_var_names.push_back(Print(iter_var->var)); + if (is_simple_remap(iter_var, value)) { + block_var_remaps_.push_back(std::make_pair(iter_var, value)); + } else { + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); + } + doc << Doc::NewLine() << PrintBlockVar(iter_var, value); + } } - if (!block_var_names.empty()) { - doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]"; + if (!block_var_remaps_.empty()) { + doc << Doc::NewLine() << PrintBlockVarRemaps(); + block_var_remaps_.clear(); } - doc << ":"; return doc; } @@ -957,10 +1019,6 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { if (!is_one(op->predicate)) { block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")"; } - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) - block_attr_doc << Doc::NewLine() << tir_prefix_ << ".bind(" - << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) - << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")"; block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { @@ -994,12 +1052,18 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); // print block name and block vars - Doc doc = PrintBlockVar(block_op); + Doc doc; + doc << "with " << tir_prefix_ << ".block("; + if (!block_op->name_hint.empty()) { + doc << Doc::StrLiteral(block_op->name_hint); + } + doc << "):"; + Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); // print body Doc body = PrintBlockBody(block_op); - doc << Doc::Indent(4, block_attr_doc << Doc::NewLine() << body); + doc << Doc::Indent(4, block_var << block_attr_doc << Doc::NewLine() << body); for (const auto& iter_var : block_op->iter_vars) { TryDeallocVar(iter_var->var); } @@ -1265,11 +1329,11 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) { Doc TVMScriptPrinter::PrintLoopStack() { Doc res; - if (loop_stack_.size() == 1) { - res << PrintLoop(loop_stack_[0]); - } else if (loop_stack_.size() > 1) { + if (simple_loop_stack_.size() == 1) { + res << PrintLoop(simple_loop_stack_[0]); + } else if (simple_loop_stack_.size() > 1) { std::vector vars, extents; - for (const auto& loop : loop_stack_) { + for (const auto& loop : simple_loop_stack_) { vars.push_back(Print(loop->loop_var)); extents.push_back(Print(loop->extent)); } diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index f265a8ae2b1b9..335fdfa215038 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -44,21 +44,11 @@ class ScriptCompleter : public StmtMutator { Map* buffer_var_map_; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; - Stmt body = StmtMutator::VisitStmt_(op); - if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) { - auto block_with_binding = CopyOnWrite(Downcast(body).get()); - std::vector bindings; - for (size_t i = 0; i < op->iter_values.size(); ++i) { - bindings.push_back(Var("i" + std::to_string(i))); - } - block_with_binding->iter_values = bindings; - body = BlockRealize(block_with_binding); - for (int i = op->iter_values.size() - 1; i >= 0; --i) { - body = For(Downcast(bindings[i]), op->block->iter_vars[i]->dom->min, - op->block->iter_vars[i]->dom->extent, {}, body); - } + for (const PrimExpr& value : op->iter_values) { + CHECK(value.dtype().is_int()) + << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); } - return body; + return StmtMutator::VisitStmt_(op); } Stmt VisitStmt_(const BlockNode* op) override { diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 690258c2fa3b4..63733b05ab3fa 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -33,9 +33,8 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: # body for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"): for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"): - with T.block([16, 8]) as [bx, by]: - T.bind(bx, blockIdx_x) - T.bind(by, blockIdx_y) + with T.block(): + bx, by = T.axis.remap("SS", [blockIdx_x, blockIdx_y]) shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared") shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared") wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") @@ -44,9 +43,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for ty in T.thread_binding(0, 2, "threadIdx.y"): for tz in T.thread_binding(0, 2, "threadIdx.z"): for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads([]) T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) C0 = T.match_buffer( @@ -74,23 +73,23 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: for tx in T.thread_binding(0, 32, "threadIdx.x"): for i0, j0 in T.grid(1, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, bx * 64 + ty * 32 + tx + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, bx * 64 + ty * 32 + tx + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_A[vi, vj + 8] = A[vi, vj] for i0, j0 in T.grid(2, 4): for j1 in T.vectorized(0, 4): - with T.block([1024, 1024]) as [vi, vj]: - T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) - T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + with T.block(): + vi = T.axis.S(1024, by * 128 + ty * 64 + tx * 2 + i0) + vj = T.axis.S(1024, ko * 32 + tz * 16 + j0 * 4 + j1) shared_B[vi, vj + 8] = B[vi, vj] for ki in range(0, 2): for i in range(0, 2): - with T.block([64, 64]) as [vi, vk]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_A[ vi * 16 : vi * 16 + 16, @@ -142,9 +141,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for j in range(0, 4): - with T.block([64, 64]) as [vj, vk]: - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.S(64, ko * 2 + ki) T.reads( shared_B[ vj * 16 : vj * 16 + 16, @@ -196,14 +195,10 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64, T.reduce_axis(0, 64)]) as [ - vi, - vj, - vk, - ]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) - T.bind(vk, ko * 2 + ki) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) + vk = T.axis.R(64, ko * 2 + ki) T.reads( [ wmma_A[ @@ -258,9 +253,9 @@ def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) for i, j in T.grid(2, 4): - with T.block([64, 64]) as [vi, vj]: - T.bind(vi, bx * 4 + ty * 2 + i) - T.bind(vj, by * 8 + tz * 4 + j) + with T.block(): + vi = T.axis.S(64, bx * 4 + ty * 2 + i) + vj = T.axis.S(64, by * 8 + tz * 4 + j) T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) s0 = T.var("int32") diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 6502f0c67de62..fabf41705698f 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -41,10 +41,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py index 7bedea9082d14..62dcb52f74153 100644 --- a/tests/python/unittest/test_meta_schedule_arg_info.py +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -28,10 +28,12 @@ def Matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 256), "float32") B = T.match_buffer(b, (256, 512), "float32") C = T.match_buffer(c, (128, 512), "float32") - with T.block([128, 256, T.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 256, 512): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index fa09a092c8c46..fb3fa135a9b8c 100644 --- a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -47,10 +47,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @script.ir_module @@ -64,12 +66,16 @@ def matmul_relu( # pylint: disable=no-self-argument B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @script.ir_module @@ -82,10 +88,12 @@ def batch_matmul( # pylint: disable=no-self-argument A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index cb39c91eaca46..121ec2fd480bb 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -41,10 +41,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -56,12 +58,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) # fmt: on diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index 9fb1e5ef19c1c..46be12569c783 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -68,10 +68,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -83,12 +85,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(16, 16): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -99,10 +105,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 32, 32]) B = T.match_buffer(b, [16, 32, 32]) C = T.match_buffer(c, [16, 32, 32]) - with T.block([16, 32, 32, T.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 32, 32, 32): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @tvm.script.ir_module @@ -113,8 +121,10 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [32], "float32") B = T.match_buffer(b, [32], "float32") C = T.match_buffer(c, [32], "float32") - with T.block([32], "add") as [vi]: - C[vi] = A[vi] + B[vi] + for i in range(32): + with T.block("add"): + vi = T.axis.S(32, i) + C[vi] = A[vi] + B[vi] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index e12871391558c..9b3ddfd7c7893 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -45,10 +45,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (32, 32), "float32") B = T.match_buffer(b, (32, 32), "float32") C = T.match_buffer(c, (32, 32), "float32") - with T.block([32, 32, T.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(32, 32, 32): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 39bb1acf065f8..3f7749ca9e2cf 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -40,10 +40,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index a304096965433..4854aeb5f5aa9 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -48,10 +48,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module @@ -63,12 +65,16 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") C = T.alloc_buffer((1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = T.max(C[vi, vj], 0.0) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) @tvm.script.ir_module @@ -79,10 +85,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) - with T.block([16, 128, 128, T.reduce_axis(0, 128)], "matmul") as [vn, vi, vj, vk]: - with T.init(): - C[vn, vi, vj] = 0.0 - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("matmul"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py index 44bb949b925b9..01a4379e5127f 100644 --- a/tests/python/unittest/test_meta_schedule_tune_context.py +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -35,10 +35,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") - with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 987898001a1b7..6b5c26d08b7b5 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -54,10 +54,12 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_matmul(): @@ -77,10 +79,14 @@ def tir_element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B = T.alloc_buffer((128, 128)) - with T.block([128, 128]) as [i, j]: - B[i, j] = A[i, j] * 2.0 - with T.block([128, 128]) as [i, j]: - C[i, j] = B[i, j] + 1.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + B[i, j] = A[i, j] * 2.0 + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + C[i, j] = B[i, j] + 1.0 def test_element_wise(): @@ -125,19 +131,21 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [16, 32, 14, 14]) Apad = T.alloc_buffer([16, 16, 16, 16]) - with T.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: - Apad[nn, cc, yy, xx] = T.if_then_else( - yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, - A[nn, cc, yy - 1, xx - 1], - 0.0, - dtype="float32", - ) - with T.block( - [16, 32, 14, 14, T.reduce_axis(0, 16), T.reduce_axis(0, 3), T.reduce_axis(0, 3)], "B" - ) as [nn, ff, yy, xx, rc, ry, rx]: - with T.init(): - B[nn, ff, yy, xx] = 0.0 - B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] + for n, c, y, x in T.grid(16, 16, 16, 16): + with T.block("Apad"): + nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x]) + Apad[nn, cc, yy, xx] = T.if_then_else( + yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, + A[nn, cc, yy - 1, xx - 1], + 0.0, + dtype="float32", + ) + for n, f, y, x, kc, ky, kx in T.grid(16, 32, 14, 14, 16, 3, 3): + with T.block("B"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [n, f, y, x, kc, ky, kx]) + with T.init(): + B[nn, ff, yy, xx] = 0.0 + B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] def test_conv2d(): @@ -163,9 +171,11 @@ def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> B1 = T.match_buffer(b1, (m, n)) for i0, i1 in T.grid(m, n): - with T.block([m, n], "B.v0") as [i, j]: + with T.block("B.v0"): + i, j = T.axis.remap("SS", [i0, i1]) B0[i, j] = A0[i, j] + 2.0 - with T.block([m, n], "B.v1") as [i, j]: + with T.block("B.v1"): + i, j = T.axis.remap("SS", [i0, i1]) B1[i, j] = A1[i, j] * 3.0 @@ -193,7 +203,7 @@ def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) # body - with T.block([], "C"): + with T.block("C"): T.reads([A[0:128, 0:128], B[0:128, 0:128]]) T.writes([C[0:128, 0:128]]) T.evaluate( @@ -251,10 +261,12 @@ def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: - with T.init(): - C[i, j] = 0.0 - C[i, j] += A[i, k] * B[j, k] + for i0, j0, k0 in T.grid(128, 128, 128): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] def test_arg_order(): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 1aae8cdd03e1a..1a0dfd09a2df2 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -25,17 +25,23 @@ def buffer_load_store_func(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - with T.init(): + for ii, jj in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [ii, jj]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func @@ -43,7 +49,7 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [16, 16], "float32") C = T.match_buffer(c, [16, 16], "float32") - with T.block([]): + with T.block(): T.reads([]) T.writes(B[0:16, 0:16]) A = T.allocate([256], "float32", "global") @@ -56,9 +62,8 @@ def buffer_opaque_access(b: T.handle, c: T.handle) -> None: T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] @@ -72,16 +77,20 @@ def lca_is_func_root(a: T.handle) -> None: def match_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - AA = T.match_buffer(A[i, j], ()) - AA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + AA = T.match_buffer(A[vii, vjj], ()) + AA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) def test_buffer_load_store(): diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index e3a63c3254344..4ea35c0a2d6c4 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -27,57 +27,65 @@ def func() -> None: B = T.alloc_buffer((128, 128), "float32") C = T.alloc_buffer((128, 128), "float32") D = T.alloc_buffer((128, 128), "float32") - with T.block([]): + with T.block(): # Need add read/write region manually to avoid triggering block access region detector T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) T.writes([A[0:12, 0:12]]) for i, j in T.grid(8, 8): A[i, j] = B[0, 0] + C[0, 0] - with T.block([2, 2]) as [vi, vj]: - T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) - T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) - for i, j in T.grid(4, 4): - A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] + for i, j in T.grid(2, 2): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) + T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) + for i, j in T.grid(4, 4): + A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] T.evaluate(D.data) @T.prim_func def match_buffer_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector - with T.block([8, 8], "block") as [vi, vj]: - T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) - B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with T.block([16, 16], "AAA") as [i, j]: - T.reads([]) - T.writes(AA[i, j]) - AAA = T.match_buffer(AA[i, j], ()) - AAA[()] = 1.0 - T.evaluate(B0.data) - T.evaluate(B1.data) + for i, j in T.grid(8, 8): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer( + B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8) + ) + for ii, jj in T.grid(16, 16): + with T.block("AAA"): + vii, vjj = T.axis.remap("SS", [ii, jj]) + T.reads([]) + T.writes(AA[vii, vjj]) + AAA = T.match_buffer(AA[vii, vjj], ()) + AAA[()] = 1.0 + T.evaluate(B0.data) + T.evaluate(B1.data) @T.prim_func def opaque_block_func() -> None: - with T.block([], "root"): + with T.block("root"): A = T.alloc_buffer((16, 16), "float32") B = T.alloc_buffer((16, 16), "float32") T.reads([]) T.writes([]) # Need add read/write region manually to avoid triggering block access region detector for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes([B[i, 0:16]]) for j in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -88,8 +96,8 @@ def opaque_access_func() -> None: A = T.alloc_buffer([1024]) B = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [v]: - T.bind(v, i) + with T.block(): + v = T.axis.S(8, i) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([B[v * 128 : v * 128 + 128]]) T.evaluate( diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 7129275aebcdd..5ca9cf0da3c93 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -39,7 +39,7 @@ def buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) sub_A = T.match_buffer( @@ -55,7 +55,7 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16)) C = T.match_buffer(c, (16, 16)) for i, j, k in T.grid(4, 16, 8): - with T.block([]): + with T.block(): T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) for ii, kk in T.grid(4, 2): @@ -72,7 +72,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -93,7 +93,7 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): Bs_0 = T.var("int32") Bs_1 = T.var("int32") T.reads([]) @@ -122,7 +122,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (32, 64, 128)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(2, 64, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) T.evaluate( @@ -137,7 +137,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for i, j, k in T.grid(64, 2, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) T.evaluate( @@ -157,7 +157,7 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: def high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -185,7 +185,7 @@ def high_dim_opaque_access(a: T.handle) -> None: def transformed_high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -205,7 +205,7 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None: def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): As_0 = T.var("int32") As_1 = T.var("int32") T.reads([]) @@ -233,7 +233,7 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) T.evaluate( @@ -254,7 +254,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -276,7 +276,7 @@ def recursive_match(a: T.handle, b: T.handle) -> None: offset_factor=1, ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -317,7 +317,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (64, 64, 64)) B = T.match_buffer(b, (64, 64, 64)) for i, j, k in T.grid(64, 4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -326,7 +326,7 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None: ] ) for jj, kk in T.grid(4, 4): - with T.block([]): + with T.block(): T.reads([]) T.writes( [ @@ -362,7 +362,7 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) Bs_0 = T.var("int32") @@ -392,7 +392,7 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) A = T.match_buffer(a, (n * m, m)) B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) for ii, jj in T.grid(m, m): @@ -416,7 +416,7 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) sub_A = T.match_buffer(A[i, j], (), offset_factor=1) @@ -440,7 +440,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (8, 8)) B = T.match_buffer(b, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j], B[i, j]]) A[i, j] = 1 @@ -461,7 +461,7 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: def fail_match_load(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads(A[i, j]) T.writes([]) sub_A = T.match_buffer(A[i, j], ()) @@ -472,7 +472,7 @@ def fail_match_load(a: T.handle) -> None: def fail_match_store(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 8): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) sub_A = T.match_buffer(A[i, j], ()) @@ -483,7 +483,7 @@ def fail_match_store(a: T.handle) -> None: def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): stride = T.var("int32") sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 @@ -496,7 +496,7 @@ def fail_buffer_bind(a: T.handle) -> None: def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): sub_A = T.match_buffer(A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index 2182c7b9f449e..ad789a0107450 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -58,9 +64,11 @@ def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index ff5b61a135ebf..853f44affe5d5 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -33,10 +33,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,20 +49,23 @@ def access_under_scope(b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([8, 8], "scope") as [i, j]: - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A[vi, vj] + 1.0 - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -68,76 +75,82 @@ def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: C = T.match_buffer(c, (128, 128), dtype="float16") D = T.match_buffer(d, (128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -147,15 +160,16 @@ def func_multi_consumer() -> None: C = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -163,12 +177,18 @@ def func_multi_consumer() -> None: def func_multi_producer() -> None: A = T.alloc_buffer((128)) B = T.alloc_buffer((128)) - with T.block([128], "A0") as [vi]: - A[vi] = 1.0 - with T.block([128], "A1") as [vi]: - A[vi] = 2.0 - with T.block([128], "B") as [vi]: - B[vi] = A[vi] + for i in range(128): + with T.block("A0"): + vi = T.axis.S(128, i) + A[vi] = 1.0 + for i in range(128): + with T.block("A1"): + vi = T.axis.S(128, i) + A[vi] = 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] ########## Expected function after cache_read ########## @@ -181,14 +201,22 @@ def cache_read_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) A_global = T.alloc_buffer((128, 128)) B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A_global[vi, vj] * 2.0 - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_global[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 @T.prim_func @@ -198,27 +226,33 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = A[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = A_local[vi, vj] + 1.0 - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A_global[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = A[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_global[vi, vj] * 2.0 @T.prim_func @@ -229,78 +263,86 @@ def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) D = T.match_buffer(d, (128, 128), dtype="float16") A_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "A_global") as [vi, vj]: - A_global[vi, vj] = A[vi, vj] - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A_global[vi, vj]) - T.writes(D[vi, vj]) - D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A_global.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A_global[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A_global.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) @T.prim_func @@ -311,20 +353,21 @@ def cache_read_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = A[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A_global[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A_global[vi] @@ -335,14 +378,22 @@ def continuous_cache_read(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B_shared") as [vi, vj]: - B_shared[vi, vj] = B[vi, vj] - with T.block([128, 128], "B_local") as [vi, vj]: - B_local[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B_local[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_local[vi, vj] + 1.0 ########## Expected function after cache_write ########## @@ -355,14 +406,22 @@ def cache_write_elementwise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) B_global = T.alloc_buffer((128, 128), scope="local") C_local = T.alloc_buffer((128, 128)) - with T.block([128, 128], "B_global") as [vi, vj]: - B_global[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C_local") as [vi, vj]: - C_local[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B_global"): + vi, vj = T.axis.remap("SS", [i, j]) + B_global[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C_local"): + vi, vj = T.axis.remap("SS", [i, j]) + C_local[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_local[vi, vj] @T.prim_func @@ -372,33 +431,39 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) A_global = T.alloc_buffer((128, 128)) - with T.block([8, 8], "scope") as [i, j]: - A_local = T.alloc_buffer((128, 128), scope="local") - B_global = T.alloc_buffer((128, 128)) - for x, y in T.grid(16, 16): - with T.block([128, 128], "A_local") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_local[vi, vj] = 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - A_global[vi, vj] = A_local[vi, vj] - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B_global[vi, vj] = A_global[vi, vj] + 1.0 - for x, y in T.grid(16, 16): - with T.block([128, 128], "B_global") as [vi, vj]: - T.bind(vi, i * 16 + x) - T.bind(vj, j * 16 + y) - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "A_global") as [vi, vj]: - A[vi, vj] = A_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + for i0, j0 in T.grid(8, 8): + with T.block("scope"): + i, j = T.axis.remap("SS", [i0, j0]) + A_local = T.alloc_buffer((128, 128), scope="local") + B_global = T.alloc_buffer((128, 128)) + for x, y in T.grid(16, 16): + with T.block("A_local"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_local[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("A"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + A_global[vi, vj] = A_local[vi, vj] + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B_global[vi, vj] = A_global[vi, vj] + 1.0 + for x, y in T.grid(16, 16): + with T.block("B_global"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("A_global"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -411,83 +476,95 @@ def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle B_global = T.alloc_buffer((128, 128), dtype="float16") C_global = T.alloc_buffer((128, 128), dtype="float16") - with T.block([128, 128], "load_store") as [vi, vj]: - T.reads(A[vi, vj]) - T.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) - with T.block([8, 8], "opaque") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.evaluate( - T.tvm_load_matrix_sync( - B_global.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A.data, - vi * 2048 + vj * 16, + for i, j in T.grid(128, 128): + with T.block("load_store"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + for i, j in T.grid(8, 8): + with T.block("opaque"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( + B_global.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), 128, - 1, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) + ) + for i, j in T.grid(8, 8): + with T.block("match_buffer"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, ) - ) - with T.block([8, 8], "match_buffer") as [vi, vj]: - T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = T.match_buffer( - A[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - C0 = T.match_buffer( - C_global[ - vi * 16 : vi * 16 + 16, - vj * 16 : vj * 16 + 16, - ], - (16, 16), - "float16", - strides=[128, 1], - offset_factor=1, - ) - T.evaluate( - T.tvm_load_matrix_sync( - C0.data, - 16, - 16, - 16, - vi * 8 + vj, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - A0.data, - A0.elem_offset, - A0.strides[0], - 1, + C0 = T.match_buffer( + C_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + T.evaluate( + T.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", dtype="handle", - ), - 128, - "row_major", - dtype="handle", + ) ) - ) - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = D_global[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_global[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = C_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = D_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_global[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = C_global[vi, vj] @T.prim_func @@ -498,20 +575,21 @@ def cache_write_multi_consumer() -> None: A_global = T.alloc_buffer((128)) for i in T.grid(8): for j in T.grid(16): - with T.block([128], "A_global") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A_global"): + vi = T.axis.S(128, i * 16 + j) A_global[vi] = 1.0 for j in T.grid(16): - with T.block([128], "A") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("A"): + vi = T.axis.S(128, i * 16 + j) A[vi] = A_global[vi] for j in T.grid(16): - with T.block([128], "B") as [vi]: - T.bind(vi, i * 16 + j) + with T.block("B"): + vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 for i in T.grid(128): - with T.block([128], "C") as [vi]: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = A[vi] @@ -522,14 +600,22 @@ def continuous_cache_write(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128)) B_shared = T.alloc_buffer((128, 128), scope="shared") B_local = T.alloc_buffer((128, 128), scope="local") - with T.block([128, 128], "B") as [vi, vj]: - B_local[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "B") as [vi, vj]: - B_shared[vi, vj] = B_local[vi, vj] - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = B_shared[vi, vj] - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = B_local[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_shared[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 ########## Testcases for cache_read ########## diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 5235664595add..6e956e1ee6887 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -32,10 +32,15 @@ def two_elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -45,12 +50,13 @@ def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for ax0, ax1 in T.grid(1, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i + ax0) - T.bind(vj, ax1) + with T.block("B"): + vi = T.axis.S(128, i + ax0) + vj = T.axis.S(128, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -59,22 +65,26 @@ def blockized_1(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], "float32") B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.reads([B[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16, - ]]) - T.writes([C[ - vi_o * 16 : vi_o * 16 + 16, - vj_o * 16 : vj_o * 16 + 16 - ]]) - for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i_i) - T.bind(vj, vj_o * 16 + j_i) - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(8, 8): + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i, j]) + T.reads([B[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16, + ]]) + T.writes([C[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16 + ]]) + for i_i, j_i in T.grid(16, 16): + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i_i) + vj = T.axis.S(128, vj_o * 16 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -84,13 +94,12 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i0_0, i1_0 in T.grid(8, 8): for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0_0 * 16 + ax0) - T.bind(vj, i1_0 * 16 + ax1) + with T.block("B"): + vi = T.axis.S(128, i0_0 * 16 + ax0) + vj = T.axis.S(128, i1_0 * 16 + ax1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([8, 8], "C_outer") as [vi_o, vj_o]: - T.bind(vi_o, i0_0) - T.bind(vj_o, i1_0) + with T.block("C_outer"): + vi_o, vj_o = T.axis.remap("SS", [i0_0, i1_0]) T.reads([B[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16, @@ -100,9 +109,9 @@ def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: vj_o * 16 : vj_o * 16 + 16 ]]) for i0_1, i1_1 in T.grid(16, 16): - with T.block([128, 128], "C_inner") as [vi, vj]: - T.bind(vi, vi_o * 16 + i0_1) - T.bind(vj, vj_o * 16 + i1_1) + with T.block("C_inner"): + vi = T.axis.S(128, vi_o * 16 + i0_1) + vj = T.axis.S(128, vj_o * 16 + i1_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -112,9 +121,8 @@ def blockized_2(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -124,14 +132,14 @@ def blockized_2(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_o, j_o, i_i, j_i in T.grid(4, 4, 32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @@ -141,9 +149,8 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(8, 8): - with T.block([8, 8], "B_outer") as [vio, vjo]: - T.bind(vio, i_o) - T.bind(vjo, j_o) + with T.block("B_outer"): + vio, vjo = T.axis.remap("SS", [i_o, j_o]) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -153,14 +160,14 @@ def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16 ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B_inner") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B_inner"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for ax0, ax1 in T.grid(16, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 16 + ax0) - T.bind(vj, j_o * 16 + ax1) + with T.block("C"): + vi = T.axis.S(128, i_o * 16 + ax0) + vj = T.axis.S(128, j_o * 16 + ax1) T.reads([B[vi, vj]]) T.writes([C[vi, vj]]) C[vi, vj] = B[vi, vj] + 1.0 @@ -173,9 +180,9 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_o, j_o in T.grid(4, 4): for ax0, ax1 in T.grid(2, 2): - with T.block([8, 8], "blockized_B") as [vio, vjo]: - T.bind(vio, i_o * 2 + ax0) - T.bind(vjo, j_o * 2 + ax1) + with T.block("blockized_B"): + vio = T.axis.S(8, i_o * 2 + ax0) + vjo = T.axis.S(8, j_o * 2 + ax1) T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, @@ -185,14 +192,14 @@ def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: vjo * 16 : vjo * 16 + 16, ]]) for i_i, j_i in T.grid(16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, vio * 16 + i_i) - T.bind(vj, vjo * 16 + j_i) + with T.block("B"): + vi = T.axis.S(128, vio * 16 + i_i) + vj = T.axis.S(128, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 for i_i, j_i in T.grid(32, 32): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 32 + i_i) - T.bind(vj, j_o * 32 + j_i) + with T.block("C"): + vi = T.axis.S(128, i_o * 32 + i_i) + vj = T.axis.S(128, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -205,18 +212,28 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - with T.init(): - C_local[vi, vj] = 0.0 - C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -224,9 +241,9 @@ def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0_4, v1_4]: - T.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0_4 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1_4 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0_4, v1_4] = C_local[v0_4, v1_4] @@ -240,14 +257,22 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -255,17 +280,17 @@ def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> Non for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for i, j, k in T.grid(4, 4, 2048): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -279,14 +304,22 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - A_shared_local[v0, v1] = A_shared[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -296,17 +329,17 @@ def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [vi, vj]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] @@ -320,12 +353,18 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - B_shared_local[v0, v1] = B_shared[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared_local"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared_local[v0, v1] = B_shared[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -335,22 +374,22 @@ def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k_0 in T.serial(0, 256): for k_1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k_0 * 8 + k_1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k_0 * 8 + k_1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k_0 * 8 + k_1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k_0 * 8 + k_1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -364,10 +403,14 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "A_shared") as [v0, v1]: - A_shared[v0, v1] = A[v0, v1] - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("A_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + A_shared[v0, v1] = A[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -377,27 +420,27 @@ def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for k0 in T.serial(0, 256): for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -411,8 +454,10 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") - with T.block([2048, 2048], "B_shared") as [v0, v1]: - B_shared[v0, v1] = B[v0, v1] + for i, j in T.grid(2048, 2048): + with T.block("B_shared"): + v0, v1 = T.axis.remap("SS", [i, j]) + B_shared[v0, v1] = B[v0, v1] for by in T.thread_binding(0, 32, thread = "blockIdx.y"): for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): for vy in T.thread_binding(0, 2, thread = "vthread.y"): @@ -421,33 +466,33 @@ def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -469,38 +514,38 @@ def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: dis for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): for k0 in T.serial(0, 256): for i, j in T.grid(8, 64): - with T.block([2048, 2048], "A_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, by * 64 + j) + with T.block("A_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, by * 64 + j) A_shared[v0, v1] = A[v0, v1] for i, j in T.grid(8, 64): - with T.block([2048, 2048], "B_shared") as [v0, v1]: - T.bind(v0, k0 * 8 + i) - T.bind(v1, bx * 64 + j) + with T.block("B_shared"): + v0 = T.axis.S(2048, k0 * 8 + i) + v1 = T.axis.S(2048, bx * 64 + j) B_shared[v0, v1] = B[v0, v1] for k1 in T.unroll(0, 8): for i, j in T.grid(1, 4): - with T.block([2048, 2048], "A_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + with T.block("A_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] for i, j in T.grid(1, 4): - with T.block([2048, 2048], "B_shared_local") as [v0, v1]: - T.bind(v0, k0 * 8 + k1 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("B_shared_local"): + v0 = T.axis.S(2048, k0 * 8 + k1 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] for _, i, j in T.grid(1, 4, 4): - with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - T.bind(vk, k0 * 8 + k1) + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] for i, j in T.grid(4, 4): - with T.block([2048, 2048], "C_local") as [v0, v1]: - T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + with T.block("C_local"): + v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] @@ -510,12 +555,14 @@ def tiled(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], "float32") C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -525,14 +572,14 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], "float32") for i_0, j_0, i_1 in T.grid(8, 8, 16): for j_1 in T.serial(0, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 for j_1 in T.serial(0, 16): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_0 * 16 + i_1) - T.bind(vj, j_0 * 16 + j_1) + with T.block("C"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) C[vi, vj] = B[vi, vj] + 1.0 @@ -544,17 +591,15 @@ def factorized(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj, vk = T.axis.remap("SR", [j, k]) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for i, k in T.grid(16, 16): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -568,17 +613,17 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: for j in T.thread_binding(0, 16, thread = "blockIdx.x"): for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): for i_i, k in T.grid(4, 16): - with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - T.bind(vi, i_o * 4 + i_i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B_rf"): + vi = T.axis.S(16, i_o * 4 + i_i) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k) with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] for k in T.serial(0, 4): - with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: - T.bind(vi, j) - T.bind(vk, i_o * 4 + k) + with T.block("B"): + vi = T.axis.S(16, j) + vk = T.axis.R(16, i_o * 4 + k) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] @@ -591,17 +636,19 @@ def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for j in range(0, 64): - with T.block([128, 128], "B_0") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B_0"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 64): - with T.block([128, 128], "B_1") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j + 64) + with T.block("B_1"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j + 64) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -611,13 +658,16 @@ def fail_all_consumers_under_loop(a: T.handle, c: T.handle, d: T.handle) -> None C = T.match_buffer(c, (128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + 1.0 @@ -628,13 +678,16 @@ def fail_all_producers_under_loop(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer((128, 128), "float32") D = T.match_buffer(d, (128, 128), "float32") for i, j in T.grid(128, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = A[vi, vj] + 1.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "D") as [vi, vj]: + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) D[vi, vj] = B[vi, vj] + C[vi, vj] @@ -644,10 +697,12 @@ def read_out_of_bound(a: T.handle, c:T.handle) -> None: B = T.alloc_buffer([16], "float32") C = T.match_buffer(c, [16], "float32") for i in T.serial(0, 16): - with T.block([16], "B") as [v]: + with T.block("B"): + v = T.axis.S(16, i) B[v] = A[v] for j in T.serial(0, 16): - with T.block([16], "C") as [v]: + with T.block("C"): + v = T.axis.S(16, j) T.reads(B[v : v + 2]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -659,11 +714,11 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16], "float32") for j in T.serial(0, 16): for i in T.serial(0, T.min(1, 15 - j) + 1): - with T.block([16], "B") as [v]: - T.bind(v, j + i) + with T.block("B"): + v = T.axis.S(16, j + i) B[v] = A[v] - with T.block([16], "C") as [v]: - T.bind(v, j) + with T.block("C"): + v = T.axis.S(16, j) T.reads([B[v : v + 2]]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index f9049f6da732b..617c75b75cd91 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -31,10 +31,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -43,12 +47,18 @@ def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) - B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers @T.prim_func @@ -56,10 +66,14 @@ def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 - with T.block([128, 128], "D") as [vi, vj]: - D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] @T.prim_func @@ -67,18 +81,24 @@ def elementwise_standalone(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func def elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] + 1.0 @T.prim_func @@ -88,14 +108,12 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 @@ -103,8 +121,10 @@ def elementwise_under_loop(a: T.handle, c: T.handle) -> None: def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -113,11 +133,15 @@ def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: B = T.alloc_buffer((128, 128)) C = T.alloc_buffer((128, 128)) D = T.match_buffer(d, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - C[vi, vj] = A[vi, vj] + 2.0 - with T.block([128, 128], "C") as [vi, vj]: - D[vi, vj] = B[vi, vj] + C[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + C[vi, vj] = A[vi, vj] + 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = B[vi, vj] + C[vi, vj] @T.prim_func @@ -125,18 +149,24 @@ def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 @T.prim_func def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 @T.prim_func @@ -144,12 +174,16 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 @T.prim_func @@ -157,13 +191,17 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - T.reads(B[0:128, 0:128]) - T.writes(C[0:128, 0:128]) - T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 @T.prim_func @@ -171,11 +209,15 @@ def buffer_matched(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) - C[vi, vj] = Bb[0, 0] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) + C[vi, vj] = Bb[0, 0] + 1.0 @T.prim_func @@ -183,10 +225,13 @@ def elementwise_predicate(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(B[i, j] < 10.0) C[vi, vj] = B[vi, vj] + 1.0 @@ -196,7 +241,8 @@ def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) T.where(A[i, j] * 2.0 < 10.0) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @@ -206,18 +252,24 @@ def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] @T.prim_func def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 126], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 7a9c8e01d3554..ad6a1931bb0b4 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -31,10 +31,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) - for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for k in range(128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 60269ac01c14d..9075e93b9d45c 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -31,9 +31,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -42,9 +43,8 @@ def element_wise_parallelized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.parallel(0, 128): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -54,9 +54,8 @@ def element_wise_i_bound(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128)) for i0 in T.thread_binding(0, 128, thread="threadIdx.x"): for i1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, i1]) B[vi, vj] = A[vi, vj] * 2.0 @@ -67,14 +66,13 @@ def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o, j1i in T.grid(32, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -85,15 +83,14 @@ def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.serial(0, 32): for j1i in T.vectorized(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -102,10 +99,10 @@ def element_wise_split_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) for i, j_0, j_1 in T.grid(128, 13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -116,10 +113,10 @@ def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for j_0 in T.parallel(0, 13): for j_1 in T.serial(0, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -129,10 +126,10 @@ def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128]) for i in T.vectorized(0, 128): for j_0, j_1 in T.grid(13, 10): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.where(j_0 * 10 + j_1 < 128) - T.bind(vi, i) - T.bind(vj, j_0 * 10 + j_1) + vi = T.axis.S(128, i) + vj = T.axis.S(128, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 @@ -143,15 +140,14 @@ def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None B = T.alloc_buffer((128, 128)) for i in T.serial(0, 128): for j0 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j0) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j0]) B[vi, vj] = A[vi, vj] * 2.0 for j1o in T.thread_binding(0, 32, thread="threadIdx.x"): for j1i in T.serial(0, 4): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j1o * 4 + j1i) + with T.block("C"): + vi = T.axis.S(128, i) + vj = T.axis.S(128, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 @@ -161,10 +157,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -172,10 +170,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -184,9 +184,8 @@ def rowsum_unrolled(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.unroll(0, 128): for i1 in T.serial(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -198,9 +197,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -211,10 +210,12 @@ def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vk] = 0.0 - B[vk] = B[vk] + A[vi, vk] + for i, k in T.grid(128, 16): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] @T.prim_func @@ -223,9 +224,8 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i0 in T.serial(0, 128): for i1 in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i0) - T.bind(vk, i1) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i0, i1]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -235,7 +235,7 @@ def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: def opaque_block(a: T.handle) -> None: A = T.match_buffer(a, (16,)) for i in T.serial(0, 15): - with T.block([], "opaque"): + with T.block("opaque"): A[i + 1] = A[i + 1] + A[i] diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 8460b5cf3e66f..e158f6a026e13 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -32,18 +32,17 @@ def rowsum_blockized(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4]) A = T.match_buffer(a, [32, 4, 128]) for i0, i2_0 in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]: - T.bind(io, i0) - T.bind(ko, i2_0) + with T.block("blockized_B"): + io, ko = T.axis.remap("SR", [i0, i2_0]) with T.init(): for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii_init]: - T.bind(ii_init, i1) + with T.block("B_init"): + ii_init = T.axis.S(4, i1) B[io, ii_init] = 0.0 for i1_1, i2_1 in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1_1) - T.bind(k, ko * 8 + i2_1) + with T.block("B"): + ii = T.axis.S(4, i1_1) + k = T.axis.R(128, ko * 8 + i2_1) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -52,11 +51,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -65,11 +65,15 @@ def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = 0.0 + for i, j in T.grid(128, 128): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = 0.0 - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -78,16 +82,19 @@ def matmul_decompose1(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 32): - with T.block([32], "blockized_B_init") as [io]: + with T.block("blockized_B_init"): + io = T.axis.S(32, i0) for i1 in T.serial(0, 4): - with T.block([4], "B_init") as [ii]: + with T.block("B_init"): + ii = T.axis.S(4, i1) B[io, ii] = T.float32(0) for i0, i2_o in T.grid(32, 16): - with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]: + with T.block("blockized_B_update"): + io, ko = T.axis.remap("SR", [i0, i2_o]) for i1, i2_i in T.grid(4, 8): - with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]: - T.bind(ii, i1) - T.bind(k, ((ko * 8) + i2_i)) + with T.block("B"): + ii = T.axis.S(4, i1) + k = T.axis.R(128, ko * 8 + i2_i) B[io, ii] = B[io, ii] + A[io, ii, k] @@ -98,10 +105,12 @@ def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(128, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i0, i1]) C[vi_init, vj_init] = T.float32(0) for i2 in T.serial(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]: + with T.block("update_update"): + vi, vj, vk = T.axis.remap("SSR", [i0, i1, i2]) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @@ -112,12 +121,10 @@ def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, k, j in T.grid(128, 128, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -127,25 +134,21 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) for i0_0 in T.serial(0, 16): for i0_1_init, i1_init in T.grid(8, 128): - with T.block([128, 128], "update_init") as [vi_init, vj_init]: - T.bind(vi_init, ((i0_0 * 8) + i0_1_init)) - T.bind(vj_init, i1_init) + with T.block("update_init"): + vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init) + vj_init = T.axis.S(128, i1_init) C[vi_init, vj_init] = T.float32(0) for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7): - with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [ - vi, - vj, - vk, - ]: + with T.block("update_update"): T.where((((i2_0 * 7) + i2_1) < 128)) - T.bind(vi, ((i0_0 * 8) + i0_1)) - T.bind(vj, i1) - T.bind(vk, ((i2_0 * 7) + i2_1)) + vi = T.axis.S(128, i0_0 * 8 + i0_1) + vj = T.axis.S(128, i1) + vk = T.axis.R(128, i2_0 * 7 + i2_1) C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index a60ab8dca9725..8267a369cf5de 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @T.prim_func @@ -39,11 +41,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 8): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l * 16) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -53,7 +53,8 @@ def elementwise_dependent_loop(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128, 128)) for i in T.serial(0, 128): for j, k, l in T.grid(128, i, 128): - with T.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -62,8 +63,9 @@ def elementwise_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for i, j, k, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -74,16 +76,12 @@ def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -92,12 +90,11 @@ def elementwise_with_loops_not_same_scope(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): - with T.block([128, 128], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) for k in T.serial(0, 128): - with T.block([128], "B") as [vk]: - T.bind(vk, k) + with T.block("B"): + vk = T.axis.S(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -108,10 +105,9 @@ def elementwise_with_wrong_block_var_type(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([128, 128, T.scan_axis(0, 128)], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.scan(128, k) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -122,11 +118,8 @@ def elementwise_reordered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -135,11 +128,8 @@ def elementwise_reordered2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for k, j, i, l in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -148,12 +138,9 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) B = T.match_buffer(b, (128, 128, 128, 128)) for l, j, k, i in T.grid(128, 128, 128, 128): - with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + with T.block("B"): T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) - T.bind(vl, l) + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 @@ -161,14 +148,18 @@ def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -176,16 +167,14 @@ def opaque_access_reorder(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") for j, i in T.grid(16, 16): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, vi * 16 + vj, 1) for j, i in T.grid(16, 16): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 78b6a4696baa0..bd474ed342954 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -34,10 +34,9 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) T.writes([C[vi, vj]]) with T.init(): @@ -53,18 +52,12 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([4, 128, 128]) for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): - with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [ - vi2_inner_inner, - vi, - vj, - vi2_outer, - vi2_inner_outer, - ]: - T.bind(vi2_inner_inner, i2_inner_inner) - T.bind(vi, i0) - T.bind(vj, i1) - T.bind(vi2_outer, i2_outer) - T.bind(vi2_inner_outer, i2_inner_outer) + with T.block("update_rf"): + vi2_inner_inner = T.axis.S(4, i2_inner_inner) + vi = T.axis.S(128, i0) + vj = T.axis.S(128, i1) + vi2_outer = T.axis.R(4, i2_outer) + vi2_inner_outer = T.axis.R(8, i2_inner_outer) with T.init(): C_rf[vi2_inner_inner, vi, vj] = 0.0 C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( @@ -73,14 +66,8 @@ def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: ) for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): - with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [ - vi2_inner_inner_1, - vi_1, - vj_1, - ]: - T.bind(vi2_inner_inner_1, i2_inner_inner_1) - T.bind(vi_1, i0_1) - T.bind(vj_1, i1_1) + with T.block("update"): + vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) with T.init(): C[vi_1, vj_1] = 0.0 C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] @@ -93,13 +80,17 @@ def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: D = T.match_buffer(d, [256, 256]) C = T.alloc_buffer([256, 256]) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with T.block([256, 256], "D") as [vi, vj]: - D[vi, vj] = C[vi, vj] + for i, j in T.grid(256, 256): + with T.block("D"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] @T.prim_func @@ -108,10 +99,12 @@ def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] @T.prim_func @@ -122,17 +115,13 @@ def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.ha D = T.match_buffer(d, [128, 128]) for k, i, j in T.grid(128, 128, 128): - with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: - T.bind(ck, k) - T.bind(ci, i) - T.bind(cj, j) + with T.block("C"): + ck, ci, cj = T.axis.remap("RSS", [k, i, j]) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] - with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: - T.bind(dk, k) - T.bind(di, i) - T.bind(dj, j) + with T.block("D"): + dk, di, dj = T.axis.remap("RSS", [k, i, j]) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] @@ -143,10 +132,12 @@ def square_sum(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [16, 256, 256]) C = T.match_buffer(c, [16]) - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - with T.init(): - C[b] = 0.0 - C[b] = C[b] + A[b, i, j] * A[b, i, j] + for b0, i0, j0 in T.grid(16, 256, 256): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [b0, i0, j0]) + with T.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] @T.prim_func @@ -156,18 +147,15 @@ def square_sum_rfactor(a: T.handle, c: T.handle) -> None: C_rf = T.alloc_buffer([16, 256]) for i0, i1, i2 in T.grid(16, 256, 256): - with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: - T.bind(vi2, i2) - T.bind(b, i0) - T.bind(i, i1) + with T.block("C_rf"): + vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) for i0_1, i2_1 in T.grid(16, 256): - with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: - T.bind(vi2_1, i2_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[b_1, vi2_1] @@ -180,18 +168,18 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: C = T.alloc_buffer([16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C"): + b = T.axis.S(16, i0) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) T.reads([C[b], A[b, i, j]]) T.writes([C[b]]) with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): - with T.block([16], "D") as [b_1]: - T.bind(b_1, i0_1) + with T.block("D"): + b_1 = T.axis.S(16, i0_1) T.reads([C[b_1]]) T.writes([D[b_1]]) D[b_1] = T.sqrt(C[b_1], dtype="float32") @@ -205,31 +193,24 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [ - vi1_i2_fused_inner, - b, - i, - j, - ]: - T.bind(vi1_i2_fused_inner, i1_i2_fused_inner) - T.bind(b, i0) - T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) - T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.block("C_rf"): + vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) with T.init(): C_rf[vi1_i2_fused_inner, b] = 0.0 C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): - with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: - T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) - T.bind(b_1, i0_1) + with T.block("C"): + vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] for i0_2 in T.serial(0, 16): - with T.block([16], "D") as [b_2]: - T.bind(b_2, i0_2) + with T.block("D"): + b_2 = T.axis.S(16, i0_2) D[b_2] = T.sqrt(C[b_2], dtype="float32") @@ -238,8 +219,10 @@ def element_wise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -247,10 +230,12 @@ def rowsum(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -259,9 +244,9 @@ def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for i, k in T.grid(128, 16): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, T.floordiv(k * k, 2)) + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, T.floordiv(k * k, 2)) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -272,10 +257,12 @@ def rowsum_not_dominant(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi, vk] = 0.0 - B[vi, vk] = B[vi, vk] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] @T.prim_func @@ -285,9 +272,8 @@ def rowsum_not_serial(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): for k in T.parallel(0, 128): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, i) - T.bind(vk, k) + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -298,10 +284,12 @@ def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 1.0 - B[vi] = B[vi] + A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] @T.prim_func @@ -309,10 +297,12 @@ def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128,)) - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] - A[vi, vk] + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] @T.prim_func @@ -321,9 +311,9 @@ def rowsum_transformed(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128,)) for io, ii_ko_fused, ki in T.grid(32, 128, 4): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: - T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32)) - T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki) + with T.block("B"): + vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32)) + vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -334,10 +324,12 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128]) B = T.match_buffer(b, []) - with T.block([T.reduce_axis(0, 128)], "B") as [k]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + A[k] + for k0 in range(128): + with T.block("B"): + k = T.axis.R(128, k0) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] @T.prim_func @@ -346,15 +338,19 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, []) B_rf = T.alloc_buffer([128]) - with T.block([128], "B_rf") as [vi0]: - with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] + for i in range(128): + with T.block("B_rf"): + vi0 = T.axis.S(128, i) + with T.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = B_rf[vi0] + A[vi0] - with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]: - with T.init(): - B[()] = 0.0 - B[()] = B[()] + B_rf[vi0_1] + for i in range(128): + with T.block("B"): + vi0_1 = T.axis.R(128, i) + with T.init(): + B[()] = 0.0 + B[()] = B[()] + B_rf[vi0_1] @T.prim_func @@ -362,10 +358,10 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.block("B"): T.where(k_0 * 10 + k_1 < 128) - T.bind(vi, i) - T.bind(vk, k_0 * 10 + k_1) + vi = T.axis.S(128, i) + vk = T.axis.R(128, k_0 * 10 + k_1) with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] @@ -377,18 +373,15 @@ def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128], dtype="float32") B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): - with T.block([13, 128, T.reduce_axis(0, 10)], "B_rf") as [vk_0, vi, vk_1]: + with T.block("B_rf"): + vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) T.where(k_0 * 10 + k_1 < 128) - T.bind(vk_0, k_0) - T.bind(vi, i) - T.bind(vk_1, k_1) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): - with T.block([T.reduce_axis(0, 13), 128], "B") as [vk_0, vi]: - T.bind(vk_0, k_0) - T.bind(vi, i) + with T.block("B"): + vk_0, vi = T.axis.remap("RS", [k_0, i]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + B_rf[vi, vk_0] @@ -405,35 +398,31 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: for i in T.serial(0, 16): for j1 in T.serial(0, 16): for k1o, k1i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]: - T.bind(ci, i) - T.bind(cj, j1) - T.bind(ck, k1o * 4 + k1i) + with T.block("C"): + ci, cj = T.axis.remap("SS", [i, j1]) + ck = T.axis.R(16, k1o * 4 + k1i) with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, cj, ck] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i) - T.bind(dj, j1) - T.bind(dk, k2o * 4 + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i, j1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i) - T.bind(ej, j2) - T.bind(ek, k3o * 4 + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i) - T.bind(fj, j2) - T.bind(fk, k4o * 4 + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] @@ -449,46 +438,38 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: C_rf = T.alloc_buffer([16, 16, 4]) for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): - with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: - T.bind(vk1o, k1o) - T.bind(ci, i) - T.bind(cj, j1) - T.bind(vk1i, k1i) + with T.block("C_rf"): + vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) with T.init(): C_rf[ci, cj, vk1o] = 0.0 C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] for i_1 in T.serial(0, 16): for j1_1 in T.serial(0, 16): for k1o_1 in T.serial(0, 4): - with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: - T.bind(vk1o_1, k1o_1) - T.bind(ci_1, i_1) - T.bind(cj_1, j1_1) + with T.block("C"): + vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) with T.init(): C[ci_1, cj_1] = 0.0 C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] for k2o, k2i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: - T.bind(di, i_1) - T.bind(dj, j1_1) - T.bind(dk, (k2o * 4) + k2i) + with T.block("D"): + di, dj = T.axis.remap("SS", [i_1, j1_1]) + dk = T.axis.R(16, k2o * 4 + k2i) with T.init(): D[di, dj] = 0.0 D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] for j2 in T.serial(0, 16): for k3o, k3i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - T.bind(ei, i_1) - T.bind(ej, j2) - T.bind(ek, (k3o * 4) + k3i) + with T.block("E"): + ei, ej = T.axis.remap("SS", [i_1, j2]) + ek = T.axis.R(16, k3o * 4 + k3i) with T.init(): E[ei, ej] = 0.0 E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] for k4o, k4i in T.grid(4, 4): - with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - T.bind(fi, i_1) - T.bind(fj, j2) - T.bind(fk, (k4o * 4) + k4i) + with T.block("F"): + fi, fj = T.axis.remap("SS", [i_1, j2]) + fk = T.axis.R(16, k4o * 4 + k4i) with T.init(): F[fi, fj] = 0.0 F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index c93c7ca63aa88..fbf0a6a5bd78d 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -32,8 +32,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 29cfe8cadfb35..d2365c39c9cb2 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -30,8 +30,10 @@ def elementwise(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @T.prim_func @@ -40,7 +42,10 @@ def elementwise_dependent_loops(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i in T.serial(0, 128): for j, k in T.grid(i, 128): - with T.block([128, i, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi = T.axis.S(128, i) + vj = T.axis.S(i, j) + vk = T.axis.S(128, k) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -49,7 +54,8 @@ def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k in T.grid(128, 128, n): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -58,10 +64,10 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i_j_k_fused in T.serial(0, (n * 16384)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, (n * 128))) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, n), 128)) - T.bind(vk, T.floormod(i_j_k_fused, n)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -72,11 +78,10 @@ def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (128, 128, n)) B = T.match_buffer(b, (128, 128, n)) for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)): - with T.block([128, 128, n], "B") as [vi, vj, vk]: + with T.block("B"): T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n)) - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, ((k0 * T.floordiv((n + 9), 10)) + k1)) + vi, vj = T.axis.remap("SS", [i, j]) + vk = T.axis.S(n, k0 * T.floordiv(n + 9, 10) + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -89,10 +94,12 @@ def elementwise_with_seq(a: T.handle, b: T.handle) -> None: C = T.alloc_buffer((128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128): - with T.block([128, 128, 128], "C") as [vi, vj, vk]: + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 for k in T.serial(0, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 @@ -102,10 +109,8 @@ def elementwise_with_anno(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(0, 128, annotations={"useless_annotation": True}): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -117,10 +122,8 @@ def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.thread_binding(0, 128, thread="threadIdx.x"): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -132,10 +135,8 @@ def elementwise_with_starting_point(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (128, 128, 128)) for i, j in T.grid(128, 128): for k in T.serial(10, 128): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -146,13 +147,11 @@ def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for i, j, k in T.grid(128, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i, j, k]]) T.writes([B[i, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -163,10 +162,10 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128)) B = T.match_buffer(b, (128, 128, 128)) for fused in T.serial(0, 2097152): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(fused, 128), 128)) - T.bind(vk, T.floormod(fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -177,10 +176,10 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, ((i1 * 64) + i3)) - T.bind(vj, ((j1 * 32) + j2)) - T.bind(vk, ((k1 * 8) + k2)) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 32 + j2) + vk = T.axis.S(128, k1 * 8 + k2) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -191,10 +190,10 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) B = T.match_buffer(b, [128, 128, 128]) for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i1 * 64 + i3) - T.bind(vj, j1 * 64 + j3) - T.bind(vk, k1 * 64 + k3) + with T.block("B"): + vi = T.axis.S(128, i1 * 64 + i3) + vj = T.axis.S(128, j1 * 64 + j3) + vk = T.axis.S(128, k1 * 64 + k3) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -205,16 +204,11 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.where( - ( - ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) - and (((k0 * 43) + k1) < 128) - ) - ) - T.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) - T.bind(vj, j1) - T.bind(vk, ((k0 * 43) + k1)) + with T.block("B"): + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) + vj = T.axis.S(128, j1) + vk = T.axis.S(128, k0 * 43 + k1) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -225,7 +219,7 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [128, 128, 128]) A = T.match_buffer(a, [128, 128, 128]) for i_j_k_fused in T.serial(0, 2097152): - with T.block([], "opaque"): + with T.block("opaque"): T.reads( [ A[ @@ -244,10 +238,10 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: ] ] ) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, T.floordiv(i_j_k_fused, 16384)) - T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) - T.bind(vk, T.floormod(i_j_k_fused, 128)) + with T.block("B"): + vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) + vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -259,13 +253,12 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, j, k in T.grid(8, 16, 128, 128): - with T.block([], "opaque"): + with T.block("opaque"): T.reads([A[i0 * 16 + i1, j, k]]) T.writes([B[i0 * 16 + i1, j, k]]) - with T.block([128, 128, 128], "B") as [vi, vj, vk]: - T.bind(vi, i0 * 16 + i1) - T.bind(vj, j) - T.bind(vk, k) + with T.block("B"): + vi = T.axis.S(128, i0 * 16 + i1) + vj, vk = T.axis.remap("SS", [j, k]) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -275,14 +268,18 @@ def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: def opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16], "float32") B = T.match_buffer(b, [16, 16], "float32") - with T.block([16, 16], "A") as [vi, vj]: - T.reads([]) - T.writes([A[0:16, 0:16]]) - T.store(A.data, vi * 16 + vj, 1) - with T.block([16, 16], "B") as [vi, vj]: - T.reads([]) - T.writes([B[0:16, 0:16]]) - T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) @T.prim_func @@ -290,16 +287,16 @@ def opaque_access_fused(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16]) B = T.match_buffer(b, [16, 16]) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("A"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i_j_fused in T.serial(0, 256): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, T.floordiv(i_j_fused, 16)) - T.bind(vj, T.floormod(i_j_fused, 16)) + with T.block("B"): + vi = T.axis.S(16, T.floordiv(i_j_fused, 16)) + vj = T.axis.S(16, T.floormod(i_j_fused, 16)) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -310,16 +307,16 @@ def opaque_access_split(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16)) B = T.match_buffer(b, (16, 16)) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "A") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("A"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([A[0:16, 0:16]]) T.store(A.data, ((vi * 16) + vj), 1, 1) for i, j0, j1 in T.grid(16, 4, 4): - with T.block([16, 16], "B") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, ((j0 * 4) + j1)) + with T.block("B"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j0 * 4 + j1) T.reads([]) T.writes([B[0:16, 0:16]]) T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) @@ -331,9 +328,9 @@ def elementwise_not_affine(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, (127, 128)) for i in T.serial(0, 4): for j, k in T.grid(T.min(31, 126 - i * 32) + 1, 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind(vi, i * 32 + j) - T.bind(vj, k) + with T.block("B"): + vi = T.axis.S(127, i * 32 + j) + vj = T.axis.S(128, k) B[vi, vj] = A[vi, vj] @@ -343,12 +340,12 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [127, 128]) for i in T.grid(4): for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128): - with T.block([127, 128], "B") as [vi, vj]: - T.bind( - vi, + with T.block("B"): + vi = T.axis.S( + 127, i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), ) - T.bind(vj, T.floormod(j_k_fused, 128)) + vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = A[vi, vj] diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 94e1b4a6b3959..bc62fa1ba950d 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e2b39ce7c2895..e3bd000c2e705 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -32,10 +32,14 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -44,10 +48,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = 0.0 for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -55,22 +61,28 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: def block_in_opaque_block(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") B = T.match_buffer(b, (128, 128), "float32") - with T.block([128], "B") as vi: - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - B[vi, 0] = A[vi, 0] - if A[vi, 0] == 0.0: - with T.block([], "C"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "D") as vj: - B[vi, vj] = A[vi, vj] * 3.0 - else: - with T.block([], "E"): - T.reads([A[0:128, 0:128]]) - T.writes([B[0:128, 0:128]]) - with T.block([128], "F") as vj: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(128): + with T.block("B"): + vi = T.axis.S(128, i) + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + B[vi, 0] = A[vi, 0] + if A[vi, 0] == 0.0: + with T.block("C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("D"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 3.0 + else: + with T.block("E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + for j in range(128): + with T.block("F"): + vj = T.axis.S(128, j) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -78,10 +90,14 @@ def write_after_read(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 @T.prim_func @@ -90,9 +106,11 @@ def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") @@ -101,14 +119,17 @@ def concatenate_multi_producer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -116,14 +137,17 @@ def concatenate_multi_producer_uncovered(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 63): - with T.block([63], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(63, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 - with T.block([128], "B") as vi: - B[vi] = A[vi] * 2.0 + for i in range(0, 128): + with T.block("B"): + vi = T.axis.S(128, i) + B[vi] = A[vi] * 2.0 @T.prim_func @@ -132,9 +156,11 @@ def lca_at_loop(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, (128,)) C = T.match_buffer(c, (128,)) for i in range(0, 128): - with T.block([128], "B") as vi: + with T.block("B"): + vi = T.axis.S(128, i) B[vi] = A[vi] * 2.0 - with T.block([128], "C") as vi: + with T.block("C"): + vi = T.axis.S(128, i) C[vi] = B[vi] + 1.0 @@ -143,18 +169,20 @@ def multi_producer_consumer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128,)) B = T.match_buffer(b, (128,)) for i in range(0, 64): - with T.block([64], "A_0") as vi: + with T.block("A_0"): + vi = T.axis.S(64, i) A[vi] = vi + 1 for i in range(0, 64): - with T.block([64], "A_1") as vi: - T.bind(vi, i + 64) + with T.block("A_1"): + vi = T.axis.S(64, i + 64) A[vi] = vi + 2 for i in range(0, 64): - with T.block([64], "B_0") as vi: + with T.block("B_0"): + vi = T.axis.S(64, i) B[vi] = A[vi] + 2.0 for i in range(0, 64): - with T.block([64], "B_1") as vi: - T.bind(vi, i + 64) + with T.block("B_1"): + vi = T.axis.S(64, i + 64) B[vi] = A[vi] + 3.0 @@ -164,12 +192,14 @@ def elementwise_affine_producer(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") for i, j, k, l in T.grid(16, 2, 32, 16): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i * 8 + j * 4 + k // 8) - T.bind(vj, k % 8 * 16 + l) + with T.block("B"): + vi = T.axis.S(128, i * 8 + j * 4 + k // 8) + vj = T.axis.S(128, k % 8 * 16 + l) B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -177,13 +207,19 @@ def elementwise_subblock(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - with T.block([4, 4], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + for ii, jj in T.grid(4, 4): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -191,13 +227,19 @@ def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([32, 32], "B") as [vi, vj]: - T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - with T.block([2, 2], "B_sub") as [vi_i, vj_i]: - B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(32, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + for ii, jj in T.grid(2, 2): + with T.block("B_sub"): + vi_i, vj_i = T.axis.remap("SS", [ii, jj]) + B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -207,10 +249,12 @@ def bound_to_thread(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer([128, 128], scope="shared") for i in T.thread_binding(0, 128, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) C[vj, vi] = B[vj, vi] + 1.0 @@ -222,14 +266,14 @@ def equal_ranked_threads(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 16, thread="threadIdx.x"): for i_i in T.thread_binding(0, 8, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("B"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) B[vi, vj] = A[vi, vj] * 2.0 for j in T.serial(0, 128): - with T.block([128, 128], "C") as [vi, vj]: - T.bind(vi, i_o * 8 + i_i) - T.bind(vj, j) + with T.block("C"): + vi = T.axis.S(128, i_o * 8 + i_i) + vj = T.axis.S(128, j) C[vj, vi] = B[vj, vi] + 1.0 @@ -241,10 +285,12 @@ def warp_memory(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for j in T.serial(0, 128): - with T.block([4, 32, 128], "C") as [warp_id, lane_id, vj]: + with T.block("C"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 @@ -256,11 +302,15 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): for j in T.serial(0, 128): - with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: + with T.block("B"): + warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j]) B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"): for j in T.serial(0, 128): - with T.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]: + with T.block("C"): + _warp_id, warp_id, lane_id, vj = T.axis.remap( + "SSSS", [i_o, i_i, i_o_prime, j] + ) C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 7d0e91f70e609..3b699fd8f1b2d 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -29,22 +29,20 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -55,23 +53,21 @@ def element_wise_storage_align(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, ax1) + with T.block("B"): + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) @@ -82,23 +78,21 @@ def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) for i0 in T.serial(0, 128): for ax1 in T.serial(0, 128): - with T.block([128, 128], "B") as [vi, vj]: + with T.block("B"): T.block_attr({"buffer_dim_align": [0]}) - T.bind(vi, i0) - T.bind(vj, ax1) + vi, vj = T.axis.remap("SS", [i0, ax1]) T.reads([A[vi, vj]]) T.writes([B[vi, vj]]) B[vi, vj] = (A[vi, vj]*T.float32(2)) for i1 in T.serial(0, 128): - with T.block([128, 128], "C") as [vi_1, vj_1]: - T.bind(vi_1, i0) - T.bind(vj_1, i1) + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) T.reads([B[vi_1, vj_1]]) T.writes([C[vi_1, vj_1]]) C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 36e05c6b51701..f1c97c57b2ff0 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -32,18 +32,24 @@ def elementwise(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) B = T.alloc_buffer((128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func def elementwise_inlined(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 185d229b44e14..440d0ab67a50a 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -34,10 +34,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(0, 128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 86dc5dffed9f1..72666a89ebcb0 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -27,10 +27,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: B = T.match_buffer(b, [m, n]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, n)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, n): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,10 +41,12 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -52,10 +56,12 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, 128]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -66,10 +72,12 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [m, x * 8]) C = T.match_buffer(c, [m, m]) - with T.block([m, m, T.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(m, m, x * 8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -81,11 +89,15 @@ def element_wise(a: T.handle, c: T.handle) -> None: B = T.alloc_buffer((m, n), "float32") - with T.block([m, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(m, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([m, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(m, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -94,11 +106,15 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 64), "float32") B = T.alloc_buffer((128, 64), "float32") - with T.block([128, 64], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 64): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, 64], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, 64): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -108,11 +124,15 @@ def element_wise_128_n(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, n), "float32") B = T.alloc_buffer((128, n), "float32") - with T.block([128, n], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - with T.block([128, n], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for i, j in T.grid(128, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 @T.prim_func @@ -120,8 +140,10 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T. A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -129,8 +151,10 @@ def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) - with T.block([16, 16], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -138,8 +162,10 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) - with T.block([m, n], "") as [vi, vj]: - B[vi, vj] = A[vi, vj] + for i, j in T.grid(m, n): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] @T.prim_func @@ -147,8 +173,10 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [n // 8, 8], "int32") B = T.match_buffer(b, [n], "int32") - with T.block([n - 1], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + for i in range(n - 1): + with T.block(): + vi = T.axis.S(n - 1, i) + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 @T.prim_func @@ -156,8 +184,10 @@ def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: n = T.var("int32") A = T.match_buffer(a, [2, 8], "int32") B = T.match_buffer(b, [16], "int32") - with T.block([15], "") as [vi]: - B[vi] = A[vi // 8, vi % 8] + 714 + for i in range(15): + with T.block(): + vi = T.axis.S(15, i) + B[vi] = A[vi // 8, vi % 8] + 714 def test_specialize_nothing(): diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 0cfc724e41de2..7d3115428f5ad 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -32,17 +32,17 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -53,7 +53,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), "float32") @@ -74,7 +74,7 @@ def unschedulable_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") @@ -89,11 +89,11 @@ def param_buffer_access_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (20, 20), "float32") B = T.match_buffer(c, (20, 20), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(B[i, 0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 @@ -106,17 +106,17 @@ def shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -129,17 +129,17 @@ def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((8, 16), "float32", scope="shared") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i1 * 4 + i2, j]) B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 @@ -152,17 +152,17 @@ def warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((16, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i0 * 8 + i1 * 4 + i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 @@ -175,17 +175,17 @@ def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="vthread"): for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): - with T.block([]): + with T.block(): T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) B = T.alloc_buffer((4, 16), "float32", scope="warp") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i0 * 8 + i1 * 4 + i2, j]) T.writes(B[i2, j]) B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i2, j]) T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 @@ -196,17 +196,17 @@ def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((n * 8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[i * 8 + j]) B[i * 8 + j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i * 8 + j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[i * 8 + j] * 2.0 @@ -217,17 +217,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (n * 8,), "float32") C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) B = T.alloc_buffer((8,), "float32") for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 8 + j]) T.writes(B[j]) B[j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i * 8 + j]) C[i * 8 + j] = B[j] * 2.0 @@ -238,12 +238,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((8, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((8, 8), "float32") T.reads(A[i, j]) T.writes(B[i, j]) @@ -252,12 +252,12 @@ def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k, j]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] @@ -268,12 +268,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: A = T.match_buffer(a, (8, 8), "float32") C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with T.block([]): + with T.block(): T.reads(A[0, 8]) T.writes(C[0, 8]) B = T.alloc_buffer((1, 8), "float32") for j in range(0, 4): - with T.block([]) as []: + with T.block() as []: D = T.alloc_buffer((6, 1), "float32") T.reads(A[i, j]) T.writes(B[0, j]) @@ -282,12 +282,12 @@ def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: for k in range(2, 4): T.store(B.data, j, A[i, j] + D[k - 2, 0]) for j in range(3, 5): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] for j in range(6, 8): - with T.block([]) as []: + with T.block() as []: T.reads(B[0, j]) T.writes(C[i, j]) C[i, j] = B[0, j] @@ -298,19 +298,19 @@ def match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((16, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[i, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[i, j], ()) C1[()] = B2[()] * 2.0 @@ -321,19 +321,19 @@ def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) C0 = T.match_buffer(C[i, 0:16], (16)) B = T.alloc_buffer((1, 16)) - with T.block([]): + with T.block(): B0 = T.match_buffer(B[0, 0:16], (16)) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: C1 = T.match_buffer(C0[j], ()) B2 = T.match_buffer(B[0, j], ()) C1[()] = B2[()] * 2.0 @@ -344,18 +344,18 @@ def storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[i, j]) T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads(B[i, j]) T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 @@ -366,7 +366,7 @@ def compacted_storage_align_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 287a30916520c..ee323a64c50f8 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -32,19 +32,19 @@ def elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) B[vi, vj] = A[vi, vj] + 1.0 for j in range(0, 16): - with T.block([16, 16]) as [vi, vj]: - T.bind(vi, i) - T.bind(vj, j) + with T.block(): + vi = T.axis.S(16, i) + vj = T.axis.S(16, j) C[vi, vj] = B[vi, vj] * 2.0 @@ -53,7 +53,7 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([16, 16], "float32") diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 21c896c7bb7ef..eed82ebb91187 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -32,7 +32,7 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with T.block([]): + with T.block(): T.reads(A[i, 0:16]) T.writes(C[i, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="global") @@ -67,7 +67,7 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None: for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): for i2 in T.thread_binding(0, 2, thread="vthread"): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) B = T.alloc_buffer([1, 16], "float32", scope="local") @@ -108,17 +108,17 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with T.block([]): + with T.block(): T.reads(A[i, m]) T.writes(C[i, m]) B = T.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(A[i, j]) T.writes(B[j]) B[j] = A[i, j] + 1.0 for j in range(0, m): - with T.block([]) as []: + with T.block() as []: T.reads(B[j]) T.writes(C[i, j]) C[i, j] = B[j] * 2.0 @@ -143,7 +143,7 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for i, j in T.grid(5, 7): - with T.block([]) as []: + with T.block() as []: T.reads(A[i * 7 + j]) T.writes(C[i * 7 + j]) T.where(i * 7 + j < 32) @@ -166,7 +166,7 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (32), "float32") for x, y, z in T.grid(4, 1, 8): - with T.block([]) as []: + with T.block() as []: T.reads(A[x * 8 + y * 8 + z]) T.writes(C[x * 8 + y * 8 + z]) C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 @@ -187,7 +187,7 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - with T.block([]) as []: + with T.block() as []: T.reads(A[i]) T.writes(D[i]) B = T.alloc_buffer((32,), scope="global") @@ -215,7 +215,7 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i0 in range(0, 4): - with T.block([]): + with T.block(): T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16]) T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16]) B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global") diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index c1c4fb3d2e8fa..a4fd9404eee4f 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T # pylint: disable=no-self-argument @@ -28,10 +28,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - with T.init(): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + with T.init(): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -41,10 +44,13 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - if (j == 0) and (k == 32): - B[i] = T.float32(0) - B[i] += A[i, j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + if (j == 0) and (k == 32): + B[i] = T.float32(0) + B[i] += A[i, j, k] @tvm.script.ir_module @@ -54,12 +60,15 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - with T.init(): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with T.init(): + BB[()] = T.float32(0) + BB[()] += AA[j, k] @tvm.script.ir_module @@ -69,17 +78,21 @@ def main(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [64, 64, 64]) B = T.match_buffer(b, [64]) - with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: - BB = T.match_buffer(B[i], ()) - AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) - if (j == 0) and (k == 32): - BB[()] = T.float32(0) - BB[()] += AA[j, k] + for i0, j0 in T.grid(64, 64): + for k0 in T.serial(32, 64): + with T.block(): + i, j, k = T.axis.remap("SRR", [i0, j0, k0]) + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = T.float32(0) + BB[()] += AA[j, k] def test_lower_reduction(): origin_mod = WithInit mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + print(mod.script()) tvm.ir.assert_structural_equal(mod, WithBranch, True) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index e55555305a09c..c22f5f82ee10f 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir, te +from tvm import te from tvm.script import tir as T @@ -31,12 +31,14 @@ def element_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16)) C = T.match_buffer(c, (16, 16)) B = T.alloc_buffer((16, 16)) - for i_0 in range(0, 16): - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for i0 in range(0, 16): + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) B[i, j] = A[i, j] + 1.0 - for j_0 in range(0, 16): - with T.block([16, 16]) as [i, j]: + for j0 in range(0, 16): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) C[i, j] = B[i, j] * 2.0 @@ -46,95 +48,112 @@ def transformed_element_func(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [16, 16]) for i_0 in range(0, 16): - with T.block([]): + with T.block(): T.reads([A[i_0, 0:16]]) T.writes([C[i_0, 0:16]]) B = T.alloc_buffer([16, 16]) for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) B[i, j] = A[i, j] + 1.0 for j_0 in T.serial(0, 16): - with T.block([16, 16], "") as [i, j]: - T.bind(i, i_0) - T.bind(j, j_0) + with T.block(): + i, j = T.axis.remap("SS", [i_0, j_0]) C[i, j] = B[i, j] * 2.0 @T.prim_func def original_func() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128]) as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: - B = T.alloc_buffer((128, 128), "float32") - C = T.alloc_buffer((128, 128), "float32") - D = T.alloc_buffer((128, 128), "float32") - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] - for kk in range(0, 4): - B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += ( + D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + ) @T.prim_func def transformed_func() -> None: A = T.alloc_buffer([128, 128]) - with T.block([128, 128], "") as [i, j]: - A[i, j] = T.float32(0) - with T.block([32, 32, T.reduce_axis(0, 32)], "") as [i, j, k]: - B = T.alloc_buffer([128, 128]) - if k == 0: + for i0, j0 in T.grid(128, 128): + with T.block(): + i, j = T.axis.remap("SS", [i0, j0]) + A[i, j] = T.float32(0) + for i0, j0, k0 in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + B = T.alloc_buffer([128, 128]) + if k == 0: + for ii, jj in T.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] for ii, jj in T.grid(4, 4): - B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in T.grid(4, 4): - with T.block([], ""): - T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - C = T.alloc_buffer([128, 128]) - for kk in T.serial(0, 4): - B[((i * 4) + ii), ((j * 4) + jj)] = ( - B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] - ) - for kk in T.serial(0, 4): - with T.block([], ""): - T.reads( - [ - B[((i * 4) + ii), ((j * 4) + jj)], - C[((i * 4) + ii), ((k * 4) + kk)], - ] - ) - T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - D = T.alloc_buffer([128, 128]) - B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + ( - D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)] + with T.block(""): + T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + C = T.alloc_buffer([128, 128]) + for kk in T.serial(0, 4): + B[((i * 4) + ii), ((j * 4) + jj)] = ( + B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] ) + for kk in T.serial(0, 4): + with T.block(""): + T.reads( + [ + B[((i * 4) + ii), ((j * 4) + jj)], + C[((i * 4) + ii), ((k * 4) + kk)], + ] + ) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + D = T.alloc_buffer([128, 128]) + B[((i * 4) + ii), ((j * 4) + jj)] = B[ + ((i * 4) + ii), ((j * 4) + jj) + ] + ( + D[((j * 4) + jj), ((k * 4) + kk)] + * C[((i * 4) + ii), ((k * 4) + kk)] + ) @T.prim_func def match_buffer_func() -> None: C = T.alloc_buffer((128, 128)) - with T.block([128]) as [vi]: - C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for i in range(128): + with T.block(): + vi = T.axis.S(128, i) + C0 = T.match_buffer(C[vi, 0:128], (128)) + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func def transformed_match_buffer_func() -> None: for i in range(0, 128): - with T.block([128]) as [vi]: - T.bind(vi, i) + with T.block(): + vi = T.axis.S(128, i) C = T.alloc_buffer((128, 128)) C0 = T.match_buffer(C[vi, 0:128], (128)) - with T.block([128]) as [jj]: - C1 = T.match_buffer(C0[jj], ()) - C1[()] = 0 + for j in range(128): + with T.block(): + jj = T.axis.S(128, j) + C1 = T.match_buffer(C0[jj], ()) + C1[()] = 0 @T.prim_func @@ -143,9 +162,10 @@ def opaque_access(a: T.handle, b: T.handle) -> None: B = T.match_buffer(b, [1024]) A_cache = T.alloc_buffer([1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + vi = T.axis.S(8, i) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[(v * 128) : ((v * 128) + 128)]]) T.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) T.evaluate( @@ -161,8 +181,8 @@ def opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] @@ -173,12 +193,13 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [1024]) B = T.match_buffer(b, [1024]) for i in T.serial(0, 8): - with T.block([8]) as [vi]: + with T.block(): + vi = T.axis.S(8, i) T.reads(A[vi * 128 : vi * 128 + 128]) T.writes(B[vi * 128 : vi * 128 + 128]) A_cache = T.alloc_buffer([1024]) - with T.block([8]) as [v]: - T.bind(v, vi) + with T.block(): + v = T.axis.S(8, vi) T.reads([A[v * 128 : v * 128 + 128]]) T.writes([A_cache[v * 128 : v * 128 + 128]]) T.evaluate( @@ -187,8 +208,8 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None: ) ) for j in T.serial(0, 128): - with T.block([1024]) as [v]: - T.bind(v, ((vi * 128) + j)) + with T.block(): + v = T.axis.S(1024, vi * 128 + j) T.reads([A_cache[v]]) T.writes([B[v]]) B[v] = A_cache[v] diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 7c521db21bb84..d3b4b23acdb66 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -26,10 +26,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -39,12 +41,14 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(32, 32): - with T.block([32, 32], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) for ii, jj in T.grid(4, 4): C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( C[vi * 4 + ii, vj * 4 + jj] @@ -58,12 +62,15 @@ def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - B[vi, vj] = A[vi, vj] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -71,12 +78,13 @@ def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([]) as []: + with T.block() as []: + with T.block() as []: B[0, 0] = A[0, 0] + T.float32(1) - - with T.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -85,14 +93,18 @@ def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([]) as []: - with T.block([128, 128]) as [vi, vj]: - T.reads(A[vi, vj]) - B[vi, vj] = A[vi, vj] + T.float32(1) + with T.block() as []: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) - with T.block([128, 128]) as [vi, vj]: - T.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) def test_complete_matmul(): @@ -181,22 +193,23 @@ def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[vi, index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[vi, index_buf[0]] @T.prim_func def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[vi, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @@ -208,22 +221,23 @@ def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @T.prim_func def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with T.block([], "root"): + with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): - with T.block([16, 16], "") as [vi, vj]: - T.bind(vi, i0) - T.bind(vj, i1) + with T.block(): + vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[0:16, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @@ -240,11 +254,11 @@ def test_complete_buffer_indices(): def match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 @@ -253,15 +267,15 @@ def match_buffer_func(a: T.handle) -> None: def expected_match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, 0:16]) A0 = T.match_buffer(A[i, 0:16], (16)) - with T.block([]): + with T.block(): T.reads([]) T.writes(A0[0:16]) for j in range(0, 16): - with T.block([]) as []: + with T.block() as []: T.reads([]) T.writes(A0[j]) A1 = T.match_buffer(A0[j], ()) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 99a22636b9272..80c37229f5193 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -155,33 +155,83 @@ def test_allocate_with_buffers(): check_error(allocate_with_buffers, 2) -def inconsistent_binding() -> None: - with T.block([128, 128]) as [vi]: # error +def inconsistent_binding_value() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("SS", [i]) # error + T.evaluate(1.0) + + +def inconsistent_binding_type() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("S", [i, j]) # error T.evaluate(1.0) def test_inconsistent_binding(): - check_error(inconsistent_binding, 2) + check_error(inconsistent_binding_value, 3) + check_error(inconsistent_binding_type, 3) + + +def error_remap_type() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("TT", [i, j]) # error + T.evaluate(1.0) + + +def error_remap_value() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i + j, j]) # error + T.evaluate(1.0) + + +def test_error_remap_args(): + check_error(error_remap_type, 4) + check_error(error_remap_value, 4) def invalid_block_axes(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([A]) as [vi]: # error - T.evaluate(1.0) + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(i, A) # error + T.evaluate(1.0) def test_invalid_block_axes(): - check_error(invalid_block_axes, 3) + check_error(invalid_block_axes, 5) -def miss_block_bind() -> None: - with T.block([16, 16]) as [vi, vj]: # error - T.bind(vi, 1) - T.evaluate(1.0) +def duplicate_block_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(16, i) + vi = T.axis.S(16, j) # error + T.evaluate(1.0) + + +def duplicate_block_axes_remap() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vi = T.axis.remap("SS", [i, j]) # error + T.evaluate(1.0) + + +def test_duplicate_block_axes(): + check_error(duplicate_block_axes, 5) + check_error(duplicate_block_axes_remap, 4) + + +def miss_block_bind_value() -> None: + for i, j in T.grid(128, 128): + with T.block(): + vi = T.axis.S(i) # error + T.evaluate(1.0) def test_miss_block_bind(): - check_error(miss_block_bind, 2) + check_error(miss_block_bind_value, 4) def invalid_loop_var() -> None: @@ -203,74 +253,99 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: - with T.block([16, 16]) as [vi, vj]: - A = T.match_buffer(vi) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.match_buffer(vi) # error + T.evaluate(1.0) def test_invalid_match_buffer_region(): - check_error(invalid_match_buffer_region, 3) + check_error(invalid_match_buffer_region, 5) def duplicate_buffer() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A = T.alloc_buffer((128, 128), "float32") # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.alloc_buffer((128, 128), "float32") # error + T.evaluate(1.0) def test_duplicate_buffer(): - check_error(duplicate_buffer, 4) + check_error(duplicate_buffer, 6) def duplicate_reads() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.reads(A[0:8, 0:8]) - T.reads(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[0:8, 0:8]) + T.reads(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_writes() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - T.writes(A[0:8, 0:8]) - T.writes(A[0:16, 0:16]) # error - T.evaluate(1.0) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(A[0:8, 0:8]) + T.writes(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_predicate() -> None: - with T.block([16, 16]) as [vi, vj]: - T.where(1) - T.where(0) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(1) + T.where(0) # error def duplicate_annotations() -> None: - with T.block([16, 16]) as [vi, vj]: - T.block_attr({}) - T.block_attr({}) # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({}) + T.block_attr({}) # error def duplicate_init() -> None: - with T.block([16, 16]) as [vi, vj]: - with T.init(): - T.evaluate(1.0) - with T.init(): # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + with T.init(): + T.evaluate(1.0) + with T.init(): # error + T.evaluate(1.0) + + +def duplicate_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + vi = T.axis.S(i, 16) # error T.evaluate(1.0) def test_duplicate_block_signature(): - check_error(duplicate_reads, 5) - check_error(duplicate_writes, 5) - check_error(duplicate_predicate, 4) - check_error(duplicate_annotations, 4) - check_error(duplicate_init, 5) + check_error(duplicate_reads, 7) + check_error(duplicate_writes, 7) + check_error(duplicate_predicate, 6) + check_error(duplicate_annotations, 6) + check_error(duplicate_init, 7) + check_error(duplicate_axes, 5) def opaque_access_during_complete(a: T.handle) -> None: # error A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - T.evaluate(T.load("float32", A.data, vi * 16 + vj)) + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.evaluate(T.load("float32", A.data, vi * 16 + vj)) def test_opaque_access_during_complete(): @@ -279,55 +354,65 @@ def test_opaque_access_during_complete(): def convert_slice_to_bufferload() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi : vi + 2, vj] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi : vi + 2, vj] + 1 # error def test_convert_slice_to_bufferload(): - check_error(convert_slice_to_bufferload, 4) + check_error(convert_slice_to_bufferload, 6) def error_index_type() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 0.0] + 1 # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 0.0] + 1 # error def error_bufferslice_index_type() -> None: A = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0]] # error def test_error_index_type(): - check_error(error_index_type, 4) - check_error(error_bufferslice_index_type, 6) + check_error(error_index_type, 6) + check_error(error_bufferslice_index_type, 8) def error_index_with_stop() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = A[vi, 1:10] + 1 # error + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi, 1:10] + 1 # error def error_bufferslice_index_with_stop() -> None: A = T.alloc_buffer((1,), "int32") B = T.alloc_buffer((16, 16), "float32") C = T.alloc_buffer((16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - C[vi, vj] = B[vi, A[0:1]] # error + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, A[0:1]] # error def test_error_index_with_stop_slice(): - check_error(error_index_with_stop, 4) - check_error(error_bufferslice_index_with_stop, 6) + check_error(error_index_with_stop, 6) + check_error(error_bufferslice_index_with_stop, 8) def mismatch_args() -> None: A = T.alloc_buffer((128, 128), "float32") - with T.block([16, 16]) as [vi, vj]: + with T.block(): T.reads(A[0, 0], A[1, 1]) # error T.evaluate(1.0) @@ -338,8 +423,7 @@ def test_mismatch_args(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error - with T.block([16, 16]) as [vi, vj]: - T.evaluate(1.0) + T.evaluate(1.0) def scope_handler_except() -> None: @@ -368,7 +452,7 @@ def test_tvm_exception_catch(): def buffer_shape_mismatch(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): - with T.block([]): + with T.block(): T.reads([]) T.writes([A[i, j * 4 : j * 4 + 4]]) sub_A = T.match_buffer( @@ -383,7 +467,7 @@ def test_match_buffer_shape_mismatch(): def high_dim_store() -> None: - with T.block([], "root"): + with T.block("root"): B = T.allocate([256], "float32", "global") for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index @@ -393,6 +477,15 @@ def test_high_dim_store(): check_error(high_dim_store, 5) +def block_has_option_vars() -> None: + with T.block("root") as x: # error: block does not support option_vars + T.evaluate(0.0) + + +def test_block_has_option_vars(): + check_error(block_has_option_vars, 2) + + def check_error(func, rel_lineno): # Override the default renderer to accumulate errors errors = [] @@ -416,5 +509,7 @@ def render(e): ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" +# TODO(Siyuan): block iter errors. + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index c55fd7b692823..6203277c42e0d 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -37,22 +37,25 @@ def get_valid_counts( out_buf = T.match_buffer(out, (1, 2500, 6), "float32") out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32") - with T.block([1], "init") as [vi]: + with T.block("init"): + vi = T.axis.S(1, 0) valid_count_buf[vi] = T.int32(0) - with T.block([2500], "update") as [vj]: - T.reads([data_buf[vi, vj, 6]]) - T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) - if (data_buf[vi, vj, score_index] > score_threshold) and ( - (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) - ): - for k in T.serial(0, 6): - out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] - out_indices_buf[vi, valid_count_buf[vi]] = vj - valid_count_buf[vi] = valid_count_buf[vi] + 1 - if vj >= valid_count_buf[vi]: - for k in T.serial(0, 6): - out_buf[vi, vj, k] = T.float32(-1) - out_indices_buf[vi, vj] = T.int32(-1) + for j in range(2500): + with T.block("update"): + vj = T.axis.S(2500, j) + T.reads([data_buf[vi, vj, 6]]) + T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) + if (data_buf[vi, vj, score_index] > score_threshold) and ( + (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) + ): + for k in T.serial(0, 6): + out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] + out_indices_buf[vi, valid_count_buf[vi]] = vj + valid_count_buf[vi] = valid_count_buf[vi] + 1 + if vj >= valid_count_buf[vi]: + for k in T.serial(0, 6): + out_buf[vi, vj, k] = T.float32(-1) + out_indices_buf[vi, vj] = T.int32(-1) def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 8058b96b024d1..7c54cdc85f821 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2672,10 +2672,12 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func @@ -2685,11 +2687,13 @@ def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j in T.grid(128, 128): - with T.block([128, 128], "init") as [vi, vj]: + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = T.float32(0) for k in range(128): - with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -2699,11 +2703,14 @@ def element_wise(a: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (128, 128), "float32") B = T.alloc_buffer((128, 128), "float32") - with T.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * T.float32(2) - - with T.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + T.float32(1) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func @@ -2712,9 +2719,9 @@ def predicate(b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, (16, 16), "float32") for i, jo, ji in T.grid(16, 4, 5): - with T.block([16, 16], "update") as [vi, vj]: - T.bind(vi, i) - T.bind(vj, jo * 4 + ji) + with T.block("update"): + vi = T.axis.S(16, i) + vj = T.axis.S(16, jo * 4 + ji) T.where(jo * 4 + ji < 16) C[vi, vj] = B[vi, vj] + T.float32(1) @@ -2807,12 +2814,16 @@ def match_buffer_region(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16, 16), "float32") B = T.match_buffer(b, (1), "float32") - with T.block([16, 4]) as [vi, vj]: - C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - with T.block([4]) as [vii]: - D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in T.grid(4, 4): - B[0] += D[i, 0, j] + for i, j in T.grid(16, 4): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + for ii in range(4): + with T.block(): + vii = T.axis.S(4, ii) + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): + B[0] += D[i, 0, j] def test_match_buffer_region(): @@ -2844,8 +2855,8 @@ def block_elements(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (1, 1), "float32") - with T.block([1], "update") as [vi]: - T.bind(vi, 0) + with T.block("update"): + vi = T.axis.S(1, 0) T.where(True) T.reads(A[0:16, 0:16]) T.writes(B[0, 0]) @@ -2879,11 +2890,11 @@ def opaque_block(a: T.handle, b: T.handle) -> None: for i in range(16): for j in range(16): - with T.block([]): + with T.block(): T.reads([]) T.writes(A[i, j]) A[i, j] = T.float32(0) - with T.block([]): + with T.block(): T.reads([A[i, 0:16]]) T.writes([B[i, 0:16]]) for j in range(16): @@ -2927,7 +2938,7 @@ def rank0_block(a: T.handle) -> None: B = T.alloc_buffer((), "float32") T.store(B.data, 0, T.load("float32", A.data, 0)) - with T.block([], "update") as []: + with T.block("update") as []: T.reads([A[()]]) T.writes([B[()]]) for i in range(1): @@ -2969,8 +2980,10 @@ def test_minmax(): def abs(a: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") - with T.block([128, 128], "A") as [vi, vj]: - A[vi, vj] = T.abs(A[vi, vj]) + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.abs(A[vi, vj]) def test_abs(): @@ -3011,15 +3024,13 @@ def test_simplify_bracket(): @T.prim_func def var_with_same_name(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 - with T.block([16, 16]) as [vi, vj]: - A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 for i, j in T.grid(16, 16): - with T.block([16, 16]) as [vi, vj]: + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) A[vi, vj] = 0 @@ -3029,14 +3040,10 @@ def test_same_name_var(): rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) - assert out_str.count("with T.block([16, 16]) as [vi, vj]") == 4 + assert out_str.count('vi, vj = T.axis.remap("SS", [i, j])') == 2 assert out_str.find("vi_") == -1 assert out_str.find("vj_") == -1 - assert out_str.count("for i0, i1 in T.grid(16, 16)") == 2 - assert out_str.find("i0_") == -1 - assert out_str.find("i1_") == -1 - assert out_str.count("for i, j in T.grid(16, 16)") == 2 assert out_str.find("i_") == -1 assert out_str.find("i_") == -1 @@ -3047,11 +3054,13 @@ def while_loop(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") i = T.alloc_buffer((), "int32", scope="local") - with T.block([16]) as [vi]: - B[vi] = 0 - while i[()] < 10: - for j in range(16): - B[j] += A[j] + for ii in range(16): + with T.block(): + vi = T.axis.S(16, ii) + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] def test_while_loop(): diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index 7138effe395a4..dfd2a32165f1c 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.1 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib().