Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 91 additions & 18 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,68 @@ def __init__(self, name, args='*args, **kwargs', rtype='None', validate=True):
for invalid_default in invalid_defaults:
logger.log(lvl, " {}".format(invalid_default))

if USE_BOOST_PYTHON:
if args:
find_optional_args = re.findall('\[(.*?)\]$', args)
optional_args = None
if find_optional_args:
optional_args = find_optional_args[0]
if optional_args:
nominal_args = args.replace("[" + optional_args + "]","")
else:
nominal_args = args

num_nominal_args = 0
if nominal_args:
nominal_args = nominal_args.split(",")
num_nominal_args = len(nominal_args)

num_optional_args = 0
if optional_args:
optional_args = optional_args.split("[,")
num_optional_args = len(optional_args)
if num_optional_args > 1:
optional_args[-1] = re.sub(']'*(num_optional_args-1)+'$', '', optional_args[-1]) # Replace at the end
new_args = ""

if nominal_args:
for k,arg in enumerate(nominal_args):
type_name = re.findall('\((.*?)\)', arg)[0]
arg_name = arg.split(")")[1]
arg_name = arg_name.replace(' ','_')

new_args += arg_name + ": " + type_name
if k < num_nominal_args-1:
new_args += ", "

if num_optional_args > 0 and num_nominal_args > 0:
new_args += ", "

if optional_args and True:
for k,arg in enumerate(optional_args):
# Check for default value
split_arg_equal = arg.split('=',maxsplit=1)
main_arg = split_arg_equal[0]
type_name = re.findall('\((.*?)\)', main_arg)[0]

arg_name = main_arg.split(")")[1]
arg_name = arg_name.replace(' ','_')
new_args += arg_name + ": " + type_name
optional_value = None
if len(split_arg_equal) > 1:
optional_value = split_arg_equal[1]
new_args += " = " + optional_value

if k < num_optional_args-1:
new_args += ", "

new_args = new_args.replace(" ,", ",")
self.args = new_args
args = new_args

rtype = rtype.split(" :")[0]
self.rtype = rtype

function_def_str = "def {sig.name}({sig.args}) -> {sig.rtype}: ...".format(sig=self)
try:
ast.parse(function_def_str)
Expand Down Expand Up @@ -189,6 +251,9 @@ def replace_typing_types(match):
return "typing." + match.group('type').capitalize()


# If true, parse BOOST_PYTHON signature
USE_BOOST_PYTHON = False

class StubsGenerator(object):
INDENT = " " * 4

Expand Down Expand Up @@ -696,12 +761,18 @@ def parse(self):
logger.debug("Skip '%s' module while parsing '%s' " % (m.module.__name__, self.module.__name__))
elif inspect.isbuiltin(member) or inspect.isfunction(member):
self.free_functions.append(FreeFunctionStubsGenerator(name, member, self.module.__name__))
elif type(member) is type:
logger.debug("Skip '%s' type while parsing '%s' " % (name, self.module.__name__))
pass
elif inspect.isclass(member):
if member.__module__ == self.module.__name__:
if member.__name__ not in self.class_name_blacklist:
self.classes.append(ClassStubsGenerator(member))
else:
self.imported_classes[name] = member
importlib.import_module(member.__module__)
self.classes.append(ClassStubsGenerator(member))
self.classes[-1].parse()
elif name == "__doc__":
self.doc_string = member
elif name not in self.attributes_blacklist:
Expand Down Expand Up @@ -762,15 +833,20 @@ def to_lines(self): # type: () -> List[str]
"import typing"
]

for name, class_ in self.imported_classes.items():
class_name = getattr(class_, "__qualname__", class_.__name__)
if name == class_name:
suffix = ""
else:
suffix = " as {}".format(name)
result += [
'from {} import {}{}'.format(class_.__module__, class_name, suffix)
]
globals_ = {}
exec("from {} import *".format(self.module.__name__), globals_)

result += [""]
all_ = set(globals_.keys()) - {"__builtins__"}
result.append("__all__ = [\n " + ",\n ".join(map(lambda s: '"%s"' % s, sorted(all_))) + "\n]\n")

for x in itertools.chain(self.classes,
self.free_functions):
result.extend(x.to_lines())
result += [""]

for x in itertools.chain(self.attributes):
result.extend(x.to_lines())

# import used packages
used_modules = sorted(self.get_involved_modules_names())
Expand All @@ -787,15 +863,7 @@ def to_lines(self): # type: () -> List[str]
# add space between imports and rest of module
result += [""]

globals_ = {}
exec("from {} import *".format(self.module.__name__), globals_)
all_ = set(globals_.keys()) - {"__builtins__"}
result.append("__all__ = [\n " + ",\n ".join(map(lambda s: '"%s"' % s, sorted(all_))) + "\n]\n\n")

for x in itertools.chain(self.classes,
self.free_functions,
self.attributes):
result.extend(x.to_lines())
result.append("") # Newline at EOF
return result

Expand All @@ -809,7 +877,7 @@ def write(self):
os.mkdir(self.short_name + self.stub_suffix)

with DirectoryWalkerGuard(self.short_name + self.stub_suffix):
with open("__init__.pyi", "w") as init_pyi:
with open("__init__.pyi", "w", encoding="utf-8") as init_pyi:
init_pyi.write("\n".join(self.to_lines()))
for m in self.submodules:
m.write()
Expand Down Expand Up @@ -869,6 +937,7 @@ def main(args=None):
help="Render `numpy.ndarray` without (non-standardized) bracket-enclosed type and shape info")
parser.add_argument("module_names", nargs="+", metavar="MODULE_NAME", type=str, help="modules names")
parser.add_argument("--log-level", default="INFO", help="Set output log level")
parser.add_argument("--boost-python", action="store_true")

sys_args = parser.parse_args(args or sys.argv[1:])

Expand All @@ -880,6 +949,10 @@ def main(args=None):
global BARE_NUPMY_NDARRAY
BARE_NUPMY_NDARRAY = True

if sys_args.boost_python:
global USE_BOOST_PYTHON
USE_BOOST_PYTHON = True

if 'all' in sys_args.ignore_invalid:
FunctionSignature.ignore_invalid_signature = True
FunctionSignature.ignore_invalid_defaultarg = True
Expand Down