@@ -207,65 +207,6 @@ def _callback(op):
207207 return s
208208
209209
def dense_vnni_compute(X, packedW, bias=None):
    """Compute for uint8 x int8 -> int32 dense"""
    # X: (m, k) uint8 activations; packedW: (n_o, k_o, n_i, k_i) int8
    # weights, pre-packed so the inner (16, 4) tile matches vpdpbusd.
    m, k = X.shape
    n_outer, _, n_inner, _ = packedW.shape
    red_k = te.reduce_axis((0, k), name="k")

    def _gemm(i, j):
        a = X[i, red_k].astype("int32")
        b = packedW[
            tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(red_k, 4), j % 16, red_k % 4
        ].astype("int32")
        return te.sum(a * b, axis=red_k)

    out = te.compute((m, n_outer * n_inner), _gemm, tag="dense_vnni")

    if bias is not None:
        # Bias add is emitted as a separate broadcast stage so it can be
        # fused/inlined by the schedule.
        out = te.compute(out.shape, lambda i, j: out[i, j] + bias[j], tag=tag.BROADCAST)

    return out
232-
233-
def dense_vnni_schedule(s, C, O):
    """Schedule dense compute using VNNI vpdpbusd instruction"""
    # C is the GEMM stage; O is the last stage of the fused op.
    if C != O:
        # Tile the epilogue stage and anchor the GEMM inside its outer
        # y loop so the int32 accumulator tile stays hot.
        y, x = O.op.axis
        y_out, y_in = s[O].split(y, factor=32)
        x_out, x_in = s[O].split(x, factor=16)

        s[O].reorder(y_out, x_out, y_in, x_in)
        par = s[O].fuse(y_out, x_out)
        s[O].vectorize(x_in)
        s[O].parallel(par)

        s[C].compute_at(s[O], y_in)

    y, x = C.op.axis
    (k,) = C.op.reduce_axis

    k_out, k_in = s[C].split(k, factor=4)
    y_out, y_in = s[C].split(y, factor=32)
    x_out, x_in = s[C].split(x, factor=16)

    s[C].reorder(y_out, x_out, y_in, k_out, x_in, k_in)

    # Map the innermost (x_in, k_in) = (16, 4) tile onto one vpdpbusd.
    intrin = dot_16x1x16_uint8_int8_int32_cascadelake()
    s[C].tensorize(x_in, intrin)

    if C == O:
        par = s[O].fuse(y_out, x_out)
        s[O].parallel(par)

    return s
267-
268-
269210@autotvm .register_topi_compute ("dense_pack.x86" )
270211def dense_pack (cfg , data , weight , bias = None , out_dtype = None ):
271212 """Compute dense with transformed weight."""
@@ -275,10 +216,6 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
275216 if len (weight .shape ) == 3 :
276217 N , _ , packw_bn = get_const_tuple (weight .shape ) # out_dim
277218 N = N * packw_bn
278- elif len (weight .shape ) == 4 :
279- N , K , n_inner , k_inner = get_const_tuple (weight .shape ) # out_dim
280- assert n_inner == 16 and k_inner == 4
281- return dense_vnni_compute (data , weight , bias )
282219 else :
283220 N , _ = get_const_tuple (weight .shape ) # out_dim
284221 # create tuning space
@@ -336,15 +273,106 @@ def schedule_dense_pack(cfg, outs):
336273 s = te .create_schedule ([x .op for x in outs ])
337274
338275 def _callback (op ):
339- if "dense_vnni" in op .tag :
340- dense_vnni_schedule (s , op .output (0 ), outs [0 ])
341276 if "dense_pack" in op .tag :
342277 _schedule_dense_pack_template (cfg , s , op .output (0 ), outs [0 ])
343278
344279 traverse_inline (s , outs [0 ].op , _callback )
345280 return s
346281
347282
def dense_vnni_compute(cfg, X, packed_w, bias=None):
    """Compute for uint8 x int8 -> int32 dense"""
    # X: (m, k) uint8 activations; packed_w: (n_o, k_o, n_i, k_i) int8
    # weights, pre-packed so the inner (16, 4) tile matches vpdpbusd.
    m, k = X.shape
    n_outer, _, n_inner, _ = packed_w.shape
    red_k = te.reduce_axis((0, k), name="k")

    def _gemm(i, j):
        a = X[i, red_k].astype("int32")
        b = packed_w[
            tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(red_k, 4), j % 16, red_k % 4
        ].astype("int32")
        return te.sum(a * b, axis=red_k)

    out = te.compute((m, n_outer * n_inner), _gemm, tag="dense_vnni")

    if bias is not None:
        # Bias add as a separate broadcast stage, fusable by the schedule.
        out = te.compute(out.shape, lambda i, j: out[i, j] + bias[j], tag=tag.BROADCAST)

    # Expose the split of the y (rows) axis as an AutoTVM knob; the
    # schedule applies it via cfg["tile_y"].
    y_axis, _ = out.op.axis
    cfg.define_split("tile_y", y_axis, num_outputs=2)

    return out
