diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index ea9026688eec..6ef8efec9ed5 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -326,7 +326,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): b, m, n = data_im2col.op.axis if data_im2col.op.name == "data_im2col": - n_outer, n_inner = s[data_im2col].split(n, 16) + n_size = data_im2col.shape[2] + if n_size % 16 == 0: + split_factor = 16 + else: + split_factor = 8 + n_outer, n_inner = s[data_im2col].split(n, split_factor) s[data_im2col].unroll(n_outer) s[data_im2col].vectorize(n_inner) b_m_fused = s[data_im2col].fuse(b, m)