diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index 91f5512453fb..dcda757e2c3c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -41,7 +41,7 @@ def get_binary_elementwise_params( ------- SerialBinaryElementwise The parameters needed to construct a binary elementwise operator. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the binary elementwise operation. replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace @@ -56,17 +56,17 @@ def get_binary_elementwise_params( _, _, _, _, _, inner = get_outer_loops(body, "NHWC") # loads = [input, input, LUT, LUT] loads = get_loads(inner) - input_pointer = loads[0].buffer.data - input_pointer1 = loads[1].buffer.data + input_buffer = loads[0].buffer + input_buffer1 = loads[1].buffer if reversed_operands: - input_pointer, input_pointer1 = input_pointer1, input_pointer - output_pointer = inner.buffer.data + input_buffer, input_buffer1 = input_buffer1, input_buffer + output_buffer = inner.buffer # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, _ = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ifm2, _ = get_ifm_params(input_buffer1, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( @@ -87,7 +87,7 @@ def get_binary_elementwise_params( block_config=serial_block_config, rescale_config=rescale_config, ), - 
output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 2358e5a221bb..caf0ef5b95b3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -40,9 +40,9 @@ def get_conv2d_params(stmt, producers_consumers): ------- Serial2DConvolution The parameters needed to construct a 2D convolution. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the convolution output pointer. is_allocator : bool @@ -60,12 +60,12 @@ def get_conv2d_params(stmt, producers_consumers): loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -157,7 +157,7 @@ def get_conv2d_params(stmt, producers_consumers): upscale=attrs["upscale"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py index 
5878c2a7e09c..44af7e7647c6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -47,11 +47,11 @@ def get_depthwise_conv2d_params( ------- Serial2DDepthwise The parameters needed to construct a 2D depthwise. - output_pointer : tvm.tir.Var - The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var - The output pointer of the DMA write operation, which is to replace - the convolution output pointer. + output_buffer : tvm.tir.Buffer + The output buffer of the convolution operation. + replace_buffer : tvm.tir.Buffer + The output buffer of the DMA write operation, which is to replace + the convolution output buffer. is_allocator : bool Whether this operator allocates its output. @@ -64,12 +64,12 @@ def get_depthwise_conv2d_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -113,7 +113,7 @@ def get_depthwise_conv2d_params( upscale="NONE", block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 82485db65866..8d819cc318c8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ 
b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -34,21 +34,21 @@ def get_pad_params(stmt): ------- pad : SerialPadding The serializable padding. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ _, body = get_op_attrs(stmt) n, h, w, c, _, inner = get_outer_loops(body, "NHWC") - output_pointer = inner.buffer.data + output_buffer = inner.buffer pad = SerialPadding(top=0, left=0, bottom=0, right=0) if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data - return pad, input_pointer, output_pointer + input_buffer = inner.value.buffer + return pad, input_buffer, output_buffer padded_shape = [n.extent, h.extent, w.extent, c.extent] @@ -72,8 +72,8 @@ def _visit(expr): tvm.tir.stmt_functor.post_order_visit(cond, _visit) return ( pad, - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) @@ -87,19 +87,19 @@ def get_upscale_params(stmt): Returns ------- - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. 
""" _, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") if isinstance(inner.value, tvm.tir.Call): - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data - return (input_pointer, output_pointer) + input_buffer = inner.value.buffer + output_buffer = inner.buffer + return (input_buffer, output_buffer) def get_convert_to_nhwc_params(stmt): @@ -114,10 +114,10 @@ def get_convert_to_nhwc_params(stmt): ------- int The true number of channels. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ attrs, body = get_op_attrs(stmt) @@ -127,12 +127,12 @@ def get_convert_to_nhwc_params(stmt): # compute that is deemed uneccesary isn't removed by TVM. if attrs["layout"] == "NHCWB16": inner = inner.body - input_pointer = inner.value.b.buffer.data + input_buffer = inner.value.b.buffer else: - input_pointer = inner.value.buffer.data + input_buffer = inner.value.buffer - output_pointer = inner.buffer.data - return c.extent, input_pointer, output_pointer + output_buffer = inner.buffer + return c.extent, input_buffer, output_buffer def get_convert_to_nhcwb16_params(stmt): @@ -147,24 +147,24 @@ def get_convert_to_nhcwb16_params(stmt): ------- out_channels : int The true number of channels. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer : tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer : tvm.tir.Buffer + The buffer produced by the operation. 
""" attrs, body = get_op_attrs(stmt) _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) - output_pointer = inner.buffer.data + output_buffer = inner.buffer if isinstance(inner.value, tvm.tir.Call): cond = inner.value.args[0] out_channels = cond.b.value - input_pointer = inner.value.args[1].buffer.data + input_buffer = inner.value.args[1].buffer else: - input_pointer = inner.value.buffer.data + input_buffer = inner.value.buffer out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent - return out_channels, input_pointer, output_pointer + return out_channels, input_buffer, output_buffer class Tiles(NamedTuple): @@ -298,16 +298,16 @@ def get_read_params(stmt): ------- SerialFeatureMap The serializable feature map. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. """ attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.buffer + output_buffer = inner.buffer # Needed for stride calculation, can replace with # inner.value.buffer.strides in future. @@ -337,8 +337,8 @@ def get_read_params(stmt): stride_w=strides[1], stride_c=strides[2], ), - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) @@ -354,16 +354,16 @@ def get_write_params(stmt): ------- SerialFeatureMap The serializable feature map. - input_pointer : tvm.tir.Var - The pointer consumed by the operation. - output_pointer : tvm.tir.Var - The pointer produced by the operation. + input_buffer: tvm.tir.Buffer + The buffer consumed by the operation. + output_buffer: tvm.tir.Buffer + The buffer produced by the operation. 
""" attrs, body = get_op_attrs(stmt) _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) - input_pointer = inner.value.buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.buffer.data + output_buffer = inner.buffer # Needed for stride calculation, can replace with # inner.value.buffer.strides in future. @@ -402,18 +402,18 @@ def get_write_params(stmt): stride_c=strides[2], ), block_config, - input_pointer, - output_pointer, + input_buffer, + output_buffer, ) -def get_ifm_params(pointer, producers_consumers, stmt): +def get_ifm_params(buffer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an IFM. Parameters ---------- - pointer : tvm.tir.Var - The pointer that the IFM DMA pipeline produces. + buffer: tvm.tir.Buffer + The buffer that the IFM DMA pipeline produces. producers_consumers: ProducersConsumers It associates pointers with the loop nest that produces their values and with the loop nest that consumes their values. @@ -426,13 +426,13 @@ def get_ifm_params(pointer, producers_consumers, stmt): The serializable padding. 
""" - pad = producers_consumers.get_producer(pointer, stmt) - serial_padding, input_pointer, _ = get_pad_params(pad) - upscale = producers_consumers.get_producer(input_pointer, pad) - input_pointer, _ = get_upscale_params(upscale) - convert_to_nhwc = producers_consumers.get_producer(input_pointer, upscale) - in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) - read = producers_consumers.get_producer(input_pointer, convert_to_nhwc) + pad = producers_consumers.get_producer(buffer, stmt) + serial_padding, input_buffer, _ = get_pad_params(pad) + upscale = producers_consumers.get_producer(input_buffer, pad) + input_buffer, _ = get_upscale_params(upscale) + convert_to_nhwc = producers_consumers.get_producer(input_buffer, upscale) + in_channels, input_buffer, _ = get_convert_to_nhwc_params(convert_to_nhwc) + read = producers_consumers.get_producer(input_buffer, convert_to_nhwc) serial_ifm, _, _ = get_read_params(read) serial_ifm.channels = in_channels @@ -479,13 +479,13 @@ def _get_buffer_var(stmt): return serial_ifm, serial_padding -def get_ofm_params(pointer, producers_consumers, stmt): +def get_ofm_params(buffer, producers_consumers, stmt): """Get the parameters associated with the DMA capabilities for an OFM. Parameters ---------- - pointer : tvm.tir.Var - The pointer that the OFM DMA pipeline consumes. + buffer: tvm.tir.Buffer + The buffer that the OFM DMA pipeline consumes. producers_consumers: ProducersConsumers It associates pointers with the loop nest that produces their values and with the loop nest that consumes their values. @@ -496,20 +496,20 @@ def get_ofm_params(pointer, producers_consumers, stmt): The serializable OFM. serial_block_config : SerialBlockConfig The serializable block config. - output_pointer : tvm.tir.Var - The pointer that the OFM DMA pipeline produces. + output_buffer: tvm.tir.Buffer + The buffer that the OFM DMA pipeline produces. is_allocator : bool Whether this operator allocates its output. 
""" - convert_to_nhcwb16 = producers_consumers.get_consumer(pointer, stmt) - out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) - write = producers_consumers.get_consumer(output_pointer, convert_to_nhcwb16) - serial_ofm, serial_block_config, _, output_pointer = get_write_params(write) + convert_to_nhcwb16 = producers_consumers.get_consumer(buffer, stmt) + out_channels, _, output_buffer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) + write = producers_consumers.get_consumer(output_buffer, convert_to_nhcwb16) + serial_ofm, serial_block_config, _, output_buffer = get_write_params(write) is_allocator = True - producer = producers_consumers.get_producer(output_pointer, write) + producer = producers_consumers.get_producer(output_buffer, write) if producer is None or producer != write: is_allocator = False serial_ofm.channels = out_channels - return serial_ofm, serial_block_config, output_pointer, is_allocator + return serial_ofm, serial_block_config, output_buffer, is_allocator diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 9610c8dd3cdc..a92d2d399c31 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -30,7 +30,9 @@ from .producers_consumers import ProducersConsumers -def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]: +def _get_feature_map( + stmt: tvm.tir.AttrStmt, fm_type: str +) -> Tuple[SerialFeatureMap, tvm.tir.Buffer]: """Get the feature map parameters from a loop nest of any shape (as long there are at most 4 nested loops). @@ -45,8 +47,8 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur ------- SerialFeatureMap The serializable feature map. - output_pointer : tvm.tir.Var - The pointer produced by the operation. 
+ output_buffer : tvm.tir.Buffer + The buffer produced by the operation. """ assert fm_type in ("ifm", "ofm") @@ -96,14 +98,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur stride_c=strides[2] if len(strides) > 2 else 1, ) - output_pointer = inner.buffer.data + output_buffer = inner.buffer - return serial_feature_map, output_pointer + return serial_feature_map, output_buffer def get_identity_params( stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers -) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: +) -> Tuple[SerialPooling, tvm.tir.Buffer, tvm.tir.Buffer]: """Get the parameters necessary to construct a call_extern for an identity pooling. Parameters @@ -118,11 +120,11 @@ def get_identity_params( ------- SerialPooling The parameters needed to construct a 2D pooling. - output_pointer : tvm.tir.Var - The output pointer of the pooling operation. - replace_pointer : tvm.tir.Var - The output pointer of the DMA write operation, which is to replace - the pooling output pointer. + output_buffer : tvm.tir.Buffer + The output buffer of the pooling operation. + replace_buffer : tvm.tir.Buffer + The output buffer of the DMA write operation, which is to replace + the pooling output buffer. is_allocator : bool Whether this operator allocates its output. 
@@ -136,19 +138,19 @@ def get_identity_params( # loads = [input, LUT, LUT] loads = get_loads(store) - input_pointer = loads[0].buffer.data - output_pointer = store.buffer.data + input_buffer = loads[0].buffer + output_buffer = store.buffer - read = producers_consumers.get_producer(input_pointer, stmt) - write = producers_consumers.get_consumer(output_pointer, stmt) + read = producers_consumers.get_producer(input_buffer, stmt) + write = producers_consumers.get_consumer(output_buffer, stmt) serial_ifm, _ = _get_feature_map(read, "ifm") - serial_ofm, write_output_pointer = _get_feature_map(write, "ofm") + serial_ofm, write_output_buffer = _get_feature_map(write, "ofm") - replace_pointer = write_output_pointer + replace_buffer = write_output_buffer is_allocator = True - producer = producers_consumers.get_producer(write_output_pointer, write) + producer = producers_consumers.get_producer(write_output_buffer, write) if producer is None or producer != write: is_allocator = False @@ -169,7 +171,7 @@ def get_identity_params( rounding_mode=attrs["rounding_mode"], block_config=SerialBlockConfig(0, 0, 0), ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 9636f2044733..2a242944789f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -26,6 +26,7 @@ from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs +from .utils import collect_buffer_map from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params from .pooling import get_pooling_params @@ -72,9 +73,10 @@ def ReplaceOperators(): "ethosu_unary_elementwise": get_unary_elementwise_params, } producers_consumers = ProducersConsumers() - replace_output_pointer = {} + 
pointer_to_buffer = {} + replace_output_buffer = {} pointer_to_extents = {} - replaced_pointers = [] + replaced_buffers = [] ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"]) @@ -90,28 +92,37 @@ def _resolve_pointers(stmt): Additionally, it determines the extent (size/shape) of each pointer which is required for the _replace_pointers pass which runs later.""" - loads = [] def _get_loads(stmt): - if isinstance(stmt, tvm.tir.BufferLoad): - loads.append(stmt.buffer.data) + loads = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.BufferLoad): + loads.append(stmt.buffer) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return loads - buffer_var = None + def _get_output_buffer(stmt): + output_buffer = None - def _get_buffer_var(stmt): - if isinstance(stmt, tvm.tir.BufferStore): - nonlocal buffer_var - buffer_var = stmt.buffer.data + def _visit(stmt): + if isinstance(stmt, tvm.tir.BufferStore): + nonlocal output_buffer + output_buffer = stmt.buffer + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return output_buffer if isinstance(stmt, tvm.tir.AttrStmt): if stmt.attr_key == "pragma_op": - tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var) - producers_consumers.add_producer(buffer_var, stmt) + output_buffer = _get_output_buffer(stmt) + producers_consumers.add_producer(output_buffer, stmt) - tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads) - for load_pointer in loads: - if load_pointer != buffer_var: - producers_consumers.add_consumer(load_pointer, stmt) + loads = _get_loads(stmt) + for load_buffer in loads: + if load_buffer != output_buffer: + producers_consumers.add_consumer(load_buffer, stmt) def _replace_operator(stmt): """Replace operators with call_externs, having derived the parameters @@ -135,23 +146,28 @@ def _replace_operator(stmt): if stmt.attr_key == "pragma_op" and op_name in op_map: # Get the parameters for the extern call param_func = op_map[op_name] - info, output_pointer, replace_pointer, 
is_allocator = param_func( + info, output_buffer, replace_buffer, is_allocator = param_func( stmt, producers_consumers ) - if replace_pointer is not None: + if replace_buffer is not None: # Allocate pointer only once - if replace_pointer in replaced_pointers: + if replace_buffer in replaced_buffers: is_allocator = False - replace_output_pointer[output_pointer] = ReplaceInfo( - replace_pointer, is_allocator - ) - replaced_pointers.append(replace_pointer) + replace_output_buffer[output_buffer] = ReplaceInfo(replace_buffer, is_allocator) + replaced_buffers.append(replace_buffer) # Make the extern call irb = tvm.tir.ir_builder.create() irb.emit(tvm.tir.call_extern("handle", op_name, *info)) return irb.get() return None + def _remove_buffer_decl(stmt, buffer_var): + def _mutator(stmt): + if isinstance(stmt, tvm.tir.DeclBuffer) and stmt.buffer.data == buffer_var: + return stmt.body + + return tvm.tir.stmt_functor.ir_transform(stmt, None, _mutator, ["tir.DeclBuffer"]) + def _remove_no_compile(stmt): """Certain operators are marked as 'no compile' operators. This means they should be removed from the IR as they are compiled as part of other operators. 
@@ -170,33 +186,41 @@ def _remove_no_compile(stmt): if isinstance(stmt, tvm.tir.Allocate): # Remove allocates - producer = producers_consumers.get_last_producer(stmt.buffer_var) + buffer = pointer_to_buffer[stmt.buffer_var] + producer = producers_consumers.get_last_producer(buffer) if producer: if producer.attr_key == "pragma_op" and producer.value.value not in op_map: - return stmt.body + return _remove_buffer_decl(stmt.body, stmt.buffer_var) return None def _replace_pointers(stmt): if isinstance(stmt, tvm.tir.Allocate): # If the allocate allocates a pointer that needs replacing - if stmt.buffer_var in replace_output_pointer: - replace_pointer, reallocate = replace_output_pointer[stmt.buffer_var] - if not reallocate: - return stmt.body - # Otherwise, rewrite the allocation statement with the new pointer - # and the new extent - replace_type = replace_pointer.type_annotation.element_type.dtype - replace_extents = pointer_to_extents[replace_pointer] - return tvm.tir.Allocate( - replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body - ) + buffer = pointer_to_buffer[stmt.buffer_var] + if buffer in replace_output_buffer: + replace_buffer, reallocate = replace_output_buffer[buffer] + if reallocate: + # Otherwise, rewrite the allocation statement with the new pointer + # and the new extent + replace_pointer = replace_buffer.data + replace_type = replace_pointer.type_annotation.element_type.dtype + replace_extents = pointer_to_extents[replace_pointer] + return tvm.tir.Allocate( + replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body + ) + else: + return _remove_buffer_decl(stmt.body, stmt.buffer_var) return None - def _remove_buffer_decl(stmt): + def _replace_buffers(stmt): if isinstance(stmt, tvm.tir.DeclBuffer): - if stmt.buffer.data in replace_output_pointer: - return stmt.body + if stmt.buffer in replace_output_buffer: + replace_buffer, reallocate = replace_output_buffer[stmt.buffer] + if reallocate: + return 
tvm.tir.DeclBuffer(replace_buffer, stmt.body) + else: + return stmt.body def _post_transform(stmt): # Replace operators with call_externs @@ -205,19 +229,21 @@ def _post_transform(stmt): result = result or _remove_no_compile(stmt) # Replace necessary pointers that were removed in the previous step result = result or _replace_pointers(stmt) - # Replace BufferDecl, since only the tir.Var data pointer is - # still used, and not the tir.Buffer - result = result or _remove_buffer_decl(stmt) + # Replace or remove DeclBuffer, since the tir.Var data pointer may + # have been replaced or removed + result = result or _replace_buffers(stmt) return result def _ftransform(f, mod, ctx): + nonlocal pointer_to_buffer + pointer_to_buffer = collect_buffer_map(f.body) tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent) tvm.tir.stmt_functor.post_order_visit(f.body, _resolve_pointers) producers_consumers.add_allocate_variables(pointer_to_extents.keys()) return f.with_body( tvm.tir.stmt_functor.ir_transform( - f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate"] + f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate", "tir.DeclBuffer"] ) ) @@ -491,6 +517,7 @@ def _visit(stmt): def transform_stmt( stmt, buf_remap, + remove_decl_buffer, var_remap, pointer_to_buffer, new_buffer_var_to_const, @@ -557,6 +584,13 @@ def _visit_rewrite(stmt): offset = new_buffer_to_split_idx[new_buffer] return tvm.tir.BufferLoad(buf_remap[stmt.buffer], [offset], stmt.span) + # Update or remove DeclBuffer + if isinstance(stmt, tvm.tir.DeclBuffer): + if stmt.buffer in remove_decl_buffer: + return stmt.body + elif stmt.buffer in buf_remap: + return tvm.tir.DeclBuffer(buf_remap[stmt.buffer], stmt.body) + if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node if node_pointer in var_remap: @@ -574,7 +608,7 @@ def _visit_rewrite(stmt): stmt, None, _visit_rewrite, - ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"], + ["tir.Call", "tir.Allocate", "tir.BufferLoad", 
"tir.AttrStmt", "tir.DeclBuffer"], ) def _collect_parameter_buffer_aliases(prim_func): @@ -610,17 +644,22 @@ def _ftransform(f, mod, ctx): # Step 2: Generate variable/buffer remaps, based on the # collected information. buf_remap = {} + remove_decl_buffer = set() new_buffer_var_to_const = {} new_buffer_to_split_idx = {} def define_remap(old_buf, new_buf): try: old_buffers = param_buffer_var_usage[old_buf.data] + is_param = True except KeyError: old_buffers = [old_buf] + is_param = False for old_buffer in old_buffers: buf_remap[old_buffer] = new_buf + if is_param: + remove_decl_buffer.add(old_buffer) # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: @@ -666,6 +705,7 @@ def define_remap(old_buf, new_buf): new_body = transform_stmt( f.body, buf_remap, + remove_decl_buffer, var_remap, pointer_to_buffer, new_buffer_var_to_const, diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 069930475df9..646e5bad7223 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -26,7 +26,7 @@ def get_pooling_params( stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers -) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: +) -> Tuple[SerialPooling, tvm.tir.Buffer, tvm.tir.Buffer]: """Get the parameters necessary to construct a call_extern for a pooling. Parameters @@ -41,9 +41,9 @@ def get_pooling_params( ------- SerialPooling The parameters needed to construct a 2D convolution. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the convolution operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the convolution output pointer. 
is_allocator : bool @@ -57,12 +57,12 @@ def get_pooling_params( loads = get_loads(rw.body) # stores = [output] stores = get_stores(rw.body) - input_pointer = loads[1].buffer.data - output_pointer = stores[0].buffer.data + input_buffer = loads[1].buffer + output_buffer = stores[0].buffer # Get feature map info - serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, serial_padding = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get kernel info serial_kernel = SerialKernel( @@ -90,7 +90,7 @@ def get_pooling_params( upscale=attrs["upscale"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py index 39cbf701649f..6985ed28383f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/producers_consumers.py @@ -27,52 +27,52 @@ class ProducersConsumers: def __init__(self) -> None: self.indices: dict[tvm.tir.AttrStmt, int] = {} - self.producers: list[(tvm.tir.AttrStmt, tvm.tir.expr.Var)] = [] - self.consumers: list[(tvm.tir.AttrStmt, list[tvm.tir.expr.Var])] = [] + self.producers: list[(tvm.tir.AttrStmt, tvm.tir.expr.Buffer)] = [] + self.consumers: list[(tvm.tir.AttrStmt, list[tvm.tir.expr.Buffer])] = [] self.allocate_variables: Optional[KeysView] = None - def add_producer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + def add_producer(self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt) -> None: """Add the attribute statement attr as producer of the variable 
var.""" self.indices[attr] = len(self.producers) - self.producers.append((attr, var)) + self.producers.append((attr, buf)) def get_producer( - self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt ) -> Optional[tvm.tir.AttrStmt]: """Get the last attribute statement which produces the variable var when the current attribute statement is attr.""" - if var not in self.allocate_variables: + if buf.data not in self.allocate_variables: return None index = self.indices[attr] for i in list(reversed(range(index + 1))): - if self.producers[i][1] == var: + if self.producers[i][1] == buf: return self.producers[i][0] return None - def get_last_producer(self, var: tvm.tir.expr.Var) -> Optional[tvm.tir.AttrStmt]: + def get_last_producer(self, buf: tvm.tir.Buffer) -> Optional[tvm.tir.AttrStmt]: """Get the last attribute statement which produces the variable var.""" - return self.get_producer(var, self.producers[-1][0]) + return self.get_producer(buf, self.producers[-1][0]) def add_allocate_variables(self, allocate_variables: KeysView) -> None: """Add the allocated variables.""" self.allocate_variables = allocate_variables - def add_consumer(self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt) -> None: + def add_consumer(self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt) -> None: """Add the attribute statement attr as consumer of the variable var.""" index = self.indices[attr] if index < len(self.consumers): - self.consumers[index][1].append(var) + self.consumers[index][1].append(buf) else: - self.consumers.append((attr, [var])) + self.consumers.append((attr, [buf])) def get_consumer( - self, var: tvm.tir.expr.Var, attr: tvm.tir.AttrStmt + self, buf: tvm.tir.Buffer, attr: tvm.tir.AttrStmt ) -> Optional[tvm.tir.AttrStmt]: """Get the first attribute statement which consumes the variable var when the current attribute statement is attr.""" index = self.indices[attr] for i in range(index, len(self.consumers)): - if var in 
self.consumers[i][1]: + if buf in self.consumers[i][1]: return self.consumers[i][0] return None diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index 272318066b3f..c5cf413110e6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -36,9 +36,9 @@ def get_copy_params(stmt, producers_consumers): ------- SerialCopy The parameters needed to construct a copy. - tvm.tir.Var - The output pointer of the copy operation. - replace_pointer : tvm.tir.Var + tvm.tir.Buffer + The output buffer of the copy operation. + replace_buffer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace the convolution output pointer. is_allocator : bool @@ -56,7 +56,7 @@ def get_copy_params(stmt, producers_consumers): length=length, write_address=tvm.tir.expr.BufferLoad(write_store.buffer, write_base), ), - write_store.buffer.data, + write_store.buffer, None, True, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index cd5d71d74b84..8ec0fa952990 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -37,9 +37,9 @@ def get_unary_elementwise_params(stmt, producers_consumers): ------- SerialUnaryElementwise The parameters needed to construct a unary elementwise operator. - output_pointer : tvm.tir.Var + output_buffer : tvm.tir.Buffer The output pointer of the unary elementwise operation. - replace_pointer : tvm.tir.Var + replace_buffer : tvm.tir.Buffer The output pointer of the DMA write operation, which is to replace the unary elementwise output pointer. 
is_allocator : bool @@ -48,18 +48,18 @@ def get_unary_elementwise_params(stmt, producers_consumers): attrs, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") - input_pointer = None + input_buffer = None if isinstance(inner.value, tir.expr.Select): # ABS - input_pointer = inner.value.condition.b.buffer.data + input_buffer = inner.value.condition.b.buffer if isinstance(inner.value, tir.expr.Sub): # CLZ - input_pointer = inner.value.b.args[0].buffer.data - output_pointer = inner.buffer.data + input_buffer = inner.value.b.args[0].buffer + output_buffer = inner.buffer # Get feature map info - serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt) - serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params( - output_pointer, producers_consumers, stmt + serial_ifm, _ = get_ifm_params(input_buffer, producers_consumers, stmt) + serial_ofm, serial_block_config, replace_buffer, is_allocator = get_ofm_params( + output_buffer, producers_consumers, stmt ) # Get activation info serial_activation = SerialActivation( @@ -74,7 +74,7 @@ def get_unary_elementwise_params(stmt, producers_consumers): rounding_mode=attrs["rounding_mode"], block_config=serial_block_config, ), - output_pointer, - replace_pointer, + output_buffer, + replace_buffer, is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 396735a07c4c..485539a0c6d2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -178,7 +178,7 @@ def collect_buffer_map(stmt): buffer_map = {} def _visit(node): - if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore)): + if isinstance(node, (tvm.tir.DeclBuffer, tvm.tir.BufferLoad, tvm.tir.BufferStore)): buf = node.buffer if buf.data not in buffer_map: buffer_map[buf.data] = buf diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc 
index 0c0d47571c4a..6699a27dab5b 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -179,33 +179,48 @@ class HoistAllocatesMutator : public StmtExprMutator { HoistAllocatesMutator() {} PrimFunc operator()(PrimFunc main_func) { - Stmt new_main_func_body = SeqStmt::Flatten(this->VisitStmt(main_func->body)); - - // Write all allocates that were removed in reverse order - for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) { - Allocate current_alloc = *it; - if (it != allocates_.rbegin()) { - new_main_func_body = SeqStmt::Flatten(new_main_func_body); - } - new_main_func_body = - Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents, - current_alloc->condition, new_main_func_body, current_alloc->annotations, - current_alloc->span); + Stmt body = main_func->body; + while (auto opt = body.as()) { + auto node = opt.value(); + body = node->body; + node.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(node); } - PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, - main_func->buffer_map, main_func->attrs); - return new_main_func; + body = this->VisitStmt(body); + body = SeqStmt::Flatten(body); + + body = MergeNest(init_nest_, body); + + main_func.CopyOnWrite()->body = body; + return main_func; } private: Stmt VisitStmt_(const AllocateNode* op) override { - allocates_.push_back(GetRef(op)); - return VisitStmt(op->body); + auto body = op->body; + auto node = GetRef(op); + node.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(node); + + while (true) { + auto decl_ptr = body.as(); + + if (decl_ptr && decl_ptr->buffer->data.same_as(op->buffer_var)) { + auto decl = GetRef(decl_ptr); + body = decl->body; + decl.CopyOnWrite()->body = Evaluate(0); + init_nest_.push_back(decl); + } else { + break; + } + } + + return VisitStmt(body); } /*! A stack to store allocates as they are visited. */ - std::vector allocates_; + std::vector init_nest_; }; /*! 
@@ -416,21 +431,6 @@ tvm::transform::Pass CopyComputeReordering(Optional max_copy_movements, TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering") .set_body_typed(CopyComputeReordering); -/*! - * \brief This mutator removes all allocates. - */ -class RemoveAllocatesMutator : public StmtExprMutator { - public: - PrimFunc operator()(PrimFunc main_func) { - auto prim_func_node{main_func.CopyOnWrite()}; - prim_func_node->body = this->VisitStmt(main_func->body); - return GetRef(prim_func_node); - } - - private: - Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } -}; - /*! * \brief This extractor collects information used by the MergeConstantsMutator */ @@ -438,8 +438,8 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor { public: class Info { public: - /*! A stack to store allocates as they are visited. */ - std::vector allocates{}; + /*! A stack to store Allocate/DeclBuffer as they are visited. */ + std::vector init_nest{}; /*! A list that contains in the i-th position the write buffer of the i-th statement * if that statement is a copy to a buffer with global scope */ @@ -467,12 +467,28 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor { Info _info{}; void VisitStmt_(const AllocateNode* op) override { - _info.allocates.push_back(GetRef(op)); - VisitStmt(op->body); + auto alloc = GetRef(op); + alloc.CopyOnWrite()->body = Evaluate(0); + _info.init_nest.push_back(alloc); + + Stmt body = op->body; + while (true) { + auto decl = body.as(); + if (decl && decl->buffer->data.same_as(op->buffer_var)) { + auto node = GetRef(decl); + node.CopyOnWrite()->body = Evaluate(0); + _info.init_nest.push_back(node); + body = decl->body; + } else { + break; + } + } + + VisitStmt(body); } void VisitStmt_(const SeqStmtNode* op) override { - std::vector seq_stmt = FlattenUnwrap(GetRef(op)).seq; + auto seq_stmt = FlattenUnwrap(GetRef(op)).seq; if (seq_stmt.size() <= 1) { StmtExprVisitor::VisitStmt_(op); @@ -590,53 +606,73 @@ 
class MergeConstantsMutator : public StmtExprMutator { Stmt RewritePrimFuncBody(Stmt body) { std::unordered_map var_to_allocate{}; + std::vector init_nest; + while (auto opt = body.as()) { + auto node = opt.value(); + body = node->body; + node.CopyOnWrite()->body = Evaluate(0); + init_nest.push_back(node); + } + // Rewrite old allocates - std::unordered_set buffer_vars{GetVarsForWrittenCopyBuffers()}; - for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) { - Allocate alloc{*it}; - var_to_allocate[alloc->buffer_var.get()] = alloc; - if (buffer_vars.count(alloc->buffer_var.as()) == 0) { - body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body, - alloc->annotations, alloc->span); + auto buffer_vars = GetVarsForWrittenCopyBuffers(); + for (auto init_stmt : _info.init_nest) { + if (auto opt = init_stmt.as()) { + auto alloc = opt.value(); + var_to_allocate[alloc->buffer_var.get()] = alloc; + } + + auto var = [&]() -> Var { + if (auto opt = init_stmt.as()) { + return opt.value()->buffer_var; + } else if (auto opt = init_stmt.as()) { + return opt.value()->buffer->data; + } else { + LOG(FATAL) << "Expected Allocate or DeclBuffer, but found " << init_stmt->GetTypeKey(); + } + }(); + if (!buffer_vars.count(var.get())) { + init_nest.push_back(init_stmt); } } // Rewrite new allocates - for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) { - if (Optional buffer_opt = *it) { - Buffer old_write_buffer{buffer_opt.value()}; - int new_buffer_index{ - _info.old_to_new_write_buffer[old_write_buffer.as()].first}; + for (auto buffer_opt : _info.copy_write_buffers) { + if (buffer_opt) { + Buffer old_write_buffer = buffer_opt.value(); + int new_buffer_index = _info.old_to_new_write_buffer[old_write_buffer.get()].first; // Check if the allocate has already been created if (new_buffers.count(new_buffer_index) == 0) { - BufferNode* new_buffer{old_write_buffer.CopyOnWrite()}; - new_buffer->shape 
= {_info.new_buffers_length[new_buffer_index]}; + Buffer new_buffer = old_write_buffer; - new_buffers[new_buffer_index] = GetRef(new_buffer); + new_buffer.CopyOnWrite()->shape = {_info.new_buffers_length[new_buffer_index]}; + new_buffers[new_buffer_index] = new_buffer; Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]}; - body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(), - body, old_allocate->annotations, old_allocate->span); + init_nest.push_back(Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, + tir::const_true(), Evaluate(0), old_allocate->annotations, + old_allocate->span)); + init_nest.push_back(DeclBuffer(new_buffer, Evaluate(0))); } } } // Rewrite operators - return this->VisitStmt(body); + return MergeNest(init_nest, this->VisitStmt(body)); } - Stmt VisitStmt_(const AllocateNode* op) override { - auto allocate{CopyOnWrite(op)}; - allocate->body = this->VisitStmt(op->body); - return Stmt(allocate); - } + Stmt VisitStmt_(const AllocateNode* op) override { return this->VisitStmt(op->body); } + + Stmt VisitStmt_(const DeclBufferNode* op) override { return this->VisitStmt(op->body); } Stmt VisitStmt_(const SeqStmtNode* op) override { - std::vector seq_stmt = FlattenUnwrap(GetRef(op)).seq; + auto [seq_stmt, rewrap_nest] = FlattenUnwrap(GetRef(op)); - if (seq_stmt.size() <= 1) { - return StmtExprMutator::VisitStmt_(op); + if (seq_stmt.size() == 0) { + return Evaluate(0); + } else if (seq_stmt.size() == 1) { + return MergeNest(rewrap_nest, VisitStmt(seq_stmt[0])); } Array new_seq{}; @@ -669,7 +705,7 @@ class MergeConstantsMutator : public StmtExprMutator { } } } - return SeqStmt::Flatten(new_seq); + return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq)); } /*! 
Returns the variables of the buffers written by copies */ @@ -947,7 +983,6 @@ tvm::transform::Pass MergeConstants() { ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute"; MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)}; - f = RemoveAllocatesMutator()(f); return MergeConstantsMutator(info)(f, const_dict.value()); }; return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants", diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c04e12b8395e..0d8cc8553c95 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -41,13 +41,29 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { static PrimFunc Flatten(PrimFunc func) { arith::Analyzer ana; auto pass = BufferFlattener(&ana); - auto writer = func.CopyOnWrite(); pass.MarkBufferMapShapes(func); - writer->body = pass.VisitStmt(func->body); + auto body = pass.VisitStmt(func->body); + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buffers_used_.count(old_buf)) { + auto new_buf = pass.GetFlattenedBuffer(old_buf); + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = std::move(body); + } return func; } @@ -153,11 +169,14 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } Stmt VisitStmt_(const DeclBufferNode* op) final { - // TODO(rfc-70): Update the DeclBuffer node instead of - // stripping it out. 
Stripping it out in the current - // implementation as not all lowering passes support - // DeclBuffer. - return VisitStmt(op->body); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto new_buf = GetFlattenedBuffer(node->buffer); + if (!node->buffer.same_as(new_buf)) { + node.CopyOnWrite()->buffer = new_buf; + } + + return std::move(node); } Buffer GetFlattenedBuffer(Buffer buf) { @@ -166,16 +185,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { return it->second; } auto flattened = buf.GetFlattenedBuffer(); - auto writer = flattened.CopyOnWrite(); // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (flattened->dtype == DataType::Bool()) { - writer->dtype = DataType::Int(8); + flattened.CopyOnWrite()->dtype = DataType::Int(8); } // canonicalize shape - for (size_t i = 0; i < flattened->shape.size(); ++i) { - writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + bool shape_is_changed = false; + Array new_shape; + for (const auto& dim : flattened->shape) { + auto new_dim = analyzer_->canonical_simplify(dim); + shape_is_changed = shape_is_changed || !StructuralEqual()(dim, new_dim); + new_shape.push_back(new_dim); + } + + if (shape_is_changed) { + flattened.CopyOnWrite()->shape = std::move(new_shape); } buffer_remap_[buf] = flattened; @@ -226,6 +252,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { template Node VisitBufferAccess(Node node) { ICHECK(node->buffer.defined()); + buffers_used_.insert(node->buffer); auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); @@ -264,6 +291,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { /*! \brief Map of buffers being remapped. */ std::unordered_map buffer_remap_; + std::unordered_set buffers_used_; + /*! \brief The updated external buffer map. 
*/ Map updated_extern_buffer_map_; }; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173..f24f14542969 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1340,12 +1340,30 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); + Stmt body = pass(func->body); + + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buf_map_.count(old_buf)) { + auto new_buf = pass.GetBufferEntry(old_buf).flattened_buffer; + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. 
+ + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1542,9 +1560,10 @@ class StorageFlattener : public StmtExprMutator { buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; - Stmt ret = - Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, - make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); + Stmt ret = body; + ret = DeclBuffer(e.flattened_buffer, body); + ret = Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/tests/python/contrib/test_ethosu/cascader/test_integration.py b/tests/python/contrib/test_ethosu/cascader/test_integration.py index 14cc8fbc61cf..1eb3c3b87aab 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_integration.py +++ b/tests/python/contrib/test_ethosu/cascader/test_integration.py @@ -109,7 +109,10 @@ def test_single_conv_compute_cycles_hint(): for single convolution. """ primfunc = _compile_model(_create_single_conv2d()) - ops = primfunc.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [2944, 320] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" @@ -122,7 +125,10 @@ def test_double_conv_compute_cycles_hint(): for double convolution. 
""" primfunc = _compile_model(_create_double_conv2d()) - ops = primfunc.body.body.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [2944, 1408, 320, 240] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): assert op.attr_key == "pragma_compute_cycles_hint" @@ -135,7 +141,10 @@ def test_scalar_add_compute_cycles_hint(): for add with scalar values. """ primfunc = _compile_model(_create_scalar_add()) - ops = primfunc.body.body.seq + body = primfunc + while not isinstance(body, tvm.tir.SeqStmt): + body = body.body + ops = body.seq compute_cycles_hints = [16, 24] for op, compute_cycle_hint in zip(ops, compute_cycles_hints): diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 8c35a43e47e9..481fe066d6e5 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -40,8 +40,8 @@ class WeightStreamOnlyU55: def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) buffer1 = T.Buffer([160], "uint8") buffer3 = T.Buffer([144], "uint8") buffer5 = T.Buffer([144], "uint8") @@ -49,10 +49,10 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ buffer8 = T.Buffer([32], "uint8") # body p1_data = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([160], "uint8", data=p1_data) + p1 = 
T.decl_buffer([160], "uint8", data=p1_data) + buffer9 = T.decl_buffer([144], "uint8", data=p1_data) p2_data = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([144], "uint8", data=p2_data) - buffer9 = T.Buffer([144], "uint8", data=p1.data) + p2 = T.decl_buffer([144], "uint8", data=p2_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p1[128], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -69,27 +69,38 @@ class WeightStreamOnlyU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)}) + + ifm_1 = T.decl_buffer((8192,), "int8", data=ifm.data) + ethosu_write_1 = T.decl_buffer((2048,), "int8", data=ethosu_write.data) + p2_global_6 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_3 = T.decl_buffer((192,), "uint8", data=p2_global_6) + p2_global_4 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_4_1 = T.decl_buffer((192,), "uint8", data=p2_global_4) + p2_global_5 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p2_global_5_1 = T.decl_buffer((208,), "uint8", data=p2_global_5) + buffer_encoded = T.Buffer((192,), "uint8") - p2_global_3 = T.Buffer((192,), "uint8", data=p2_global_6) - T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 192, p2_global_3[0]) 
buffer_encoded_1 = T.Buffer((192,), "uint8") - p2_global_4_1 = T.Buffer((192,), "uint8", data=p2_global_4) - T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 192, p2_global_4_1[0]) buffer_encoded_2 = T.Buffer((208,), "uint8") - p2_global_5_1 = T.Buffer((208,), "uint8", data=p2_global_5) + buffer_encoded_3 = T.Buffer((192,), "uint8") + + + T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 192, p2_global_3[0]) + + T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 192, p2_global_4_1[0]) T.call_extern("handle", "ethosu_copy", buffer_encoded_2[0], 208, p2_global_5_1[0]) - ifm_1 = T.Buffer((8192,), "int8", data=ifm.data) - ethosu_write_1 = T.Buffer((2048,), "int8", data=ethosu_write.data) + + T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_3[0], 80, p2_global_3[80], 80, 12, p2_global_3[160], 16, p2_global_3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_3 = T.Buffer((192,), "uint8") - p2_global_6_1 = T.Buffer((192,), "uint8", data=p2_global_6) - T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 192, p2_global_6_1[0]) + + + T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 192, p2_global_3[0]) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_4_1[0], 80, p2_global_4_1[80], 80, 12, p2_global_4_1[160], 16, p2_global_4_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[4], 0, 0, 0, T.float32(0.25), 14, 
"NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_5_1[0], 96, p2_global_5_1[96], 80, 12, p2_global_5_1[176], 16, p2_global_5_1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_6_1[0], 80, p2_global_6_1[80], 80, 12, p2_global_6_1[160], 16, p2_global_6_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) + T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_3[0], 80, p2_global_3[80], 80, 12, p2_global_3[160], 16, p2_global_3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) # fmt: on @@ -155,16 +166,18 @@ def _get_func(): class RereadWeightsU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer1 = T.Buffer([384], "uint8") - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([384], "uint8", data=p1_data) + p1 = T.decl_buffer([384], "uint8", data=p1_data) + p2_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([384], "uint8", data=p2_data) + p2 = T.decl_buffer([384], 
"uint8", data=p2_data) + + buffer1 = T.Buffer([384], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -176,17 +189,19 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ class RereadWeightsU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition - placeholder = T.Buffer([8192], dtype="int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], dtype="int8", data=input_ethosu_write.data) - placeholder_encoded_1 = T.Buffer([464], "uint8") - # body + + placeholder = T.decl_buffer([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write.data) + p1_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([464], "uint8", data=p1_data) + p1 = T.decl_buffer([464], "uint8", data=p1_data) + p2_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([464], "uint8", data=p2_data) + p2 = T.decl_buffer([464], "uint8", data=p2_data) + + placeholder_encoded_1 = T.Buffer([464], "uint8") + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 
464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -256,17 +271,19 @@ def _get_func(): class DirectReadOnlyU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + placeholder = T.decl_buffer([8192], "int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + + ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_1 = T.decl_buffer([4096], "int8", data=ethosu_write_1_data) + buffer = T.Buffer([592], "uint8") buffer_1 = T.Buffer([160], "uint8") buffer_2 = T.Buffer([160], "uint8") buffer_3 = T.Buffer([80], "uint8") - placeholder = T.Buffer([8192], "int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body - ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([4096], "int8", data=ethosu_write_1_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 
16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -276,18 +293,19 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ class DirectReadOnlyU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition + + placeholder = T.decl_buffer([8192], dtype="int8", data=input_placeholder.data) + ethosu_write = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write.data) + + ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + ethosu_write_2 = T.decl_buffer([4096], "int8", data=ethosu_write_2_data) + placeholder_encoded = T.Buffer([608], dtype="uint8") placeholder_encoded_1 = T.Buffer([160], dtype="uint8") placeholder_encoded_2 = T.Buffer([208], dtype="uint8") placeholder_encoded_3 = T.Buffer([96], dtype="uint8") - placeholder = T.Buffer([8192], dtype="int8", data=input_placeholder.data) - ethosu_write = T.Buffer([2048], dtype="int8", data=input_ethosu_write.data) - # body - ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_2 = T.Buffer([4096], "int8", data=ethosu_write_2_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, 
placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -354,23 +372,28 @@ def _get_func(): class MixedReadU55: @T.prim_func def main(input_ifm: T.Buffer((1,16,16,32), "int8"), input_ethosu_write: T.Buffer((1,16,16,8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + + ifm = T.decl_buffer([8192], "int8", data=input_ifm.data) + ethosu_write = T.decl_buffer([2048], "int8", data=input_ethosu_write.data) + + p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p1 = T.decl_buffer([112], "uint8", data=p1_data) + + p3_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + p3 = T.decl_buffer([4096], "int8", data=p3_data) + + p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) + p2 = T.decl_buffer([112], "uint8", data=p2_data) + buffer1 = T.Buffer([112], "uint8") buffer3 = T.Buffer([112], "uint8") buffer5 = T.Buffer([112], "uint8") buffer7 = T.Buffer([112], "uint8") buffer9 = T.Buffer([592], "uint8") buffer10 = T.Buffer([160], "uint8") - ifm = T.Buffer([8192], "int8", data=input_ifm.data) - ethosu_write = T.Buffer([2048], "int8", data=input_ethosu_write.data) - # body - p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([112], "uint8", data=p1_data) - p3_data = T.allocate([4096], "int8", "global", 
annotations={"disable_lower_builtin":True}) - p3 = T.Buffer([4096], "int8", data=p3_data) - p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([112], "uint8", data=p2_data) + + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle")) @@ -388,29 +411,35 @@ class MixedReadU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)}) + + ifm_1 = T.decl_buffer((8192,), "int8", data=ifm.data) + ethosu_write_3 = T.decl_buffer((2048,), "int8", data=ethosu_write.data) + p5_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_3 = T.decl_buffer((128,), "uint8", data=p5_global) + p5_global_1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_4 = T.decl_buffer((128,), "uint8", data=p5_global_1) + ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + ethosu_write_2 = T.decl_buffer((4096,), "int8", data=ethosu_write_1) + p5_global_2 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)}) + p5_global_5 = T.decl_buffer((128,), "uint8", data=p5_global_2) + + p1_encoded = T.Buffer((608,), "uint8") + p2_encoded = T.Buffer((160,), "uint8") buffer_encoded = T.Buffer((128,), "uint8") - 
p5_global_3 = T.Buffer((128,), "uint8", data=p5_global) - T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 128, p5_global_3[0]) buffer_encoded_1 = T.Buffer((128,), "uint8") - p5_global_4 = T.Buffer((128,), "uint8", data=p5_global_1) + buffer_encoded_2 = T.Buffer((128,), "uint8") + buffer_encoded_3 = T.Buffer((128,), "uint8") + + T.call_extern("handle", "ethosu_copy", buffer_encoded[0], 128, p5_global_3[0]) T.call_extern("handle", "ethosu_copy", buffer_encoded_1[0], 128, p5_global_4[0]) - ifm_1 = T.Buffer((8192,), "int8", data=ifm.data) - ethosu_write_2 = T.Buffer((4096,), "int8", data=ethosu_write_1) - p1_encoded = T.Buffer((608,), "uint8") - p2_encoded = T.Buffer((160,), "uint8") T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, p1_encoded[0], 304, p1_encoded[304], 304, 12, p2_encoded[0], 80, p2_encoded[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_2 = T.Buffer((128,), "uint8") - p5_global_5 = T.Buffer((128,), "uint8", data=p5_global_2) T.call_extern("handle", "ethosu_copy", buffer_encoded_2[0], 128, p5_global_5[0]) - ethosu_write_3 = T.Buffer((2048,), "int8", data=ethosu_write.data) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_3[0], 48, p5_global_3[48], 48, 12, p5_global_3[96], 16, p5_global_3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - buffer_encoded_3 = T.Buffer((128,), "uint8") - p5_global_6 = T.Buffer((128,), "uint8", data=p5_global) - T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 128, p5_global_6[0]) + T.call_extern("handle", "ethosu_copy", buffer_encoded_3[0], 128, 
p5_global_3[0]) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_4[0], 48, p5_global_4[48], 48, 12, p5_global_4[96], 16, p5_global_4[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_5[0], 48, p5_global_5[48], 48, 12, p5_global_5[96], 16, p5_global_5[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_6[0], 48, p5_global_6[48], 48, 12, p5_global_6[96], 16, p5_global_6[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) @@ -511,7 +540,10 @@ def get_graph(): # Check tile address for the scalar constant input hasn't been # overwritten. - extern_calls = tir_mod["main"].body.body.body.body + extern_calls = tir_mod["main"] + while not isinstance(extern_calls, tvm.tir.SeqStmt): + extern_calls = extern_calls.body + binary_elementwise = extern_calls[-1].value args = binary_elementwise.args diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py b/tests/python/contrib/test_ethosu/test_identity_optimizer.py index 3ae58dfc81ba..e5f288482703 100644 --- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py @@ -311,7 +311,10 @@ def get_graph(): # Check for hints in the TIR prim func that the identity optimization pass # has ran. 
There should not be an identity in the prim func. - assert prim_func.body.value.args[0] == "ethosu_pooling" + body = prim_func + while not isinstance(body, tvm.tir.Evaluate): + body = body.body + assert body.value.args[0] == "ethosu_pooling" def test_same_output(): diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index 69d549acbb3b..355e302be956 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -794,7 +794,7 @@ def get_graph(): prim_func = mod[external_gv_name] # Check for hints in the TIR prim func that the layout optimization pass has ran - ops = prim_func.body.body.seq + ops = prim_func.body.body.body.body.body.seq max_pool1, max_pool2 = ops assert str(max_pool1.value.args[31]) == '"NHCWB16"' diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index 58cf5f72d7c0..6040a705ff14 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -36,9 +36,9 @@ def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer(1536, dtype="int8", data=input_placeholder.data) - placeholder_1 = T.Buffer(1280, dtype="int8", data=input_placeholder_1.data) - T_concat = T.Buffer(4096, dtype="int8", data=input_T_concat.data) + placeholder = T.decl_buffer(1536, dtype="int8", data=input_placeholder.data) + placeholder_1 = T.decl_buffer(1280, dtype="int8", data=input_placeholder_1.data) + T_concat = T.decl_buffer(4096, dtype="int8", data=input_T_concat.data) buffer = T.Buffer([2992], "uint8") buffer_1 = T.Buffer([160], "uint8") @@ -50,7 +50,7 @@ def main(input_placeholder: T.Buffer((1,8,12,16), 
"int8"), input_placeholder_1: buffer_7 = T.Buffer([160], "uint8") # body T_concat_1_data = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T_concat_1 = T.Buffer([2816], "int8", data=T_concat_1_data) + T_concat_1 = T.decl_buffer([2816], "int8", data=T_concat_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index a8aa4043293f..bdca9590f8a0 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -369,17 +369,17 @@ def _visit(stmt): class Conv2dDoubleCascade1: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), 
"int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.Buffer([304], "uint8") + buffer = T.Buffer([304], "uint8") buffer_1 = T.Buffer([80], "uint8") buffer_2 = T.Buffer([320], "uint8") buffer_3 = T.Buffer([160], "uint8") - placeholder_5 = T.Buffer([192], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([512], 'int8', data=input_ethosu_write_1.data) - # body + + placeholder_5 = T.decl_buffer([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([512], 'int8', data=input_ethosu_write_1.data) + ethosu_write_2_data = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([1024], "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer([1024], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, 
T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -393,15 +393,18 @@ class Conv2dDoubleCascade2: def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), "int8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder_5 = T.decl_buffer([192], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([512], 'int8', data=input_ethosu_write_1.data) + + ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) + ethosu_write_2 = T.decl_buffer([1536], "int8", data=ethosu_write_2_data) + buffer = T.Buffer([80], "uint8") buffer_1 = T.Buffer([320], "uint8") buffer_2 = T.Buffer([1312], "uint8") buffer_3 = T.Buffer([2608], "uint8") - placeholder_5 = T.Buffer([192], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([512], 'int8', data=input_ethosu_write_1.data) # body - ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([1536], "int8", data=ethosu_write_2_data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, 
dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -419,12 +422,12 @@ def main(input_placeholder_5: T.Buffer((1, 16, 16, 3), "int8"), input_ethosu_wri buffer_1 = T.Buffer([80], "uint8") buffer_2 = T.Buffer([320], "uint8") buffer_3 = T.Buffer([880], "uint8") - placeholder_5 = T.Buffer([768], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([640], 'int8', data=input_ethosu_write_1.data) + placeholder_5 = T.decl_buffer([768], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([640], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer([2560], "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer([2560], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, 
dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -444,11 +447,11 @@ def main(input_placeholder_5: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_w buffer_1 = T.Buffer([352], "uint8") buffer_2 = T.Buffer([272], "uint8") buffer_3 = T.Buffer([11040], "uint8") - placeholder_5 = T.Buffer([1024], 'int8', data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([2048], 'int8', data=input_ethosu_write_1.data) + placeholder_5 = T.decl_buffer([1024], 'int8', data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([2048], 'int8', data=input_ethosu_write_1.data) # body ethosu_write_2_data = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - ethosu_write_2 = T.Buffer((2304,), "int8", data=ethosu_write_2_data) + ethosu_write_2 = T.decl_buffer((2304,), "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, 
"TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -466,11 +469,11 @@ def main(input_placeholder: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write: buffer_1 = T.Buffer([320], "uint8") buffer_2 = T.Buffer([304], "uint8") buffer_3 = T.Buffer([80], "uint8") - placeholder = T.Buffer([192], 'int8', data=input_placeholder.data) - ethosu_write = T.Buffer([8192], 'int8', data=input_ethosu_write.data) + placeholder = T.decl_buffer([192], 'int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([8192], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([4096], "int8", data=ethosu_write_1_data) + ethosu_write_1 = T.decl_buffer([4096], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 
0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) @@ -488,11 +491,11 @@ def main(input_placeholder: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_wri buffer_1 = T.Buffer([352], "uint8") buffer_2 = T.Buffer([11040], "uint8") buffer_3 = T.Buffer([272], "uint8") - placeholder = T.Buffer([1024], 'int8', data=input_placeholder.data) - ethosu_write = T.Buffer([32768], 'int8', data=input_ethosu_write.data) + placeholder = T.decl_buffer([1024], 'int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([32768], 'int8', data=input_ethosu_write.data) # body ethosu_write_1_data = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) - ethosu_write_1 = T.Buffer([12288], "int8", data=ethosu_write_1_data) + ethosu_write_1 = T.decl_buffer([12288], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_3[0], 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", 
"NEAREST", 0, 0, 0, dtype="handle")) # fmt: on @@ -636,6 +639,7 @@ def _get_func( config = { "enable_cascader": True, } + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethos-u.options": config}): func = _get_func(*params[:-1]) mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) @@ -654,8 +658,8 @@ def main(input_placeholder_3: T.Buffer((1, 10, 12, 8), "int8"), input_ethosu_wri T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([848], "uint8") buffer_1 = T.Buffer([160], "uint8") - placeholder_3 = T.Buffer([960], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([1024], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([960], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([1024], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -669,8 +673,8 @@ def main(input_placeholder_3: T.Buffer((1, 7, 9, 5), "int8"), input_ethosu_write T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([656], "uint8") - placeholder_3 = T.Buffer([315], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([240], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([315], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([240], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 
1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -713,8 +717,8 @@ def main(input_placeholder_3: T.Buffer((4, 6, 8, 1), "int8"), input_ethosu_write T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -729,8 +733,8 @@ def main(input_placeholder_3: T.Buffer((1, 24, 8), "int8"), input_ethosu_write_1 T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - 
ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -745,8 +749,8 @@ def main(input_placeholder_3: T.Buffer((192, 1), "int8"), input_ethosu_write_1: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - placeholder_3 = T.Buffer([192], 'int8', data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + placeholder_3 = T.decl_buffer([192], 'int8', data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, 
dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) @@ -761,7 +765,7 @@ def main(placeholder_3: T.Buffer((192,), "int8"), input_ethosu_write_1: T.Buffer T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) buffer = T.Buffer([160], "uint8") buffer_1 = T.Buffer([848], "uint8") - ethosu_write_1 = T.Buffer([768], 'int8', data=input_ethosu_write_1.data) + ethosu_write_1 = T.decl_buffer([768], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index ff343517352d..c85cf892fe50 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -36,14 +36,16 @@ class ReferenceModule: @T.prim_func def 
main(input_placeholder_3: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 8), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer_1 = T.Buffer([384], "uint8") - placeholder_3 = T.Buffer([8192], dtype="int8", data=input_placeholder_3.data) - ethosu_write_1 = T.Buffer([2048], dtype="int8", data=input_ethosu_write_1.data) - # body + + placeholder_3 = T.decl_buffer([8192], dtype="int8", data=input_placeholder_3.data) + ethosu_write_1 = T.decl_buffer([2048], dtype="int8", data=input_ethosu_write_1.data) + placeholder_global_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_global = T.Buffer([384], "uint8", data=placeholder_global_data) + placeholder_global = T.decl_buffer([384], "uint8", data=placeholder_global_data) + + buffer_1 = T.Buffer([384], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -80,17 +82,19 @@ def _get_func(): class WeightStream: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 16), "int8")) -> None: - # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.Buffer([528], "uint8") - buffer_2 = T.Buffer([336], "uint8") - placeholder_5 = T.Buffer([8192], dtype="int8", data=input_placeholder_5.data) - ethosu_write_1 = T.Buffer([4096], dtype="int8", 
data=input_ethosu_write_1.data) - # body + + placeholder_5 = T.decl_buffer([8192], dtype="int8", data=input_placeholder_5.data) + ethosu_write_1 = T.decl_buffer([4096], dtype="int8", data=input_ethosu_write_1.data) + placeholder_d_global_data = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.Buffer([528], "uint8", data=placeholder_d_global_data) + placeholder_d_global = T.decl_buffer([528], "uint8", data=placeholder_d_global_data) placeholder_d_global_1_data = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global_1 = T.Buffer([336], "uint8", data=placeholder_d_global_1_data) + placeholder_d_global_1 = T.decl_buffer([336], "uint8", data=placeholder_d_global_1_data) + + buffer = T.Buffer([528], "uint8") + buffer_2 = T.Buffer([336], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 0b6f4a2629b7..2f7bd91df141 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -218,18 +218,22 @@ class DiamondGraphTir: @T.prim_func def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_write: T.Buffer((1, 56, 56, 24), "int8")) -> None: T.func_attr({"from_legacy_te_schedule": 
True, "global_symbol": "main", "tir.noalias": True}) - placeholder = T.Buffer([301056], dtype='int8', data=input_placeholder.data) - ethosu_write = T.Buffer([75264], dtype='int8', data=input_ethosu_write.data) - buffer1 = T.Buffer([2848], "uint8") - buffer3 = T.Buffer([976], "uint8") + + placeholder = T.decl_buffer([301056], dtype='int8', data=input_placeholder.data) + ethosu_write = T.decl_buffer([75264], dtype='int8', data=input_ethosu_write.data) + p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True}) - p1 = T.Buffer([2848], "uint8", data=p1_data) + p1 = T.decl_buffer([2848], "uint8", data=p1_data) p2_data = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True}) - p2 = T.Buffer([976], "uint8", data=p2_data) + p2 = T.decl_buffer([976], "uint8", data=p2_data) p5_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - p5 = T.Buffer([75264], "int8", data=p5_data) + p5 = T.decl_buffer([75264], "int8", data=p5_data) p6_data = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True}) - p6 = T.Buffer([75264], "int8", data=p6_data) + p6 = T.decl_buffer([75264], "int8", data=p6_data) + + buffer1 = T.Buffer([2848], "uint8") + buffer3 = T.Buffer([976], "uint8") + T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) diff --git a/tests/python/te/test_te_build_lower.py b/tests/python/te/test_te_build_lower.py index 50d5119b43a0..6da7a2df3563 100644 --- 
a/tests/python/te/test_te_build_lower.py +++ b/tests/python/te/test_te_build_lower.py @@ -56,7 +56,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse) + assert isinstance(stmt.body.body.body.body.body, tvm.tir.stmt.IfThenElse) if __name__ == "__main__": diff --git a/tests/python/te/test_te_hybrid_script.py b/tests/python/te/test_te_hybrid_script.py index d6b11785a4a3..61821adb841f 100644 --- a/tests/python/te/test_te_hybrid_script.py +++ b/tests/python/te/test_te_hybrid_script.py @@ -756,6 +756,8 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -777,6 +779,8 @@ def outer_product(a, b): sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/te/test_te_schedule.py b/tests/python/te/test_te_schedule.py index ed224883478e..e2811765045b 100644 --- a/tests/python/te/test_te_schedule.py +++ b/tests/python/te/test_te_schedule.py @@ -324,7 +324,7 @@ def test_legalize_invalid_attach(): s[A].compute_at(s[B], B.op.axis[1]) s[B].fuse(B.op.axis[0], B.op.axis[1]) stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body - assert isinstance(stmt, tvm.tir.stmt.For) + assert isinstance(stmt.body.body, tvm.tir.stmt.For) def test_compute_at(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index e94a4f09ec56..3a68d6ff935c 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py 
@@ -60,9 +60,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -82,9 +82,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -144,7 +144,4 @@ def test_lower_build_lowered_module(): if __name__ == "__main__": - test_lower_build_te_schedule() - test_lower_build_tir_func() - test_lower_build_tir_module() - test_lower_build_lowered_module() + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index 20f91b639497..a7965e4db423 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -41,42 +41,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i, j] = B_new[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i in T.serial(0, 16): - B_new_data = T.allocate([16], 
"float32", scope="global") - B_new = T.Buffer([16], "float32", scope="global", data=B_new_data) - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -class TestElementwiseWithoutDeclBuffer(BaseCompare): - """2-d buffers are flattened to 1-d - - Like TestElementwise, but the TIR doesn't have the DeclBuffer - node. The T.Buffer declaration applies only during the - parsing the TVMScript, and doesn't occur in the TIR itself. In - this case, the allocation should be assumed to be targeting flat - memory, and should be flattened to a 1-d allocation. - """ - - def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): - for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.Buffer([1, 16], "float32", data=B_new_data) - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) - for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", "global") - B_new = T.Buffer(16, "float32", data=B_new_data) + B_new = T.decl_buffer(16, "float32", scope="global") for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -101,8 +69,8 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ 
-111,8 +79,7 @@ def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B_data = T.allocate([16], "float32", scope="local") - B = T.Buffer([16], "float32", scope="local", data=B_data) + B = T.decl_buffer(16, "float32", scope="local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -136,12 +103,11 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: input_A = T.match_buffer(a, (n, m), "float32") input_C = T.match_buffer(c, (n, m), "float32") - A = T.Buffer(n * m, "float32", data=input_A.data) - C = T.Buffer(n * m, "float32", data=input_C.data) + A = T.decl_buffer(n * m, "float32", data=input_A.data) + C = T.decl_buffer(n * m, "float32", data=input_C.data) for i in range(0, n): - B_data = T.allocate([m], "float32", scope="global") - B = T.Buffer([m], "float32", scope="global", data=B_data) + B = T.decl_buffer(m, "float32", scope="global") for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -161,8 +127,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for i in range(0, n * n * 32): B[i] = A[i] @@ -185,8 +151,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 
32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for bx, tx in T.grid((n * n + 1) // 2, 64): if bx * 64 + tx < n * n * 32: @@ -205,14 +171,12 @@ def before(A: T.Buffer((4, 32), "float32"), D: T.Buffer((4, 32), "float32")): D[i, j] = C[i, j] * 2.0 def expected(input_A: T.Buffer((4, 32), "float32"), input_D: T.Buffer((4, 32), "float32")): - A = T.Buffer(128, "float32", data=input_A.data) - D = T.Buffer(128, "float32", data=input_D.data) + A = T.decl_buffer(128, "float32", data=input_A.data) + D = T.decl_buffer(128, "float32", data=input_D.data) for i, j in T.grid(4, 32): - B_data = T.allocate([128], "float32", scope="global") - B = T.Buffer([128], "float32", scope="global", data=B_data) - C_data = T.allocate([128], "float32", scope="global") - C = T.Buffer([128], "float32", scope="global", data=C_data) + B = T.decl_buffer(128, "float32", scope="global") + C = T.decl_buffer(128, "float32", scope="global") B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -231,11 +195,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i0 in T.serial(0, 4): - B_new_data = T.allocate([68], "float32", scope="global") - B_new = T.Buffer([68], "float32", scope="global", data=B_new_data) + B_new = T.decl_buffer(68, "float32", scope="global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -252,8 +215,8 @@ def before(A: T.Buffer(10, "bool"), B: T.Buffer(10, 
"bool")) -> None: B[i0] = A[i0] def expected(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) -> None: - A = T.Buffer(10, dtype="int8", data=input_A.data) - B = T.Buffer(10, dtype="int8", data=input_B.data) + A = T.decl_buffer(10, dtype="int8", data=input_A.data) + B = T.decl_buffer(10, dtype="int8", data=input_B.data) # body for i0 in T.serial(10): B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") @@ -329,8 +292,7 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A_data = T.allocate([30, 1001], dtype="float32", scope="global") - A = T.Buffer([30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data) + A = T.decl_buffer([30, 1001], axis_separators=[1], dtype="float32", scope="global") for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py index 2b3f73e24f88..9fa0b7b13d14 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py @@ -17,7 +17,7 @@ import pytest import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm.ir.module import IRModule from tvm.script import tir as T import numpy @@ -182,7 +182,11 @@ def test_vectorize(): s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) stmt = tvm.lower(s, [A, B], name="main")["main"] - body = stmt.body.body.body.body + + body = stmt + while not isinstance(body, tir.IfThenElse): + body = body.body + assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -233,7 +237,11 @@ def test_thread_axis2(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) stmt = tvm.lower(s, [A, B], name="main")["main"] - for_body = stmt.body.body.body.body[0] 
+ + while not isinstance(stmt, tir.SeqStmt): + stmt = stmt.body + + for_body = stmt[0] assert "threadIdx" not in str(for_body.extent) @@ -567,7 +575,7 @@ def test_explicit_partition_hint(): mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.LoopPartition()(mod) mod = tvm.tir.transform.Simplify()(mod) - assert tvm.ir.structural_equal(mod["main"], partitioned_concat) + tvm.ir.assert_structural_equal(mod["main"], partitioned_concat) def partition_from_scheduled_tir(prim_func, pass_cfg): @@ -629,7 +637,7 @@ def test_condition_mutually_exclusive(): mod = partition_from_scheduled_tir( concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}} ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( mod["main"], partitioned_concat_3.with_attr("global_symbol", "main") ) @@ -681,7 +689,7 @@ def partitioned_main( mod = tvm.tir.transform.UnrollLoop()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) mod = tvm.tir.transform.Simplify()(mod) - assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def test_loop_partition_recursive_unroll_hint(): @@ -712,32 +720,28 @@ def main(): @T.prim_func def partitioned_main(): - placeholder_0_dm = T.allocate([16384], "int8", "global") - placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm) + placeholder_0_dm = T.decl_buffer([16384], "int8") for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp = T.allocate([4096], "int8", "global") - pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp) + pad_temp = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: - pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 ] for i2_0 in T.unroll(2): - pad_temp_2 = 
T.allocate([4096], "int8", "global") - pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2) + pad_temp_2 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0: - pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_2[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 ] for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp_4 = T.allocate([4096], "int8", "global") - pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4) + pad_temp_4 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: - pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_4[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 ] @@ -750,7 +754,7 @@ def partitioned_main(): } }, ) - assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) def test_loop_partition_keep_loop_annotations(): @@ -784,7 +788,7 @@ def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: } }, ) - assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) def test_loop_partition_with_unit_loop_in_condition(): @@ -832,7 +836,7 @@ def after( } }, ) - assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) + tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index c03dd7a5291d..e2641a65f287 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ 
b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -170,6 +170,8 @@ def check(m, target_bits, target_dtype): B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B") s = te.create_schedule(B.op) stmt = lower_sch(s, [A, B], target_bits) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body assert stmt[1].loop_var.dtype == target_dtype # i32 -> i32 @@ -221,6 +223,8 @@ def check(shapex, shapey, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body # outer loop assert stmt.loop_var.dtype == target_dtype # inner loop @@ -262,7 +266,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.indices[0].dtype == target_dtype + assert stmt.body.body.value.indices[0].dtype == target_dtype check( (const(2**16, "int64"), const(2**15 + 1, "int64")), diff --git a/tests/python/tir-transform/test_tir_transform_storage_flatten.py b/tests/python/tir-transform/test_tir_transform_storage_flatten.py index 8ddfbb5adfd3..d3adea149fb9 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_flatten.py +++ b/tests/python/tir-transform/test_tir_transform_storage_flatten.py @@ -53,7 +53,7 @@ def test_flatten_prefetch(): mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -80,7 +80,7 @@ def test_flatten_storage_align(): [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body.body assert stmt.extents[0].value == 17 * 8 @@ -114,9 +114,9 @@ def main(A_param: 
T.handle, C_param: T.handle): ] )(mod) - stmt = mod["main"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert list(stmt.body.extents) == [8] + stmt = mod["main"].body.body.body.body + assert isinstance(stmt, tvm.tir.Allocate) + assert list(stmt.extents) == [8] mod = tvm.tir.transform.ThreadSync("shared")(mod) f = mod["main"]