Skip to content

Commit d705401

Browse files
author
shingjan
committed
addr cmts
1 parent ed9ba73 commit d705401

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

python/tvm/script/tir/ty.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,20 @@ def __init__(self, vtype):
5555
def evaluate(self):
5656
return tvm.ir.PrimType(self.type)
5757

58-
def __call__(self, shape, dtype, elem_offset):
59-
pass
58+
def __call__(
59+
self,
60+
shape,
61+
dtype="float32",
62+
data=None,
63+
strides=None,
64+
elem_offset=None,
65+
scope="global",
66+
align=-1,
67+
offset_factor=0,
68+
buffer_type="default",
69+
span=None,
70+
):
71+
self.name = "match_buffer"
6072

6173

6274
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method

tests/python/unittest/test_tvmscript_syntax_sugar.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,53 @@
1818
import sys
1919

2020
import pytest
21+
import tvm
2122
from tvm.script import tir as T
2223

23-
# match buffer - use kwargs
24+
# match buffer - no syntax sugar
2425
@T.prim_func
25-
def elementwise(
26+
def elementwise_handle(
27+
a: T.handle,
28+
b: T.handle,
29+
) -> None:
30+
A = T.match_buffer(a, (128, 128, 128, 128))
31+
B = T.match_buffer(b, (128, 128, 128, 128))
32+
for i, j, k, l in T.grid(128, 128, 128, 128):
33+
with T.block("B"):
34+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
35+
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
36+
37+
38+
# match buffer - use buffer with kwargs
39+
@T.prim_func
40+
def elementwise_buffer_kwargs(
2641
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=1),
2742
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=2),
2843
) -> None:
29-
# A = T.match_buffer(a, (128, 128, 128, 128))
30-
# B = T.match_buffer(b, (128, 128, 128, 128))
3144
for i, j, k, l in T.grid(128, 128, 128, 128):
3245
with T.block("B"):
3346
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
3447
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
3548

3649

37-
# match buffer - no kwargs
50+
# match buffer - use buffer without kwargs
3851
@T.prim_func
39-
def elementwise(
52+
def elementwise_buffer_no_kwargs(
4053
a: T.Buffer[(128, 128, 128, 128), "float32"],
4154
b: T.Buffer[(128, 128, 128, 128), "float32"],
4255
) -> None:
43-
# A = T.match_buffer(a, (128, 128, 128, 128))
44-
# B = T.match_buffer(b, (128, 128, 128, 128))
4556
for i, j, k, l in T.grid(128, 128, 128, 128):
4657
with T.block("B"):
4758
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
4859
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
4960

5061

62+
def test_match_buffer_syntax_sugar():
63+
# with kwargs
64+
tvm.ir.assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
65+
# without kwargs
66+
tvm.ir.assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)
67+
68+
5169
if __name__ == "__main__":
5270
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)