Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
import types
import json
import operator
import inspect
Expand Down Expand Up @@ -543,7 +544,7 @@ def transform_Assign(self, node):
AST abstract grammar:
Assign(expr* targets, expr value, string? type_comment)

By now 3 patterns of Assign is supported:
By now 5 patterns of Assign is supported:
1. special stmts with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
Expand All @@ -552,6 +553,9 @@ def transform_Assign(self, node):
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = T.allocate()
5. A call to a pure python function, consuming and producing TVMScript values.
The outputs are inlined into the following body (no variable is created).
x, y = f(...)
"""

if isinstance(node.rhs, ast.Call):
Expand All @@ -577,6 +581,35 @@ def transform_Assign(self, node):
arg_list = self.parse_arg_list(func, node.rhs)
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
return self.parse_body(node)
elif isinstance(func, types.FunctionType):
# Pattern 5
args = [self.transform(arg) for arg in node.rhs.params]
try:
out = func(*args)
except Exception as e:
self.report_error(
"Error occured when invoking the function "
+ func.__name__
+ ": \n"
+ str(e),
node.rhs.span,
)

if len(node.lhs) == 1 and not isinstance(out, list):
out = [out]

assert len(out) == len(node.lhs)

for var, value in zip(node.lhs, out):
self.context.update_symbol(var.id.name, value, node)

body = self.parse_body(node)

for var, value in zip(node.lhs, out):
self.context.remove_symbol(var.id.name)

return body

if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
value = self.transform(node.rhs)
Expand Down Expand Up @@ -606,15 +639,18 @@ def transform_Assign(self, node):
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

self.report_error(
"""Assignments should be either
"""Assignments should be one of:
1. A "special statement" with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
1.3 Var = T.env_thread()
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
3. A store into a variable: Var[PrimExpr] = PrimExpr
4. A with scope handler with concise scoping and var def
4.1 var = T.allocate()""",
4.1 var = T.allocate()
5. The right-hand side being a call to a pure python function, consuming and
producing TVMScript values.
x, y = f(...)""",
node.span,
)

Expand Down
81 changes: 81 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,5 +265,86 @@ def constant_binds_wrapped():
assert_structural_equal(constant_binds, constant_binds_wrapped)


def test_func_call():
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j % 8) // 2
return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)

@T.prim_func
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

T.reads(
C[thread_id_C, local_id_C],
A[thread_id_A, local_id_A],
B[thread_id_B, local_id_B],
)
T.writes(C[thread_id_C, local_id_C])

C[thread_id_C, local_id_C] += (
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
)

@T.prim_func
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
)
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
* B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
)

assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)

# The following is an example of an error message from calling an invalid function

# error: Error occured when invoking the function sqrt:
# loop of ufunc does not support argument 0 of type Var which has no callable sqrt method
# --> test_tvmscript_syntax_sugar.py:334:19
# |
# 334 | ind = sqrt(i)
# | ^^^^^^^
# note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.

# Uncomment to see the error above.
# def sqrt(x):
# import numpy as np
# return np.sqrt(x)

# @T.prim_func
# def loop(a: T.handle) -> None:
# A = T.match_buffer(a, (128,))
# for i in T.serial(128):
# ind = sqrt(i)
# A[i] = A[ind]


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))