2424from tvm .script import tir as T
2525import tvm .contrib .nnpack
2626from tvm .tir .schedule .analysis import has_block
27- from tvm .topi .arm_cpu .matmul import _get_transpose_interleave_intrin_name
2827
2928from ..utils import traverse_inline , get_const_tuple
3029from .. import nn
@@ -776,10 +775,6 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
776775 get_transpose_interleave_intrin_name ,
777776 )
778777
779- transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name (
780- in_dtype , out_dtype
781- )
782-
783778 # Interleave the padded im2col matrix utilizing the matrix tile
784779 interleave_t_A_block = sch .cache_read (gemm_block , 0 , "global" )
785780 sch .transform_layout (interleave_t_A_block , ("write" , 0 ), lambda b , m , k : (b , k , m ))
@@ -788,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
788783 ko , ki = sch .split (k , factors = (None , tile_K ), disable_predication = True )
789784 sch .parallel (b )
790785 sch .reorder (b , ko , mo , ki , mi )
791- sch .tensorize (ki , get_transpose_interleave_intrin_name (in_dtype , out_dtype , M_padded , K_padded ))
786+ sch .tensorize (
787+ ki , get_transpose_interleave_intrin_name (in_dtype , out_dtype , M_padded , K_padded )
788+ )
792789
793790 # Interleave the padded weights matrix utilizing the matrix tile
794791 if in_dtype == "float16" :
@@ -798,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
798795 ko , ki = sch .split (k , factors = (None , tile_K ), disable_predication = True )
799796 no , ni = sch .split (n , factors = (None , tile_N ), disable_predication = True )
800797 sch .reorder (ko , no , ki , ni )
801- sch .tensorize (ki , get_transpose_interleave_intrin_name (in_dtype , out_dtype , M_padded , K_padded ))
798+ sch .tensorize (
799+ ki , get_transpose_interleave_intrin_name (in_dtype , out_dtype , M_padded , K_padded )
800+ )
802801
803802 # Split and reorder the loops of the GeMM for tensorization
804803 b , m , n , k = sch .get_loops (gemm_block )
@@ -821,7 +820,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
821820 )
822821 tvm .tir .TensorIntrin .register (
823822 sme_gemm_interleaved_intrin_name ,
824- * get_sme_gemm_interleaved_mopa_2svlx2svl_intrin (K_padded , in_dtype ),
823+ * get_sme_gemm_interleaved_mopa_2svlx2svl_intrin (M_padded , K_padded , in_dtype ),
825824 override = True ,
826825 )
827826 sch .tensorize (mi , sme_gemm_interleaved_intrin_name )
0 commit comments