Skip to content

Commit a4be2ed

Browse files
authored
[TVMScript] Support inlined function call as a sugar (#11324)
* [TVMScript] Support function call to help construct AST
* add test
* update test
* more comment
* fix for avoiding Buffer.vload(...) case
* update parse error msg
* wrap func call with try / catch, emit error msg
* silence pylint
1 parent b5e1fdd commit a4be2ed

File tree

2 files changed

+121
-4
lines changed

2 files changed

+121
-4
lines changed

python/tvm/script/parser.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
different python versions. Synr also provides an error handling context that we
2121
use for error reporting.
2222
"""
23-
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return
23+
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
24+
import types
2425
import json
2526
import operator
2627
import inspect
@@ -543,7 +544,7 @@ def transform_Assign(self, node):
543544
AST abstract grammar:
544545
Assign(expr* targets, expr value, string? type_comment)
545546
546-
By now 3 patterns of Assign is supported:
547+
By now 5 patterns of Assign are supported:
547548
1. special stmts with return value
548549
1.1 Buffer = T.match_buffer()/T.buffer_decl()
549550
1.2 Var = T.var()
@@ -552,6 +553,9 @@ def transform_Assign(self, node):
552553
3. (Store) Var[PrimExpr] = PrimExpr
553554
4. with scope handlers with concise scoping and var def
554555
4.1 var = T.allocate()
556+
5. A call to a pure python function, consuming and producing TVMScript values.
557+
The outputs are inlined into the following body (no variable is created).
558+
x, y = f(...)
555559
"""
556560

557561
if isinstance(node.rhs, ast.Call):
@@ -577,6 +581,35 @@ def transform_Assign(self, node):
577581
arg_list = self.parse_arg_list(func, node.rhs)
578582
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
579583
return self.parse_body(node)
584+
elif isinstance(func, types.FunctionType):
585+
# Pattern 5
586+
args = [self.transform(arg) for arg in node.rhs.params]
587+
try:
588+
out = func(*args)
589+
except Exception as e:
590+
self.report_error(
591+
"Error occured when invoking the function "
592+
+ func.__name__
593+
+ ": \n"
594+
+ str(e),
595+
node.rhs.span,
596+
)
597+
598+
if len(node.lhs) == 1 and not isinstance(out, list):
599+
out = [out]
600+
601+
assert len(out) == len(node.lhs)
602+
603+
for var, value in zip(node.lhs, out):
604+
self.context.update_symbol(var.id.name, value, node)
605+
606+
body = self.parse_body(node)
607+
608+
for var, value in zip(node.lhs, out):
609+
self.context.remove_symbol(var.id.name)
610+
611+
return body
612+
580613
if isinstance(node.rhs, (ast.Call, ast.Constant)):
581614
# Pattern 4 of let binding
582615
value = self.transform(node.rhs)
@@ -606,15 +639,18 @@ def transform_Assign(self, node):
606639
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
607640

608641
self.report_error(
609-
"""Assignments should be either
642+
"""Assignments should be one of:
610643
1. A "special statement" with return value
611644
1.1 Buffer = T.match_buffer()/T.buffer_decl()
612645
1.2 Var = T.var()
613646
1.3 Var = T.env_thread()
614647
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
615648
3. A store into a variable: Var[PrimExpr] = PrimExpr
616649
4. A with scope handler with concise scoping and var def
617-
4.1 var = T.allocate()""",
650+
4.1 var = T.allocate()
651+
5. The right-hand side being a call to a pure python function, consuming and
652+
producing TVMScript values.
653+
x, y = f(...)""",
618654
node.span,
619655
)
620656

tests/python/unittest/test_tvmscript_syntax_sugar.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,5 +265,86 @@ def constant_binds_wrapped():
265265
assert_structural_equal(constant_binds, constant_binds_wrapped)
266266

267267

268+
def test_func_call():
269+
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
270+
thread_id = (i % 8) * 4 + (j % 8) // 2
271+
return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)
272+
273+
@T.prim_func
274+
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
275+
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
276+
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
277+
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
278+
279+
with T.block("root"):
280+
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
281+
T.writes(C[0:32, 0:8])
282+
for i, j, k in T.grid(16, 16, 16):
283+
with T.block("C"):
284+
i, j, k = T.axis.remap("SSR", [i, j, k])
285+
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
286+
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
287+
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
288+
289+
T.reads(
290+
C[thread_id_C, local_id_C],
291+
A[thread_id_A, local_id_A],
292+
B[thread_id_B, local_id_B],
293+
)
294+
T.writes(C[thread_id_C, local_id_C])
295+
296+
C[thread_id_C, local_id_C] += (
297+
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
298+
)
299+
300+
@T.prim_func
301+
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
302+
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
303+
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
304+
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
305+
306+
with T.block("root"):
307+
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
308+
T.writes(C[0:32, 0:8])
309+
for i, j, k in T.grid(16, 16, 16):
310+
with T.block("C"):
311+
i, j, k = T.axis.remap("SSR", [i, j, k])
312+
T.reads(
313+
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
314+
A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
315+
B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
316+
)
317+
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
318+
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
319+
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
320+
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
321+
* B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
322+
)
323+
324+
assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)
325+
326+
# The following is an example of an error message from calling an invalid function
327+
328+
# error: Error occured when invoking the function sqrt:
329+
# loop of ufunc does not support argument 0 of type Var which has no callable sqrt method
330+
# --> test_tvmscript_syntax_sugar.py:334:19
331+
# |
332+
# 334 | ind = sqrt(i)
333+
# | ^^^^^^^
334+
# note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
335+
336+
# Uncomment to see the error above.
337+
# def sqrt(x):
338+
# import numpy as np
339+
# return np.sqrt(x)
340+
341+
# @T.prim_func
342+
# def loop(a: T.handle) -> None:
343+
# A = T.match_buffer(a, (128,))
344+
# for i in T.serial(128):
345+
# ind = sqrt(i)
346+
# A[i] = A[ind]
347+
348+
268349
if __name__ == "__main__":
269350
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)