Skip to content

Commit

Permalink
feat: Pass down agent to extension hooks
Browse files Browse the repository at this point in the history
Some extensions need to know the user-selected docstring parser and options. The docstring parser and options can only be accessed through the agent.

Other extensions who create objects might want to trigger the rest of the extensions, which can again only be done if we have a reference to the agent.

Issue-312: #312
  • Loading branch information
pawamoy committed Aug 11, 2024
1 parent 87525f9 commit 71acb01
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 56 deletions.
42 changes: 21 additions & 21 deletions src/_griffe/agents/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def inspect_module(self, node: ObjectNode) -> None:
Parameters:
node: The node to inspect.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_module_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_module_node", node=node, agent=self)
self.current = module = Module(
name=self.module_name,
filepath=self.filepath,
Expand All @@ -305,20 +305,20 @@ def inspect_module(self, node: ObjectNode) -> None:
lines_collection=self.lines_collection,
modules_collection=self.modules_collection,
)
self.extensions.call("on_instance", node=node, obj=module)
self.extensions.call("on_module_instance", node=node, mod=module)
self.extensions.call("on_instance", node=node, obj=module, agent=self)
self.extensions.call("on_module_instance", node=node, mod=module, agent=self)
self.generic_inspect(node)
self.extensions.call("on_members", node=node, obj=module)
self.extensions.call("on_module_members", node=node, mod=module)
self.extensions.call("on_members", node=node, obj=module, agent=self)
self.extensions.call("on_module_members", node=node, mod=module, agent=self)

