2
2
3
3
import builtins
4
4
import re
5
+ import sys
5
6
import textwrap
6
7
import types
7
8
from ast import AST
14
15
from ast import Module
15
16
from ast import NodeTransformer
16
17
from ast import alias
18
+ from ast import copy_location
19
+ from ast import fix_missing_locations
17
20
from ast import unparse
18
21
from typing import TYPE_CHECKING
19
22
from typing import Any
@@ -57,18 +60,34 @@ def wrapper(*vargs, **kwargs):
57
60
symbols .update (kwargs )
58
61
59
62
class Transformer (NodeTransformer ):
60
- def visit_FunctionDef (self , node ) -> AST :
61
- name = symbols .get (node .name , self )
62
- if name is self :
63
+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> AST :
64
+ if node .name not in symbols :
63
65
return self .generic_visit (node )
64
66
65
- return FunctionDef (
66
- name = name ,
67
- args = node .args ,
68
- body = list (map (self .visit , node .body )),
69
- decorator_list = getattr (node , "decorator_list" , []),
70
- lineno = None ,
71
- )
67
+ name = symbols [node .name ]
68
+ assert isinstance (name , str )
69
+ body : list [ast .stmt ] = [self .visit (stmt ) for stmt in node .body ]
70
+ if sys .version_info >= (3 , 12 ):
71
+ # mypy complains if type_params is missing
72
+ funcdef = FunctionDef (
73
+ name = name ,
74
+ args = node .args ,
75
+ body = body ,
76
+ decorator_list = node .decorator_list ,
77
+ returns = node .returns ,
78
+ type_params = node .type_params ,
79
+ )
80
+ else :
81
+ funcdef = FunctionDef (
82
+ name = name ,
83
+ args = node .args ,
84
+ body = body ,
85
+ decorator_list = node .decorator_list ,
86
+ returns = node .returns ,
87
+ )
88
+ copy_location (funcdef , node )
89
+ fix_missing_locations (funcdef )
90
+ return funcdef
72
91
73
92
def visit_Name (self , node : ast .Name ) -> AST :
74
93
value = symbols .get (node .id , self )
@@ -170,15 +189,17 @@ def require(self, value: type[Any] | Hashable) -> ast.Name:
170
189
def visit_Module (self , module : Module ) -> AST :
171
190
assert isinstance (module , Module )
172
191
module = super ().generic_visit (module ) # type: ignore[assignment]
173
- preamble : list [AST ] = []
192
+ preamble : list [ast . stmt ] = []
174
193
175
194
for name , node in self .defines .items ():
176
- assignment = Assign (targets = [store (name )], value = node , lineno = None )
195
+ assignment = Assign (targets = [store (name )], value = node )
196
+ copy_location (assignment , node )
197
+ fix_missing_locations (assignment )
177
198
preamble .append (self .visit (assignment ))
178
199
179
- imports : list [AST ] = []
200
+ imports : list [ast . stmt ] = []
180
201
for value , node in self .imports .items ():
181
- stmt : AST
202
+ stmt : ast . stmt
182
203
183
204
if isinstance (value , types .ModuleType ):
184
205
stmt = Import (
@@ -198,7 +219,7 @@ def visit_Module(self, module: Module) -> AST:
198
219
199
220
imports .append (stmt )
200
221
201
- return Module (imports + preamble + module .body , () )
222
+ return Module (imports + preamble + module .body , [] )
202
223
203
224
def visit_Comment (self , node : Comment ) -> AST :
204
225
self .comments .append (node .text )
0 commit comments