From a941ae9f521cb48b22e5d5c2674c193f5971545e Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Wed, 6 Sep 2023 16:04:37 +0000 Subject: [PATCH] [TOPI] Ensure vectorization of input padding in `arm_cpu` int8 conv2d interleaved schedule When padding the input data, the int8 conv2d interleaved schedule tries to split the `data_im2col` cols axis by a factor of 16 in order to then vectorize over those splits. However, the size of the axis is `n_size = KH x KW x IC` and the `Legalize` pass only pads the number of input channels up to a multiple of 8. Therefore, `n_size` is only guaranteed to be a multiple of 8, not 16. I modified the schedule to check whether a split factor of 16 is appropriate, otherwise use 8 instead, in order to ensure vectorization is performed in all cases. --- python/tvm/topi/arm_cpu/conv2d_gemm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index ea9026688eec..6ef8efec9ed5 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -326,7 +326,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): b, m, n = data_im2col.op.axis if data_im2col.op.name == "data_im2col": - n_outer, n_inner = s[data_im2col].split(n, 16) + n_size = data_im2col.shape[2] + if n_size % 16 == 0: + split_factor = 16 + else: + split_factor = 8 + n_outer, n_inner = s[data_im2col].split(n, split_factor) s[data_im2col].unroll(n_outer) s[data_im2col].vectorize(n_inner) b_m_fused = s[data_im2col].fuse(b, m)