From 97f2d12dec560f38eeaa52539baed53ebdba00ab Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 3 Aug 2023 16:24:46 -0700 Subject: [PATCH 1/3] [TVMScript] Allow use of Python builtins in script The builtins are already supported by `eval` (they are automatically injected in the global scope), but they are not recognized by the evaluator's checks. When the evaluator sees doc.Name, it looks it up in the current `var_table`, and flags an error if it's not there. Make the evaluator also consult the current builtins before erroring out. --- python/tvm/script/parser/core/evaluator.py | 21 ++++++++++++++++--- .../unittest/test_tvmscript_parser_tir.py | 17 +++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 939b7e82ce61..26e9d091bfb8 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -58,6 +58,17 @@ } +def _get_builtin_or_none(name: str): + builtins = globals().get("__builtins__") + if not builtins: + return None + if not isinstance(builtins, dict) and hasattr(builtins, "__dict__"): + builtins = builtins.__dict__ + if isinstance(builtins, dict): + return builtins.get(name) + return None + + class ExprEvaluator: """Expression evaluator for TVMScript parser. @@ -106,9 +117,13 @@ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: self = ExprEvaluator(parser, value_table) result = self._visit(node) # pylint: disable=protected-access if isinstance(result, doc.Name): - if result.id not in self.value_table: + if result.id in self.value_table: + return self.value_table[result.id] + else: + builtin = _get_builtin_or_none(result.id) + if builtin: + return builtin raise ParserError(result, f"Undefined variable: {result.id}") - return self.value_table[result.id] if isinstance(result, doc.Constant): return result.value raise TypeError(f"Unexpected result type: {type(result)}") @@ -202,7 +217,7 @@ def _visit(self, node: doc.AST) -> Any: return tuple(self._visit(n) for n in node) assert isinstance(node, doc.AST) if isinstance(node, doc.Name): - if node.id not in self.value_table: + if node.id not in self.value_table and not _get_builtin_or_none(node.id): raise ParserError(node, f"Undefined variable: {node.id}") return node if isinstance( diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index ef02df497b7b..a7b6e53e6fe4 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -308,5 +308,22 @@ def expected(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): tvm.ir.assert_structural_equal(func_with_empty_tuple, expected) +def test_tir_builtin_expression(): + dims = (128, 128) + + @T.prim_func(private=True) + def with_builtin(a: T.handle) -> None: + A = T.match_buffer(a, [len(dims), *dims], "int32") + for i, j, k in T.grid(*A.shape): + A[i, j, k] = T.int32(1 + len(A.shape)) + + @T.prim_func(private=True) + def evaluated(A: T.Buffer((2, 128, 128), "int32")): + for i, j, k in T.grid(2, 128, 128): + A[i, j, k] = 4 + + tvm.ir.assert_structural_equal(with_builtin, evaluated) + + if __name__ == "__main__": tvm.testing.main() From 6953efb8332edfa23eae946021ba8d40212f47c5 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sun, 6 Aug 2023 09:03:34 -0700 Subject: [PATCH 2/3] Remove test that is no longer invalid Buitin functions in T.prim_func's body are evaluated when the actual PrimFunc is created, so this test will print 'a' before generating `T.evaluate(0)` (i.e. equivalent of an empty statement) as the body. --- tests/python/unittest/test_tvmscript_error_report.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index f902ebb41183..279785fdca51 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -85,14 +85,6 @@ def undefined_buffer(a: T.handle) -> None: check_error(undefined_buffer, 5) -def test_unsupported_stmt(): - def unsupported_stmt(a: T.int32) -> None: - if a > 0: - print("I love tvm") # error - - check_error(unsupported_stmt, 3) - - def test_unsupported_function_call(): def unsupported_function_call(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") From e41922bbd2b08a7dc913e134bc232e2d92947fa8 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sun, 6 Aug 2023 15:31:19 -0700 Subject: [PATCH 3/3] Restart build