diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index 91f5512453fb..dcda757e2c3c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -41,7 +41,7 @@ def get_binary_elementwise_params( ------- SerialBinaryElementwise The parameters needed to construct a binary elementwise operator. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the binary elementwise operation. replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace @@ -56,17 +56,17 @@ def get_binary_elementwise_params( _, _, _, _, _, inner = get_outer_loops(body, "NHWC") # loads = [input, input, LUT, LUT] loads = get_loads(inner) - input_pointer = loads[0].buffer.data - input_pointer1 = loads[1].buffer.data + input_buffer = loads[0].buffer + input_buffer1 = loads[1].buffer if reversed_operands: - input_pointer, input_pointer1 = input_pointer1, input_pointer - output_pointer = inner.buffer.data + input_buffer, input_buffer1 = input_buffer1, input_buffer + output_buffer = inner.buffer # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, _ = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ifm2, _ = get_ifm_params(input_buffer1, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( @@ -87,7 +87,7 @@ def get_binary_elementwise_params( block_config=serial_block_config, rescale_config=rescale_config, ), - 
output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 2358e5a221bb..caf0ef5b95b3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -40,9 +40,9 @@ def get_conv2d_params(stmt, producers_consumers): ------- Serial2DConvolution The parameters needed to construct a 2D convolution. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the convolution output pointer. is_allocator : bool @@ -60,12 +60,12 @@ def get_conv2d_params(stmt, producers_consumers): loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -157,7 +157,7 @@ def get_conv2d_params(stmt, producers_consumers): upscale=attrs["upscale"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py index 
5878c2a7e09c..44af7e7647c6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -47,11 +47,11 @@ def get_depthwise_conv2d_params( ------- Serial2DDepthwise The parameters needed to construct a 2D depthwise. - output_pointer : tvm.tir.Var - The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var - The output pointer of the DMA write operation, which is to replace - the convolution output pointer. + output_buffer : tvm.tir.Buffer + The output buffer of the convolution operation. + replace_buffer : tvm.tir.Buffer + The output buffer of the DMA write operation, which is to replace + the convolution output buffer. is_allocator : bool Whether this operator allocates its output. @@ -64,12 +64,12 @@ def get_depthwise_conv2d_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -113,7 +113,7 @@ def get_depthwise_conv2d_params( upscale="NONE", block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 82485db65866..8d819cc318c8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ 
b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -34,21 +34,21 @@ def get_pad_params(stmt): ------- pad : SerialPadding The serializable padding. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ _, body = get_op_attrs(stmt) n, h, w, c, _, inner = get_outer_loops(body, "NHWC") - output_pointer = inner.buffer.data + output_buffer = inner.buffer pad = SerialPadding(top=0, left=0, bottom=0, right=0) if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data - return pad, input_pointer, output_pointer + input_buffer = inner.value.buffer + return pad, input_buffer, output_buffer padded_shape = [n.extent, h.extent, w.extent, c.extent] @@ -72,8 +72,8 @@ def _visit(expr): tvm.tir.stmt_functor.post_order_visit(cond, _visit) return ( pad, - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) @@ -87,19 +87,19 @@ def get_upscale_params(stmt): Returns ------- - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. 
""" _, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data - return (input_pointer, output_pointer) + input_buffer = inner.value.buffer + output_buffer = inner.buffer + return (input_buffer, output_buffer) def get_convert_to_nhwc_params(stmt): @@ -114,10 +114,10 @@ def get_convert_to_nhwc_params(stmt): ------- int The true number of channels. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ attrs, body = get_op_attrs(stmt) @@ -127,12 +127,12 @@ def get_convert_to_nhwc_params(stmt): # compute that is deemed uneccesary isn't removed by TVM. if attrs["layout"] == "NHCWB16": inner = inner.body - input_pointer = inner.value.b.buffer.data + input_buffer = inner.value.b.buffer else: - input_pointer = inner.value.buffer.data + input_buffer = inner.value.buffer - output_pointer = inner.buffer.data - return c.extent, input_pointer, output_pointer + output_buffer = inner.buffer + return c.extent, input_buffer, output_buffer def get_convert_to_nhcwb16_params(stmt): @@ -147,24 +147,24 @@ def get_convert_to_nhcwb16_params(stmt): ------- out_channels : int The true number of channels. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer : tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer : tvm.tir.Buffer + The buffer produced by the operation. 
""" attrs, body = get_op_attrs(stmt) _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) - output_pointer = inner.buffer.data + output_buffer = inner.buffer if isinstance(inner.value, tvm.tir.Call): cond = inner.value.args[0] out_channels = cond.b.value - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data + input_buffer = inner.value.buffer out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent - return out_channels, input_pointer, output_pointer + return out_channels, input_buffer, output_buffer class Tiles(NamedTuple): @@ -298,16 +298,16 @@ def get_read_params(stmt): ------- SerialFeatureMap The serializable feature map. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.buffer + output_buffer = inner.buffer # Needed for stride calculation, can replace with # inner.value.buffer.strides in future. @@ -337,8 +337,8 @@ def get_read_params(stmt): stride_w=strides[1], stride_c=strides[2], ), - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) @@ -354,16 +354,16 @@ def get_write_params(stmt): ------- SerialFeatureMap The serializable feature map. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. 
""" attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.buffer.data + output_buffer = inner.buffer # Needed for stride calculation, can replace with # inner.value.buffer.strides in future. @@ -402,18 +402,18 @@ def get_write_params(stmt): stride_c=strides[2], ), block_config, - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) -def get_ifm_params(pointer, producers_consumers, stmt): +def get_ifm_params(buffer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an IFM. Parameters ---------- - pointer : tvm.tir.Var - The pointer that the IFM DMA pipeline produces. + buffer: tvm.tir.Buffer + The buffer that the IFM DMA pipeline produces. producers_consumers: ProducersConsumers It associates pointers with the loop nest that produces their values and with the loop nest that consumes their values. @@ -426,13 +426,13 @@ def get_ifm_params(pointer, producers_consumers, stmt): The serializable padding. 
""" - pad = producers_consumers.get_producer(pointer, stmt) - serial_padding, input_pointer, _ = get_pad_params(pad) - upscale = producers_consumers.get_producer(input_pointer, pad) - input_pointer, _ = get_upscale_params(upscale) - convert_to_nhwc = producers_consumers.get_producer(input_pointer, upscale) - in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) - read = producers_consumers.get_producer(input_pointer, convert_to_nhwc) + pad = producers_consumers.get_producer(buffer, stmt) + serial_padding, input_buffer, _ = get_pad_params(pad) + upscale = producers_consumers.get_producer(input_buffer, pad) + input_buffer, _ = get_upscale_params(upscale) + convert_to_nhwc = producers_consumers.get_producer(input_buffer, upscale) + in_channels, input_buffer, _ = get_convert_to_nhwc_params(convert_to_nhwc) + read = producers_consumers.get_producer(input_buffer, convert_to_nhwc) serial_ifm, _, _ = get_read_params(read) serial_ifm.channels = in_channels @@ -479,13 +479,13 @@ def _get_buffer_var(stmt): return serial_ifm, serial_padding -def get_ofm_params(pointer, producers_consumers, stmt): +def get_ofm_params(buffer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an OFM. Parameters ---------- - pointer : tvm.tir.Var - The pointer that the OFM DMA pipeline consumes. + buffer: tvm.tir.Buffer + The buffer that the OFM DMA pipeline consumes. producers_consumers: ProducersConsumers It associates pointers with the loop nest that produces their values and with the loop nest that consumes their values. @@ -496,20 +496,20 @@ def get_ofm_params(pointer, producers_consumers, stmt): The serializable OFM. serial_block_config : SerialBlockConfig The serializable block config. - output_pointer : tvm.tir.Var - The pointer that the OFM DMA pipeline produces. + output_buffer: tvm.tir.Buffer + The buffer that the OFM DMA pipeline produces. is_allocator : bool Whether this operator allocates its output. 
""" - convert_to_nhcwb16 = producers_consumers.get_consumer(pointer, stmt) - out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) - write = producers_consumers.get_consumer(output_pointer, convert_to_nhcwb16) - serial_ofm, serial_block_config, _, output_pointer = get_write_params(write) + convert_to_nhcwb16 = producers_consumers.get_consumer(buffer, stmt) + out_channels, _, output_buffer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) + write = producers_consumers.get_consumer(output_buffer, convert_to_nhcwb16) + serial_ofm, serial_block_config, _, output_buffer = get_write_params(write) is_allocator = True - producer = producers_consumers.get_producer(output_pointer, write) + producer = producers_consumers.get_producer(output_buffer, write) if producer is None or producer != write: is_allocator = False serial_ofm.channels = out_channels - return serial_ofm, serial_block_config, output_pointer, is_allocator + return serial_ofm, serial_block_config, output_buffer, is_allocator diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 43ae52b3bae7..b6c965909ae5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -30,7 +30,9 @@ from .producers_consumers import ProducersConsumers -def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]: +def _get_feature_map( + stmt: tvm.tir.AttrStmt, fm_type: str +) -> Tuple[SerialFeatureMap, tvm.tir.Buffer]: """Get the feature map parameters from a loop nest of any shape (as long there are at most 4 nested loops). @@ -45,8 +47,8 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur ------- SerialFeatureMap The serializable feature map. - output_pointer : tvm.tir.Var - The pointer produced by the operation. 
+ output_buffer : tvm.tir.Buffer + The buffer produced by the operation. """ assert fm_type in ("ifm", "ofm") @@ -96,14 +98,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur stride_c=strides[2] if len(strides) > 2 else 1, ) - output_pointer = inner.buffer.data + output_buffer = inner.buffer - return serial_feature_map, output_pointer + return serial_feature_map, output_buffer def get_identity_params( stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers -) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: +) -> Tuple[SerialPooling, tvm.tir.Buffer, tvm.tir.Buffer]: """Get the parameters necessary to construct a call_extern for an identity pooling. Parameters @@ -118,11 +120,11 @@ def get_identity_params( ------- SerialPooling The parameters needed to construct a 2D pooling. - output_pointer : tvm.tir.Var - The output pointer of the pooling operation. - replace_pointer : tvm.tir.Var - The output pointer of the DMA write operation, which is to replace - the pooling output pointer. + output_buffer : tvm.tir.Buffer + The output buffer of the pooling operation. + replace_buffer : tvm.tir.Buffer + The output buffer of the DMA write operation, which is to replace + the pooling output buffer. is_allocator : bool Whether this operator allocates its output. 
@@ -136,19 +138,19 @@ def get_identity_params( # loads = [input, LUT, LUT] loads = get_loads(store) - input_pointer = loads[0].buffer.data - output_pointer = store.buffer.data + input_buffer = loads[0].buffer + output_buffer = store.buffer - read = producers_consumers.get_producer(input_pointer, stmt) - write = producers_consumers.get_consumer(output_pointer, stmt) + read = producers_consumers.get_producer(input_buffer, stmt) + write = producers_consumers.get_consumer(output_buffer, stmt) serial_ifm, _ = _get_feature_map(read, "ifm") - serial_ofm, write_output_pointer = _get_feature_map(write, "ofm") + serial_ofm, write_output_buffer = _get_feature_map(write, "ofm") - replace_pointer = write_output_pointer + replace_buffer = write_output_buffer is_allocator = True - producer = producers_consumers.get_producer(write_output_pointer, write) + producer = producers_consumers.get_producer(write_output_buffer, write) if producer is None or producer != write: is_allocator = False @@ -169,7 +171,7 @@ def get_identity_params( rounding_mode="TFL", block_config=SerialBlockConfig(0, 0, 0), ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 9636f2044733..2a242944789f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -26,6 +26,7 @@ from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs +from .utils import collect_buffer_map from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params from .pooling import get_pooling_params @@ -72,9 +73,10 @@ def ReplaceOperators(): "ethosu_unary_elementwise": get_unary_elementwise_params, } producers_consumers = ProducersConsumers() - replace_output_pointer = {} + pointer_to_buffer = {} + 
replace_output_buffer = {} pointer_to_extents = {} - replaced_pointers = [] + replaced_buffers = [] ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"]) @@ -90,28 +92,37 @@ def _resolve_pointers(stmt): Additionally, it determines the extent (size/shape) of each pointer which is required for the _replace_pointers pass which runs later.""" - loads = [] def _get_loads(stmt): - if isinstance(stmt, tvm.tir.BufferLoad): - loads.append(stmt.buffer.data) + loads = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.BufferLoad): + loads.append(stmt.buffer) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return loads - buffer_var = None + def _get_output_buffer(stmt): + output_buffer = None - def _get_buffer_var(stmt): - if isinstance(stmt, tvm.tir.BufferStore): - nonlocal buffer_var - buffer_var = stmt.buffer.data + def _visit(stmt): + if isinstance(stmt, tvm.tir.BufferStore): + nonlocal output_buffer + output_buffer = stmt.buffer + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return output_buffer if isinstance(stmt, tvm.tir.AttrStmt): if stmt.attr_key == "pragma_op": - tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var) - producers_consumers.add_producer(buffer_var, stmt) + output_buffer = _get_output_buffer(stmt) + producers_consumers.add_producer(output_buffer, stmt) - tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) - for load_pointer in loads: - if load_pointer != buffer_var: - producers_consumers.add_consumer(load_pointer, stmt) + loads = _get_loads(stmt) + for load_buffer in loads: + if load_buffer != output_buffer: + producers_consumers.add_consumer(load_buffer, stmt) def _replace_operator(stmt): """Replace operators with call_externs, having derived the parameters @@ -135,23 +146,28 @@ def _replace_operator(stmt): if stmt.attr_key == "pragma_op" and op_name in op_map: # Get the parameters for the extern call param_func = op_map[op_name] - info, output_pointer, replace_pointer, is_allocator = param_func( + info, 
output_buffer, replace_buffer, is_allocator = param_func( stmt, producers_consumers ) - if replace_pointer is not None: + if replace_buffer is not None: # Allocate pointer only once - if replace_pointer in replaced_pointers: + if replace_buffer in replaced_buffers: is_allocator = False - replace_output_pointer[output_pointer] = ReplaceInfo( - replace_pointer, is_allocator - ) - replaced_pointers.append(replace_pointer) + replace_output_buffer[output_buffer] = ReplaceInfo(replace_buffer, is_allocator) + replaced_buffers.append(replace_buffer) # Make the extern call irb = tvm.tir.ir_builder.create() irb.emit(tvm.tir.call_extern("handle", op_name, *info)) return irb.get() return None + def _remove_buffer_decl(stmt, buffer_var): + def _mutator(stmt): + if isinstance(stmt, tvm.tir.DeclBuffer) and stmt.buffer.data == buffer_var: + return stmt.body + + return tvm.tir.stmt_functor.ir_transform(stmt, None, _mutator, ["tir.DeclBuffer"]) + def _remove_no_compile(stmt): """Certain operators are marked as 'no compile' operators. This means they should be removed from the IR as they are compiled as part of other operators. 
@@ -170,33 +186,41 @@ def _remove_no_compile(stmt): if isinstance(stmt, tvm.tir.Allocate): # Remove allocates - producer = producers_consumers.get_last_producer(stmt.buffer_var) + buffer = pointer_to_buffer[stmt.buffer_var] + producer = producers_consumers.get_last_producer(buffer) if producer: if producer.attr_key == "pragma_op" and producer.value.value not in op_map: - return stmt.body + return _remove_buffer_decl(stmt.body, stmt.buffer_var) return None def _replace_pointers(stmt): if isinstance(stmt, tvm.tir.Allocate): # If the allocate allocates a pointer that needs replacing - if stmt.buffer_var in replace_output_pointer: - replace_pointer, reallocate = replace_output_pointer[stmt.buffer_var] - if not reallocate: - return stmt.body - # Otherwise, rewrite the allocation statement with the new pointer - # and the new extent - replace_type = replace_pointer.type_annotation.element_type.dtype - replace_extents = pointer_to_extents[replace_pointer] - return tvm.tir.Allocate( - replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body - ) + buffer = pointer_to_buffer[stmt.buffer_var] + if buffer in replace_output_buffer: + replace_buffer, reallocate = replace_output_buffer[buffer] + if reallocate: + # Otherwise, rewrite the allocation statement with the new pointer + # and the new extent + replace_pointer = replace_buffer.data + replace_type = replace_pointer.type_annotation.element_type.dtype + replace_extents = pointer_to_extents[replace_pointer] + return tvm.tir.Allocate( + replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body + ) + else: + return _remove_buffer_decl(stmt.body, stmt.buffer_var) return None - def _remove_buffer_decl(stmt): + def _replace_buffers(stmt): if isinstance(stmt, tvm.tir.DeclBuffer): - if stmt.buffer.data in replace_output_pointer: - return stmt.body + if stmt.buffer in replace_output_buffer: + replace_buffer, reallocate = replace_output_buffer[stmt.buffer] + if reallocate: + return 
tvm.tir.DeclBuffer(replace_buffer, stmt.body) + else: + return stmt.body def _post_transform(stmt): # Replace operators with call_externs @@ -205,19 +229,21 @@ def _post_transform(stmt): result = result or _remove_no_compile(stmt) # Replace necessary pointers that were removed in the previous step result = result or _replace_pointers(stmt) - # Replace BufferDecl, since only the tir.Var data pointer is - # still used, and not the tir.Buffer - result = result or _remove_buffer_decl(stmt) + # Replace BufferDecl, since only the tir.Var data pointer may + # have been replaced or removed + result = result or _replace_buffers(stmt) return result def _ftransform(f, mod, ctx): + nonlocal pointer_to_buffer + pointer_to_buffer = collect_buffer_map(f.body) tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent) tvm.tir.stmt_functor.post_order_visit(f.body, _resolve_pointers) producers_consumers.add_allocate_variables(pointer_to_extents.keys()) return f.with_body( tvm.tir.stmt_functor.ir_transform( - f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate"] + f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate", "tir.DeclBuffer"] ) ) @@ -491,6 +517,7 @@ def _visit(stmt): def transform_stmt( stmt, buf_remap, + remove_decl_buffer, var_remap, pointer_to_buffer, new_buffer_var_to_const, @@ -557,6 +584,13 @@ def _visit_rewrite(stmt): offset = new_buffer_to_split_idx[new_buffer] return tvm.tir.BufferLoad(buf_remap[stmt.buffer], [offset], stmt.span) + # Update or remove DeclBuffer + if isinstance(stmt, tvm.tir.DeclBuffer): + if stmt.buffer in remove_decl_buffer: + return stmt.body + elif stmt.buffer in buf_remap: + return tvm.tir.DeclBuffer(buf_remap[stmt.buffer], stmt.body) + if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node if node_pointer in var_remap: @@ -574,7 +608,7 @@ def _visit_rewrite(stmt): stmt, None, _visit_rewrite, - ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], + ["tir.Call", "tir.Allocate", "tir.BufferLoad", 
"tir.AttrStmt", "tir.DeclBuffer"], ) def _collect_parameter_buffer_aliases(prim_func): @@ -610,17 +644,22 @@ def _ftransform(f, mod, ctx): # Step 2: Generate variable/buffer remaps, based on the # collected information. buf_remap = {} + remove_decl_buffer = set() new_buffer_var_to_const = {} new_buffer_to_split_idx = {} def define_remap(old_buf, new_buf): try: old_buffers = param_buffer_var_usage[old_buf.data] + is_param = True except KeyError: old_buffers = [old_buf] + is_param = False for old_buffer in old_buffers: buf_remap[old_buffer] = new_buf + if is_param: + remove_decl_buffer.add(old_buffer) # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: @@ -666,6 +705,7 @@ def define_remap(old_buf, new_buf): new_body = transform_stmt( f.body, buf_remap, + remove_decl_buffer, var_remap, pointer_to_buffer, new_buffer_var_to_const, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 069930475df9..646e5bad7223 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -26,7 +26,7 @@ def get_pooling_params( stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers -) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: +) -> Tuple[SerialPooling, tvm.tir.Buffer, tvm.tir.Buffer]: """Get the parameters necessary to construct a call_extern for a pooling. Parameters @@ -41,9 +41,9 @@ def get_pooling_params( ------- SerialPooling The parameters needed to construct a 2D convolution. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the convolution output pointer. 
is_allocator : bool @@ -57,12 +57,12 @@ def get_pooling_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -90,7 +90,7 @@ def get_pooling_params( upscale=attrs["upscale"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py index 39cbf701649f..6985ed28383f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py @@ -27,52 +27,52 @@ class ProducersConsumers: def __init__(self) -> None: self.indices: dict[tvm.tir.AttrStmt, int] = {} - self.producers: list[(tvm.tir.AttrStmt, tvm.tir.expr.Var)] = [] - self.consumers: list[(tvm.tir.AttrStmt, list[tvm.tir.expr.Var])] = [] + self.producers: list[(tvm.tir.AttrStmt, tvm.tir.expr.Buffer)] = [] + self.consumers: list[(tvm.tir.AttrStmt, list[tvm.tir.expr.Buffer])] = [] self.allocate_variables: Optional[KeysView] = None - def add_producer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + def add_producer(self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt) -> None: """Add the attribute statement attr as producer of the variable 
var.""" self.indices[attr] = len(self.producers) - self.producers.append((attr, var)) + self.producers.append((attr, buf)) def get_producer( - self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt ) -> Optional[tvm.tir.AttrStmt]: """Get the last attribute statement which produces the variable var when the current attribute statement is attr.""" - if var not in self.allocate_variables: + if buf.data not in self.allocate_variables: return None index = self.indices[attr] for i in list(reversed(range(index + 1))): - if self.producers[i][1] == var: + if self.producers[i][1] == buf: return self.producers[i][0] return None - def get_last_producer(self, var: tvm.tir.expr.Var) -> Optional[tvm.tir.AttrStmt]: + def get_last_producer(self, buf: tvm.tir.Buffer) -> Optional[tvm.tir.AttrStmt]: """Get the last attribute statement which produces the variable var.""" - return self.get_producer(var, self.producers[-1][0]) + return self.get_producer(buf, self.producers[-1][0]) def add_allocate_variables(self, allocate_variables: KeysView) -> None: """Add the allocated variables.""" self.allocate_variables = allocate_variables - def add_consumer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + def add_consumer(self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt) -> None: """Add the attribute statement attr as consumer of the variable var.""" index = self.indices[attr] if index < len(self.consumers): - self.consumers[index][1].append(var) + self.consumers[index][1].append(buf) else: - self.consumers.append((attr, [var])) + self.consumers.append((attr, [buf])) def get_consumer( - self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt ) -> Optional[tvm.tir.AttrStmt]: """Get the first attribute statement which consumes the variable var when the current attribute statement is attr.""" index = self.indices[attr] for i in range(index, len(self.consumers)): - if var in 
self.consumers[i][1]: + if buf in self.consumers[i][1]: return self.consumers[i][0] return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 272318066b3f..c5cf413110e6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -36,9 +36,9 @@ def get_copy_params(stmt, producers_consumers): ------- SerialCopy The parameters needed to construct a copy. - tvm.tir.Var - The output pointer of the copy operation. - replace_pointer : tvm.tir.Var + tvm.tir.Buffer + The output buffer of the copy operation. + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the convolution output pointer. is_allocator : bool @@ -56,7 +56,7 @@ def get_copy_params(stmt, producers_consumers): length=length, write_address=tvm.tir.expr.BufferLoad(write_store.buffer, write_base), ), - write_store.buffer.data, + write_store.buffer, None, True, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index cd5d71d74b84..8ec0fa952990 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -37,9 +37,9 @@ def get_unary_elementwise_params(stmt, producers_consumers): ------- SerialUnaryElementwise The parameters needed to construct a unary elementwise operator. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the unary elementwise operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the unary elementwise output pointer. 
is_allocator : bool @@ -48,18 +48,18 @@ def get_unary_elementwise_params(stmt, producers_consumers): attrs, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") - input_pointer = None + input_buffer = None if isinstance(inner.value, tir.expr.Select): # ABS - input_pointer = inner.value.condition.b.buffer.data + input_buffer = inner.value.condition.b.buffer if isinstance(inner.value, tir.expr.Sub): # CLZ - input_pointer = inner.value.b.args[0].buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.b.args[0].buffer + output_buffer = inner.buffer # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, _ = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( @@ -74,7 +74,7 @@ def get_unary_elementwise_params(stmt, producers_consumers): rounding_mode=attrs["rounding_mode"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 396735a07c4c..485539a0c6d2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -178,7 +178,7 @@ def collect_buffer_map(stmt): buffer_map = {} def _visit(node): - if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore)): + if isinstance(node, (tvm.tir.DeclBuffer, tvm.tir.BufferLoad, tvm.tir.BufferStore)): buf = node.buffer if buf.data not in buffer_map: buffer_map[buf.data] = buf diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 
50de995a9145..4fb0ed2861fc 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -448,6 +448,7 @@ def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""): buffer_var = buffer.data self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) + self.emit(lambda x: _stmt.DeclBuffer(buffer, x)) return BufferVar(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 52e38c7ba2d8..72389b5b1b02 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -223,9 +223,11 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { const auto& res_buffer = res_buffers[idx - 1]; + body = DeclBuffer(res_buffer, body); body = Allocate(res_buffer->data, res_buffer->dtype, res_buffer->shape, const_true(), body); if (!normal_red.empty()) { const auto& normal_res_buffer = normal_res_buffers[idx - 1]; + body = DeclBuffer(normal_res_buffer, body); body = Allocate(normal_res_buffer->data, normal_res_buffer->dtype, normal_res_buffer->shape, const_true(), body); } diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 898183533ccd..ff3ce6b58aed 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -26,26 +26,25 @@ #include #include +#include + #include "../ir/functor_common.h" #include "tvm/ir/module.h" namespace tvm { namespace tir { -/*! \brief Verify all Expr inside the block does not contain: - * 1. loop vars outside the current block. - * 2. block vars of parent blocks. 
- */ -class BlockVarAccessVerifier : public StmtExprVisitor { +template +class BaseVerifier : public StmtExprVisitor { public: static bool Verify(const PrimFunc& func, bool assert_mode) { - BlockVarAccessVerifier verifier(assert_mode); + Derived verifier(func, assert_mode); verifier(func->body); return !verifier.has_error_; } - private: - explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {} + protected: + explicit BaseVerifier(const PrimFunc&, bool assert_mode) : assert_mode_(assert_mode) {} void VisitStmt(const Stmt& stmt) final { if (!has_error_) { @@ -59,6 +58,20 @@ class BlockVarAccessVerifier : public StmtExprVisitor { } } + /*! \brief Whether it's in assert mode. */ + bool assert_mode_; + /*! \brief Whether there is error. */ + bool has_error_{false}; +}; + +/*! \brief Verify all Expr inside the block does not contain: + * 1. loop vars outside the current block. + * 2. block vars of parent blocks. + */ +class BlockVarAccessVerifier : public BaseVerifier { + private: + using BaseVerifier::BaseVerifier; + void VisitExpr_(const VarNode* op) final { auto it = loop_vars_.find(op); if (it != loop_vars_.end() && it->second < block_stack_.size()) { @@ -127,18 +140,153 @@ class BlockVarAccessVerifier : public StmtExprVisitor { private: /*! \brief The map from outside loop vars to its corresponding block level. */ std::unordered_map loop_vars_; - /*! \brief Whether it's in assert mode. */ - bool assert_mode_; /*! \brief Current nested block stack level. */ std::vector block_stack_; - /*! \brief Whether there is error. 
*/ - bool has_error_{false}; +}; + +class UndefinedBufferAccessVerifier : public BaseVerifier { + private: + using Parent = BaseVerifier; + friend class BaseVerifier; + + UndefinedBufferAccessVerifier(const PrimFunc& func, bool assert_mode) + : BaseVerifier(func, assert_mode) { + for (const Var& param : func->params) { + if (auto opt = func->buffer_map.Get(param)) { + global_defs_.emplace_back(this, opt.value(), NullOpt); + } + } + } + + // Buffer definition sites + void VisitStmt_(const BufferRealizeNode* op) final { + Context context(this, op->buffer, GetRef(op)); + Parent::VisitStmt_(op); + } + void VisitStmt_(const DeclBufferNode* op) final { + Context context(this, op->buffer, GetRef(op)); + Parent::VisitStmt_(op); + } + void VisitStmt_(const BlockNode* op) final { + std::vector context; + for (const auto& buf : op->alloc_buffers) { + context.emplace_back(this, buf, GetRef(op)); + } + for (const auto& match : op->match_buffers) { + context.emplace_back(this, match->buffer, GetRef(op)); + } + Parent::VisitStmt_(op); + } + + // Buffer usage sites + void VisitExpr_(const BufferLoadNode* op) final { + AssertDefined(op->buffer, op); + Parent::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode* op) final { + AssertDefined(op->buffer, op); + Parent::VisitStmt_(op); + } + + // AttrStmt, which may be either usage or definition, depending on + // the attribute. 
+ void VisitStmt_(const AttrStmtNode* op) final { + std::vector context; + + if (op->attr_key == attr::buffer_bind_scope) { + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + Buffer source = Downcast(arr[0]); + Buffer target = Downcast(arr[1]); + AssertDefined(target, op); + context.emplace_back(this, source, GetRef(op)); + } else if (auto node = op->node.as()) { + AssertDefined(node.value(), op); + } + Parent::VisitStmt_(op); + } + + // A context manager for scoped buffer definitions + struct Context { + Context(UndefinedBufferAccessVerifier* self, const Buffer& buffer, Optional definition) + : self_(self), buffer_(buffer) { + if (auto it = self_->definition_site_.find(buffer_); it != self_->definition_site_.end()) { + Optional prev = (*it).second; + self_->has_error_ = true; + if (self_->assert_mode_) { + auto& fatal = LOG(FATAL); + fatal << "Buffer " << buffer << " was defined multiple times. " + << "The first definition occurred "; + if (prev) { + fatal << " in the " << prev->GetTypeKey() << ", " << prev << ". "; + } else { + fatal << " in the PrimFunc's buffer_map. 
"; + } + fatal << "The second definition occurred "; + if (definition) { + fatal << " in the " << definition->GetTypeKey() << ", " << definition << "."; + } else { + fatal << " in the PrimFunc's buffer_map."; + } + } + } + self_->definition_site_.Set(buffer_, definition); + } + ~Context() { + if (self_) { + self_->definition_site_.erase(buffer_); + } + } + Context& operator=(const Context&) = delete; + Context(const Context&) = delete; + Context& operator=(Context&& other) { + swap(std::move(other)); + return *this; + } + Context(Context&& other) { swap(std::move(other)); } + + void swap(Context&& other) { + std::swap(self_, other.self_); + std::swap(buffer_, other.buffer_); + } + + UndefinedBufferAccessVerifier* self_{nullptr}; + Buffer buffer_; + }; + + void AssertDefined(const Buffer& buffer, const Object* usage) { + auto it = definition_site_.find(buffer); + if (it == definition_site_.end()) { + has_error_ = true; + if (assert_mode_) { + Array defined_bufs; + for (const auto& [buf, definition] : definition_site_) { + defined_bufs.push_back(buf); + } + LOG(FATAL) << "Buffer " << buffer << "@" << buffer.get() << " was accessed as part of " + << GetRef(usage) << ", without a definition. " + << "At this location, buffers " << defined_bufs << " are defined"; + } + } + } + + // A lookup table for currently-defined buffers. The + // `Optional` contains either the location at which the buffer + // is defined, or NullOpt if the buffer was defined in the + // `buffer_map`. + Map> definition_site_; + + // A container for buffer definitions in the `buffer_map`. + std::vector global_defs_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { if (!BlockVarAccessVerifier::Verify(func, assert_mode)) { return false; } + if (!UndefinedBufferAccessVerifier::Verify(func, assert_mode)) { + return false; + } // TODO(Siyuan): add more checks here. 
return true; } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index fba506fba1c9..06344cc91592 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -179,33 +179,48 @@ class HoistAllocatesMutator : public StmtExprMutator { HoistAllocatesMutator() {} PrimFunc operator()(PrimFunc main_func) { - Stmt new_main_func_body = SeqStmt::Flatten(this->VisitStmt(main_func->body)); - - // Write all allocates that were removed in reverse order - for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) { - Allocate current_alloc = *it; - if (it != allocates_.rbegin()) { - new_main_func_body = SeqStmt::Flatten(new_main_func_body); - } - new_main_func_body = - Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents, - current_alloc->condition, new_main_func_body, current_alloc->annotations, - current_alloc->span); + Stmt body = main_func->body; + while (auto opt = body.as()) { + auto node = opt.value(); + body = node->body; + node.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(node); } - PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, - main_func->buffer_map, main_func->attrs); - return new_main_func; + body = this->VisitStmt(body); + body = SeqStmt::Flatten(body); + + body = MergeNest(init_nest_, body); + + main_func.CopyOnWrite()->body = body; + return main_func; } private: Stmt VisitStmt_(const AllocateNode* op) override { - allocates_.push_back(GetRef(op)); - return VisitStmt(op->body); + auto body = op->body; + auto node = GetRef(op); + node.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(node); + + while (true) { + auto decl_ptr = body.as(); + + if (decl_ptr && decl_ptr->buffer->data.same_as(op->buffer_var)) { + auto decl = GetRef(decl_ptr); + body = decl->body; + decl.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(decl); + } else { + break; + } + } + + return VisitStmt(body); } /*! 
A stack to store allocates as they are visited. */ - std::vector allocates_; + std::vector init_nest_; }; /*! @@ -416,21 +431,6 @@ tvm::transform::Pass CopyComputeReordering(Optional max_copy_movements, TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering") .set_body_typed(CopyComputeReordering); -/*! - * \brief This mutator removes all allocates. - */ -class RemoveAllocatesMutator : public StmtExprMutator { - public: - PrimFunc operator()(PrimFunc main_func) { - auto prim_func_node{main_func.CopyOnWrite()}; - prim_func_node->body = this->VisitStmt(main_func->body); - return GetRef(prim_func_node); - } - - private: - Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } -}; - /*! * \brief This extractor collects information used by the MergeConstantsMutator */ @@ -438,8 +438,8 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor { public: class Info { public: - /*! A stack to store allocates as they are visited. */ - std::vector allocates{}; + /*! A stack to store Allocate/DeclBuffer as they are visited. */ + std::vector init_nest{}; /*! 
A list that contains in the i-th position the write buffer of the i-th statement * if that statement is a copy to a buffer with global scope */ @@ -467,12 +467,28 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor { Info _info{}; void VisitStmt_(const AllocateNode* op) override { - _info.allocates.push_back(GetRef(op)); - VisitStmt(op->body); + auto alloc = GetRef(op); + alloc.CopyOnWrite()->body = Evaluate(0); + _info.init_nest.push_back(alloc); + + Stmt body = op->body; + while (true) { + auto decl = body.as(); + if (decl && decl->buffer->data.same_as(op->buffer_var)) { + auto node = GetRef(decl); + node.CopyOnWrite()->body = Evaluate(0); + _info.init_nest.push_back(node); + body = decl->body; + } else { + break; + } + } + + VisitStmt(body); } void VisitStmt_(const SeqStmtNode* op) override { - std::vector seq_stmt = FlattenUnwrap(GetRef(op)).seq; + auto seq_stmt = FlattenUnwrap(GetRef(op)).seq; if (seq_stmt.size() <= 1) { StmtExprVisitor::VisitStmt_(op); @@ -590,53 +606,73 @@ class MergeConstantsMutator : public StmtExprMutator { Stmt RewritePrimFuncBody(Stmt body) { std::unordered_map var_to_allocate{}; + std::vector init_nest; + while (auto opt = body.as()) { + auto node = opt.value(); + body = node->body; + node.CopyOnWrite()->body = Evaluate(0); + init_nest.push_back(node); + } + // Rewrite old allocates - std::unordered_set buffer_vars{GetVarsForWrittenCopyBuffers()}; - for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) { - Allocate alloc{*it}; - var_to_allocate[alloc->buffer_var.get()] = alloc; - if (buffer_vars.count(alloc->buffer_var.as()) == 0) { - body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body, - alloc->annotations, alloc->span); + auto buffer_vars = GetVarsForWrittenCopyBuffers(); + for (auto init_stmt : _info.init_nest) { + if (auto opt = init_stmt.as()) { + auto alloc = opt.value(); + var_to_allocate[alloc->buffer_var.get()] = alloc; + } + + auto var = [&]() -> Var { + if 
(auto opt = init_stmt.as()) { + return opt.value()->buffer_var; + } else if (auto opt = init_stmt.as()) { + return opt.value()->buffer->data; + } else { + LOG(FATAL) << "Expected Allocate or DeclBuffer, but found " << init_stmt->GetTypeKey(); + } + }(); + if (!buffer_vars.count(var.get())) { + init_nest.push_back(init_stmt); } } // Rewrite new allocates - for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) { - if (Optional buffer_opt = *it) { - Buffer old_write_buffer{buffer_opt.value()}; - int new_buffer_index{ - _info.old_to_new_write_buffer[old_write_buffer.as()].first}; + for (auto buffer_opt : _info.copy_write_buffers) { + if (buffer_opt) { + Buffer old_write_buffer = buffer_opt.value(); + int new_buffer_index = _info.old_to_new_write_buffer[old_write_buffer.get()].first; // Check if the allocate has already been created if (new_buffers.count(new_buffer_index) == 0) { - BufferNode* new_buffer{old_write_buffer.CopyOnWrite()}; - new_buffer->shape = {_info.new_buffers_length[new_buffer_index]}; + Buffer new_buffer = old_write_buffer; - new_buffers[new_buffer_index] = GetRef(new_buffer); + new_buffer.CopyOnWrite()->shape = {_info.new_buffers_length[new_buffer_index]}; + new_buffers[new_buffer_index] = new_buffer; Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]}; - body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(), - body, old_allocate->annotations, old_allocate->span); + init_nest.push_back(Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, + tir::const_true(), Evaluate(0), old_allocate->annotations, + old_allocate->span)); + init_nest.push_back(DeclBuffer(new_buffer, Evaluate(0))); } } } // Rewrite operators - return this->VisitStmt(body); + return MergeNest(init_nest, this->VisitStmt(body)); } - Stmt VisitStmt_(const AllocateNode* op) override { - auto allocate{CopyOnWrite(op)}; - allocate->body = this->VisitStmt(op->body); - return Stmt(allocate); 
- } + Stmt VisitStmt_(const AllocateNode* op) override { return this->VisitStmt(op->body); } + + Stmt VisitStmt_(const DeclBufferNode* op) override { return this->VisitStmt(op->body); } Stmt VisitStmt_(const SeqStmtNode* op) override { - std::vector seq_stmt = FlattenUnwrap(GetRef(op)).seq; + auto [seq_stmt, rewrap_nest] = FlattenUnwrap(GetRef(op)); - if (seq_stmt.size() <= 1) { - return StmtExprMutator::VisitStmt_(op); + if (seq_stmt.size() == 0) { + return Evaluate(0); + } else if (seq_stmt.size() == 1) { + return MergeNest(rewrap_nest, VisitStmt(seq_stmt[0])); } Array new_seq{}; @@ -669,7 +705,7 @@ class MergeConstantsMutator : public StmtExprMutator { } } } - return SeqStmt::Flatten(new_seq); + return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq)); } /*! Returns the variables of the buffers written by copies */ @@ -947,7 +983,6 @@ tvm::transform::Pass MergeConstants() { ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute"; MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)}; - f = RemoveAllocatesMutator()(f); return MergeConstantsMutator(info)(f, const_dict.value()); }; return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants", diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 0b71b2e8fa34..509756b3847c 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -46,88 +46,104 @@ namespace tir { class ParamsCollector : public StmtExprVisitor { public: - explicit ParamsCollector(const Map& constant_map) - : constant_map_(constant_map) {} - std::vector CollectParams(tir::Stmt body) { - this->VisitStmt(body); - return constant_list_; + static std::vector Collect(const Map& constant_map, + tir::Stmt body) { + auto pass = ParamsCollector(constant_map); + pass.VisitStmt(body); + return pass.used_constants_; } - void VisitExpr_(const BufferLoadNode* ln) { - if (constant_map_.find(ln->buffer->data) != constant_map_.end()) { - 
auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get()); - if (it == constant_list_.end()) { - constant_list_.push_back(ln->buffer->data.get()); - } + private: + explicit ParamsCollector(const Map& constant_map) { + for (const auto& it : constant_map) { + const auto& buf = it.first; + unused_constants_[buf->data.get()] = buf; } - StmtExprVisitor::VisitExpr_(ln); } - void VisitExpr_(const CallNode* cn) { - if (cn->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(cn->args.size(), 5U); - const Var& var = Downcast(cn->args[1]); - const VarNode* buffer = cn->args[1].as(); - auto it = constant_map_.find(var); - if (it != constant_map_.end()) { - auto it = std::find(constant_list_.begin(), constant_list_.end(), buffer); - if (it == constant_list_.end()) { - constant_list_.push_back(buffer); - } - } + void VisitExpr_(const BufferLoadNode* node) { + HandleAccess(node->buffer->data.get()); + StmtExprVisitor::VisitExpr_(node); + } + + void VisitExpr_(const VarNode* node) { + HandleAccess(node); + StmtExprVisitor::VisitExpr_(node); + } + + void HandleAccess(const VarNode* buffer_var) { + auto it = unused_constants_.find(buffer_var); + if (it != unused_constants_.end()) { + used_constants_.push_back(it->second); + unused_constants_.erase(it); } - StmtExprVisitor::VisitExpr_(cn); } private: - std::vector constant_list_; - Map constant_map_; + std::vector used_constants_; + std::unordered_map unused_constants_; }; -PrimFunc BindParams(PrimFunc f, const Array& constants) { - Map constant_map; - +PrimFunc BindParams(PrimFunc func, const Array& constants) { // Remove constants from the primfunc signature - size_t num_constants = constants.size(); - size_t start = f->params.size() - num_constants; + size_t first_constant = func->params.size() - constants.size(); Array params; - for (unsigned i = 0; i < start; i++) { - params.push_back(f->params[i]); + Map buffer_map; + Map constant_map; + for (unsigned i = 0; i < func->params.size(); i++) { + Var 
param = func->params[i]; + if (i < first_constant) { + params.push_back(func->params[i]); + if (auto opt = func->buffer_map.Get(param)) { + buffer_map.Set(param, opt.value()); + } + } else { + auto opt = func->buffer_map.Get(param); + ICHECK(opt.defined()) << "Attempted to bind constant NDArray to parameter " << param + << ", but " << param << " is a scalar parameter"; + constant_map.Set(opt.value(), constants[i - first_constant]); + } } - auto* n = f.CopyOnWrite(); - for (unsigned i = start; i < f->params.size(); i++) { - tir::Var p = n->params[i]; - tir::Var b = n->buffer_map[p]->data; - n->buffer_map.erase(p); - constant_map.Set(b, constants[i - start]); + auto constant_list = ParamsCollector::Collect(constant_map, func->body); + + // Unwrap the root BlockRealize/Block, if present + Stmt body = func->body; + if (auto* block_realize = body.as()) { + body = block_realize->block->body; } - n->params = params; - auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); // Allocate constants within the primfunc - for (auto i : constant_list) { - auto var = GetRef(i); - int ndim = constant_map[var]->ndim; - Array extents; + for (auto buf : constant_list) { + const auto& ndarray = constant_map[buf]; + int ndim = ndarray->ndim; + Array extents; for (int i = 0; i < ndim; i++) { - int shape = constant_map[var]->shape[i]; + int shape = ndarray->shape[i]; extents.push_back(make_const(DataType::Int(32), shape)); } - DataType dtype = DataType(constant_map[var]->dtype); - - if (n->body->IsInstance()) { - auto* block_realize = n->body.as(); - auto block = block_realize->block; - block.CopyOnWrite()->body = - tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); - n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block); - } else { - n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); - } + + DataType dtype = DataType(ndarray->dtype); + body = tir::DeclBuffer(buf, body); + body = 
tir::AllocateConst(buf->data, dtype, extents, ndarray, body); + } + + // Re-wrap the root BlockRealize/Block, if present + if (auto opt = func->body.as()) { + auto block_realize = opt.value(); + auto block = block_realize->block; + block.CopyOnWrite()->body = body; + block_realize.CopyOnWrite()->block = block; + body = block_realize; } - return f; + + auto* write_ptr = func.CopyOnWrite(); + write_ptr->params = params; + write_ptr->buffer_map = buffer_map; + write_ptr->body = body; + + return func; } namespace transform { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c04e12b8395e..0d8cc8553c95 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -41,13 +41,29 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { static PrimFunc Flatten(PrimFunc func) { arith::Analyzer ana; auto pass = BufferFlattener(&ana); - auto writer = func.CopyOnWrite(); pass.MarkBufferMapShapes(func); - writer->body = pass.VisitStmt(func->body); + auto body = pass.VisitStmt(func->body); + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buffers_used_.count(old_buf)) { + auto new_buf = pass.GetFlattenedBuffer(old_buf); + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = std::move(body); + } return func; } @@ -153,11 +169,14 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } Stmt VisitStmt_(const DeclBufferNode* op) final { - // TODO(rfc-70): Update the DeclBuffer node instead of - // stripping it out. 
Stripping it out in the current - // implementation as not all lowering passes support - // DeclBuffer. - return VisitStmt(op->body); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto new_buf = GetFlattenedBuffer(node->buffer); + if (!node->buffer.same_as(new_buf)) { + node.CopyOnWrite()->buffer = new_buf; + } + + return std::move(node); } Buffer GetFlattenedBuffer(Buffer buf) { @@ -166,16 +185,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { return it->second; } auto flattened = buf.GetFlattenedBuffer(); - auto writer = flattened.CopyOnWrite(); // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (flattened->dtype == DataType::Bool()) { - writer->dtype = DataType::Int(8); + flattened.CopyOnWrite()->dtype = DataType::Int(8); } // canonicalize shape - for (size_t i = 0; i < flattened->shape.size(); ++i) { - writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + bool shape_is_changed = false; + Array new_shape; + for (const auto& dim : flattened->shape) { + auto new_dim = analyzer_->canonical_simplify(dim); + shape_is_changed = shape_is_changed || !StructuralEqual()(dim, new_dim); + new_shape.push_back(new_dim); + } + + if (shape_is_changed) { + flattened.CopyOnWrite()->shape = std::move(new_shape); } buffer_remap_[buf] = flattened; @@ -226,6 +252,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { template Node VisitBufferAccess(Node node) { ICHECK(node->buffer.defined()); + buffers_used_.insert(node->buffer); auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); @@ -264,6 +291,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { /*! \brief Map of buffers being remapped. */ std::unordered_map buffer_remap_; + std::unordered_set buffers_used_; + /*! \brief The updated external buffer map. 
*/ Map updated_extern_buffer_map_; }; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index abc288f0eb24..1b2e8e9db04a 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -429,6 +429,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Fix all local allocations as all statements are built. Stmt body = SeqStmt::Flatten(seq); for (Buffer buf : new_alloc_bufs) { + body = DeclBuffer(buf, body); body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body); } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9b1dbf1a6618..b9fc056f1962 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator { private: Stmt SplitDeviceFunc(Stmt body, Target device_target) { - Array params = [&]() { + auto [params, buffers_to_declare] = [&]() -> std::tuple, Array> { VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); use_def(body); @@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator { }; return sort_key(a) < sort_key(b); }); - return params; + return {params, use_def.undefined_buffers_}; }(); // CodeGenCPU is used for some device-side targets, such as @@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator { kernel_ret_type = VoidType(); } - GlobalVar kernel_symbol_global = var_supply_(); + for (Buffer buf : buffers_to_declare) { + body = DeclBuffer(buf, std::move(body)); + } PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, {tir::attr::kIsGlobalFunc, Bool(true)}}); + GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); Array args = params.Map([](const Var& 
var) -> PrimExpr { return var; }); diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173..f24f14542969 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1340,12 +1340,30 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); + Stmt body = pass(func->body); + + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buf_map_.count(old_buf)) { + auto new_buf = pass.GetBufferEntry(old_buf).flattened_buffer; + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. 
+ + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1542,9 +1560,10 @@ class StorageFlattener : public StmtExprMutator { buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; - Stmt ret = - Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, - make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); + Stmt ret = body; + ret = DeclBuffer(e.flattened_buffer, body); + ret = Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 3ecd0f64bb44..5f9e4704e933 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -381,6 +381,17 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt Rewrite(Stmt stmt, bool detect_inplace) { detect_inplace_ = detect_inplace; + + // Keep DeclBuffer of flattend argument buffers at the top of the + // function. 
+ std::vector init_nest; + while (auto opt = stmt.as()) { + auto node = opt.value(); + stmt = node->body; + node.CopyOnWrite()->body = Evaluate(0); + init_nest.push_back(node); + } + // plan the rewrite LinearAccessPatternFinder finder; finder(stmt); @@ -390,43 +401,42 @@ class StoragePlanRewriter : public StmtExprMutator { this->PrepareNewAlloc(); // start rewrite stmt = operator()(std::move(stmt)); + if (attach_map_.count(nullptr)) { - return MakeAttach(attach_map_.at(nullptr), stmt); + stmt = MakeAttach(attach_map_.at(nullptr), stmt); } + stmt = MergeNest(init_nest, stmt); return stmt; } template Node VisitBufferAccess(Node node) { - auto it = alloc_map_.find(node->buffer->data.get()); - if (it != alloc_map_.end()) { - Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); + if (Buffer buf = RemapBuffer(node->buffer); !buf.same_as(node->buffer)) { + node.CopyOnWrite()->buffer = buf; + } + if (auto it = alloc_map_.find(node->buffer->data.get()); it != alloc_map_.end()) { Array indices = node->indices; indices.Set(indices.size() - 1, RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); - auto writer = node.CopyOnWrite(); - writer->buffer = buf; - writer->indices = indices; + node.CopyOnWrite()->indices = indices; } return node; } - Buffer RemapBuffer(Buffer buf, Var new_backing_array) { + Buffer RemapBuffer(Buffer buf) { auto key = buf.get(); - auto it = buffer_remap_.find(key); - if (it != buffer_remap_.end()) { - ICHECK_EQ(it->second->data.get(), new_backing_array.get()) - << "Cannot remap buffer " << buf->name << " to use backing array " - << new_backing_array->name_hint << ", previously used backing array " - << it->second->data->name_hint; + + if (auto it = buffer_remap_.find(key); it != buffer_remap_.end()) { return it->second; } - Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, - buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, - buf->offset_factor, buf->buffer_type, 
buf->axis_separators, buf->span); + Buffer remapped = buf; + if (auto it = alloc_map_.find(buf->data.get()); it != alloc_map_.end()) { + remapped.CopyOnWrite()->data = it->second->alloc_var; + } + buffer_remap_[key] = remapped; return remapped; } @@ -521,8 +531,7 @@ class StoragePlanRewriter : public StmtExprMutator { } auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - if (auto it = alloc_map_.find(op->buffer->data.get()); it != alloc_map_.end()) { - Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var); + if (Buffer buf = RemapBuffer(op->buffer); !buf.same_as(node->buffer)) { node.CopyOnWrite()->buffer = buf; } return std::move(node); @@ -659,11 +668,26 @@ class StoragePlanRewriter : public StmtExprMutator { // simply use the original allocation. e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents, e->allocs[0]->condition, Evaluate(0))); - if (auto ptr = e->allocs[0]->body.as()) { - e->alloc_nest.push_back( - DeclBuffer(RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0))); + // The first allocation may have a DeclBuffer that should be + // hoisted alongside the Allocate. + if (auto ptr = e->allocs[0]->body.as(); + ptr && ptr->buffer->data.same_as(e->alloc_var)) { + auto remapped = RemapBuffer(ptr->buffer); + e->alloc_nest.push_back(DeclBuffer(remapped, Evaluate(0))); hoisted_buffer_decls_.insert(ptr->buffer.get()); + // Remaining allocations may have a DeclBuffer that should + // be removed, with all occurrences of that buffer + // replaced. 
+ for (size_t i = 1; i < e->allocs.size(); i++) { + const auto* alloc = e->allocs[i]; + if (auto ptr = alloc->body.as(); + ptr && ptr->buffer->data.same_as(alloc->buffer_var)) { + hoisted_buffer_decls_.insert(ptr->buffer.get()); + buffer_remap_[ptr->buffer.get()] = remapped; + } + } } + if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); if (info.defined()) { @@ -712,6 +736,19 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->alloc_nest.push_back( Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0))); + + // Any buffers immediately within the Allocate should be + // hoisted to the Allocate's new location. + for (const auto& alloc : e->allocs) { + Stmt body = alloc->body; + auto decl = body.as(); + if (decl && decl->buffer->data.same_as(alloc->buffer_var)) { + e->alloc_nest.push_back(DeclBuffer(RemapBuffer(decl->buffer), Evaluate(0))); + hoisted_buffer_decls_.insert(decl->buffer.get()); + body = decl->body; + } + } + if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); if (info.defined()) { diff --git a/tests/python/contrib/test_ethosu/cascader/test_integration.py b/tests/python/contrib/test_ethosu/cascader/test_integration.py index 14cc8fbc61cf..1eb3c3b87aab 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_integration.py +++ b/tests/python/contrib/test_ethosu/cascader/test_integration.py @@ -109,7 +109,10 @@ def test_single_conv_compute_cycles_hint(): for single convolution. """ primfunc = _compile_model(_create_single_conv2d()) - ops = primfunc.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [2944, 320] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" @@ -122,7 +125,10 @@ def test_double_conv_compute_cycles_hint(): for double convolution. 
""" primfunc = _compile_model(_create_double_conv2d()) - ops = primfunc.body.body.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [2944, 1408, 320, 240] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" @@ -135,7 +141,10 @@ def test_scalar_add_compute_cycles_hint(): for add with scalar values. """ primfunc = _compile_model(_create_scalar_add()) - ops = primfunc.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [16, 24] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 4341f367f0e1..86c8763b9de6 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -39,8 +39,8 @@ class WeightStreamOnlyU55: def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) buffer1 = T.Buffer([160], "uint8") buffer3 = T.Buffer([144], "uint8") buffer5 = T.Buffer([144], "uint8") @@ -48,10 +48,10 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ buffer8 = T.Buffer([32], "uint8") # body p1_data = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([160], "uint8", data=p1_data) + p1 = 
T.decl_buffer([160], "uint8", data=p1_data) + buffer9 = T.decl_buffer([144], "uint8", data=p1_data) p2_data = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([144], "uint8", data=p2_data) - buffer9 = T.Buffer([144], "uint8", data=p1.data) + p2 = T.decl_buffer([144], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p1[128], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -68,27 +68,38 @@ class WeightStreamOnlyU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)}) + + ifm_1 = T.decl_buffer((8192,), "int8", data=ifm.data) + ethosu_write_1 = T.decl_buffer((2048,), "int8", data=ethosu_write.data) + p2_global_6 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_3 = T.decl_buffer((192,), "uint8", data=p2_global_6) + p2_global_4 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_4_1 = T.decl_buffer((192,), "uint8", data=p2_global_4) + p2_global_5 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_5_1 = T.decl_buffer((208,), "uint8", data=p2_global_5) + buffer_encoded = T.Buffer((192,), "uint8") - p2_global_3 = T.Buffer((192,), "uint8", data=p2_global_6) - T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 192, p2_global_3[0]) 
buffer_encoded_1 = T.Buffer((192,), "uint8") - p2_global_4_1 = T.Buffer((192,), "uint8", data=p2_global_4) - T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 192, p2_global_4_1[0]) buffer_encoded_2 = T.Buffer((208,), "uint8") - p2_global_5_1 = T.Buffer((208,), "uint8", data=p2_global_5) + buffer_encoded_3 = T.Buffer((192,), "uint8") + + + T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 192, p2_global_3[0]) + + T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 192, p2_global_4_1[0]) T.call_extern("handle", "ethosu_copy", buffer_encoded_2[0], 208, p2_global_5_1[0]) - ifm_1 = T.Buffer((8192,), "int8", data=ifm.data) - ethosu_write_1 = T.Buffer((2048,), "int8", data=ethosu_write.data) + + T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_3[0], 80, p2_global_3[80], 80, 12, p2_global_3[160], 16, p2_global_3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_3 = T.Buffer((192,), "uint8") - p2_global_6_1 = T.Buffer((192,), "uint8", data=p2_global_6) - T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 192, p2_global_6_1[0]) + + + T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 192, p2_global_3[0]) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_4_1[0], 80, p2_global_4_1[80], 80, 12, p2_global_4_1[160], 16, p2_global_4_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[4], 0, 0, 0, T.float32(0.25), 14, 
"NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_5_1[0], 96, p2_global_5_1[96], 80, 12, p2_global_5_1[176], 16, p2_global_5_1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_6_1[0], 80, p2_global_6_1[80], 80, 12, p2_global_6_1[160], 16, p2_global_6_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) + T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_3[0], 80, p2_global_3[80], 80, 12, p2_global_3[160], 16, p2_global_3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) __tvm_meta__ = None # fmt: on @@ -154,16 +165,18 @@ def _get_func(): class RereadWeightsU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.Buffer([384], "uint8") - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([384], "uint8", data=p1_data) + p1 = T.decl_buffer([384], "uint8", data=p1_data) + p2_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([384], "uint8", data=p2_data) + p2 = 
T.decl_buffer([384], "uint8", data=p2_data) + + buffer1 = T.Buffer([384], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -175,17 +188,19 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ class RereadWeightsU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition - placeholder = T.Buffer([8192], dtype="int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], dtype="int8", data=input_ethosu_write.data) - placeholder_encoded_1 = T.Buffer([464], "uint8") - # body + + placeholder = T.decl_buffer([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write.data) + p1_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([464], "uint8", data=p1_data) + p1 = T.decl_buffer([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([464], "uint8", data=p2_data) + p2 = T.decl_buffer([464], "uint8", data=p2_data) + + placeholder_encoded_1 = T.Buffer([464], "uint8") + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_1[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -256,17 +271,19 @@ def _get_func(): class DirectReadOnlyU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + + ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.decl_buffer([4096], "int8", data=ethosu_write_1_data) + buffer = T.Buffer([592], "uint8") buffer_1 = T.Buffer([160], "uint8") buffer_2 = T.Buffer([160], "uint8") buffer_3 = T.Buffer([80], "uint8") - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body - ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([4096], "int8", data=ethosu_write_1_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -276,18 +293,19 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ class DirectReadOnlyU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition + + placeholder = T.decl_buffer([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write.data) + + ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_2 = T.decl_buffer([4096], "int8", data=ethosu_write_2_data) + placeholder_encoded = T.Buffer([608], dtype="uint8") placeholder_encoded_1 = T.Buffer([160], dtype="uint8") placeholder_encoded_2 = T.Buffer([208], dtype="uint8") placeholder_encoded_3 = T.Buffer([96], dtype="uint8") - placeholder = T.Buffer([8192], dtype="int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], dtype="int8", data=input_ethosu_write.data) - # body - ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_2 = T.Buffer([4096], "int8", data=ethosu_write_2_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, 
placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -354,23 +372,28 @@ def _get_func(): class MixedReadU55: @T.prim_func def main(input_ifm: T.Buffer((1,16,16,32), "int8"), input_ethosu_write: T.Buffer((1,16,16,8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + ifm = T.decl_buffer([8192], "int8", data=input_ifm.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + + p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.decl_buffer([112], "uint8", data=p1_data) + + p3_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.decl_buffer([4096], "int8", data=p3_data) + + p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.decl_buffer([112], "uint8", data=p2_data) + buffer1 = T.Buffer([112], "uint8") buffer3 = T.Buffer([112], "uint8") buffer5 = T.Buffer([112], "uint8") buffer7 = T.Buffer([112], "uint8") buffer9 = T.Buffer([592], "uint8") buffer10 = T.Buffer([160], "uint8") - ifm = T.Buffer([8192], "int8", data=input_ifm.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body - p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = 
T.Buffer([112], "uint8", data=p1_data) - p3_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - p3 = T.Buffer([4096], "int8", data=p3_data) - p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([112], "uint8", data=p2_data) + + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) @@ -388,32 +411,38 @@ class MixedReadU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)}) + + ifm_1 = T.decl_buffer((8192,), "int8", data=ifm.data) + ethosu_write_3 = T.decl_buffer((2048,), "int8", data=ethosu_write.data) + p5_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_3 = T.decl_buffer((128,), "uint8", data=p5_global) + p5_global_1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_4 = T.decl_buffer((128,), "uint8", data=p5_global_1) + ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + ethosu_write_2 = T.decl_buffer((4096,), "int8", data=ethosu_write_1) + p5_global_2 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_5 = T.decl_buffer((128,), "uint8", data=p5_global_2) + + p1_encoded = T.Buffer((608,), "uint8") 
+ p2_encoded = T.Buffer((160,), "uint8") buffer_encoded = T.Buffer((128,), "uint8") - p5_global_3 = T.Buffer((128,), "uint8", data=p5_global) - T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 128, p5_global_3[0]) buffer_encoded_1 = T.Buffer((128,), "uint8") - p5_global_4 = T.Buffer((128,), "uint8", data=p5_global_1) + buffer_encoded_2 = T.Buffer((128,), "uint8") + buffer_encoded_3 = T.Buffer((128,), "uint8") + + T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 128, p5_global_3[0]) T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 128, p5_global_4[0]) - ifm_1 = T.Buffer((8192,), "int8", data=ifm.data) - ethosu_write_2 = T.Buffer((4096,), "int8", data=ethosu_write_1) - p1_encoded = T.Buffer((608,), "uint8") - p2_encoded = T.Buffer((160,), "uint8") T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, p1_encoded[0], 304, p1_encoded[304], 304, 12, p2_encoded[0], 80, p2_encoded[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_2 = T.Buffer((128,), "uint8") - p5_global_5 = T.Buffer((128,), "uint8", data=p5_global_2) T.call_extern("handle", "ethosu_copy", buffer_encoded_2[0], 128, p5_global_5[0]) - ethosu_write_3 = T.Buffer((2048,), "int8", data=ethosu_write.data) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_3[0], 48, p5_global_3[48], 48, 12, p5_global_3[96], 16, p5_global_3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_3 = T.Buffer((128,), "uint8") - p5_global_6 = T.Buffer((128,), "uint8", data=p5_global) - T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 128, 
p5_global_6[0]) + T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 128, p5_global_3[0]) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_4[0], 48, p5_global_4[48], 48, 12, p5_global_4[96], 16, p5_global_4[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_5[0], 48, p5_global_5[48], 48, 12, p5_global_5[96], 16, p5_global_5[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_6[0], 48, p5_global_6[48], 48, 12, p5_global_6[96], 16, p5_global_6[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) + T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_3[0], 48, p5_global_3[48], 48, 12, p5_global_3[96], 16, p5_global_3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) __tvm_meta__ = None # fmt: on @@ -512,7 +541,10 @@ def get_graph(): # Check tile address for the scalar constant input hasn't been # overwritten. 
- extern_calls = tir_mod["main"].body.body.body.body + extern_calls = tir_mod["main"] + while not isinstance(extern_calls, tvm.tir.SeqStmt): + extern_calls = extern_calls.body + binary_elementwise = extern_calls[-1].value args = binary_elementwise.args diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py b/tests/python/contrib/test_ethosu/test_identity_optimizer.py index 3ae58dfc81ba..e5f288482703 100644 --- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py @@ -311,7 +311,10 @@ def get_graph(): # Check for hints in the TIR prim func that the identity optimization pass # has ran. There should not be an identity in the prim func. - assert prim_func.body.value.args[0] == "ethosu_pooling" + body = prim_func + while not isinstance(body, tvm.tir.Evaluate): + body = body.body + assert body.value.args[0] == "ethosu_pooling" def test_same_output(): diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index 69d549acbb3b..355e302be956 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -794,7 +794,7 @@ def get_graph(): prim_func = mod[external_gv_name] # Check for hints in the TIR prim func that the layout optimization pass has ran - ops = prim_func.body.body.seq + ops = prim_func.body.body.body.body.body.seq max_pool1, max_pool2 = ops assert str(max_pool1.value.args[31]) == '"NHCWB16"' diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index ef034930d7bc..d719b67b9009 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -35,9 +35,9 @@ def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1: # function attr dict 
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer(1536, dtype="int8", data=input_placeholder.data) - placeholder_1 = T.Buffer(1280, dtype="int8", data=input_placeholder_1.data) - T_concat = T.Buffer(4096, dtype="int8", data=input_T_concat.data) + placeholder = T.decl_buffer(1536, dtype="int8", data=input_placeholder.data) + placeholder_1 = T.decl_buffer(1280, dtype="int8", data=input_placeholder_1.data) + T_concat = T.decl_buffer(4096, dtype="int8", data=input_T_concat.data) buffer = T.Buffer([2992], "uint8") buffer_1 = T.Buffer([160], "uint8") @@ -49,7 +49,7 @@ def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1: buffer_7 = T.Buffer([160], "uint8") # body T_concat_1_data = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T_concat_1 = T.Buffer([2816], "int8", data=T_concat_1_data) + T_concat_1 = T.decl_buffer([2816], "int8", data=T_concat_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, 
T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 32d1303e124e..80cf012fb92f 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -368,17 +368,17 @@ def _visit(stmt): class Conv2dDoubleCascade1: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.Buffer([304], "uint8") + buffer = T.Buffer([304], "uint8") buffer_1 = T.Buffer([80], "uint8") buffer_2 = T.Buffer([320], "uint8") buffer_3 = T.Buffer([160], "uint8") - placeholder_5 = T.Buffer([192], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([512], 'int8', data=input_ethosu_write_1.data) - # body + + placeholder_5 = T.decl_buffer([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([512], 'int8', data=input_ethosu_write_1.data) + ethosu_write_2_data = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([1024], "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer([1024], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -392,15 +392,18 @@ class Conv2dDoubleCascade2: def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), "int8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder_5 = T.decl_buffer([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([512], 'int8', data=input_ethosu_write_1.data) + + ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.decl_buffer([1536], "int8", data=ethosu_write_2_data) + buffer = T.Buffer([80], "uint8") buffer_1 = T.Buffer([320], "uint8") buffer_2 = T.Buffer([1312], "uint8") buffer_3 = T.Buffer([2608], "uint8") - placeholder_5 = T.Buffer([192], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([512], 'int8', data=input_ethosu_write_1.data) # body - ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([1536], "int8", data=ethosu_write_2_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 
0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -418,12 +421,12 @@ def main(input_placeholder_5: T.Buffer((1, 16, 16, 3), "int8"), input_ethosu_wri buffer_1 = T.Buffer([80], "uint8") buffer_2 = T.Buffer([320], "uint8") buffer_3 = T.Buffer([880], "uint8") - placeholder_5 = T.Buffer([768], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([640], 'int8', data=input_ethosu_write_1.data) + placeholder_5 = T.decl_buffer([768], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([640], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([2560], "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer([2560], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, 
T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -443,11 +446,11 @@ def main(input_placeholder_5: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_w buffer_1 = T.Buffer([352], "uint8") buffer_2 = T.Buffer([272], "uint8") buffer_3 = T.Buffer([11040], "uint8") - placeholder_5 = T.Buffer([1024], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([2048], 'int8', data=input_ethosu_write_1.data) + placeholder_5 = T.decl_buffer([1024], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([2048], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer((2304,), "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer((2304,), "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, 
T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -465,11 +468,11 @@ def main(input_placeholder: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write: buffer_1 = T.Buffer([320], "uint8") buffer_2 = T.Buffer([304], "uint8") buffer_3 = T.Buffer([80], "uint8") - placeholder = T.Buffer([192], 'int8', data=input_placeholder.data) - ethosu_write = T.Buffer([8192], 'int8', data=input_ethosu_write.data) + placeholder = T.decl_buffer([192], 'int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([8192], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([4096], "int8", data=ethosu_write_1_data) + ethosu_write_1 = T.decl_buffer([4096], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, 
T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -487,11 +490,11 @@ def main(input_placeholder: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_wri buffer_1 = T.Buffer([352], "uint8") buffer_2 = T.Buffer([11040], "uint8") buffer_3 = T.Buffer([272], "uint8") - placeholder = T.Buffer([1024], 'int8', data=input_placeholder.data) - ethosu_write = T.Buffer([32768], 'int8', data=input_ethosu_write.data) + placeholder = T.decl_buffer([1024], 'int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([32768], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([12288], "int8", data=ethosu_write_1_data) + ethosu_write_1 = T.decl_buffer([12288], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, 
"NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_3[0], 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -636,6 +639,7 @@ def _get_func( config = { "enable_cascader": True, } + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethos-u.options": config}): func = _get_func(*params[:-1]) mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) @@ -653,8 +657,8 @@ def main(input_placeholder_3: T.Buffer((1, 10, 12, 8), "int8"), input_ethosu_wri T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([848], "uint8") buffer_1 = T.Buffer([160], "uint8") - placeholder_3 = T.Buffer([960], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([1024], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([960], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([1024], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
__tvm_meta__ = None @@ -668,8 +672,8 @@ def main(input_placeholder_3: T.Buffer((1, 7, 9, 5), "int8"), input_ethosu_write T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([656], "uint8") - placeholder_3 = T.Buffer([315], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([240], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([315], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([240], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -712,8 +716,8 @@ def main(input_placeholder_3: T.Buffer((4, 6, 8, 1), "int8"), input_ethosu_write T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -728,8 +732,8 @@ def main(input_placeholder_3: T.Buffer((1, 24, 8), "int8"), input_ethosu_write_1 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -744,8 +748,8 @@ def main(input_placeholder_3: T.Buffer((192, 1), "int8"), input_ethosu_write_1: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], 
"uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -760,7 +764,7 @@ def main(placeholder_3: T.Buffer((192,), "int8"), input_ethosu_write_1: T.Buffer T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 94763c5d3fbf..d4f35c23f76e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -35,14 +35,16 @@ class ReferenceModule: @T.prim_func def main(input_placeholder_3: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer_1 = T.Buffer([384], "uint8") - placeholder_3 = T.Buffer([8192], dtype="int8", data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([2048], dtype="int8", data=input_ethosu_write_1.data) - # body + + placeholder_3 = T.decl_buffer([8192], dtype="int8", data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write_1.data) + placeholder_global_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_global = T.Buffer([384], "uint8", data=placeholder_global_data) + placeholder_global = T.decl_buffer([384], "uint8", data=placeholder_global_data) + + buffer_1 = T.Buffer([384], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, 
"NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -79,17 +81,19 @@ def _get_func(): class WeightStream: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 16), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.Buffer([528], "uint8") - buffer_2 = T.Buffer([336], "uint8") - placeholder_5 = T.Buffer([8192], dtype="int8", data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([4096], dtype="int8", data=input_ethosu_write_1.data) - # body + + placeholder_5 = T.decl_buffer([8192], dtype="int8", data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([4096], dtype="int8", data=input_ethosu_write_1.data) + placeholder_d_global_data = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.Buffer([528], "uint8", data=placeholder_d_global_data) + placeholder_d_global = T.decl_buffer([528], "uint8", data=placeholder_d_global_data) placeholder_d_global_1_data = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_1 = T.Buffer([336], "uint8", data=placeholder_d_global_1_data) + placeholder_d_global_1 = T.decl_buffer([336], "uint8", data=placeholder_d_global_1_data) + + buffer = T.Buffer([528], "uint8") + buffer_2 = T.Buffer([336], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, 
ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 1edd840b0b0e..e5e200b4cc7c 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -198,18 +198,22 @@ class DiamondGraphTir: @T.prim_func def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_write: T.Buffer((1, 56, 56, 24), "int8")) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer([301056], dtype='int8', data=input_placeholder.data) - ethosu_write = T.Buffer([75264], dtype='int8', data=input_ethosu_write.data) - buffer1 = T.Buffer([2848], "uint8") - buffer3 = T.Buffer([976], "uint8") + + placeholder = T.decl_buffer([301056], dtype='int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([75264], dtype='int8', data=input_ethosu_write.data) + p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([2848], "uint8", data=p1_data) + p1 = T.decl_buffer([2848], "uint8", data=p1_data) p2_data = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([976], "uint8", data=p2_data) + p2 = T.decl_buffer([976], "uint8", data=p2_data) p5_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - p5 = T.Buffer([75264], "int8", data=p5_data) + p5 = T.decl_buffer([75264], "int8", data=p5_data) p6_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - p6 = T.Buffer([75264], "int8", data=p6_data) + p6 = T.decl_buffer([75264], "int8", data=p6_data) + + buffer1 = 
T.Buffer([2848], "uint8") + buffer3 = T.Buffer([976], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index e94a4f09ec56..9882b5885b05 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -60,9 +60,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer([16384], data=A.data) + B_flat = T.decl_buffer([16384], data=B.data) + C_flat = T.decl_buffer([16384], data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -82,9 +82,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer([16384], data=A.data) + B_flat = T.decl_buffer([16384], data=B.data) + C_flat = T.decl_buffer([16384], data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index 50d5119b43a0..6da7a2df3563 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ 
b/tests/python/unittest/test_te_build_lower.py @@ -56,7 +56,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse) + assert isinstance(stmt.body.body.body.body.body, tvm.tir.stmt.IfThenElse) if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index d6b11785a4a3..61821adb841f 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -756,6 +756,8 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -777,6 +779,8 @@ def outer_product(a, b): sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index ed224883478e..e2811765045b 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -324,7 +324,7 @@ def test_legalize_invalid_attach(): s[A].compute_at(s[B], B.op.axis[1]) s[B].fuse(B.op.axis[0], B.op.axis[1]) stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body - assert isinstance(stmt, tvm.tir.stmt.For) + assert isinstance(stmt.body.body, tvm.tir.stmt.For) def test_compute_at(): diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py index 4f88cc8be1e1..608fb7e3b992 100644 --- 
a/tests/python/unittest/test_tir_analysis_verify_well_formed.py +++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm.script import tir as T +from tvm.ir.instrument import pass_instrument def test_pass_simple(): @@ -54,5 +57,118 @@ def element_wise( assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False) +def test_pass_buffer_usage(): + @T.prim_func + def func(a: T.handle): + # Buffer declaration as part of buffer_map + A = T.match_buffer(a, 128, "float32") + + # Buffer declaration as part of BlockNode + B = T.alloc_buffer(128, "float32") + for i in range(128): + B[i] = A[i] * 2.0 + + # Buffer declaration in a DeclBuffer node + c_data = T.allocate([128], "float32") + C = T.decl_buffer(128, "float32", data=c_data) + for i in range(128): + C[i] = B[i] * 2.0 + + assert tvm.tir.analysis.verify_well_formed(func) + + +def test_fail_implicit_buffer_alias(): + @T.prim_func + def func(A: T.Buffer([128, 128], "float32")): + # Aliased buffer usage without declaration. The `T.Buffer` in + # TVMScript does not actually make any TIR node, and does not + # count as a TIR declaration. + Alias = T.Buffer(128 * 128, "float32", data=A.data) + T.evaluate(Alias[0]) + + assert not tvm.tir.analysis.verify_well_formed(func, assert_mode=False) + + +def test_pass_explicit_buffer_alias(): + @T.prim_func + def func(A: T.Buffer([128, 128], "float32")): + # Aliased buffer usage with declaration. 
+ Alias = T.decl_buffer(128 * 128, "float32", data=A.data) + T.evaluate(Alias[0]) + + assert tvm.tir.analysis.verify_well_formed(func) + + +def matmul(): + @T.prim_func + def func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return func + + +def launch_env_thread(): + @T.prim_func + def func(inputs: T.Buffer((64, 2, 4), "float32")) -> None: + bx = T.launch_thread("blockIdx.x", 64) + for i, j in T.grid(2, 4): + T.evaluate(inputs[bx, i, j]) + + return func + + +def copy_using_env_thread(): + shape = (64, 2, 4) + + @T.prim_func + def func(A: T.Buffer(shape), B: T.Buffer(shape)): + blocks, M, N = T.meta_var(shape) + + bx = T.launch_thread("blockIdx.x", blocks) + for i, j in T.grid(M, N): + B[bx, i, j] = A[bx, i, j] + + return func + + +@pass_instrument +class InstrumentWellFormed: + def run_after_pass(self, mod, info): + for func in mod.functions.values(): + tvm.tir.analysis.verify_well_formed(func) + + +@pytest.mark.parametrize( + "generator,target", + [ + (matmul, "llvm"), + pytest.param( + launch_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + pytest.param( + copy_using_env_thread, + "cuda", + marks=tvm.testing.Feature("cuda").marks(support_required="compile-only"), + ), + ], +) +def test_well_formed_all_lowering_steps(generator, target): + func = generator() + + with tvm.transform.PassContext(instruments=[InstrumentWellFormed()]): + tvm.build(func, target=target) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 8a39337575a7..63d310aaa897 100644 --- 
a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -33,6 +33,8 @@ def test_for(): body = ib.get() assert isinstance(body, tvm.tir.Allocate) body = body.body + assert isinstance(body, tvm.tir.DeclBuffer) + body = body.body assert isinstance(body, tvm.tir.For) body = body.body assert isinstance(body, tvm.tir.SeqStmt) diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 840a18ae6aea..ef7ce4a3d0e6 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -1543,21 +1543,14 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name): sch.cache_write(block_b, 0, "test_scope") -@pytest.mark.parametrize("use_decl_buffer", [True, False]) -def test_cache_write_allocate_const(use_decl_buffer): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - +def test_cache_write_allocate_const(): @T.prim_func def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")): B = T.alloc_buffer([128, 128], dtype="float32") const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const1_buf = apply_decl_buffer([8], dtype="float32", data=const1) + const1_buf = T.decl_buffer([8], dtype="float32", data=const1) const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) + const2_buf = T.decl_buffer([8], dtype="float32", data=const2) for i, j in T.grid(128, 128): for x in range(8): with T.block("B"): @@ -1578,9 +1571,9 @@ def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float1 A_global = T.alloc_buffer([128, 128], dtype="float32") C_global = T.alloc_buffer([128, 128], dtype="float16") const1 = T.allocate_const([0, 
0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const1_buf = apply_decl_buffer([8], dtype="float32", data=const1) + const1_buf = T.decl_buffer([8], dtype="float32", data=const1) const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) + const2_buf = T.decl_buffer([8], dtype="float32", data=const2) for ax0, ax1 in T.grid(128, 128): with T.block("A_global"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 963d9586bcaa..e1934b1ac6e2 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1745,21 +1745,14 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), verify_trace_roundtrip(sch=sch, mod=before) -@pytest.mark.parametrize("use_decl_buffer", [True, False]) @pytest.mark.parametrize("use_reverse_compute_at", [True, False]) -def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - +def test_compute_at_allocate_const(use_reverse_compute_at): @T.prim_func def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): B = T.alloc_buffer([4]) offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) + offset = T.decl_buffer([4], data=offset_ptr) for i in range(4): with T.block("compute_B"): vi = T.axis.remap("S", [i]) @@ -1775,7 +1768,7 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) B = T.alloc_buffer([4]) offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) + offset = 
T.decl_buffer([4], data=offset_ptr) for i in range(4): with T.block("compute_B"): vi = T.axis.remap("S", [i]) @@ -1802,20 +1795,13 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")) verify_trace_roundtrip(sch=sch, mod=before) -@pytest.mark.parametrize("use_decl_buffer", [True, False]) -def test_compute_inline_allocate_const(use_decl_buffer): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - +def test_compute_inline_allocate_const(): @T.prim_func def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): B = T.alloc_buffer([4]) offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) + offset = T.decl_buffer([4], data=offset_ptr) for i in range(4): with T.block("compute_B"): vi = T.axis.remap("S", [i]) @@ -1829,7 +1815,7 @@ def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): @T.prim_func def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) + offset = T.decl_buffer([4], data=offset_ptr) for i, j in T.grid(4, 256): with T.block("compute_C"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 7dacd8e046cc..2d45118f39f2 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body + body = stmt.body.body.body blist = tvm.tir.stmt_list(body) assert 
blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body) + slist = tvm.tir.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 20f91b639497..a7965e4db423 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -41,42 +41,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i, j] = B_new[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", scope="global") - B_new = T.Buffer([16], "float32", scope="global", data=B_new_data) - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -class TestElementwiseWithoutDeclBuffer(BaseCompare): - """2-d buffers are flattened to 1-d - - Like TestElementwise, but the TIR doesn't have the DeclBuffer - node. The T.Buffer declaration applies only during the - parsing the TVMScript, and doesn't occur in the TIR itself. In - this case, the allocation should be assumed to be targeting flat - memory, and should be flattened to a 1-d allocation. 
- """ - - def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): - for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.Buffer([1, 16], "float32", data=B_new_data) - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) - for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", "global") - B_new = T.Buffer(16, "float32", data=B_new_data) + B_new = T.decl_buffer(16, "float32", scope="global") for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -101,8 +69,8 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -111,8 +79,7 @@ def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B_data = T.allocate([16], "float32", scope="local") - B = T.Buffer([16], "float32", scope="local", data=B_data) + B = T.decl_buffer(16, "float32", scope="local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -136,12 +103,11 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: input_A = T.match_buffer(a, (n, m), 
"float32") input_C = T.match_buffer(c, (n, m), "float32") - A = T.Buffer(n * m, "float32", data=input_A.data) - C = T.Buffer(n * m, "float32", data=input_C.data) + A = T.decl_buffer(n * m, "float32", data=input_A.data) + C = T.decl_buffer(n * m, "float32", data=input_C.data) for i in range(0, n): - B_data = T.allocate([m], "float32", scope="global") - B = T.Buffer([m], "float32", scope="global", data=B_data) + B = T.decl_buffer(m, "float32", scope="global") for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -161,8 +127,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for i in range(0, n * n * 32): B[i] = A[i] @@ -185,8 +151,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for bx, tx in T.grid((n * n + 1) // 2, 64): if bx * 64 + tx < n * n * 32: @@ -205,14 +171,12 @@ def before(A: T.Buffer((4, 32), "float32"), D: T.Buffer((4, 32), "float32")): D[i, j] = C[i, j] * 2.0 def expected(input_A: T.Buffer((4, 32), "float32"), input_D: T.Buffer((4, 32), "float32")): - A = T.Buffer(128, "float32", data=input_A.data) - D = T.Buffer(128, "float32", data=input_D.data) + A = T.decl_buffer(128, "float32", data=input_A.data) + D = 
T.decl_buffer(128, "float32", data=input_D.data) for i, j in T.grid(4, 32): - B_data = T.allocate([128], "float32", scope="global") - B = T.Buffer([128], "float32", scope="global", data=B_data) - C_data = T.allocate([128], "float32", scope="global") - C = T.Buffer([128], "float32", scope="global", data=C_data) + B = T.decl_buffer(128, "float32", scope="global") + C = T.decl_buffer(128, "float32", scope="global") B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -231,11 +195,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i0 in T.serial(0, 4): - B_new_data = T.allocate([68], "float32", scope="global") - B_new = T.Buffer([68], "float32", scope="global", data=B_new_data) + B_new = T.decl_buffer(68, "float32", scope="global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -252,8 +215,8 @@ def before(A: T.Buffer(10, "bool"), B: T.Buffer(10, "bool")) -> None: B[i0] = A[i0] def expected(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) -> None: - A = T.Buffer(10, dtype="int8", data=input_A.data) - B = T.Buffer(10, dtype="int8", data=input_B.data) + A = T.decl_buffer(10, dtype="int8", data=input_A.data) + B = T.decl_buffer(10, dtype="int8", data=input_B.data) # body for i0 in T.serial(10): B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") @@ -329,8 +292,7 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A_data = T.allocate([30, 1001], dtype="float32", scope="global") - A = T.Buffer([30, 1001], dtype="float32", 
scope="global", axis_separators=[1], data=A_data) + A = T.decl_buffer([30, 1001], axis_separators=[1], dtype="float32", scope="global") for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 593f9447d44c..46ce18fbb813 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -108,8 +108,12 @@ def get_vthread(name): ) )["main"] - assert list(stmt.body.body.extents) == [A_expected_alloc] - assert list(stmt.body.body.body.body.extents) == [C_expected_alloc] + A_alloc = stmt.body.body + assert A_alloc.buffer_var.name == "A" + assert list(A_alloc.extents) == [A_expected_alloc] + C_alloc = A_alloc.body.body.body.body + assert C_alloc.buffer_var.name == "C" + assert list(C_alloc.extents) == [C_expected_alloc] def test_vthread_if_then_else(): @@ -132,8 +136,8 @@ def test_vthread_if_then_else(): tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) )["main"] - assert stmt.body.body.body[0].else_case != None - assert stmt.body.body.body[1].else_case == None + assert stmt.body.body.body.body[0].else_case != None + assert stmt.body.body.body.body[1].else_case == None def test_vthread_simplified(): diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 65e317dfbcb8..9af5bf7e64a4 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,9 +38,9 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] + func = 
tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] - assert body.body.body.node == cp + assert func.body.body.node == cp # only able to lift to the common pattern of the last two fors. ib = tvm.tir.ir_builder.create() @@ -58,10 +58,10 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] + func = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] - assert body.body.body.body[1].node == cp - assert len(body.body.body.body) == 2 + assert func.body.body.body.body[1].node == cp + assert len(func.body.body.body.body) == 2 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index aa11ae5a5f7b..61b4f5a7bff8 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -16,7 +16,7 @@ # under the License. 
import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm.ir.module import IRModule from tvm.script import tir as T import numpy @@ -181,7 +181,11 @@ def test_vectorize(): s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) stmt = tvm.lower(s, [A, B], name="main")["main"] - body = stmt.body.body.body.body + + body = stmt + while not isinstance(body, tir.IfThenElse): + body = body.body + assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -232,7 +236,11 @@ def test_thread_axis2(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) stmt = tvm.lower(s, [A, B], name="main")["main"] - for_body = stmt.body.body.body.body[0] + + while not isinstance(stmt, tir.SeqStmt): + stmt = stmt.body + + for_body = stmt[0] assert "threadIdx" not in str(for_body.extent) @@ -566,7 +574,7 @@ def test_explicit_partition_hint(): mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.LoopPartition()(mod) mod = tvm.tir.transform.Simplify()(mod) - assert tvm.ir.structural_equal(mod["main"], partitioned_concat) + tvm.ir.assert_structural_equal(mod["main"], partitioned_concat) def partition_from_scheduled_tir(prim_func, pass_cfg): @@ -628,7 +636,7 @@ def test_condition_mutually_exclusive(): mod = partition_from_scheduled_tir( concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}} ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( mod["main"], partitioned_concat_3.with_attr("global_symbol", "main") ) @@ -680,7 +688,7 @@ def partitioned_main( mod = tvm.tir.transform.UnrollLoop()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) mod = tvm.tir.transform.Simplify()(mod) - assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def 
test_loop_partition_recursive_unroll_hint(): @@ -711,32 +719,28 @@ def main(): @T.prim_func def partitioned_main(): - placeholder_0_dm = T.allocate([16384], "int8", "global") - placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm) + placeholder_0_dm = T.decl_buffer([16384], "int8") for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp = T.allocate([4096], "int8", "global") - pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp) + pad_temp = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: - pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 ] for i2_0 in T.unroll(2): - pad_temp_2 = T.allocate([4096], "int8", "global") - pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2) + pad_temp_2 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0: - pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_2[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 ] for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp_4 = T.allocate([4096], "int8", "global") - pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4) + pad_temp_4 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: - pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_4[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 ] @@ -749,7 +753,7 @@ def partitioned_main(): } }, ) - assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def test_loop_partition_keep_loop_annotations(): @@ -783,7 +787,7 @@ 
def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: } }, ) - assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) def test_loop_partition_with_unit_loop_in_condition(): @@ -831,7 +835,7 @@ def after( } }, ) - assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index f797d35d47ca..08796871c89d 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -70,10 +70,10 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): T.reinterpret("handle", T.uint64(0)), ): mask_data = T.allocate([1], "uint32", "local") - mask = T.Buffer(1, "uint32", data=mask_data, scope="local") + mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local") t0_data = T.allocate([1], "float32", "local") - t0 = T.Buffer(1, data=t0_data, scope="local") + t0 = T.decl_buffer(1, data=t0_data, scope="local") reduce[0] = A_flat[0] mask[0] = T.tvm_warp_activemask() @@ -96,7 +96,7 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): class TestBasicWithDeclBuffer(BaseCompare): def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - A_flat = T.Buffer(4096, data=A.data) + A_flat = T.decl_buffer(4096, data=A.data) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 32) @@ -120,7 +120,7 @@ def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, 
"float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - A_flat = T.Buffer(4096, data=A.data) + A_flat = T.decl_buffer(4096, data=A.data) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 32) @@ -133,10 +133,10 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): T.reinterpret("handle", T.uint64(0)), ): mask_data = T.allocate([1], "uint32", "local") - mask = T.Buffer(1, "uint32", data=mask_data, scope="local") + mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local") t0_data = T.allocate([1], "float32", "local") - t0 = T.Buffer(1, data=t0_data, scope="local") + t0 = T.decl_buffer(1, data=t0_data, scope="local") reduce[0] = A_flat[0] mask[0] = T.tvm_warp_activemask() @@ -159,16 +159,16 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")): class TestReduceSummation(BaseCompare): def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - A_flat = T.Buffer((16384,), data=A.data) + A_flat = T.decl_buffer(16384, data=A.data) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 32) normal_reduce_data = T.allocate([1], "float32", "local") - normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local") + normal_reduce = T.decl_buffer(1, data=normal_reduce_data, scope="local") reduce_data = T.allocate([1], "float32", "local") - reduce = T.Buffer(1, data=reduce_data, scope="local") + reduce = T.decl_buffer(1, data=reduce_data, scope="local") normal_reduce[0] = T.float32(0) @@ -192,16 +192,16 @@ def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - A_flat = T.Buffer(16384, data=A.data) + A_flat = T.decl_buffer(16384, data=A.data) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 32) normal_reduce_data = T.allocate([1], 
"float32", "local") - normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local") + normal_reduce = T.decl_buffer(1, data=normal_reduce_data, scope="local") reduce_data = T.allocate([1], "float32", "local") - reduce = T.Buffer(1, data=reduce_data, scope="local") + reduce = T.decl_buffer(1, data=reduce_data, scope="local") normal_reduce[0] = T.float32(0) for ko in range(4): @@ -212,10 +212,10 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): T.reinterpret("handle", T.uint64(0)), ): mask_data = T.allocate([1], "uint32", "local") - mask = T.Buffer(1, "uint32", data=mask_data, scope="local") + mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local") t0_data = T.allocate([1], "float32", "local") - t0 = T.Buffer(1, data=t0_data, scope="local") + t0 = T.decl_buffer(1, data=t0_data, scope="local") reduce[0] = normal_reduce[0] mask[0] = T.tvm_warp_activemask() diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index c7e90d4e7dc9..99ccc5556585 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -18,7 +18,7 @@ import pytest import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm.contrib.nvcc import have_fp16 @@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope(): mod = _run_passes(mod) fdevice = mod["f_kernel"] - allocate = fdevice.body.body + + allocate = fdevice + while not isinstance(allocate, tir.Allocate): + allocate = allocate.body + assert allocate.buffer_var.type_annotation.storage_scope == "local" - assert fdevice.body.body.extents[0].value == 2 + assert allocate.extents[0].value == 2 @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index c03dd7a5291d..e2641a65f287 100644 --- 
a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -170,6 +170,8 @@ def check(m, target_bits, target_dtype): B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B") s = te.create_schedule(B.op) stmt = lower_sch(s, [A, B], target_bits) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body assert stmt[1].loop_var.dtype == target_dtype # i32 -> i32 @@ -221,6 +223,8 @@ def check(shapex, shapey, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body # outer loop assert stmt.loop_var.dtype == target_dtype # inner loop @@ -262,7 +266,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.indices[0].dtype == target_dtype + assert stmt.body.body.value.indices[0].dtype == target_dtype check( (const(2**16, "int64"), const(2**15 + 1, "int64")), diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index f09645462366..4fd260d32da8 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -53,7 +53,7 @@ def test_flatten_prefetch(): mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -80,7 +80,7 @@ def test_flatten_storage_align(): [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body.body assert stmt.extents[0].value == 17 * 8 @@ 
-114,9 +114,9 @@ def main(A_param: T.handle, C_param: T.handle): ] )(mod) - stmt = mod["main"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert list(stmt.body.extents) == [8] + stmt = mod["main"].body.body.body.body + assert isinstance(stmt, tvm.tir.Allocate) + assert list(stmt.extents) == [8] mod = tvm.tir.transform.ThreadSync("shared")(mod) f = mod["main"] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 197e81818ee3..880a9caf0dfc 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -414,7 +414,9 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.DeclBuffer) + assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body.body.body, tvm.tir.DeclBuffer) mod = get_mod(kind="serial") # for (i, 0, n) { @@ -438,7 +440,9 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body, tvm.tir.DeclBuffer) + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body.body, tvm.tir.DeclBuffer) def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 571927dffe6e..2cfc65aae069 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -57,7 +57,7 @@ def test_thread_storage_sync(): func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) mod = run_passes(func) f = mod["test_kernel"] - body_list 
= tvm.tir.stmt_list(f.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))