Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/python-packages/api-stub-generator/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Fixed issue where types would appear wrapped in "Optional" even though
they do not accept `None`.
Fixed issue where, in some cases, string literal default values would not appear wrapped
in quotes.
Added support for @overloads decorators.
Fixed issue where decorators with parameters would not appear in APIView.

## Version 0.2.11 (2022-04-06)
Added __main__ to execute as module
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import astroid
import inspect
from ._base_node import get_qualified_name
from ._argtype import ArgType


class AstroidArgumentParser:

def __init__(self, node: astroid.Arguments, namespace: str, func_node):
if not isinstance(node, astroid.Arguments):
raise TypeError("Can only pass in an astroid Arguments node.")
self._namespace = namespace
self._node = node
self._parent = func_node
self.args = {}
self.kwargs = {}
self.posargs = {}
self.varargs = None
self._parse_args()
self._parse_kwargs()
self._parse_posonly_args()
self._parse_varargs()

def _default_value(self, name):
try:
return self._node.default_value(name).as_string()
except astroid.NoDefault:
return inspect.Parameter.empty

def _argtype(self, name, idx, annotations, type_comments):
if annotations:
argtype = annotations[idx]
elif type_comments:
argtype = type_comments[idx]
else:
argtype = None
return get_qualified_name(argtype, self._namespace) if argtype else None

def _parse_args(self):
for (idx, arg) in enumerate(self._node.args):
name = arg.name
argtype = self._argtype(name, idx, self._node.annotations, self._node.type_comment_args)
default = self._default_value(name)
self.args[name] = ArgType(name, argtype=argtype, default=default, keyword=None, func_node=self._parent)

def _parse_kwargs(self):
for (idx, arg) in enumerate(self._node.kwonlyargs):
name = arg.name
argtype = self._argtype(name, idx, self._node.kwonlyargs_annotations, self._node.type_comment_kwonlyargs)
default = self._default_value(name)
self.kwargs[name] = ArgType(name, argtype=argtype, default=default, keyword="keyword", func_node=self._parent)
if self._node.kwarg:
kwarg_name = self._node.kwarg
if self._node.kwargannotation:
kwarg_type = self._node.kwargannotation.as_string()
else:
kwarg_type = None
# This wonky logic matches the existing code
arg = ArgType(kwarg_name, argtype=kwarg_type, default=inspect.Parameter.empty, keyword="keyword", func_node=self._parent)
arg.argname = f"**{kwarg_name}"
self.args[arg.argname] = arg

def _parse_posonly_args(self):
for (idx, arg) in enumerate(self._node.posonlyargs):
name = arg.name
argtype = self._argtype(name, idx, self._node.posonlyargs_annotations, self._node.type_comment_posonlyargs)
default = self._default_value(name)
self.posargs[name] = ArgType(name, argtype=argtype, default=default, keyword=None, func_node=self._parent)

def _parse_varargs(self):
if self._node.vararg:
name = self._node.vararg
if self._node.varargannotation:
argtype = self._node.varargannotation.as_string()
else:
argtype = None
arg = ArgType(name, argtype=argtype, default=inspect.Parameter.empty, keyword=None, func_node=self._parent)
arg.argname = f"*{name}"
self.args[name] = arg
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import astroid
from inspect import Parameter
import re

Expand Down Expand Up @@ -45,13 +46,19 @@ def generate_tokens(self, apiview):
c.generate_tokens(apiview)
apiview.end_group()


def get_qualified_name(obj, namespace):
def get_qualified_name(obj, namespace: str) -> str:
"""Generate and return fully qualified name of object with module name for internal types.
If module name is not available for the object then it will return name
:param: obj
Parameter object of type class, function or enum
"""
module_name = getattr(obj, "__module__", "")

if module_name.startswith("astroid"):
return obj.as_string()
elif module_name == "types":
return str(obj)

if obj is Parameter.empty:
return None

Expand All @@ -61,10 +68,6 @@ def get_qualified_name(obj, namespace):
elif hasattr(obj, "__qualname__"):
name = getattr(obj, "__qualname__")

module_name = ""
if hasattr(obj, "__module__"):
module_name = getattr(obj, "__module__")

wrap_optional = False
args = []
# newer versions of Python extract inner types into __args__
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import astroid
import logging
import inspect
from enum import Enum
import operator
from typing import List

from ._base_node import NodeEntityBase, get_qualified_name
from ._function_node import FunctionNode
Expand Down Expand Up @@ -100,9 +102,7 @@ def _should_include_function(self, func_obj):
# Method or Function member should only be included if it is defined in same package.
# So this check will filter any methods defined in parent class if parent class is in non-azure package
# for e.g. as_dict method in msrest
if not (
inspect.ismethod(func_obj) or inspect.isfunction(func_obj)
) or inspect.isbuiltin(func_obj):
if not (inspect.ismethod(func_obj) or inspect.isfunction(func_obj)):
return False
if hasattr(func_obj, "__module__"):
function_module = getattr(func_obj, "__module__")
Expand Down Expand Up @@ -134,6 +134,29 @@ def _handle_class_variable(self, child_obj, name, *, type_string=None, value=Non
)
)

""" Uses AST parsing to look for @overload decorated functions
because inspect cannot see these. Note that this will not
find overloads for module-level functions.
"""
def _parse_overloads(self) -> List[FunctionNode]:
overload_nodes = []
try:
class_node = astroid.parse(inspect.getsource(self.obj)).body[0]
except:
return []
functions = [x for x in class_node.body if isinstance(x, astroid.FunctionDef)]
for func in functions:
if not func.decorators:
continue
for node in func.decorators.nodes:
try:
if node.name == "overload":
overload_node = FunctionNode(self.namespace, self, node=func)
overload_nodes.append(overload_node)
except AttributeError:
continue
return overload_nodes

def _inspect(self):
# Inspect current class and it's members recursively
logging.debug("Inspecting class {}".format(self.full_name))
Expand All @@ -156,16 +179,20 @@ def _inspect(self):
members = inspect.getmembers(self.obj)
else:
members = inspect.getmembers(self.obj)

overloads = self._parse_overloads()
for name, child_obj in members:
if inspect.isbuiltin(child_obj):
continue
elif self._should_include_function(child_obj):
# Include dunder and public methods
if not name.startswith("_") or name.startswith("__"):
try:
self.child_nodes.append(
FunctionNode(self.namespace, self, child_obj)
)
func_node = FunctionNode(self.namespace, self, obj=child_obj)
func_overloads = [x for x in overloads if x.name == func_node.name]
for overload in func_overloads:
self.child_nodes.append(overload)
self.child_nodes.append(func_node)
except OSError:
# Don't create entries for things that don't have source
pass
Expand All @@ -188,14 +215,27 @@ def _inspect(self):

if self.is_enum and isinstance(child_obj, self.obj):
self.child_nodes.append(
EnumNode(name=name, namespace=self.namespace, parent_node=self, obj=child_obj)
EnumNode(
name=name,
namespace=self.namespace,
parent_node=self,
obj=child_obj
)
)
elif inspect.isclass(child_obj):
self.child_nodes.append(
ClassNode(
name=child_obj.name,
namespace=self.namespace,
parent_node=self,
obj=child_obj,
pkg_root_namespace=self.pkg_root_namespace
)
)
elif isinstance(child_obj, property):
if not name.startswith("_"):
# Add instance properties
self.child_nodes.append(
PropertyNode(self.namespace, self, name, child_obj)
)
self.child_nodes.append(PropertyNode(self.namespace, self, name, child_obj))
else:
self._handle_class_variable(child_obj, name, value=str(child_obj))

Expand Down
Loading