Skip to content

Commit

Permalink
fix: Fix extensions loader
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Nov 28, 2021
1 parent 75a8a8b commit 78fb70b
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions src/griffe/agents/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from __future__ import annotations

import ast
import enum
from collections import defaultdict
from inspect import isclass
from types import ModuleType
from typing import TYPE_CHECKING, Any, Sequence, Type, Union

from griffe.agents.base import BaseInspector, BaseVisitor
from griffe.agents.nodes import ObjectNode

if TYPE_CHECKING:
from griffe.agents.inspector import Inspector
Expand Down Expand Up @@ -214,7 +215,44 @@ def after_inspection(self) -> list[InspectorExtension]:
return self._inspectors[When.after_all]


builtin_extensions: dict[str, ModuleType] = {}
builtin_extensions: set[str] = {
"hybrid",
}


def load_extension(extension: str | dict[str, Any] | Extension | Type[Extension]) -> Extension:
"""Load a configured extension.
Parameters:
extension: An extension, with potential configuration options.
Returns:
An extension instance.
"""
if isinstance(extension, (VisitorExtension, InspectorExtension)):
return extension

if isclass(extension) and issubclass(extension, (VisitorExtension, InspectorExtension)): # type: ignore[arg-type]
return extension() # type: ignore[operator]

if isinstance(extension, dict):
import_path, options = next(iter(extension.items()))

else: # we consider it's a string
import_path = str(extension)
options = {}

if import_path in builtin_extensions:
import_path = f"griffe.agents.extensions.{import_path}"

module = __import__(import_path, level=0)
ext_path = import_path.split(".")[1:]
ext_module = module
for part in ext_path:
ext_module = getattr(ext_module, part)

# TODO: handle AttributeError
return ext_module.Extension(**options)


def load_extensions(exts: Sequence[str | dict[str, Any] | Extension | Type[Extension]]) -> Extensions: # noqa: WPS231
Expand All @@ -227,30 +265,6 @@ def load_extensions(exts: Sequence[str | dict[str, Any] | Extension | Type[Exten
An extensions container.
"""
extensions = Extensions()

for extension in exts:
if isinstance(extension, (str, dict)):
if isinstance(extension, str):
if extension in builtin_extensions:
ext_module = builtin_extensions[extension]
else:
ext_module = __import__(extension)
options = {}
elif isinstance(extension, dict):
import_path = next(extension.keys()) # type: ignore[arg-type,call-overload]
if import_path in builtin_extensions:
ext_module = builtin_extensions[import_path]
else:
ext_module = __import__(import_path)
options = next(extension.values()) # type: ignore[arg-type,call-overload]

# TODO: handle AttributeError
extensions.add(ext_module.Extension(**options)) # type: ignore[attr-defined]

elif isinstance(extension, (VisitorExtension, InspectorExtension)):
extensions.add(extension)

elif isclass(extension) and issubclass(extension, (VisitorExtension, InspectorExtension)):
extensions.add(extension())

extensions.add(load_extension(extension))
return extensions

0 comments on commit 78fb70b

Please sign in to comment.