diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index 8ddb591397e4..6c3dbb48c9f2 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -25,6 +25,7 @@
 from tvm import te
 from tvm import relay
 from tvm import autotvm
+from tvm.target.target import Target
 
 from ..nn import conv2d_alter_layout, conv2d_legalize
 from ..utils import get_const_tuple
@@ -503,12 +504,45 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     # Collect the input exprs.
     data, kernel = inputs
 
+    # Determine conv2d implementation
+    target = Target.current(allow_none=False)
+    _, outs = relay.backend.te_compiler.select_implementation(
+        relay.op.get("nn.conv2d"),
+        attrs,
+        [
+            te.placeholder(data_tensor.shape, data_dtype),
+            te.placeholder(kernel_tensor.shape, kernel_dtype),
+        ],
+        output_tensor,
+        target,
+    )
+    workload = autotvm.task.get_workload(outs)
+    if workload is not None:
+        topi_tmpl = workload[0]
+
     # ARM vector instructions operate on the same dtype for data and kernel, we
     # provide those here and conv2d_alter_int8_common will convert to the
     # correct datatype.
     if is_int8_hw_support(kernel_dtype, kernel_dtype):
         # ARM intrinsics need the datatypes of data and kernel to be the same
+        if (
+            attrs["data_layout"] == "NHWC"
+            and attrs["kernel_layout"] == "HWIO"
+            and topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu"
+        ):
+            in_channel_vector_length = data_tensor.shape[3]
+        else:
+            in_channel_vector_length = 8
+
         return conv2d_alter_int8_common(
-            data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, kernel_dtype, 8, 8
+            data,
+            data_tensor,
+            kernel,
+            kernel_tensor,
+            output_tensor,
+            attrs,
+            kernel_dtype,
+            in_channel_vector_length,
+            8,
         )
     return None
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index ea9026688eec..ca6fa158da88 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -166,8 +166,10 @@ def compute_conv2d_gemm_without_weight_transform(
     pad_before = (0, 0, 0)
     pad_after = (0, pad_M, pad_K)
 
-    if pad_M != 0 or pad_K != 0:
-        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
+    if pad_K != 0:
+        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
+    elif pad_M != 0:
+        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
 
     idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
@@ -316,7 +318,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     # Input transform
     A_interleaved_input = A_interleaved.op.input_tensors[0]
-    if A_interleaved_input.op.name == "A_padded":
+    if A_interleaved_input.op.name == "A_padded_K" or A_interleaved_input.op.name == "A_padded_M":
         s[A_interleaved_input].compute_at(s[A_interleaved], A_interleaved.op.axis[3])
         s[A_interleaved_input].vectorize(A_interleaved_input.op.axis[2])
         s[A_interleaved_input].compute_inline()
@@ -411,7 +413,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis
     k_outer, k_inner = s[C].split(k, 16)
-    x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=16)
+    y_tile_size = 16
+    x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
     s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
     gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
     s[C].unroll(x_inner)
@@ -419,7 +422,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     s[C].parallel(x_outer)
 
     # Input transform
-    if A.op.name == "A_padded":
+    if A.op.name == "A_padded_K" or A.op.name == "A_padded_M":
         padding_A = True
         data_im2col = A.op.input_tensors[0]
     else:
@@ -428,12 +431,32 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        # Either only pad_K or both pad_K and pad_M applied
+        if A.op.name == "A_padded_K":
+            s[data_im2col].compute_at(s[A], A.op.axis[1])
+            s[A].parallel(A.op.axis[1])
+        # Only pad_M applied
+        elif A.op.name == "A_padded_M":
+            s[data_im2col].parallel(m)
+            s[A].parallel(A.op.axis[1])
+        # No padding
+        else:
+            s[data_im2col].parallel(m)
+
+        split_factor = 16
+        n_size = data_im2col.shape[2]
+        if n_size % split_factor != 0:
+            # Split by kernel area (KH * KW) to ensure proper vectorization
+            ic = data_im2col.op.input_tensors[0].shape[3]
+            split_factor = n_size // ic
+
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
-        s[data_im2col].parallel(m)
     elif padding_A:
         s[data_im2col].compute_inline()
+        _, n_inner = s[A].split(A.op.axis[2], y_tile_size)
+        s[A].vectorize(n_inner)
         s[A].compute_at(s[C], x_inner)
     else:
         s[data_im2col].compute_at(s[C], x_inner)
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
index 95069d29fd84..1466784394ac 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -16,6 +16,7 @@
 # under the License.
 """Test legalize pass"""
 import numpy as np
+import pytest
 
 import tvm
 from tvm import te
@@ -178,8 +179,60 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+@pytest.mark.parametrize(
+    "target,exp_in_channels",
+    [
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            8,
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
+            3,
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
+            8,
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            8,
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            8,
+        ),
+    ],
+)
+def test_conv2d_NHWC_legalize(target, exp_in_channels):
+    target = tvm.target.Target(target)
+
+    dtype = "int8"
+    data_layout = "NHWC"
+    kernel_layout = "HWIO"
+    in_channels = 3
+    out_channels = 4
+    kernel_size = (1, 1)
+
+    x = relay.var("x", shape=(1, 1, 1, in_channels), dtype=dtype)
+    weight = relay.var("weight", shape=(1, 1, in_channels, out_channels), dtype=dtype)
+    out = relay.nn.conv2d(
+        x,
+        weight,
+        kernel_size=kernel_size,
+        channels=out_channels,
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+        out_dtype=dtype,
+    )
+
+    with target:
+        out = run_opt_pass(out, transform.Legalize())
+
+    act_in_channels = out.args[0].type_args[0].shape[3]
+
+    assert act_in_channels == exp_in_channels, "Actual input channels = " + str(act_in_channels)
+
+
 if __name__ == "__main__":
-    test_legalize()
-    test_legalize_none()
-    test_legalize_multiple_ops()
-    test_legalize_multi_input()
+    tvm.testing.main()