308+
309+
def dense_vnni_schedule(cfg, s, C, O):
    """Schedule dense compute using VNNI vpdpbusd instruction"""
    # C: The output of GEMM
    # O: The output of the fused op

    def _split_y(stage_out):
        # Use the tuned "tile_y" split when a config is available,
        # otherwise fall back to a fixed factor of 32.
        y = stage_out.op.axis[0]
        if cfg.is_fallback:
            return s[stage_out].split(y, factor=32)
        return cfg["tile_y"].apply(s, stage_out, y)

    (k,) = C.op.reduce_axis

    y_out, y_in = _split_y(C)
    x_out, x_in = s[C].split(C.op.axis[1], factor=16)
    k_out, k_in = s[C].split(k, factor=4)

    s[C].reorder(y_out, x_out, y_in, k_out, x_in, k_in)

    # Map the innermost (x_in, k_in) = (16, 4) tile onto one vpdpbusd.
    intrin = dot_16x1x16_uint8_int8_int32_cascadelake()
    s[C].tensorize(x_in, intrin)

    if C == O:
        par = s[O].fuse(y_out, x_out)
        s[O].parallel(par)
    else:
        # Tile the epilogue stage and anchor the GEMM inside its outer
        # y loop.
        y_out, y_in = _split_y(O)
        x_out, x_in = s[O].split(O.op.axis[1], factor=16)

        s[O].reorder(y_out, x_out, y_in, x_in)
        par = s[O].fuse(y_out, x_out)
        s[O].vectorize(x_in)
        s[O].parallel(par)

        s[C].compute_at(s[O], y_in)

    return s
349+
350+
@autotvm.register_topi_compute("dense_vnni.x86")
def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
    """Compute uint8 x int8 -> int32 dense with a VNNI-packed weight.

    Parameters
    ----------
    cfg : autotvm config
        Tuning space; consumed by dense_vnni_compute (defines "tile_y").
    data : te.Tensor
        2-D uint8 input of shape (m, k).
    weight : te.Tensor
        4-D int8 weight, pre-packed with inner tile (n_inner, k_inner) == (16, 4)
        to match the vpdpbusd instruction.
    bias : te.Tensor, optional
        1-D bias broadcast-added over the output columns.
    out_dtype : str, optional
        Accepted for signature compatibility with other dense strategies;
        the result dtype is always int32 here.

    Returns
    -------
    te.Tensor
        The (m, n) int32 result.
    """
    # NOTE(review): the previous body defaulted out_dtype to data.dtype but
    # never used it (dead store) — dense_vnni_compute fixes the output to
    # int32 regardless; the dead assignment has been removed.
    assert len(weight.shape) == 4, "dense_vnni expects a 4-D pre-packed weight"
    assert (
        data.dtype == "uint8" and weight.dtype == "int8"
    ), "dense_vnni requires uint8 data and int8 weight"
    _, _, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
    assert n_inner == 16 and k_inner == 4, "weight inner tile must be (16, 4)"
    return dense_vnni_compute(cfg, data, weight, bias)
361+
362+
@autotvm.register_topi_schedule("dense_vnni.x86")
def schedule_dense_vnni(cfg, outs):
    """Create a schedule for dense_vnni"""
    sch = te.create_schedule([t.op for t in outs])

    def _visit(op):
        # Only the GEMM stage (tagged "dense_vnni" by the compute) gets a
        # hand-written schedule; traverse_inline handles fused elementwise
        # stages.
        if "dense_vnni" in op.tag:
            dense_vnni_schedule(cfg, sch, op.output(0), outs[0])

    traverse_inline(sch, outs[0].op, _visit)
    return sch
374+
375+
348376def matmul_blas_common (cfg , tensor_a , tensor_b , bias , out_dtype , transpose_a , transpose_b , lib ):
349377 """Compute matmul/dense using a BLAS library"""
350378 M , K = get_const_tuple (tensor_a .shape )
0 commit comments