Skip to content

Commit 6161a8d

Browse files
lightzhan-intelliflightzhan-intellif
andauthored
[BugFix][TVMScript]fix var capturing order error (#13640)
This PR try to fix the following bug: ```python def test_var_capturing_order(): b = 2 @T.prim_func def test_case(): k: T.int32 = b if __name__ == "__main__": b = 1 ``` In the prim func `test_case`, the vaule of b should be 2, rather than 1. The parser wrongly uses global vars to shadow the value of nonlocal vars, which should be reversed. Co-authored-by: lightzhan-intellif <[email protected]>
1 parent ddb006e commit 6161a8d

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

python/tvm/script/parser/core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
3737
The function variables map with non-local or global variables.
3838
"""
3939
captured = {
40-
**inspect.getclosurevars(func).nonlocals,
4140
**func.__globals__, # type: ignore
41+
**inspect.getclosurevars(func).nonlocals,
4242
}
4343
return captured
4444

tests/python/unittest/test_tvmscript_regression.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,24 @@ def func_ref():
5858
tvm.ir.assert_structural_equal(test_case, func_ref)
5959

6060

61+
def test_var_capturing_order():
62+
b = 2
63+
64+
@T.prim_func
65+
def test_case():
66+
k: T.int32 = b
67+
68+
@T.prim_func
69+
def func_ref():
70+
k: T.int32 = 2
71+
T.evaluate(0)
72+
73+
tvm.ir.assert_structural_equal(test_case, func_ref)
74+
75+
6176
if __name__ == "__main__":
6277
a = numpy.zeros((10, 10), dtype="int8")
6378
test_multi_element_array_in_outmost_namespace()
6479
test_different_dtype_assignment_to_var()
80+
b = 1
81+
test_var_capturing_order()

0 commit comments

Comments
 (0)