From 2ea88c020481e78060c90d8307a4f6a68047eaa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Mazzucotelli?= Date: Fri, 5 Nov 2021 21:40:34 +0100 Subject: [PATCH] refactor: Improve visitor getters --- src/griffe/visitor.py | 252 +++++++++++++++++++++++++++++++----------- 1 file changed, 188 insertions(+), 64 deletions(-) diff --git a/src/griffe/visitor.py b/src/griffe/visitor.py index 0c90cd1a..2f581674 100644 --- a/src/griffe/visitor.py +++ b/src/griffe/visitor.py @@ -9,28 +9,45 @@ from __future__ import annotations import inspect -from ast import ( - AST, - AnnAssign, - Assign, - Attribute, - BinOp, - BitOr, - Call, - Constant, - Dict, - Expr, - FormattedValue, - Index, - JoinedStr, - List, - Name, - PyCF_ONLY_AST, - Str, - Subscript, - Tuple, - keyword, -) +from ast import AST as Node +from ast import And as NodeAnd +from ast import AnnAssign as NodeAnnAssign +from ast import Assign as NodeAssign +from ast import Attribute as NodeAttribute +from ast import BinOp as NodeBinOp +from ast import BitOr as NodeBitOr +from ast import BoolOp as NodeBoolOp +from ast import Call as NodeCall +from ast import Compare as NodeCompare +from ast import Constant as NodeConstant +from ast import Dict as NodeDict +from ast import DictComp as NodeDictComp +from ast import Expr as NodeExpr +from ast import FormattedValue as NodeFormattedValue +from ast import GeneratorExp as NodeGeneratorExp +from ast import IfExp as NodeIfExp +from ast import Index as NodeIndex +from ast import JoinedStr as NodeJoinedStr +from ast import Lambda as NodeLambda +from ast import List as NodeList +from ast import ListComp as NodeListComp +from ast import Mult as NodeMult +from ast import Name as NodeName +from ast import Not as NodeNot +from ast import NotEq as NodeNotEq +from ast import Or as NodeOr +from ast import PyCF_ONLY_AST +from ast import Set as NodeSet +from ast import Slice as NodeSlice +from ast import Starred as NodeStarred +from ast import Str as NodeStr +from ast import Subscript as NodeSubscript +from ast import Tuple as NodeTuple +from ast import UAdd as NodeUAdd +from ast import UnaryOp as NodeUnaryOp +from ast import USub as NodeUSub +from ast import comprehension as NodeComprehension +from ast import keyword as NodeKeyword from itertools import zip_longest from pathlib import Path @@ -66,15 +83,15 @@ def visit( # ========================================================== # docstrings def _get_docstring(node): - if isinstance(node, Expr): + if isinstance(node, NodeExpr): doc = node.value - elif node.body and isinstance(node.body[0], Expr): + elif node.body and isinstance(node.body[0], NodeExpr): doc = node.body[0].value else: return None - if isinstance(doc, Constant) and isinstance(doc.value, str): + if isinstance(doc, NodeConstant) and isinstance(doc.value, str): return Docstring(doc.value, doc.lineno, doc.end_lineno) - if isinstance(doc, Str): + if isinstance(doc, NodeStr): return Docstring(doc.s, doc.lineno, doc.end_lineno) return None @@ -82,10 +99,13 @@ def _get_docstring(node): # ========================================================== # base classes def _get_base_class_name(node): - if isinstance(node, Name): + if isinstance(node, NodeName): return node.id - if isinstance(node, Attribute): + if isinstance(node, NodeAttribute): return f"{_get_base_class_name(node.value)}.{node.attr}" + # TODO: resolve subscript + if isinstance(node, NodeSubscript): + return f"{_get_base_class_name(node.value)}[{_get_base_class_name(node.slice)}]" # ========================================================== @@ -103,7 +123,7 @@ def _get_attribute_annotation(node): def _get_binop_annotation(node): - if isinstance(node.op, BitOr): + if isinstance(node.op, NodeBitOr): return f"{_get_annotation(node.left)} | {_get_annotation(node.right)}" @@ -124,14 +144,14 @@ def _get_list_annotation(node): _node_annotation_map = { - Name: _get_name_annotation, - Constant: _get_constant_annotation, - Attribute: _get_attribute_annotation, - BinOp: _get_binop_annotation, - Subscript: _get_subscript_annotation, - Index: _get_index_annotation, - Tuple: _get_tuple_annotation, - List: _get_list_annotation, + NodeName: _get_name_annotation, + NodeConstant: _get_constant_annotation, + NodeAttribute: _get_attribute_annotation, + NodeBinOp: _get_binop_annotation, + NodeSubscript: _get_subscript_annotation, + NodeIndex: _get_index_annotation, + NodeTuple: _get_tuple_annotation, + NodeList: _get_list_annotation, } @@ -154,8 +174,31 @@ def _get_attribute_value(node): def _get_binop_value(node): - if isinstance(node.op, BitOr): - return f"{_get_value(node.left)} | {_get_value(node.right)}" + return f"{_get_value(node.left)} {_get_value(node.op)} {_get_value(node.right)}" + + +def _get_bitor_value(node): + return "|" + + +def _get_mult_value(node): + return "*" + + +def _get_unaryop_value(node): + if isinstance(node.op, NodeUSub): + return f"-{_get_value(node.operand)}" + if isinstance(node.op, NodeUAdd): + return f"+{_get_value(node.operand)}" + if isinstance(node.op, NodeNot): + return f"not {_get_value(node.operand)}" + + +def _get_slice_value(node): + value = f"{_get_value(node.lower) if node.lower else ''}:{_get_value(node.upper) if node.upper else ''}" + if node.step: + value = f"{value}:{_get_value(node.step)}" + return value def _get_subscript_value(node): @@ -166,6 +209,10 @@ def _get_index_value(node): return _get_value(node.value) +def _get_lambda_value(node): + return f"lambda {_get_value(node.args)}: {_get_value(node.body)}" + + def _get_list_value(node): return "[" + ", ".join(_get_value(el) for el in node.elts) + "]" @@ -183,10 +230,18 @@ def _get_dict_value(node): return "{" + ", ".join(f"{_get_value(key)}: {_get_value(value)}" for key, value in pairs) + "}" +def _get_set_value(node): + return "{" + ", ".join(_get_value(el) for el in node.elts) + "}" + + def _get_ellipsis_value(node): return "..." +def _get_starred_value(node): + return _get_value(node.value) + + def _get_formatted_value(node): return f"{{{_get_value(node.value)}}}" @@ -195,6 +250,59 @@ def _get_joinedstr_value(node): return "".join(_get_value(value) for value in node.values) +def _get_boolop_value(node): + if isinstance(node.op, NodeOr): + return " or ".join(_get_value(value) for value in node.values) + if isinstance(node.op, NodeAnd): + return " and ".join(_get_value(value) for value in node.values) + + +def _get_compare_value(node): + left = _get_value(node.left) + ops = [_get_value(op) for op in node.ops] + comparators = [_get_value(comparator) for comparator in node.comparators] + return f"{left} " + " ".join(f"{op} {comp}" for op, comp in zip(ops, comparators)) + + +def _get_noteq_value(node): + return "!=" + + +def _get_generatorexp_value(node): + element = _get_value(node.elt) + generators = [_get_value(gen) for gen in node.generators] + return f"{element} " + " ".join(generators) + + +def _get_listcomp_value(node): + element = _get_value(node.elt) + generators = [_get_value(gen) for gen in node.generators] + return f"[{element} " + " ".join(generators) + "]" + + +def _get_dictcomp_value(node): + key = _get_value(node.key) + value = _get_value(node.value) + generators = [_get_value(gen) for gen in node.generators] + return f"{{{key}: {value} " + " ".join(generators) + "}" + + +def _get_comprehension_value(node): + target = _get_value(node.target) + iterable = _get_value(node.iter) + conditions = [_get_value(condition) for condition in node.ifs] + value = f"for {target} in {iterable}" + if conditions: + value = f"{value} if " + " if ".join(conditions) + if node.is_async: + value = f"async {value}" + return value + + +def _get_ifexp_value(node): + return f"{_get_value(node.body)} if {_get_value(node.test)} else {_get_value(node.orelse)}" + + def _get_call_value(node): posargs = ", ".join(_get_value(arg) for arg in node.args) kwargs = ", ".join(_get_value(kwarg) for kwarg in node.keywords) @@ -211,19 +319,34 @@ def _get_call_value(node): _node_value_map = { type(None): lambda _: repr(None), - Name: _get_name_value, - Constant: _get_constant_value, - Attribute: _get_attribute_value, - BinOp: _get_binop_value, - Subscript: _get_subscript_value, - Index: _get_index_value, - List: _get_list_value, - Tuple: _get_tuple_value, - keyword: _get_keyword_value, - Dict: _get_dict_value, - FormattedValue: _get_formatted_value, - JoinedStr: _get_joinedstr_value, - Call: _get_call_value, + NodeName: _get_name_value, + NodeConstant: _get_constant_value, + NodeAttribute: _get_attribute_value, + NodeBinOp: _get_binop_value, + NodeUnaryOp: _get_unaryop_value, + NodeSubscript: _get_subscript_value, + NodeIndex: _get_index_value, + NodeList: _get_list_value, + NodeTuple: _get_tuple_value, + NodeKeyword: _get_keyword_value, + NodeDict: _get_dict_value, + NodeSet: _get_set_value, + NodeFormattedValue: _get_formatted_value, + NodeJoinedStr: _get_joinedstr_value, + NodeCall: _get_call_value, + NodeSlice: _get_slice_value, + NodeBoolOp: _get_boolop_value, + NodeGeneratorExp: _get_generatorexp_value, + NodeComprehension: _get_comprehension_value, + NodeCompare: _get_compare_value, + NodeNotEq: _get_noteq_value, + NodeBitOr: _get_bitor_value, + NodeMult: _get_mult_value, + NodeListComp: _get_listcomp_value, + NodeLambda: _get_lambda_value, + NodeDictComp: _get_dictcomp_value, + NodeStarred: _get_starred_value, + NodeIfExp: _get_ifexp_value, } @@ -234,7 +357,7 @@ def _get_value(node): # ========================================================== # names def _get_attribute_name(node): - return f"{node.attr}.{_get_names(node.value)}" + return f"{_get_names(node.value)}.{node.attr}" def _get_name_name(node): @@ -242,18 +365,18 @@ def _get_name_name(node): def _get_assign_names(node): - return [_get_names(target) for target in node.targets] + return [name for name in [_get_names(target) for target in node.targets] if name] def _get_annassign_names(node): - return [_get_names(node.target)] + return [name for name in _get_names(node.target) if name] _node_names_map = { - Assign: _get_assign_names, - AnnAssign: _get_annassign_names, - Name: _get_name_name, - Attribute: _get_attribute_name, + NodeAssign: _get_assign_names, + NodeAnnAssign: _get_annassign_names, + NodeName: _get_name_name, + NodeAttribute: _get_attribute_name, } @@ -270,9 +393,9 @@ def _get_instance_names(node): def _get_parameter_default(node, filepath): if node is None: return None - if isinstance(node, Constant): + if isinstance(node, NodeConstant): return repr(node.value) - if isinstance(node, Name): + if isinstance(node, NodeName): return node.id if node.lineno == node.end_lineno: return lines_collection[filepath][node.lineno - 1][node.col_offset : node.end_col_offset] @@ -296,11 +419,12 @@ def __init__( self.code: str = code self.extensions: Extensions = extensions.instantiate(self) # self.scope = defaultdict(dict) + self.root: Node | None = None self.parent: Module | None = parent self.current: Module | Class | Function = None # type: ignore self.in_decorator: bool = False - def _visit(self, node: AST, parent: AST | None = None) -> None: + def _visit(self, node: Node, parent: Node | None = None) -> None: node.parent = parent # type: ignore self._run_specific_or_generic(node) @@ -311,14 +435,14 @@ def get_module(self) -> Module: self.visit(top_node) return self.current.module # type: ignore # there's always a module after the visit - def visit(self, node: AST, parent: AST | None = None) -> None: + def visit(self, node: Node, parent: Node | None = None) -> None: for start_visitor in self.extensions.when_visit_starts: start_visitor.visit(node, parent) super().visit(node, parent) for stop_visitor in self.extensions.when_visit_stops: stop_visitor.visit(node, parent) - def generic_visit(self, node: AST) -> None: # noqa: WPS231 + def generic_visit(self, node: Node) -> None: # noqa: WPS231 for start_visitor in self.extensions.when_children_visit_starts: start_visitor.visit(node) super().generic_visit(node)