Skip to content

Commit 1351fde

Browse files
committed
use buffer syntax sugar
1 parent 0ced85f commit 1351fde

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

python/tvm/tir/tensor_intrin/x86.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@
2323

2424

2525
@T.prim_func
26-
def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
27-
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
28-
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
29-
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
30-
26+
def dot_product_16x4_u8i8i32_desc(
27+
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
28+
) -> None:
3129
with T.block("root"):
3230
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
3331
T.writes(C[0:16])
@@ -41,7 +39,9 @@ def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
4139

4240

4341
@T.prim_func
44-
def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
42+
def dot_product_16x4_u8i8i32_vnni_impl(
43+
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
44+
) -> None:
4545
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
4646
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
4747
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
@@ -66,6 +66,8 @@ def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
6666
)
6767

6868

69-
VNNI_INTRIN = "dot_16x4_vnni"
69+
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
7070

71-
TensorIntrin.register(VNNI_INTRIN, dot_product_16x4_desc, dot_product_16x4_vnni_impl)
71+
TensorIntrin.register(
72+
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl
73+
)

tests/python/unittest/test_meta_schedule_tune_relay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from tvm.target.target import Target
4343
from tvm.tir.schedule import BlockRV, Schedule
4444
from tvm.tir.schedule.trace import Trace
45-
from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN
45+
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
4646

4747

4848
logging.basicConfig()

tests/python/unittest/test_tir_schedule_tensorize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tvm import tir, te
2323
from tvm.script import tir as T
2424
from tvm.tir.schedule.testing import verify_trace_roundtrip
25-
from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN
25+
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
2626

2727
# fmt: off
2828
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks

0 commit comments

Comments
 (0)