From bd996e340610b7149d8e1a54fe6fdc328828c9d7 Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Sun, 18 Dec 2022 07:30:15 +0000 Subject: [PATCH 1/2] fix var capturing order error. --- python/tvm/script/parser/core/utils.py | 2 +- .../unittest/test_tvmscript_regression.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index a304afddbe55..453ac18b382b 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -37,8 +37,8 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]: The function variables map with non-local or global variables. """ captured = { - **inspect.getclosurevars(func).nonlocals, **func.__globals__, # type: ignore + **inspect.getclosurevars(func).nonlocals, } return captured diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index 05c1665ea2a1..f66990b9cffa 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -58,7 +58,24 @@ def func_ref(): tvm.ir.assert_structural_equal(test_case, func_ref) +def test_var_capturing_order(): + b = 2 + + @T.prim_func + def test_case(): + k: T.int32 = b + + @T.prim_func + def func_ref(): + k: T.int32 = 2 + T.evaluate(0) + + tvm.ir.assert_structural_equal(test_case, func_ref) + + if __name__ == "__main__": + b = 1 a = numpy.zeros((10, 10), dtype="int8") test_multi_element_array_in_outmost_namespace() test_different_dtype_assignment_to_var() + test_var_capturing_order() From beedcb9f82f8af3ae65d8b56da06659166de3e4e Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Sun, 18 Dec 2022 08:49:49 +0000 Subject: [PATCH 2/2] Make it easy to understand. --- tests/python/unittest/test_tvmscript_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index f66990b9cffa..d063c0fcab7f 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -74,8 +74,8 @@ def func_ref(): if __name__ == "__main__": - b = 1 a = numpy.zeros((10, 10), dtype="int8") test_multi_element_array_in_outmost_namespace() test_different_dtype_assignment_to_var() + b = 1 test_var_capturing_order()