Skip to content

Commit f88c31e

Browse files
committed
add VNNI unittest
1 parent 6cc8009 commit f88c31e

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

tests/python/unittest/test_tir_schedule_tensorize.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
import pytest
2020
import tvm
2121
import tvm.testing
22-
from tvm import tir
22+
from tvm import tir, te
2323
from tvm.script import tir as T
2424
from tvm.tir.schedule.testing import verify_trace_roundtrip
25+
from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN
2526

2627
# fmt: off
2728
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -531,5 +532,40 @@ def test_tensorize_with_annotation():
531532
verify_trace_roundtrip(sch=s, mod=func)
532533

533534

535+
def test_tensorize_vnni():
536+
n, m, k = 128, 128, 128
537+
X = te.placeholder((m, k), name="X", dtype="uint8")
538+
packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")
539+
540+
ak = te.reduce_axis((0, k), name="k")
541+
matmul = te.compute(
542+
(m, n),
543+
lambda i, j: te.sum(
544+
X[i, ak].astype("int32")
545+
* packed_W[
546+
tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
547+
].astype("int32"),
548+
axis=ak,
549+
),
550+
name="compute",
551+
)
552+
553+
func = te.create_prim_func([X, packed_W, matmul])
554+
555+
sch = tir.Schedule(func, debug_mask="all")
556+
block = sch.get_block("compute")
557+
_, j, k = sch.get_loops(block)
558+
559+
_, ji = sch.split(j, factors=[None, 16])
560+
ko, ki = sch.split(k, factors=[None, 4])
561+
sch.reorder(ko, ji, ki)
562+
563+
sch.decompose_reduction(block, ko)
564+
sch.tensorize(ji, VNNI_INTRIN)
565+
566+
verify_trace_roundtrip(sch=sch, mod=func)
567+
568+
534569
if __name__ == "__main__":
535-
sys.exit(pytest.main([__file__] + sys.argv[1:]))
570+
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
571+
test_tensorize_vnni()

0 commit comments

Comments
 (0)