@@ -126,6 +126,13 @@ class VisitorActionEnum(Enum):
126126}
127127
128128
129+ class EnterLeaveVisitor (NamedTuple ):
130+ """Visitor with functions for entering and leaving."""
131+
132+ enter : Optional [Callable [..., Optional [VisitorAction ]]]
133+ leave : Optional [Callable [..., Optional [VisitorAction ]]]
134+
135+
129136class Visitor :
130137 """Visitor that walks through an AST.
131138
@@ -170,6 +177,8 @@ def leave(self, node, key, parent, path, ancestors):
170177 # Provide special return values as attributes
171178 BREAK , SKIP , REMOVE , IDLE = BREAK , SKIP , REMOVE , IDLE
172179
180+ enter_leave_map : Dict [str , EnterLeaveVisitor ]
181+
173182 def __init_subclass__ (cls ) -> None :
174183 """Verify that all defined handlers are valid."""
175184 super ().__init_subclass__ ()
@@ -191,13 +200,34 @@ def __init_subclass__(cls) -> None:
191200 ):
192201 raise TypeError (f"Invalid AST node kind: { kind } ." )
193202
194- def get_visit_fn (self , kind : str , is_leaving : bool = False ) -> Callable :
195- """Get the visit function for the given node kind and direction."""
196- method = "leave" if is_leaving else "enter"
197- visit_fn = getattr (self , f"{ method } _{ kind } " , None )
198- if not visit_fn :
199- visit_fn = getattr (self , method , None )
200- return visit_fn
203+ def __init__ (self ) -> None :
204+ self .enter_leave_map = {}
205+
206+ def get_enter_leave_for_kind (self , kind : str ) -> EnterLeaveVisitor :
207+ """Given a node kind, return the EnterLeaveVisitor for that kind."""
208+ try :
209+ return self .enter_leave_map [kind ]
210+ except KeyError :
211+ enter_fn = getattr (self , f"enter_{ kind } " , None )
212+ if not enter_fn :
213+ enter_fn = getattr (self , "enter" , None )
214+ leave_fn = getattr (self , f"leave_{ kind } " , None )
215+ if not leave_fn :
216+ leave_fn = getattr (self , "leave" , None )
217+ enter_leave = EnterLeaveVisitor (enter_fn , leave_fn )
218+ self .enter_leave_map [kind ] = enter_leave
219+ return enter_leave
220+
221+ def get_visit_fn (
222+ self , kind : str , is_leaving : bool = False
223+ ) -> Optional [Callable [..., Optional [VisitorAction ]]]:
224+ """Get the visit function for the given node kind and direction.
225+
226+ .. deprecated:: 3.2
227+ Please use ``get_enter_leave_for_kind`` instead. Will be removed in v3.3.
228+ """
229+ enter_leave = self .get_enter_leave_for_kind (kind )
230+ return enter_leave .leave if is_leaving else enter_leave .enter
201231
202232
203233class Stack (NamedTuple ):
@@ -237,6 +267,7 @@ def visit(
237267 raise TypeError (f"Not an AST Visitor: { inspect (visitor )} ." )
238268 if visitor_keys is None :
239269 visitor_keys = QUERY_DOCUMENT_KEYS
270+
240271 stack : Any = None
241272 in_array = isinstance (root , list )
242273 keys : Tuple [Node , ...] = (root ,)
@@ -299,7 +330,8 @@ def visit(
299330 else :
300331 if not isinstance (node , Node ):
301332 raise TypeError (f"Invalid AST Node: { inspect (node )} ." )
302- visit_fn = visitor .get_visit_fn (node .kind , is_leaving )
333+ enter_leave = visitor .get_enter_leave_for_kind (node .kind )
334+ visit_fn = enter_leave .leave if is_leaving else enter_leave .enter
303335 if visit_fn :
304336 result = visit_fn (node , key , parent , path , ancestors )
305337
@@ -357,39 +389,63 @@ class ParallelVisitor(Visitor):
357389
358390 def __init__ (self , visitors : Collection [Visitor ]):
359391 """Create a new visitor from the given list of parallel visitors."""
392+ super ().__init__ ()
360393 self .visitors = visitors
361394 self .skipping : List [Any ] = [None ] * len (visitors )
362395
363- def enter (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
364- skipping = self .skipping
365- for i , visitor in enumerate (self .visitors ):
366- if not skipping [i ]:
367- fn = visitor .get_visit_fn (node .kind )
368- if fn :
369- result = fn (node , * args )
370- if result is SKIP or result is False :
371- skipping [i ] = node
372- elif result is BREAK or result is True :
373- skipping [i ] = BREAK
374- elif result is not None :
375- return result
376- return None
377-
378- def leave (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
379- skipping = self .skipping
380- for i , visitor in enumerate (self .visitors ):
381- if not skipping [i ]:
382- fn = visitor .get_visit_fn (node .kind , is_leaving = True )
383- if fn :
384- result = fn (node , * args )
385- if result is BREAK or result is True :
386- skipping [i ] = BREAK
387- elif (
388- result is not None
389- and result is not SKIP
390- and result is not False
391- ):
392- return result
393- elif skipping [i ] is node :
394- skipping [i ] = None
395- return None
396+ def get_enter_leave_for_kind (self , kind : str ) -> EnterLeaveVisitor :
397+ """Given a node kind, return the EnterLeaveVisitor for that kind."""
398+ try :
399+ return self .enter_leave_map [kind ]
400+ except KeyError :
401+ has_visitor = False
402+ enter_list : List [Optional [Callable [..., Optional [VisitorAction ]]]] = []
403+ leave_list : List [Optional [Callable [..., Optional [VisitorAction ]]]] = []
404+ for visitor in self .visitors :
405+ enter , leave = visitor .get_enter_leave_for_kind (kind )
406+ if not has_visitor and (enter or leave ):
407+ has_visitor = True
408+ enter_list .append (enter )
409+ leave_list .append (leave )
410+
411+ if has_visitor :
412+
413+ def enter (node : Node , * args : Any ) -> Optional [VisitorAction ]:
414+ skipping = self .skipping
415+ for i , fn in enumerate (enter_list ):
416+ if not skipping [i ]:
417+ if fn :
418+ result = fn (node , * args )
419+ if result is SKIP or result is False :
420+ skipping [i ] = node
421+ elif result is BREAK or result is True :
422+ skipping [i ] = BREAK
423+ elif result is not None :
424+ return result
425+ return None
426+
427+ def leave (node : Node , * args : Any ) -> Optional [VisitorAction ]:
428+ skipping = self .skipping
429+ for i , fn in enumerate (leave_list ):
430+ if not skipping [i ]:
431+ if fn :
432+ result = fn (node , * args )
433+ if result is BREAK or result is True :
434+ skipping [i ] = BREAK
435+ elif (
436+ result is not None
437+ and result is not SKIP
438+ and result is not False
439+ ):
440+ return result
441+ elif skipping [i ] is node :
442+ skipping [i ] = None
443+ return None
444+
445+ else :
446+
447+ enter = leave = None
448+
449+ enter_leave = EnterLeaveVisitor (enter , leave )
450+ self .enter_leave_map [kind ] = enter_leave
451+ return enter_leave
0 commit comments