Skip to content

Commit e8889ae

Browse files
author
Yuanjing Shi
authored
[TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
1 parent b54beed commit e8889ae

File tree

7 files changed

+174
-13
lines changed

7 files changed

+174
-13
lines changed

docker/install/ubuntu_install_python_package.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ pip3 install \
3636
pytest-xdist \
3737
requests \
3838
scipy \
39-
synr==0.5.0 \
39+
synr==0.6.0 \
4040
six \
4141
tornado

python/gen_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@
255255
("sphinx_autodoc_annotation", None),
256256
("sphinx_gallery", None),
257257
("sphinx_rtd_theme", None),
258-
("synr", "==0.5.0"),
258+
("synr", "==0.6.0"),
259259
("tensorflow", None),
260260
("tensorflow-estimator", None),
261261
("tflite", None),

python/tvm/script/parser.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tvm._ffi.base import TVMError
3333
from tvm.ir import GlobalVar
3434
from tvm.ir.function import BaseFunc
35+
from tvm.tir import buffer
3536
from tvm.tir.function import PrimFunc
3637
from . import _ffi_api
3738
from . import tir
@@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
154155
ast.BuiltinOp.Not: tvm.tir.Not,
155156
}
156157

157-
def __init__(self, base_lienno, tir_namespace):
158+
def __init__(self, base_lineno, tir_namespace):
158159
self.context = None
159160

160-
self.base_lineno = base_lienno
161+
self.base_lineno = base_lineno
161162
self.current_lineno = 0
162163
self.current_col_offset = 0
163164
self.tir_namespace = tir_namespace
@@ -249,20 +250,23 @@ def parse_arg_list(self, func, node_call):
249250
func : Function
250251
The function that provides the signature
251252
252-
node_call: ast.Call
253+
node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
253254
The AST call node that calls into the function.
254255
255256
Returns
256257
-------
257258
arg_list : list
258259
The parsed positional argument.
259260
"""
260-
assert isinstance(node_call, ast.Call)
261+
assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
261262
# collect arguments
262263
args = [self.transform(arg) for arg in node_call.params]
263-
kw_args = {
264-
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
265-
}
264+
if isinstance(node_call, ast.TypeApply):
265+
kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
266+
else:
267+
kw_args = {
268+
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
269+
}
266270
# get the name and parameter list of func
267271
if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
268272
func_name, param_list = func.signature()
@@ -276,6 +280,7 @@ def parse_arg_list(self, func, node_call):
276280
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
277281
pos_only, kwargs, varargs = param_list
278282
internal_args = list()
283+
279284
for i, arg_name in enumerate(pos_only):
280285
internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
281286
for i, arg_info in enumerate(kwargs):
@@ -439,8 +444,22 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
439444

440445
# add parameters of function
441446
for arg in node.params:
442-
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
443-
self.context.update_symbol(arg.name, arg_var, node)
447+
# Note that this case is for T.match_buffer syntax sugar
448+
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
449+
result = self.handle_match_buffer_type(arg.ty, arg.name)
450+
if not isinstance(result, buffer.Buffer):
451+
self.report_error(
452+
"The result type of evaluating TypeCall and TypeApply stmt"
453+
f" is wrong: {type(result)}. It should be a Buffer",
454+
node.span,
455+
)
456+
arg_name_with_handle = arg.name + "_handle"
457+
arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
458+
self.context.func_buffer_map[arg_var] = result
459+
self.context.update_symbol(arg.name, result, node)
460+
else:
461+
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
462+
self.context.update_symbol(arg.name, arg_var, node)
444463
self.context.func_params.append(arg_var)
445464

446465
if not check_decorator(node.decorators):
@@ -1110,6 +1129,30 @@ def transform_TypeConstant(self, node):
11101129
"""
11111130
return node.value
11121131

