From bec55e23c1130d68d6525cc552766c2036f6b731 Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Thu, 31 Aug 2023 09:54:10 +0000
Subject: [PATCH 1/4] [Relay][TOPI] Remove input padding for arm_cpu conv2d int8 native schedule in Legalize pass

The Legalize pass was unnecessarily padding the input channels for conv2d
int8 native implementations. Since the conv2d schedule itself can add padding
more efficiently, I skipped padding during the pass and further optimised the
schedule to deal with it.

For the int8 interleaved implementation, I kept the Legalize pass padding,
which transforms the input channels into a multiple of 8, and modified the
schedule to ensure vectorization of the input data.

I also added a test to check whether or not the Legalize pass pads the conv2d
input data.
---
 python/tvm/topi/arm_cpu/conv2d_alter_op.py | 37 ++++++++++++-
 python/tvm/topi/arm_cpu/conv2d_gemm.py     | 38 +++++++++++---
 tests/python/relay/test_pass_legalize.py   | 61 ++++++++++++++++++++--
 3 files changed, 123 insertions(+), 13 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 8ddb591397e4..9fb32dfdb8d9 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -25,6 +25,7 @@
 from tvm import te
 from tvm import relay
 from tvm import autotvm
+from tvm.target.target import Target
 
 from ..nn import conv2d_alter_layout, conv2d_legalize
 from ..utils import get_const_tuple
@@ -503,12 +504,44 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     # Collect the input exprs.
     data, kernel = inputs
 
+    # Determine conv2d implementation
+    target = Target.current(allow_none=False)
+    _, outs = relay.backend.te_compiler.select_implementation(
+        relay.op.get("nn.conv2d"),
+        attrs,
+        [
+            te.placeholder(data_tensor.shape, data_dtype),
+            te.placeholder(kernel_tensor.shape, kernel_dtype),
+        ],
+        output_tensor,
+        target,
+    )
+    workload = autotvm.task.get_workload(outs)
+    topi_tmpl = workload[0]
+
     # ARM vector instructions operate on the same dtype for data and kernel, we
     # provide those here and conv2d_alter_int8_common will convert to the
     # correct datatype.
-    if is_int8_hw_support(kernel_dtype, kernel_dtype):
+    if is_int8_hw_support(data_dtype, kernel_dtype):
         # ARM intrinsics need the datatypes of data and kernel to be the same
+        if (
+            attrs["data_layout"] == "NHWC"
+            and attrs["kernel_layout"] == "HWIO"
+            and topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu"
+        ):
+            in_channel_vector_length = data_tensor.shape[3]
+        else:
+            in_channel_vector_length = 8
+
         return conv2d_alter_int8_common(
-            data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, kernel_dtype, 8, 8
+            data,
+            data_tensor,
+            kernel,
+            kernel_tensor,
+            output_tensor,
+            attrs,
+            kernel_dtype,
+            in_channel_vector_length,
+            8,
         )
     return None
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index ea9026688eec..fe564d2c12de 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -166,8 +166,10 @@ def compute_conv2d_gemm_without_weight_transform(
     pad_before = (0, 0, 0)
     pad_after = (0, pad_M, pad_K)
 
-    if pad_M != 0 or pad_K != 0:
-        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
+    if pad_K != 0:
+        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
+    elif pad_M != 0:
+        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
 
     idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
@@ -316,7 +318,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     # Input transform
     A_interleaved_input = A_interleaved.op.input_tensors[0]
-    if A_interleaved_input.op.name == "A_padded":
+    if A_interleaved_input.op.name == "A_padded_K" or A_interleaved_input.op.name == "A_padded_M":
         s[A_interleaved_input].compute_at(s[A_interleaved], A_interleaved.op.axis[3])
         s[A_interleaved_input].vectorize(A_interleaved_input.op.axis[2])
         s[A_interleaved_input].compute_inline()
@@ -326,7 +328,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        n_size = data_im2col.shape[2]
+        if n_size % 16 == 0:
+            split_factor = 16
+        else:
+            split_factor = 8
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
         b_m_fused = s[data_im2col].fuse(b, m)
@@ -419,7 +426,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     s[C].parallel(x_outer)
 
     # Input transform
-    if A.op.name == "A_padded":
+    if A.op.name == "A_padded_K" or A.op.name == "A_padded_M":
         padding_A = True
         data_im2col = A.op.input_tensors[0]
     else:
@@ -428,12 +435,29 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        if A.op.name == "A_padded_K":
+            s[data_im2col].compute_at(s[A], A.op.axis[1])
+            s[A].parallel(A.op.axis[1])
+        elif A.op.name == "A_padded_M":
+            s[data_im2col].parallel(m)
+            s[A].parallel(A.op.axis[1])
+        else:
+            s[data_im2col].parallel(m)
+
+        split_factor = 16
+        n_size = data_im2col.shape[2]
+        if n_size % split_factor != 0:
+            # Split by kernel area (KH * KW) to ensure proper vectorization
+            ic = data_im2col.op.input_tensors[0].shape[3]
+            split_factor = n_size // ic
+
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
-        s[data_im2col].parallel(m)
     elif padding_A:
         s[data_im2col].compute_inline()
+        _, n_inner = s[A].split(A.op.axis[2], 16)
+        s[A].vectorize(n_inner)
         s[A].compute_at(s[C], x_inner)
     else:
         s[data_im2col].compute_at(s[C], x_inner)
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
index 95069d29fd84..1466784394ac 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -16,6 +16,7 @@
 # under the License.
 """Test legalize pass"""
 import numpy as np
+import pytest
 
 import tvm
 from tvm import te
@@ -178,8 +179,60 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+@pytest.mark.parametrize(
+    "target,exp_in_channels",
+    [
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            8,
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+            3,
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
+            8,
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            8,
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            8,
+        ),
+    ],
+)
+def test_conv2d_NHWC_legalize(target, exp_in_channels):
+    target = tvm.target.Target(target)
+
+    dtype = "int8"
+    data_layout = "NHWC"
+    kernel_layout = "HWIO"
+    in_channels = 3
+    out_channels = 4
+    kernel_size = (1, 1)
+
+    x = relay.var("x", shape=(1, 1, 1, in_channels), dtype=dtype)
+    weight = relay.var("weight", shape=(1, 1, in_channels, out_channels), dtype=dtype)
+    out = relay.nn.conv2d(
+        x,
+        weight,
+        kernel_size=kernel_size,
+        channels=out_channels,
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+        out_dtype=dtype,
+    )
+
+    with target:
+        out = run_opt_pass(out, transform.Legalize())
+
+    act_in_channels = out.args[0].type_args[0].shape[3]
+
+    assert act_in_channels == exp_in_channels, "Actual input channels = " + str(act_in_channels)
+
+
 if __name__ == "__main__":
-    test_legalize()
-    test_legalize_none()
-    test_legalize_multiple_ops()
-    test_legalize_multi_input()
+    tvm.testing.main()

From 1a138a300abbf99ab7cc5f623fed84b437950418 Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Thu, 31 Aug 2023 13:26:09 +0000
Subject: [PATCH 2/4] Revert `is_int8_hw_support` argument change

---
 python/tvm/topi/arm_cpu/conv2d_alter_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 9fb32dfdb8d9..c2ac6105953d 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -522,7 +522,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     # ARM vector instructions operate on the same dtype for data and kernel, we
     # provide those here and conv2d_alter_int8_common will convert to the
     # correct datatype.
-    if is_int8_hw_support(data_dtype, kernel_dtype):
+    if is_int8_hw_support(kernel_dtype, kernel_dtype):
         # ARM intrinsics need the datatypes of data and kernel to be the same
         if (
             attrs["data_layout"] == "NHWC"

From 0787b6e48946f8ff55938d7ee5762d492e86e2e3 Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Thu, 31 Aug 2023 16:22:39 +0000
Subject: [PATCH 3/4] Fix no workload bug

---
 python/tvm/topi/arm_cpu/conv2d_alter_op.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index c2ac6105953d..6c3dbb48c9f2 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -517,7 +517,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         target,
     )
     workload = autotvm.task.get_workload(outs)
-    topi_tmpl = workload[0]
+    if workload is not None:
+        topi_tmpl = workload[0]
 
     # ARM vector instructions operate on the same dtype for data and kernel, we
     # provide those here and conv2d_alter_int8_common will convert to the

From 15694edd9544f2f9a61cfd0b8f441c54b5dab47a Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Tue, 5 Sep 2023 11:08:20 +0000
Subject: [PATCH 4/4] Address code review comments

---
 python/tvm/topi/arm_cpu/conv2d_gemm.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index fe564d2c12de..ca6fa158da88 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -328,12 +328,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_size = data_im2col.shape[2]
-        if n_size % 16 == 0:
-            split_factor = 16
-        else:
-            split_factor = 8
-        n_outer, n_inner = s[data_im2col].split(n, split_factor)
+        n_outer, n_inner = s[data_im2col].split(n, 16)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
         b_m_fused = s[data_im2col].fuse(b, m)
@@ -418,7 +413,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis
     k_outer, k_inner = s[C].split(k, 16)
-    x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=16)
+    y_tile_size = 16
+    x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
     s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
     gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
     s[C].unroll(x_inner)
@@ -435,12 +431,15 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
+        # Either only pad_K or both pad_K and pad_M applied
         if A.op.name == "A_padded_K":
             s[data_im2col].compute_at(s[A], A.op.axis[1])
             s[A].parallel(A.op.axis[1])
+        # Only pad_M applied
         elif A.op.name == "A_padded_M":
             s[data_im2col].parallel(m)
             s[A].parallel(A.op.axis[1])
+        # No padding
         else:
             s[data_im2col].parallel(m)
 
@@ -456,7 +455,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
         s[data_im2col].vectorize(n_inner)
     elif padding_A:
         s[data_im2col].compute_inline()
-        _, n_inner = s[A].split(A.op.axis[2], 16)
+        _, n_inner = s[A].split(A.op.axis[2], y_tile_size)
         s[A].vectorize(n_inner)
         s[A].compute_at(s[C], x_inner)
     else:
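
Note: as a standalone illustration of the im2col split-factor selection in schedule_conv2d_gemm_native above, the sketch below mirrors that logic outside the patch; the function name `pick_im2col_split_factor` and the example shapes are hypothetical and not part of the change.

# Minimal sketch, assuming the im2col row length is n_size = KH * KW * IC and
# ic is the number of input channels: prefer a 16-wide vector split, otherwise
# fall back to the kernel area so the split divides the axis exactly.
def pick_im2col_split_factor(n_size, ic):
    split_factor = 16
    if n_size % split_factor != 0:
        # n_size // ic == KH * KW (kernel area)
        split_factor = n_size // ic
    return split_factor

# 4 input channels, 4x4 kernel: 64 is a multiple of 16, so keep the 16-wide split.
assert pick_im2col_split_factor(64, 4) == 16
# 3 unpadded input channels, 3x3 kernel: 27 is not a multiple of 16, so split by 27 // 3 == 9.
assert pick_im2col_split_factor(27, 3) == 9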