Skip to content

Commit 4dc47df

Browse files
allow constant value let binding in script (#11115)
1 parent 4330c21 commit 4dc47df

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

python/tvm/script/parser.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -574,32 +574,33 @@ def transform_Assign(self, node):
574574
arg_list = self.parse_arg_list(func, node.rhs)
575575
func.handle(node, self.context, arg_list, node.rhs.func_name.span)
576576
return self.parse_body(node)
577-
else:
578-
value = self.transform(node.rhs)
579-
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
580-
# This is a little confusing because it only is true when
581-
# we have taken this branch. We might need to clarify what
582-
# exectly is allowed in Assignments in tvmscript.
583-
self.report_error(
584-
"Left hand side of assignment must be an unqualified variable",
585-
node.span,
586-
)
587-
ast_var = node.lhs[0]
577+
if isinstance(node.rhs, (ast.Call, ast.Constant)):
578+
# Pattern 4 of let binding
579+
value = self.transform(node.rhs)
580+
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
581+
# This is a little confusing because it only is true when
582+
# we have taken this branch. We might need to clarify what
583+
# exectly is allowed in Assignments in tvmscript.
584+
self.report_error(
585+
"Left hand side of assignment must be an unqualified variable",
586+
node.span,
587+
)
588+
ast_var = node.lhs[0]
588589

589-
if node.ty is None and hasattr(value, "dtype"):
590-
var_ty = value.dtype
591-
else:
592-
var_ty = self.parse_type(node.ty, ast_var)
590+
if node.ty is None and hasattr(value, "dtype"):
591+
var_ty = value.dtype
592+
else:
593+
var_ty = self.parse_type(node.ty, ast_var)
593594

594-
var = tvm.te.var(
595-
ast_var.id.name,
596-
var_ty,
597-
span=tvm_span_from_synr(ast_var.span),
598-
)
599-
self.context.update_symbol(var.name, var, node)
600-
body = self.parse_body(node)
601-
self.context.remove_symbol(var.name)
602-
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
595+
var = tvm.te.var(
596+
ast_var.id.name,
597+
var_ty,
598+
span=tvm_span_from_synr(ast_var.span),
599+
)
600+
self.context.update_symbol(var.name, var, node)
601+
body = self.parse_body(node)
602+
self.context.remove_symbol(var.name)
603+
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
603604

604605
self.report_error(
605606
"""Assignments should be either

tests/python/unittest/test_tvmscript_syntax_sugar.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,5 +249,21 @@ def func_without_type_annotation(A: T.Buffer[(1,), "int32"]):
249249
T.evaluate(x)
250250

251251

252+
def test_letstmt_bind_with_constant():
253+
@T.prim_func
254+
def constant_binds():
255+
x = 1
256+
y = 42.0
257+
T.evaluate(T.cast(x, "float32") + y)
258+
259+
@T.prim_func
260+
def constant_binds_wrapped():
261+
x = T.int32(1)
262+
y = T.float32(42.0)
263+
T.evaluate(T.cast(x, "float32") + y)
264+
265+
assert_structural_equal(constant_binds, constant_binds_wrapped)
266+
267+
252268
if __name__ == "__main__":
253269
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)