def inspect_class(self, node: ObjectNode) -> None:
"""Inspect a class.
Parameters:
node: The node to inspect.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_class_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_class_node", node=node, agent=self)

bases = []
for base in node.obj.__bases__:
Expand All @@ -336,11 +336,11 @@ def inspect_class(self, node: ObjectNode) -> None:
)
self.current.set_member(node.name, class_)
self.current = class_
self.extensions.call("on_instance", node=node, obj=class_)
self.extensions.call("on_class_instance", node=node, cls=class_)
self.extensions.call("on_instance", node=node, obj=class_, agent=self)
self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
self.generic_inspect(node)
self.extensions.call("on_members", node=node, obj=class_)
self.extensions.call("on_class_members", node=node, cls=class_)
self.extensions.call("on_members", node=node, obj=class_, agent=self)
self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
self.current = self.current.parent # type: ignore[assignment]

def inspect_staticmethod(self, node: ObjectNode) -> None:
Expand Down Expand Up @@ -430,8 +430,8 @@ def handle_function(self, node: ObjectNode, labels: set | None = None) -> None:
node: The node to inspect.
labels: Labels to add to the data object.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_function_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_function_node", node=node, agent=self)

try:
signature = getsignature(node.obj)
Expand Down Expand Up @@ -475,11 +475,11 @@ def handle_function(self, node: ObjectNode, labels: set | None = None) -> None:
)
obj.labels |= labels
self.current.set_member(node.name, obj)
self.extensions.call("on_instance", node=node, obj=obj)
self.extensions.call("on_instance", node=node, obj=obj, agent=self)
if obj.is_attribute:
self.extensions.call("on_attribute_instance", node=node, attr=obj)
self.extensions.call("on_attribute_instance", node=node, attr=obj, agent=self)
else:
self.extensions.call("on_function_instance", node=node, func=obj)
self.extensions.call("on_function_instance", node=node, func=obj, agent=self)

def inspect_attribute(self, node: ObjectNode) -> None:
"""Inspect an attribute.
Expand All @@ -496,8 +496,8 @@ def handle_attribute(self, node: ObjectNode, annotation: str | Expr | None = Non
node: The node to inspect.
annotation: A potential annotation.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_attribute_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_attribute_node", node=node, agent=self)

# TODO: to improve
parent = self.current
Expand Down Expand Up @@ -533,8 +533,8 @@ def handle_attribute(self, node: ObjectNode, annotation: str | Expr | None = Non

if node.name == "__all__":
parent.exports = set(node.obj)
self.extensions.call("on_instance", node=node, obj=attribute)
self.extensions.call("on_attribute_instance", node=node, attr=attribute)
self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)


_kind_map = {
Expand Down
44 changes: 22 additions & 22 deletions src/_griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def visit_module(self, node: ast.Module) -> None:
Parameters:
node: The node to visit.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_module_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_module_node", node=node, agent=self)
self.current = module = Module(
name=self.module_name,
filepath=self.filepath,
Expand All @@ -245,20 +245,20 @@ def visit_module(self, node: ast.Module) -> None:
lines_collection=self.lines_collection,
modules_collection=self.modules_collection,
)
self.extensions.call("on_instance", node=node, obj=module)
self.extensions.call("on_module_instance", node=node, mod=module)
self.extensions.call("on_instance", node=node, obj=module, agent=self)
self.extensions.call("on_module_instance", node=node, mod=module, agent=self)
self.generic_visit(node)
self.extensions.call("on_members", node=node, obj=module)
self.extensions.call("on_module_members", node=node, mod=module)
self.extensions.call("on_members", node=node, obj=module, agent=self)
self.extensions.call("on_module_members", node=node, mod=module, agent=self)

def visit_classdef(self, node: ast.ClassDef) -> None:
"""Visit a class definition node.
Parameters:
node: The node to visit.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_class_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_class_node", node=node, agent=self)

# handle decorators
decorators = []
Expand Down Expand Up @@ -293,11 +293,11 @@ def visit_classdef(self, node: ast.ClassDef) -> None:
class_.labels |= self.decorators_to_labels(decorators)
self.current.set_member(node.name, class_)
self.current = class_
self.extensions.call("on_instance", node=node, obj=class_)
self.extensions.call("on_class_instance", node=node, cls=class_)
self.extensions.call("on_instance", node=node, obj=class_, agent=self)
self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
self.generic_visit(node)
self.extensions.call("on_members", node=node, obj=class_)
self.extensions.call("on_class_members", node=node, cls=class_)
self.extensions.call("on_members", node=node, obj=class_, agent=self)
self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
self.current = self.current.parent # type: ignore[assignment]

def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]:
Expand Down Expand Up @@ -349,8 +349,8 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
node: The node to visit.
labels: Labels to add to the data object.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_function_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_function_node", node=node, agent=self)

labels = labels or set()

Expand Down Expand Up @@ -387,8 +387,8 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
)
attribute.labels |= labels
self.current.set_member(node.name, attribute)
self.extensions.call("on_instance", node=node, obj=attribute)
self.extensions.call("on_attribute_instance", node=node, attr=attribute)
self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
return

# handle parameters
Expand Down Expand Up @@ -438,8 +438,8 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:

function.labels |= labels

self.extensions.call("on_instance", node=node, obj=function)
self.extensions.call("on_function_instance", node=node, func=function)
self.extensions.call("on_instance", node=node, obj=function, agent=self)
self.extensions.call("on_function_instance", node=node, func=function, agent=self)
if self.current.kind is Kind.CLASS and function.name == "__init__":
self.current = function # type: ignore[assignment] # temporary assign a function
self.generic_visit(node)
Expand Down Expand Up @@ -529,8 +529,8 @@ def handle_attribute(
node: The node to visit.
annotation: A potential annotation.
"""
self.extensions.call("on_node", node=node)
self.extensions.call("on_attribute_node", node=node)
self.extensions.call("on_node", node=node, agent=self)
self.extensions.call("on_attribute_node", node=node, agent=self)
parent = self.current
labels = set()

Expand Down Expand Up @@ -618,8 +618,8 @@ def handle_attribute(
name if isinstance(name, str) else ExprName(name.name, parent=name.parent)
for name in safe_get__all__(node, self.current) # type: ignore[arg-type]
]
self.extensions.call("on_instance", node=node, obj=attribute)
self.extensions.call("on_attribute_instance", node=node, attr=attribute)
self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)

def visit_assign(self, node: ast.Assign) -> None:
"""Visit an assignment node.
Expand Down
75 changes: 62 additions & 13 deletions src/_griffe/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,98 +135,147 @@ def generic_inspect(self, node: ObjectNode) -> None:
if not child.alias_target_path:
self.inspect(child)

def on_node(self, *, node: ast.AST | ObjectNode, **kwargs: Any) -> None:
def on_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when visiting a new node during static/dynamic analysis.
Parameters:
node: The currently visited node.
"""

def on_instance(self, *, node: ast.AST | ObjectNode, obj: Object, **kwargs: Any) -> None:
def on_instance(
self,
*,
node: ast.AST | ObjectNode,
obj: Object,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when an Object has been created.
Parameters:
node: The currently visited node.
obj: The object instance.
"""

def on_members(self, *, node: ast.AST | ObjectNode, obj: Object, **kwargs: Any) -> None:
def on_members(self, *, node: ast.AST | ObjectNode, obj: Object, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when members of an Object have been loaded.
Parameters:
node: The currently visited node.
obj: The object instance.
"""

def on_module_node(self, *, node: ast.AST | ObjectNode, **kwargs: Any) -> None:
def on_module_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when visiting a new module node during static/dynamic analysis.
Parameters:
node: The currently visited node.
"""

def on_module_instance(self, *, node: ast.AST | ObjectNode, mod: Module, **kwargs: Any) -> None:
def on_module_instance(
self,
*,
node: ast.AST | ObjectNode,
mod: Module,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when a Module has been created.
Parameters:
node: The currently visited node.
mod: The module instance.
"""

def on_module_members(self, *, node: ast.AST | ObjectNode, mod: Module, **kwargs: Any) -> None:
def on_module_members(
self,
*,
node: ast.AST | ObjectNode,
mod: Module,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when members of a Module have been loaded.
Parameters:
node: The currently visited node.
mod: The module instance.
"""

def on_class_node(self, *, node: ast.AST | ObjectNode, **kwargs: Any) -> None:
def on_class_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when visiting a new class node during static/dynamic analysis.
Parameters:
node: The currently visited node.
"""

def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None:
def on_class_instance(
self,
*,
node: ast.AST | ObjectNode,
cls: Class,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when a Class has been created.
Parameters:
node: The currently visited node.
cls: The class instance.
"""

def on_class_members(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None:
def on_class_members(
self,
*,
node: ast.AST | ObjectNode,
cls: Class,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when members of a Class have been loaded.
Parameters:
node: The currently visited node.
cls: The class instance.
"""

def on_function_node(self, *, node: ast.AST | ObjectNode, **kwargs: Any) -> None:
def on_function_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when visiting a new function node during static/dynamic analysis.
Parameters:
node: The currently visited node.
"""

def on_function_instance(self, *, node: ast.AST | ObjectNode, func: Function, **kwargs: Any) -> None:
def on_function_instance(
self,
*,
node: ast.AST | ObjectNode,
func: Function,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when a Function has been created.
Parameters:
node: The currently visited node.
func: The function instance.
"""

def on_attribute_node(self, *, node: ast.AST | ObjectNode, **kwargs: Any) -> None:
def on_attribute_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
"""Run when visiting a new attribute node during static/dynamic analysis.
Parameters:
node: The currently visited node.
"""

def on_attribute_instance(self, *, node: ast.AST | ObjectNode, attr: Attribute, **kwargs: Any) -> None:
def on_attribute_instance(
self,
*,
node: ast.AST | ObjectNode,
attr: Attribute,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Run when an Attribute has been created.
Parameters:
Expand Down

0 comments on commit 71acb01

Please sign in to comment.