1132+
def transform_TypeTuple(self, node):
1133+
"""Tuple value visitor for types.
1134+
1135+
Mostly used in `transform_TypeCall` and `transform_TypeApply`.
1136+
"""
1137+
return [self.transform(value) for value in node.values]
1138+
1139+
def handle_match_buffer_type(self, node, buffer_name):
1140+
"""special function to handle syntax sugar for match buffer.
1141+
1142+
This method is for buffer declarations in the function parameters.
1143+
"""
1144+
func = self.transform(node.func_name)
1145+
assert isinstance(func, SpecialStmt)
1146+
1147+
# parse args and kwargs for TypeCall and TypeApply
1148+
arg_list = self.parse_arg_list(func, node)
1149+
# Note that the third element in arg_list would always be the 'name'
1150+
# TODO: This index is hardcoded as a workaround. Better to make it programmatic
1151+
if arg_list[2] is None:
1152+
arg_list[2] = buffer_name
1153+
buf = func.handle(node, self.context, arg_list, node.func_name.span)
1154+
return buf
1155+
11131156
def transform_Return(self, node):
11141157
self.report_error(
11151158
"TVM script does not support return statements. Instead the last statement in any "

python/tvm/script/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@
1818

1919
# Type system
2020
from .ty import int8, int16, int32, int64, float16, float32, float64
21-
from .ty import boolean, handle, Ptr, Tuple
21+
from .ty import boolean, handle, Ptr, Tuple, Buffer
2222

2323
from .prim_func import prim_func

python/tvm/script/tir/ty.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222
# pylint: disable=invalid-name
2323
import tvm
24+
from .special_stmt import SpecialStmt, convert_to_int
2425

2526

2627
class TypeGeneric: # pylint: disable=too-few-public-methods
@@ -67,6 +68,75 @@ def __getitem__(self, vtypes):
6768
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
6869

6970

71+
class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method
72+
"""TVM script typing class for uniform Type objects"""
73+
74+
def __init__(self, vtype):
75+
def match_buffer_syntax_sugar(
76+
shape,
77+
dtype: str = "float32",
78+
name: str = None,
79+
data=None,
80+
strides=None,
81+
elem_offset=None,
82+
scope="global",
83+
align=-1,
84+
offset_factor=0,
85+
buffer_type="default",
86+
span=None,
87+
):
88+
if strides is None:
89+
strides = []
90+
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
91+
offset_factor = convert_to_int(
92+
offset_factor, "offset_factor", self.context.report_error, self.node.span
93+
)
94+
buffer = tvm.tir.decl_buffer(
95+
shape,
96+
dtype,
97+
name,
98+
data,
99+
strides,
100+
elem_offset,
101+
scope,
102+
align,
103+
offset_factor,
104+
buffer_type,
105+
span=span,
106+
)
107+
return buffer
108+
109+
self.type = vtype
110+
super().__init__(match_buffer_syntax_sugar, def_symbol=True)
111+
112+
def __call__(
113+
self,
114+
shape,
115+
dtype="float32",
116+
*,
117+
name: str = None,
118+
data=None,
119+
strides=None,
120+
elem_offset=None,
121+
scope="global",
122+
align=-1,
123+
offset_factor=0,
124+
buffer_type="default",
125+
span=None,
126+
):
127+
"""
128+
This function is for Buffer(...) syntax sugar.
129+
"""
130+
pass # pylint: disable=unnecessary-pass
131+
132+
def __getitem__(self, args):
133+
"""
134+
This function is for Buffer[...] syntax sugar
135+
Note that args is the list of all arguments
136+
"""
137+
pass # pylint: disable=unnecessary-pass
138+
139+
70140
int8 = ConcreteType("int8")
71141
int16 = ConcreteType("int16")
72142
int32 = ConcreteType("int32")
@@ -78,3 +148,6 @@ def __getitem__(self, vtypes):
78148
handle = ConcreteType("handle")
79149
Ptr = GenericPtrType()
80150
Tuple = GenericTupleType()
151+
# we don't have 'buffer' type on the cpp side
152+
# thus 'handle' is used here for convenience's sake
153+
Buffer = GenericBufferType("handle")

tests/python/unittest/test_tvmscript_syntax_sugar.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,50 @@ def test_syntax_sugar_fail():
101101
check_error(loop_syntax_sugar_fail, 3)
102102

103103

104+
# match buffer - use kwargs
105+
@T.prim_func
106+
def elementwise_handle(
107+
a: T.handle,
108+
b: T.handle,
109+
) -> None:
110+
A = T.match_buffer(a, (128, 128, 128, 128))
111+
B = T.match_buffer(b, (128, 128, 128, 128))
112+
for i, j, k, l in T.grid(128, 128, 128, 128):
113+
with T.block("B"):
114+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
115+
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
116+
117+
118+
# match buffer - use buffer with kwargs
119+
@T.prim_func
120+
def elementwise_buffer_kwargs(
121+
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
122+
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
123+
) -> None:
124+
for i, j, k, l in T.grid(128, 128, 128, 128):
125+
with T.block("B"):
126+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
127+
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
128+
129+
130+
# match buffer - use buffer without kwargs
131+
@T.prim_func
132+
def elementwise_buffer_no_kwargs(
133+
a: T.Buffer[(128, 128, 128, 128), "float32"],
134+
b: T.Buffer[(128, 128, 128, 128), "float32"],
135+
) -> None:
136+
for i, j, k, l in T.grid(128, 128, 128, 128):
137+
with T.block("B"):
138+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
139+
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
140+
141+
142+
def test_match_buffer_syntax_sugar():
143+
# with kwargs
144+
assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
145+
# without kwargs
146+
assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)
147+
148+
104149
if __name__ == "__main__":
105150
sys.exit(pytest.main([__file__] + sys.argv[1:]))

tests/scripts/task_ci_setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ set -o pipefail
3030
#
3131
echo "Addtiional setup in" ${CI_IMAGE_NAME}
3232

33-
python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0
33+
python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0
3434

3535
# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
3636
# Jenkinsfile. We expect config.cmake to be present from pack_lib().

0 commit comments

Comments
 (0)