
Commit accd6cd

make it tunable
1 parent babe4ac

File tree

2 files changed: +114 -70 lines changed


python/tvm/relay/op/strategy/x86.py

Lines changed: 21 additions & 5 deletions
@@ -530,11 +530,27 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
     """dense_pack x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.x86.dense_pack),
-        wrap_topi_schedule(topi.x86.schedule_dense_pack),
-        name="dense_pack.x86",
-    )
+
+    if (
+        inputs[0].dtype == "uint8"
+        and inputs[1].dtype == "int8"
+        and out_type.dtype == "int32"
+        and attrs["weight_layout"] == "NC16n4c"
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.x86.dense_vnni),
+            wrap_topi_schedule(topi.x86.schedule_dense_vnni),
+            name="dense_vnni.x86",
+            plevel=12,
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.x86.dense_pack),
+            wrap_topi_schedule(topi.x86.schedule_dense_pack),
+            name="dense_pack.x86",
+            plevel=10,
+        )
+
     return strategy
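The VNNI implementation is registered at `plevel=12` so it takes priority over the generic `dense_pack.x86` (at `plevel=10`) whenever the dtype and layout conditions match. A minimal sketch of a workload that should take the new path, assuming `relay.nn.contrib_dense_pack` accepts the `weight_layout` attribute checked above (shapes, names, and the target string are illustrative, not part of this commit):

```python
import tvm
from tvm import relay

M, N, K = 128, 768, 768
data = relay.var("data", shape=(M, K), dtype="uint8")
# NC16n4c packing of an (N, K) int8 weight: (N // 16, K // 4, 16, 4)
weight = relay.var("weight", shape=(N // 16, K // 4, 16, 4), dtype="int8")

out = relay.nn.contrib_dense_pack(data, weight, weight_layout="NC16n4c", out_dtype="int32")
mod = tvm.IRModule.from_expr(out)

# The VNNI vpdpbusd instruction is available on Cascade Lake and newer
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=cascadelake")
```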

python/tvm/topi/x86/dense.py

Lines changed: 93 additions & 65 deletions
@@ -207,65 +207,6 @@ def _callback(op):
     return s


-def dense_vnni_compute(X, packedW, bias=None):
-    """Compute for uint8 x int8 -> int32 dense"""
-    m, k = X.shape
-    n_o, _, n_i, _ = packedW.shape
-    ak = te.reduce_axis((0, k), name="k")
-
-    C = te.compute(
-        (m, n_o * n_i),
-        lambda i, j: te.sum(
-            X[i, ak].astype("int32")
-            * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
-                "int32"
-            ),
-            axis=ak,
-        ),
-        tag="dense_vnni",
-    )
-
-    if bias is not None:
-        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
-
-    return C
-
-
-def dense_vnni_schedule(s, C, O):
-    """Schedule dense compute using VNNI vpdpbusd instruction"""
-    # C: The output of GEMM
-    # O: The output of the fused op
-    if C != O:
-        a_y, a_x = O.op.axis
-        a_yo, a_yi = s[O].split(a_y, factor=32)
-        a_xo, a_xi = s[O].split(a_x, factor=16)
-
-        s[O].reorder(a_yo, a_xo, a_yi, a_xi)
-        fused = s[O].fuse(a_yo, a_xo)
-        s[O].vectorize(a_xi)
-        s[O].parallel(fused)
-
-        s[C].compute_at(s[O], a_yi)
-
-    a_y, a_x = C.op.axis
-    (a_k,) = C.op.reduce_axis
-
-    a_ko, a_ki = s[C].split(a_k, factor=4)
-    a_yo, a_yi = s[C].split(a_y, factor=32)
-    a_xo, a_xi = s[C].split(a_x, factor=16)
-
-    s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
-
-    pc = dot_16x1x16_uint8_int8_int32_cascadelake()
-    s[C].tensorize(a_xi, pc)
-
-    if C == O:
-        fused = s[O].fuse(a_yo, a_xo)
-        s[O].parallel(fused)
-
-    return s
-
-
 @autotvm.register_topi_compute("dense_pack.x86")
 def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense with transformed weight."""
@@ -275,10 +216,6 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     if len(weight.shape) == 3:
         N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim
         N = N * packw_bn
-    elif len(weight.shape) == 4:
-        N, K, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
-        assert n_inner == 16 and k_inner == 4
-        return dense_vnni_compute(data, weight, bias)
     else:
         N, _ = get_const_tuple(weight.shape)  # out_dim
     # create tuning space
@@ -336,15 +273,106 @@ def schedule_dense_pack(cfg, outs):
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
-        if "dense_vnni" in op.tag:
-            dense_vnni_schedule(s, op.output(0), outs[0])
         if "dense_pack" in op.tag:
             _schedule_dense_pack_template(cfg, s, op.output(0), outs[0])

     traverse_inline(s, outs[0].op, _callback)
     return s


+def dense_vnni_compute(cfg, X, packed_w, bias=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    m, k = X.shape
+    n_o, _, n_i, _ = packed_w.shape
+    ak = te.reduce_axis((0, k), name="k")
+
+    C = te.compute(
+        (m, n_o * n_i),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
+                "int32"
+            ),
+            axis=ak,
+        ),
+        tag="dense_vnni",
+    )
+
+    if bias is not None:
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
+
+    a_y, _ = C.op.axis
+    cfg.define_split("tile_y", a_y, num_outputs=2)
+
+    return C
+
+
+def dense_vnni_schedule(cfg, s, C, O):
+    """Schedule dense compute using VNNI vpdpbusd instruction"""
+    # C: The output of GEMM
+    # O: The output of the fused op
+    def split_y(out):
+        default_y_split_factor = 32
+        a_y = out.op.axis[0]
+
+        if cfg.is_fallback:
+            return s[out].split(a_y, factor=default_y_split_factor)
+        else:
+            return cfg["tile_y"].apply(s, out, a_y)
+
+    (a_k,) = C.op.reduce_axis
+
+    a_yo, a_yi = split_y(C)
+    a_xo, a_xi = s[C].split(C.op.axis[1], factor=16)
+    a_ko, a_ki = s[C].split(a_k, factor=4)
+
+    s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
+
+    pc = dot_16x1x16_uint8_int8_int32_cascadelake()
+    s[C].tensorize(a_xi, pc)
+
+    if C == O:
+        fused = s[O].fuse(a_yo, a_xo)
+        s[O].parallel(fused)
+    else:
+        a_yo, a_yi = split_y(O)
+        a_xo, a_xi = s[O].split(O.op.axis[1], factor=16)
+
+        s[O].reorder(a_yo, a_xo, a_yi, a_xi)
+        fused = s[O].fuse(a_yo, a_xo)
+        s[O].vectorize(a_xi)
+        s[O].parallel(fused)
+
+        s[C].compute_at(s[O], a_yi)
+
+    return s
+
+
+@autotvm.register_topi_compute("dense_vnni.x86")
+def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    assert len(weight.shape) == 4
+    assert data.dtype == "uint8" and weight.dtype == "int8"
+    _, _, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
+    assert n_inner == 16 and k_inner == 4
+    return dense_vnni_compute(cfg, data, weight, bias)
+
+
+@autotvm.register_topi_schedule("dense_vnni.x86")
+def schedule_dense_vnni(cfg, outs):
+    """Create a schedule for dense_vnni"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "dense_vnni" in op.tag:
+            dense_vnni_schedule(cfg, s, op.output(0), outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
     """Compute matmul/dense using a BLAS library"""
     M, K = get_const_tuple(tensor_a.shape)
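For reference, the indexing `packed_w[j // 16, ak // 4, j % 16, ak % 4]` in `dense_vnni_compute` implies the `NC16n4c` layout. A NumPy sketch (illustrative, not part of the commit) of packing an `(N, K)` int8 weight into that layout:

```python
import numpy as np

N, K = 64, 128  # assumed multiples of 16 and 4, matching the blocking below
weight = np.random.randint(-128, 127, size=(N, K), dtype="int8")

packed_w = (
    weight.reshape(N // 16, 16, K // 4, 4)  # block both axes: 16 along N, 4 along K
    .transpose(0, 2, 1, 3)  # -> (N // 16, K // 4, 16, 4), i.e. NC16n4c
    .copy()
)

# Round trip matches the compute's indexing:
# packed_w[j // 16, k // 4, j % 16, k % 4] == weight[j, k]
j, k = 37, 101
assert packed_w[j // 16, k // 4, j % 16, k % 4] == weight[j, k]
```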

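The commit title refers to the new `tile_y` knob: the y-axis split factor, previously hardcoded to 32, is now defined via `cfg.define_split("tile_y", ...)`, with 32 kept as the fallback when no tuning record exists. A hedged sketch of tuning it with standard AutoTVM, assuming a Relay module `mod` with `params` containing a quantized dense that dispatches to `dense_vnni.x86`; the tuner choice, trial count, and log path are illustrative:

```python
from tvm import autotvm

target = "llvm -mcpu=cascadelake"
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(
        n_trial=32,  # a single split knob has few candidates, so few trials suffice
        measure_option=autotvm.measure_option(
            builder=autotvm.LocalBuilder(),
            runner=autotvm.LocalRunner(number=10),
        ),
        callbacks=[autotvm.callback.log_to_file("dense_vnni_tuning.log")],
    )

# Later, compile with the best records applied:
# with autotvm.apply_history_best("dense_vnni_tuning.log"):
#     lib = relay.build(mod, target=target, params=params)
```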