diff --git a/src/griffe/cli.py b/src/griffe/cli.py index 41b66ecd..d798d8e8 100644 --- a/src/griffe/cli.py +++ b/src/griffe/cli.py @@ -14,6 +14,7 @@ from __future__ import annotations import argparse +import asyncio import json import logging import sys @@ -22,7 +23,7 @@ from griffe.encoders import Encoder from griffe.extended_ast import extend_ast from griffe.extensions import Extensions -from griffe.loader import GriffeLoader +from griffe.loader import AsyncGriffeLoader, GriffeLoader from griffe.logger import get_logger logger = get_logger(__name__) @@ -36,6 +37,34 @@ def _print_data(data, output_file): print(data, file=fd) +async def _load_packages_async(packages, extensions, search_paths): + loader = AsyncGriffeLoader(extensions=extensions) + loaded = {} + for package in packages: + logger.info(f"Loading package {package}") + try: + module = await loader.load_module(package, search_paths=search_paths) + except ModuleNotFoundError: + logger.error(f"Could not find package {package}") + else: + loaded[module.name] = module + return loaded + + +def _load_packages(packages, extensions, search_paths): + loader = GriffeLoader(extensions=extensions) + loaded = {} + for package in packages: + logger.info(f"Loading package {package}") + try: + module = loader.load_module(package, search_paths=search_paths) + except ModuleNotFoundError: + logger.error(f"Could not find package {package}") + else: + loaded[module.name] = module + return loaded + + def get_parser() -> argparse.ArgumentParser: """ Return the program argument parser. @@ -43,19 +72,26 @@ def get_parser() -> argparse.ArgumentParser: Returns: The argument parser for the program. """ - parser = argparse.ArgumentParser(prog="griffe") + parser = argparse.ArgumentParser(prog="griffe", add_help=False) parser.add_argument( - "-s", - "--search", - action="append", - type=Path, - help="Paths to search packages into.", + "-A", + "--async-loader", + action="store_true", + help="Whether to read files on disk asynchronously. " + "Very large projects with many files will be processed faster. " + "Small projects with a few files will not see any speed up.", ) parser.add_argument( "-a", - "--append-search", + "--append-sys-path", action="store_true", - help="Whether to append sys.path to specified search paths.", + help="Whether to append sys.path to search paths specified with -s.", + ) + parser.add_argument( + "-h", + "--help", + action="help", + help="Show this help message and exit.", ) parser.add_argument( "-o", @@ -63,11 +99,18 @@ def get_parser() -> argparse.ArgumentParser: default=sys.stdout, help="Output file. Supports templating to output each package in its own file, with {{package}}.", ) + parser.add_argument( + "-s", + "--search", + action="append", + type=Path, + help="Paths to search packages into.", + ) parser.add_argument("packages", metavar="PACKAGE", nargs="+", help="Packages to find and parse.") return parser -def main(args: list[str] | None = None) -> int: +def main(args: list[str] | None = None) -> int: # noqa: WPS231 """ Run the main program. @@ -82,7 +125,7 @@ def main(args: list[str] | None = None) -> int: parser = get_parser() opts: argparse.Namespace = parser.parse_args(args) # type: ignore - logging.basicConfig(format="%(levelname)-10s %(message)s", level=logging.INFO) # noqa: WPS323 + logging.basicConfig(format="%(levelname)-10s %(message)s", level=logging.WARNING) # noqa: WPS323 output = opts.output @@ -91,25 +134,18 @@ def main(args: list[str] | None = None) -> int: per_package_output = True search = opts.search - if opts.append_search: + if opts.append_sys_path: search.extend(sys.path) extend_ast() - extensions = Extensions() - loader = GriffeLoader(extensions=extensions) - packages = {} - success = True - for package in opts.packages: - logger.info(f"Loading package {package}") - try: - module = loader.load_module(package, search_paths=search) - except ModuleNotFoundError: - logger.error(f"Could not find package {package}") - success = False - else: - packages[module.name] = module + if opts.async_loader: + loop = asyncio.get_event_loop() + coroutine = _load_packages_async(opts.packages, extensions=extensions, search_paths=search) + packages = loop.run_until_complete(coroutine) + else: + packages = _load_packages(opts.packages, extensions=extensions, search_paths=search) if per_package_output: for package_name, data in packages.items(): @@ -119,4 +155,4 @@ def main(args: list[str] | None = None) -> int: serialized = json.dumps(packages, cls=Encoder, indent=2, full=True) _print_data(serialized, output) - return 0 if success else 1 + return 0 if len(packages) == len(opts.packages) else 1 diff --git a/src/griffe/loader.py b/src/griffe/loader.py index 276c454a..263aaea4 100644 --- a/src/griffe/loader.py +++ b/src/griffe/loader.py @@ -12,9 +12,10 @@ from __future__ import annotations +import asyncio import sys from pathlib import Path -from typing import Iterator +from typing import Iterator, Tuple from griffe.collections import lines_collection from griffe.dataclasses import Module @@ -22,16 +23,31 @@ from griffe.logger import get_logger from griffe.visitor import visit +NamePartsType = Tuple[str, ...] +NamePartsAndPathType = Tuple[NamePartsType, Path] + logger = get_logger(__name__) -class GriffeLoader: - """The griffe loader, allowing to load data from modules. +async def _read_async(path): + async with aopen(path) as fd: + return await fd.read() - Attributes: - extensions: The extensions to use. - """ +async def _fallback_read_async(path): + logger.warning("aiofiles is not installed, fallback to blocking read") + return path.read_text() + + +try: + from aiofiles import open as aopen # type: ignore +except ModuleNotFoundError: + read_async = _fallback_read_async +else: + read_async = _read_async + + +class _BaseGriffeLoader: def __init__(self, extensions: Extensions | None = None) -> None: """Initialize the loader. @@ -40,35 +56,51 @@ def __init__(self, extensions: Extensions | None = None) -> None: """ self.extensions = extensions or Extensions() + def _module_name_and_path( + self, + module: str | Path, + search_paths: list[str | Path] | None = None, + ) -> tuple[str, Path]: + if isinstance(module, Path): + # programatically passed a Path, try only that + module_name, module_path = module_name_path(module) + else: + # passed a string (from CLI or Python code), try both + try: + module_name, module_path = module_name_path(Path(module)) + except FileNotFoundError: + module_name = module + module_path = find_module(module_name, search_paths=search_paths) + return module_name, module_path + + +class GriffeLoader(_BaseGriffeLoader): + """The Griffe loader, allowing to load data from modules. + + Attributes: + extensions: The extensions to use. + """ + def load_module( self, module: str | Path, - recursive: bool = True, + submodules: bool = True, search_paths: list[str | Path] | None = None, ) -> Module: """Load a module. Arguments: module: The module name or path. - recursive: Whether to recurse on the submodules. + submodules: Whether to recurse on the submodules. search_paths: The paths to search into. Returns: A module. """ - if isinstance(module, Path): - # programatically passed a Path, try only that - module_name, module_path = module_name_path(module) - else: - # passed a string (from CLI or Python code), try both - try: - module_name, module_path = module_name_path(Path(module)) - except FileNotFoundError: - module_name = module - module_path = find_module(module_name, search_paths=search_paths) - return self._load_module_path(module_name, module_path, recursive=recursive) + module_name, module_path = self._module_name_and_path(module, search_paths) + return self._load_module_path(module_name, module_path, submodules=submodules) - def _load_module_path(self, module_name, module_path, recursive=True): + def _load_module_path(self, module_name: str, module_path: Path, submodules: bool = True) -> Module: logger.debug(f"Loading path {module_path}") code = module_path.read_text() lines_collection[module_path] = code.splitlines(keepends=False) @@ -78,19 +110,83 @@ def _load_module_path(self, module_name, module_path, recursive=True): code=code, extensions=self.extensions, ) - if recursive: - for subparts, subpath in sorted(iter_submodules(module_path), key=_module_depth): - parent_parts = subparts[:-1] - try: - member_parent = module[parent_parts] - except KeyError: - logger.debug(f"Skipping (not importable) {subpath}") - continue - member_parent[subparts[-1]] = self._load_module_path(subparts[-1], subpath, recursive=False) + if submodules: + self._load_submodules(module) return module + def _load_submodules(self, module: Module) -> None: + for subparts, subpath in sorted(iter_submodules(module.filepath), key=_module_depth): + self._load_submodule(module, subparts, subpath) + + def _load_submodule(self, module: Module, subparts: NamePartsType, subpath: Path) -> None: + parent_parts = subparts[:-1] + try: + member_parent = module[parent_parts] + except KeyError: + logger.debug(f"Skipping (not importable) {subpath}") + else: + member_parent[subparts[-1]] = self._load_module_path(subparts[-1], subpath, submodules=False) + + +class AsyncGriffeLoader(_BaseGriffeLoader): + """The asynchronous Griffe loader, allowing to load data from modules. + + Attributes: + extensions: The extensions to use. + """ + + async def load_module( + self, + module: str | Path, + submodules: bool = True, + search_paths: list[str | Path] | None = None, + ) -> Module: + """Load a module. + + Arguments: + module: The module name or path. + submodules: Whether to recurse on the submodules. + search_paths: The paths to search into. + + Returns: + A module. + """ + module_name, module_path = self._module_name_and_path(module, search_paths) + return await self._load_module_path(module_name, module_path, submodules=submodules) + + async def _load_module_path(self, module_name: str, module_path: Path, submodules: bool = True) -> Module: + logger.debug(f"Loading path {module_path}") + code = await read_async(module_path) + lines_collection[module_path] = code.splitlines(keepends=False) + module = visit( + module_name, + filepath=module_path, + code=code, + extensions=self.extensions, + ) + if submodules: + await self._load_submodules(module) + return module + + async def _load_submodules(self, module: Module) -> None: + await asyncio.gather( + *[ + self._load_submodule(module, subparts, subpath) + for subparts, subpath in sorted(iter_submodules(module.filepath), key=_module_depth) + ] + ) + + async def _load_submodule(self, module: Module, subparts: NamePartsType, subpath: Path) -> None: + parent_parts = subparts[:-1] + try: + member_parent = module[parent_parts] + except KeyError: + logger.debug(f"Skipping (not importable) {subpath}") + else: + member_parent[subparts[-1]] = await self._load_module_path(subparts[-1], subpath, submodules=False) + -def _module_depth(name_parts_and_path): +def _module_depth(name_parts_and_path: NamePartsAndPathType) -> int: return len(name_parts_and_path[0]) @@ -117,12 +213,15 @@ def module_name_path(path: Path) -> tuple[str, Path]: raise FileNotFoundError if path.exists(): if path.stem == "__init__": + if path.parent.is_absolute(): + return path.parent.name, path return path.parent.resolve().name, path return path.stem, path raise FileNotFoundError # credits to @NiklasRosenstein and the docspec project +# TODO: possible optimization by caching elements of search directories def find_module(module_name: str, search_paths: list[str | Path] | None = None) -> Path: """Find a module in a given list of paths or in `sys.path`. @@ -156,7 +255,7 @@ def find_module(module_name: str, search_paths: list[str | Path] | None = None) raise ModuleNotFoundError(module_name) -def iter_submodules(path) -> Iterator[tuple[list[str], Path]]: # noqa: WPS234 +def iter_submodules(path: Path) -> Iterator[NamePartsAndPathType]: # noqa: WPS234 """Iterate on a module's submodules, if any. Arguments: