diff --git a/benchmark.sh b/benchmark.sh index b2b7f57e..49abaa76 100644 --- a/benchmark.sh +++ b/benchmark.sh @@ -39,12 +39,12 @@ from griffe.loader import GriffeLoader from griffe.extensions import load_extensions from griffe.extensions import Extension stdlib_packages = sorted([m for m in sys.stdlib_module_names if not m.startswith("_")]) -# extensions = load_extensions([ +# extensions = load_extensions( # Extension, Extension, Extension, Extension, # Extension, Extension, Extension, Extension, # Extension, Extension, Extension, Extension, # Extension, Extension, Extension, Extension, -# ]) +# ) extensions = None loader = GriffeLoader(allow_inspection=False, extensions=extensions) for package in stdlib_packages: diff --git a/docs/extensions.md b/docs/extensions.md index 0c71bbe2..390ff097 100644 --- a/docs/extensions.md +++ b/docs/extensions.md @@ -87,13 +87,11 @@ import griffe from mypackage.extensions import ThisExtension, ThisOtherExtension extensions = griffe.load_extensions( - [ - {"pydantic": {"schema": true}}, - {"scripts/exts.py:DynamicDocstrings": {"paths": ["mypkg.mymod.myobj"]}}, - "griffe_attrs", - ThisExtension(option="value"), - ThisOtherExtension, - ] + {"pydantic": {"schema": true}}, + {"scripts/exts.py:DynamicDocstrings": {"paths": ["mypkg.mymod.myobj"]}}, + "griffe_attrs", + ThisExtension(option="value"), + ThisOtherExtension, ) data = griffe.load("mypackage", extensions=extensions) diff --git a/src/griffe/cli.py b/src/griffe/cli.py index 5000c756..d0a0defd 100644 --- a/src/griffe/cli.py +++ b/src/griffe/cli.py @@ -376,7 +376,7 @@ def dump( search_paths.extend(sys.path) try: - loaded_extensions = load_extensions(extensions or ()) + loaded_extensions = load_extensions(*(extensions or ())) except ExtensionError as error: logger.exception(str(error)) # noqa: TRY401 return 1 @@ -464,7 +464,7 @@ def check( repository = get_repo_root(against_path) try: - loaded_extensions = load_extensions(extensions or ()) + loaded_extensions = load_extensions(*(extensions or ())) except ExtensionError as error: logger.exception(str(error)) # noqa: TRY401 return 1 diff --git a/src/griffe/extensions/base.py b/src/griffe/extensions/base.py index 20d6190f..bd931130 100644 --- a/src/griffe/extensions/base.py +++ b/src/griffe/extensions/base.py @@ -9,7 +9,7 @@ from importlib.util import module_from_spec, spec_from_file_location from inspect import isclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Sequence, Type, Union from griffe.agents.nodes import ast_children, ast_kind from griffe.enumerations import When @@ -236,6 +236,7 @@ def on_package_loaded(self, *, pkg: Module) -> None: ExtensionType = Union[VisitorExtension, InspectorExtension, Extension] +"""All the types that can be passed to `Extensions.add`.""" class Extensions: @@ -455,8 +456,13 @@ def _load_extension( return [ext(**options) for ext in extensions] +LoadableExtension = Union[str, Dict[str, Any], ExtensionType, Type[ExtensionType]] +"""All the types that can be passed to `load_extensions`.""" + + def load_extensions( - exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]] | None = None, + # TODO: Only accept LoadableExtension at some point. + *exts: LoadableExtension | Sequence[LoadableExtension], ) -> Extensions: """Load configured extensions. @@ -467,7 +473,22 @@ def load_extensions( An extensions container. """ extensions = Extensions() - for extension in exts or (): + + # TODO: Remove at some point. + all_exts: list[LoadableExtension] = [] + for ext in exts: + if isinstance(ext, (list, tuple)): + warnings.warn( + "Passing multiple extensions as a single list or tuple is deprecated. " + "Please pass them as separate arguments instead.", + DeprecationWarning, + stacklevel=2, + ) + all_exts.extend(ext) + else: + all_exts.append(ext) # type: ignore[arg-type] + + for extension in all_exts: ext = _load_extension(extension) if isinstance(ext, list): extensions.add(*ext) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 8b3e05de..1949e393 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -94,7 +94,7 @@ def test_loading_extensions(extension: str | dict[str, dict[str, Any]] | Extensi Parameters: extension: Extension specification (parametrized). """ - extensions = load_extensions([extension]) + extensions = load_extensions(extension) loaded: ExtensionTest = extensions._extensions[0] # type: ignore[assignment] # We cannot use isinstance here, # because loading from a filepath drops the parent `tests` package, @@ -115,7 +115,7 @@ class Class: cattr = 1 def method(self): ... """, - extensions=load_extensions([extension]), + extensions=load_extensions(extension), ): pass events = [