diff --git a/pybind11_stubgen/__init__.py b/pybind11_stubgen/__init__.py index ba7d85a6..78d0743f 100644 --- a/pybind11_stubgen/__init__.py +++ b/pybind11_stubgen/__init__.py @@ -85,6 +85,68 @@ def __init__(self, name, args='*args, **kwargs', rtype='None', validate=True): for invalid_default in invalid_defaults: logger.log(lvl, " {}".format(invalid_default)) + if USE_BOOST_PYTHON: + if args: + find_optional_args = re.findall('\[(.*?)\]$', args) + optional_args = None + if find_optional_args: + optional_args = find_optional_args[0] + if optional_args: + nominal_args = args.replace("[" + optional_args + "]","") + else: + nominal_args = args + + num_nominal_args = 0 + if nominal_args: + nominal_args = nominal_args.split(",") + num_nominal_args = len(nominal_args) + + num_optional_args = 0 + if optional_args: + optional_args = optional_args.split("[,") + num_optional_args = len(optional_args) + if num_optional_args > 1: + optional_args[-1] = re.sub(']'*(num_optional_args-1)+'$', '', optional_args[-1]) # Replace at the end + new_args = "" + + if nominal_args: + for k,arg in enumerate(nominal_args): + type_name = re.findall('\((.*?)\)', arg)[0] + arg_name = arg.split(")")[1] + arg_name = arg_name.replace(' ','_') + + new_args += arg_name + ": " + type_name + if k < num_nominal_args-1: + new_args += ", " + + if num_optional_args > 0 and num_nominal_args > 0: + new_args += ", " + + if optional_args and True: + for k,arg in enumerate(optional_args): + # Check for default value + split_arg_equal = arg.split('=',maxsplit=1) + main_arg = split_arg_equal[0] + type_name = re.findall('\((.*?)\)', main_arg)[0] + + arg_name = main_arg.split(")")[1] + arg_name = arg_name.replace(' ','_') + new_args += arg_name + ": " + type_name + optional_value = None + if len(split_arg_equal) > 1: + optional_value = split_arg_equal[1] + new_args += " = " + optional_value + + if k < num_optional_args-1: + new_args += ", " + + new_args = new_args.replace(" ,", ",") + self.args = new_args + args = new_args + + rtype = rtype.split(" :")[0] + self.rtype = rtype + function_def_str = "def {sig.name}({sig.args}) -> {sig.rtype}: ...".format(sig=self) try: ast.parse(function_def_str) @@ -189,6 +251,9 @@ def replace_typing_types(match): return "typing." + match.group('type').capitalize() +# If true, parse BOOST_PYTHON signature +USE_BOOST_PYTHON = False + class StubsGenerator(object): INDENT = " " * 4 @@ -696,12 +761,18 @@ def parse(self): logger.debug("Skip '%s' module while parsing '%s' " % (m.module.__name__, self.module.__name__)) elif inspect.isbuiltin(member) or inspect.isfunction(member): self.free_functions.append(FreeFunctionStubsGenerator(name, member, self.module.__name__)) + elif type(member) is type: + logger.debug("Skip '%s' type while parsing '%s' " % (name, self.module.__name__)) + pass elif inspect.isclass(member): if member.__module__ == self.module.__name__: if member.__name__ not in self.class_name_blacklist: self.classes.append(ClassStubsGenerator(member)) else: self.imported_classes[name] = member + importlib.import_module(member.__module__) + self.classes.append(ClassStubsGenerator(member)) + self.classes[-1].parse() elif name == "__doc__": self.doc_string = member elif name not in self.attributes_blacklist: @@ -762,15 +833,20 @@ def to_lines(self): # type: () -> List[str] "import typing" ] - for name, class_ in self.imported_classes.items(): - class_name = getattr(class_, "__qualname__", class_.__name__) - if name == class_name: - suffix = "" - else: - suffix = " as {}".format(name) - result += [ - 'from {} import {}{}'.format(class_.__module__, class_name, suffix) - ] + globals_ = {} + exec("from {} import *".format(self.module.__name__), globals_) + + result += [""] + all_ = set(globals_.keys()) - {"__builtins__"} + result.append("__all__ = [\n " + ",\n ".join(map(lambda s: '"%s"' % s, sorted(all_))) + "\n]\n") + + for x in itertools.chain(self.classes, + self.free_functions): + result.extend(x.to_lines()) + result += [""] + + for x in itertools.chain(self.attributes): + result.extend(x.to_lines()) # import used packages used_modules = sorted(self.get_involved_modules_names()) @@ -787,15 +863,7 @@ def to_lines(self): # type: () -> List[str] # add space between imports and rest of module result += [""] - globals_ = {} - exec("from {} import *".format(self.module.__name__), globals_) - all_ = set(globals_.keys()) - {"__builtins__"} - result.append("__all__ = [\n " + ",\n ".join(map(lambda s: '"%s"' % s, sorted(all_))) + "\n]\n\n") - for x in itertools.chain(self.classes, - self.free_functions, - self.attributes): - result.extend(x.to_lines()) result.append("") # Newline at EOF return result @@ -809,7 +877,7 @@ def write(self): os.mkdir(self.short_name + self.stub_suffix) with DirectoryWalkerGuard(self.short_name + self.stub_suffix): - with open("__init__.pyi", "w") as init_pyi: + with open("__init__.pyi", "w", encoding="utf-8") as init_pyi: init_pyi.write("\n".join(self.to_lines())) for m in self.submodules: m.write() @@ -869,6 +937,7 @@ def main(args=None): help="Render `numpy.ndarray` without (non-standardized) bracket-enclosed type and shape info") parser.add_argument("module_names", nargs="+", metavar="MODULE_NAME", type=str, help="modules names") parser.add_argument("--log-level", default="INFO", help="Set output log level") + parser.add_argument("--boost-python", action="store_true") sys_args = parser.parse_args(args or sys.argv[1:]) @@ -880,6 +949,10 @@ def main(args=None): global BARE_NUPMY_NDARRAY BARE_NUPMY_NDARRAY = True + if sys_args.boost_python: + global USE_BOOST_PYTHON + USE_BOOST_PYTHON = True + if 'all' in sys_args.ignore_invalid: FunctionSignature.ignore_invalid_signature = True FunctionSignature.ignore_invalid_defaultarg = True