1818# pylint: disable=unused-argument, redefined-builtin
1919"""GEMM Convolution schedule on ARM"""
2020import tvm
21+ from tvm .target import Target
2122from tvm import te
2223from tvm .topi import nn
2324from tvm .autotvm .task .space import AnnotateEntity , ReorderEntity , OtherOptionEntity
2930 gemm_acc_nx16_int8_int8_int32 ,
3031 gemm_acc_2x2_int8_int8_int32 ,
3132)
32- from .arm_utils import is_aarch64_arm , is_dotprod_available , is_mmla_available
3333
3434
35- def configure_knobs (cfg , M , K ):
35+ def configure_knobs (cfg , M , K , target ):
3636 """Configure auto-tuning knobs for the interleaved strategy"""
3737
3838 x , y = cfg .axis (M // 4 ), cfg .axis (K // 16 )
@@ -48,7 +48,7 @@ def configure_knobs(cfg, M, K):
4848 cfg ["reorder_gemm" ] = ReorderEntity ([0 , 1 ])
4949 cfg ["A_interleaved_unroll_vec" ] = AnnotateEntity (["unroll" , "vec" ])
5050
51- if not is_dotprod_available () :
51+ if not target . features . has_dotprod :
5252 cfg .define_knob ("gemm_quantized_unroll" , [True , False ])
5353 if cfg .is_fallback :
5454 cfg ["gemm_quantized_unroll" ] = OtherOptionEntity (False )
@@ -133,12 +133,13 @@ def compute_conv2d_gemm_without_weight_transform(
133133 # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
134134 # In order to have more information
135135 #
136- if is_mmla_available ():
136+ target = Target .current (allow_none = False )
137+ if target .features .has_matmul_i8 :
137138 # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
138139 # will contain 8 elements
139140 tile_rows_A = 8
140141 tile_cols_A = 8
141- elif is_dotprod_available () and interleave_A :
142+ elif target . features . has_dotprod and interleave_A :
142143 # If dot product has been enabled, and we are interleaving A
143144 # tile size should be 8x4
144145 tile_rows_A = 8
@@ -173,15 +174,16 @@ def compute_conv2d_gemm_without_weight_transform(
173174
174175 if interleave_A :
175176 # Configuration space
176- configure_knobs (cfg , M_padded , K_padded )
177+ configure_knobs (cfg , M_padded , K_padded , target )
177178
178179 # Pack the input data
179180 A_interleaved = te .compute (
180181 (batches , M_padded // tile_rows_A , K_padded // tile_cols_A , tile_rows_A , tile_cols_A ),
181182 lambda b , x , y , z , w : A [b , z + tile_rows_A * x , w + tile_cols_A * y ],
182183 name = "A_interleaved" ,
183184 )
184- if is_mmla_available ():
185+ target = Target .current (allow_none = False )
186+ if target .features .has_matmul_i8 :
185187 # Execute GEMM. In the case of mmla, we need to enforce the tiling
186188 # from the compute. This is because mmla is doing a tiled computation
187189 # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
@@ -323,7 +325,8 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
323325 k = C_interleaved .op .reduce_axis [0 ]
324326 _ , M , N = C .shape
325327 if in_type in ["int8" , "uint8" ]:
326- if is_mmla_available ():
328+ target = Target .current (allow_none = False )
329+ if target .features .has_matmul_i8 :
327330 gemm_acc = gemm_acc_2x2_int8_int8_int32 (in_type )
328331 xi_inner , yi_inner = C_interleaved .op .axis [- 2 :]
329332 k_outer , k_inner = s [C_interleaved ].split (k , 8 )
@@ -333,7 +336,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
333336 s [C_interleaved ].tensorize (xi_inner , gemm_acc )
334337 s [C_interleaved ].unroll (xi )
335338 s [C_interleaved ].unroll (yi )
336- elif is_dotprod_available () :
339+ elif target . features . has_dotprod :
337340 gemm_acc = gemm_acc_4x4_int8_int8_int32 (in_type )
338341 xi_outer , yi_outer , xi_inner , yi_inner = s [C_interleaved ].tile (
339342 xi , yi , x_factor = 8 , y_factor = 4
@@ -354,7 +357,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
354357 s [C_interleaved ].tensorize (xi_inner_inner , gemm_acc )
355358 s [C_interleaved ].unroll (xi_inner_outer )
356359
357- elif is_aarch64_arm () :
360+ elif target . features . has_asimd :
358361 s [C_interleaved ].reorder (yi , xi )
359362 K = A_interleaved_input .shape [2 ]
360363 assert in_type in ["int8" , "uint8" ], "Only int8 and uint8 gemm are supported"
0 commit comments