|
19 | 19 | import pytest |
20 | 20 | import tvm |
21 | 21 | import tvm.testing |
22 | | -from tvm import tir |
| 22 | +from tvm import tir, te |
23 | 23 | from tvm.script import tir as T |
24 | 24 | from tvm.tir.schedule.testing import verify_trace_roundtrip |
| 25 | +from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN |
25 | 26 |
|
26 | 27 | # fmt: off |
27 | 28 | # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks |
@@ -531,5 +532,40 @@ def test_tensorize_with_annotation(): |
531 | 532 | verify_trace_roundtrip(sch=s, mod=func) |
532 | 533 |
|
533 | 534 |
|
def test_tensorize_vnni():
    """Tensorize a packed uint8 x int8 matmul with the VNNI intrinsic.

    Builds a 128x128x128 matmul with TE in which the weight operand is stored
    in a pre-packed ``(n // 16, k // 4, 16, 4)`` layout and products are
    accumulated in int32.  The resulting PrimFunc is scheduled by splitting
    the output-column loop by 16 and the reduction loop by 4, reordering the
    pieces, decomposing the reduction at the outer reduction loop, and
    tensorizing the 16-wide loop with ``VNNI_INTRIN``.  The recorded schedule
    trace is then checked with ``verify_trace_roundtrip``.
    """
    rows, cols, depth = 128, 128, 128

    lhs = te.placeholder((rows, depth), name="X", dtype="uint8")
    packed_rhs = te.placeholder(
        (cols // 16, depth // 4, 16, 4), name="packedW", dtype="int8"
    )
    red = te.reduce_axis((0, depth), name="k")

    def dot(i, j):
        # Both operands are widened to int32 before multiplying; the weight is
        # addressed through its packed 4-D layout.
        lhs_elem = lhs[i, red].astype("int32")
        rhs_elem = packed_rhs[
            tir.indexdiv(j, 16), tir.indexdiv(red, 4), j % 16, red % 4
        ].astype("int32")
        return te.sum(lhs_elem * rhs_elem, axis=red)

    result = te.compute((rows, cols), dot, name="compute")
    func = te.create_prim_func([lhs, packed_rhs, result])

    sch = tir.Schedule(func, debug_mask="all")
    compute_block = sch.get_block("compute")
    _, loop_j, loop_k = sch.get_loops(compute_block)

    # Carve out the 16 x 4 inner tile that is handed to the intrinsic.
    _, j_inner = sch.split(loop_j, factors=[None, 16])
    k_outer, k_inner = sch.split(loop_k, factors=[None, 4])
    sch.reorder(k_outer, j_inner, k_inner)

    # Separate the init step from the update before tensorizing.
    sch.decompose_reduction(compute_block, k_outer)
    sch.tensorize(j_inner, VNNI_INTRIN)

    verify_trace_roundtrip(sch=sch, mod=func)
| 568 | + |
if __name__ == "__main__":
    # Run every test in this file through pytest and propagate its exit code.
    # Extra command-line arguments are forwarded, so a single test can still
    # be selected with e.g. ``-k test_tensorize_vnni`` instead of hard-coding
    # a direct call here.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
0 commit comments