@@ -173,6 +173,9 @@ def leave(self, node, key, parent, path, ancestors):
173173 # Provide special return values as attributes
174174 BREAK , SKIP , REMOVE , IDLE = BREAK , SKIP , REMOVE , IDLE
175175
176+ def __init__ (self ):
177+ self ._visit_fns = {}
178+
176179 def __init_subclass__ (cls ) -> None :
177180 """Verify that all defined handlers are valid."""
178181 super ().__init_subclass__ ()
@@ -197,11 +200,12 @@ def __init_subclass__(cls) -> None:
197200
198201 def get_visit_fn (self , kind : str , is_leaving : bool = False ) -> Callable :
199202 """Get the visit function for the given node kind and direction."""
200- method = "leave" if is_leaving else "enter"
201- visit_fn = getattr (self , f"{ method } _{ kind } " , None )
202- if not visit_fn :
203- visit_fn = getattr (self , method , None )
204- return visit_fn
203+ key = (kind , is_leaving )
204+ if key not in self ._visit_fns :
205+ method = "leave" if is_leaving else "enter"
206+ fn = getattr (self , f"{ method } _{ kind } " , None )
207+ self ._visit_fns [key ] = fn or getattr (self , method , None )
208+ return self ._visit_fns [key ]
205209
206210
207211class Stack (NamedTuple ):
@@ -367,14 +371,22 @@ class ParallelVisitor(Visitor):
367371
368372 def __init__ (self , visitors : Collection [Visitor ]):
369373 """Create a new visitor from the given list of parallel visitors."""
374+ super ().__init__ ()
370375 self .visitors = visitors
371376 self .skipping : List [Any ] = [None ] * len (visitors )
377+ self ._enter_visit_fns = {}
378+ self ._leave_visit_fns = {}
372379
373380 def enter (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
381+ visit_fns = self ._enter_visit_fns .get (node .kind )
382+ if visit_fns is None :
383+ visit_fns = [v .get_visit_fn (node .kind ) for v in self .visitors ]
384+ self ._enter_visit_fns [node .kind ] = visit_fns
385+
374386 skipping = self .skipping
375387 for i , visitor in enumerate (self .visitors ):
376388 if not skipping [i ]:
377- fn = visitor . get_visit_fn ( node . kind )
389+ fn = visit_fns [ i ]
378390 if fn :
379391 result = fn (node , * args )
380392 if result is SKIP or result is False :
@@ -386,10 +398,15 @@ def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
386398 return None
387399
388400 def leave (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
401+ visit_fns = self ._leave_visit_fns .get (node .kind )
402+ if visit_fns is None :
403+ visit_fns = [v .get_visit_fn (node .kind , is_leaving = True ) for v in self .visitors ]
404+ self ._leave_visit_fns [node .kind ] = visit_fns
405+
389406 skipping = self .skipping
390407 for i , visitor in enumerate (self .visitors ):
391408 if not skipping [i ]:
392- fn = visitor . get_visit_fn ( node . kind , is_leaving = True )
409+ fn = visit_fns [ i ]
393410 if fn :
394411 result = fn (node , * args )
395412 if result is BREAK or result is True :
0 commit comments