diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 939b7e82ce61..26e9d091bfb8 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -58,6 +58,17 @@ } +def _get_builtin_or_none(name: str): + builtins = globals().get("__builtins__") + if not builtins: + return None + if not isinstance(builtins, dict) and hasattr(builtins, "__dict__"): + builtins = builtins.__dict__ + if isinstance(builtins, dict): + return builtins.get(name) + return None + + class ExprEvaluator: """Expression evaluator for TVMScript parser. @@ -106,9 +117,13 @@ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: self = ExprEvaluator(parser, value_table) result = self._visit(node) # pylint: disable=protected-access if isinstance(result, doc.Name): - if result.id not in self.value_table: + if result.id in self.value_table: + return self.value_table[result.id] + else: + builtin = _get_builtin_or_none(result.id) + if builtin: + return builtin raise ParserError(result, f"Undefined variable: {result.id}") - return self.value_table[result.id] if isinstance(result, doc.Constant): return result.value raise TypeError(f"Unexpected result type: {type(result)}") @@ -202,7 +217,7 @@ def _visit(self, node: doc.AST) -> Any: return tuple(self._visit(n) for n in node) assert isinstance(node, doc.AST) if isinstance(node, doc.Name): - if node.id not in self.value_table: + if node.id not in self.value_table and not _get_builtin_or_none(node.id): raise ParserError(node, f"Undefined variable: {node.id}") return node if isinstance( diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index f902ebb41183..279785fdca51 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -85,14 +85,6 @@ def undefined_buffer(a: T.handle) -> None: check_error(undefined_buffer, 5) -def test_unsupported_stmt(): - def unsupported_stmt(a: T.int32) -> None: - if a > 0: - print("I love tvm") # error - - check_error(unsupported_stmt, 3) - - def test_unsupported_function_call(): def unsupported_function_call(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index ef02df497b7b..a7b6e53e6fe4 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -308,5 +308,22 @@ def expected(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): tvm.ir.assert_structural_equal(func_with_empty_tuple, expected) +def test_tir_builtin_expression(): + dims = (128, 128) + + @T.prim_func(private=True) + def with_builtin(a: T.handle) -> None: + A = T.match_buffer(a, [len(dims), *dims], "int32") + for i, j, k in T.grid(*A.shape): + A[i, j, k] = T.int32(1 + len(A.shape)) + + @T.prim_func(private=True) + def evaluated(A: T.Buffer((2, 128, 128), "int32")): + for i, j, k in T.grid(2, 128, 128): + A[i, j, k] = 4 + + tvm.ir.assert_structural_equal(with_builtin, evaluated) + + if __name__ == "__main__": tvm.testing.main()