Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Bot.load_extensions rework #796

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
67bea2f
refactor!: rewrite `Bot.load_extensions` and corresponding utils method
shiftinv Aug 30, 2022
60e7622
feat: yield loaded extension names
shiftinv Aug 30, 2022
0c9182b
fix: catch `ValueError` raised by `resolve_name` on 3.8
shiftinv Aug 30, 2022
8b8db8f
fix: change exception types
shiftinv Aug 30, 2022
5e2ada4
feat: add `return_exceptions` parameter
shiftinv Oct 9, 2022
42ccd93
feat: discover all ext names first
shiftinv Oct 9, 2022
82300de
chore: rename `walk_extensions` to `walk_modules`
shiftinv Oct 9, 2022
9c4e784
docs: improve `walk_modules` docs
shiftinv Oct 9, 2022
f3b24db
fix: update `load_extensions` return type
shiftinv Oct 9, 2022
33819be
fix: make `walk_modules:ignore` less complex
shiftinv Oct 9, 2022
2b7d78a
feat: avoid duplicate `__path__` entries
shiftinv Oct 9, 2022
a09045e
fix(test): update walk_modules tests
shiftinv Oct 9, 2022
ea069ae
fix: un-deprecate `load_extensions(<path>)`
shiftinv Oct 9, 2022
2ff26be
docs: improve exception documentation
shiftinv Oct 9, 2022
be6122e
fix: use proper error type
shiftinv Oct 9, 2022
3610789
feat: update `test_bot.__main__`
shiftinv Oct 9, 2022
460cdf3
fix: update module not found exception
shiftinv Oct 10, 2022
3084698
feat: add `load_callback` parameter, remove iterator
shiftinv Oct 10, 2022
7690556
refactor: move to separate find_extensions method for easier customiz…
shiftinv Oct 10, 2022
08033a7
chore: use list instead of sequence
shiftinv Oct 10, 2022
2f28f20
docs: mention find_extensions for customization
shiftinv Oct 10, 2022
b402320
chore(test): move directory utils to separate file
shiftinv Oct 10, 2022
e83249c
test: add some simple tests
shiftinv Oct 10, 2022
3be2be5
fix: python 3.8 moment
shiftinv Oct 10, 2022
413a5a3
docs: add changelog entries
shiftinv Oct 10, 2022
c6a779b
revert: remove `return_exceptions`, unnecessary complexity
shiftinv Oct 11, 2022
e159c7e
chore: move to separate documentation
shiftinv Oct 11, 2022
1a3b309
fix: reinstate `utils.search_directory` for now
shiftinv Oct 11, 2022
04b0bb3
Merge branch 'master' into refactor/load-extensions
shiftinv Oct 21, 2022
7ceaa17
Merge remote-tracking branch 'upstream/master' into refactor/load-ext…
shiftinv Nov 15, 2022
bece2cf
feat: use tuple instead of list in `find_extensions`
shiftinv Nov 15, 2022
b008329
chore: invert condition to save one indent level
shiftinv Nov 15, 2022
5db6b12
Merge remote-tracking branch 'upstream/master' into refactor/load-ext…
shiftinv Jan 25, 2023
3f14848
Merge branch 'master' into refactor/load-extensions
onerandomusername Feb 22, 2023
d3a1127
Merge remote-tracking branch 'upstream/master' into refactor/load-ext…
shiftinv Apr 4, 2023
4ed28a6
chore: resolve lint issues
shiftinv Apr 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/796.deprecate.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:func:`disnake.utils.search_directory` will be removed in a future version, in favor of :func:`disnake.ext.commands.Bot.find_extensions` which is the most common usecase and is more consistent.
3 changes: 3 additions & 0 deletions changelog/796.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
|commands| Improve :func:`Bot.load_extensions <ext.commands.Bot.load_extensions>`, add :func:`Bot.find_extensions <ext.commands.Bot.find_extensions>`.
- Better support for more complex extension hierarchies.
- New ``package``, ``ignore``, and ``load_callback`` parameters.
147 changes: 140 additions & 7 deletions disnake/ext/commands/common_bot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
Callable,
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
TypeVar,
Union,
Expand Down Expand Up @@ -472,7 +474,7 @@ def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str)
def _resolve_name(self, name: str, package: Optional[str]) -> str:
try:
return importlib.util.resolve_name(name, package)
except ImportError as e:
except (ValueError, ImportError) as e: # 3.8 raises ValueError instead of ImportError
raise errors.ExtensionNotFound(name) from e

def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
Expand Down Expand Up @@ -623,18 +625,149 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
sys.modules.update(modules)
raise

def load_extensions(self, path: str) -> None:
"""Loads all extensions in a directory.
def find_extensions(
self,
root_module: str,
*,
package: Optional[str] = None,
ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None,
) -> Sequence[str]:
"""Finds all extensions in a given module, also traversing into sub-packages.

See :ref:`ext_commands_extensions_load` for details on how packages are found.

.. versionadded:: 2.7

.. note::
This imports all *packages* (not all modules) in the given path(s)
to access the ``__path__`` attribute for finding submodules,
unless they are filtered by the ``ignore`` parameter.

Parameters
----------
root_module: :class:`str`
The module/package name to search in, for example ``cogs.admin``.
Also supports paths in the current working directory.
package: Optional[:class:`str`]
The package name to resolve relative imports with.
This is required when ``root_module`` is a relative module name, e.g ``.cogs.admin``.
Defaults to ``None``.
ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]]
An iterable of module names to ignore, or a callable that's used for ignoring
modules (where the callable returning ``True`` results in the module being ignored).
Defaults to ``None``, i.e. no modules are ignored.

If it's an iterable, module names that start with any of the given strings will be ignored.

Raises
------
ExtensionError
The given root module could not be found,
or the name of the root module could not be resolved using the provided ``package`` parameter.
ValueError
``root_module`` is a path and outside of the cwd.
TypeError
The ``ignore`` parameter is of an invalid type.
ImportError
A package couldn't be imported.

Returns
-------
Sequence[:class:`str`]
The list of full extension names.
"""
if "/" in root_module or "\\" in root_module:
path = os.path.relpath(root_module)
if ".." in path:
raise ValueError(
"Paths outside the cwd are not supported. Try using the module name instead."
)
root_module = path.replace(os.sep, ".")

# `find_spec` already calls `resolve_name`, but we want our custom error handling here
root_module = self._resolve_name(root_module, package)

if not (spec := importlib.util.find_spec(root_module)):
raise errors.ExtensionError(
f"Unable to find root module '{root_module}'", name=root_module
)

if not (paths := spec.submodule_search_locations):
raise errors.ExtensionError(
f"Module '{root_module}' is not a package", name=root_module
)

return tuple(disnake.utils._walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore))

def load_extensions(
self,
root_module: str,
*,
package: Optional[str] = None,
ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None,
load_callback: Optional[Callable[[str], None]] = None,
) -> Union[List[str], List[Union[str, errors.ExtensionError]]]:
"""Loads all extensions in a given module, also traversing into sub-packages.

See :func:`find_extensions` for details.

.. versionadded:: 2.4

.. versionchanged:: 2.7
Now accepts a module name instead of a filesystem path.
Improved package traversal, adding support for more complex extensions
with ``__init__.py`` files.
Also added ``package``, ``ignore``, and ``load_callback`` parameters.

.. note::
For further customization, you may use :func:`find_extensions`:

.. code-block:: python3

for extension_name in bot.find_extensions(...):
... # custom logic
bot.load_extension(extension_name)

Parameters
----------
path: :class:`str`
The path to search for extensions
root_module: :class:`str`
See :func:`find_extensions`.
package: Optional[:class:`str`]
See :func:`find_extensions`.
ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]]
See :func:`find_extensions`.
load_callback: Optional[Callable[[:class:`str`], None]]
A callback that gets invoked with the extension name when each extension gets loaded.

Raises
------
ExtensionError
The given root module could not be found,
or the name of the root module could not be resolved using the provided ``package`` parameter.
Other extension-related errors may also be raised
as this method calls :func:`load_extension` on all found extensions.
See :func:`load_extension` for further details on raised exceptions.
ValueError
``root_module`` is a path and outside of the cwd.
TypeError
The ``ignore`` parameter is of an invalid type.
ImportError
A package (not module) couldn't be imported.

Returns
-------
List[:class:`str`]
The list of module names that have been loaded.
"""
for extension in disnake.utils.search_directory(path):
self.load_extension(extension)
ret: List[str] = []

for ext_name in self.find_extensions(root_module, package=package, ignore=ignore):
self.load_extension(ext_name)
ret.append(ext_name)
if load_callback:
load_callback(ext_name)

return ret

@property
def extensions(self) -> Mapping[str, types.ModuleType]:
Expand Down
47 changes: 47 additions & 0 deletions disnake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import datetime
import functools
import importlib
import json
import os
import pkgutil
Expand Down Expand Up @@ -1275,9 +1276,12 @@ def format_dt(dt: Union[datetime.datetime, float], /, style: TimestampStyle = "f
return f"<t:{int(dt)}:{style}>"


@deprecated("disnake.ext.commands.Bot.find_extensions")
def search_directory(path: str) -> Iterator[str]:
"""Walk through a directory and yield all modules.

.. deprecated:: 2.7

Parameters
----------
path: :class:`str`
Expand Down Expand Up @@ -1311,6 +1315,49 @@ def search_directory(path: str) -> Iterator[str]:
yield prefix + name


# this is similar to pkgutil.walk_packages, but with a few modifications
def _walk_modules(
paths: Iterable[str],
prefix: str = "",
ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None,
) -> Iterator[str]:
if isinstance(ignore, str):
raise TypeError("`ignore` must be an iterable of strings or a callable")

if isinstance(ignore, Iterable):
ignore_tup = tuple(ignore)
ignore = lambda path: path.startswith(ignore_tup)
# else, it's already a callable or None

seen: Set[str] = set()

for _, name, ispkg in pkgutil.iter_modules(paths, prefix):
if ignore and ignore(name):
continue

if not ispkg:
yield name
continue

# it's a package here
mod = importlib.import_module(name)

# if this module is a package but also has a `setup` function,
# yield it and don't look for other files in this module
if hasattr(mod, "setup"):
yield name
continue

sub_paths: List[str] = []
for p in mod.__path__ or []:
if p not in seen:
seen.add(p)
sub_paths.append(p)

if sub_paths:
yield from _walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore)


def as_valid_locale(locale: str) -> Optional[str]:
"""Converts the provided locale name to a name that is valid for use with the API,
for example by returning ``en-US`` for ``en_US``.
Expand Down
57 changes: 57 additions & 0 deletions docs/ext/commands/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,60 @@ Although rare, sometimes an extension needs to clean-up or know when it's being

def teardown(bot):
print('I am being unloaded!')

.. _ext_commands_extensions_load:

Loading multiple extensions
-----------------------------

Commonly, you might have a package/folder that contains several modules.
Instead of manually loading them one by one, you can use :meth:`.Bot.load_extensions` to load the entire package in one sweep.

Consider the following directory structure:

.. code-block::

my_bot/
├─── cogs/
│ ├─── admin.py
│ ├─── fun.py
│ └─── other_complex_thing/
│ ├─── __init__.py (contains setup)
│ ├─── data.py
│ └─── models.py
└─── main.py

Now, you could call :meth:`.Bot.load_extension` separately on ``cogs.admin``, ``cogs.fun``, and ``cogs.other_complex_thing``;
however, if you add a new extension, you'd need to once again load it separately.

Instead, you can use ``bot.load_extensions("my_bot.cogs")`` (or ``.load_extensions(".cogs", package=__package__)``)
to load all of them automatically, without any extra work required.

Customization
+++++++++++++++

To adjust the loading process, for example to handle exceptions that may occur, use :meth:`.Bot.find_extensions`.
This is also what :meth:`.Bot.load_extensions` uses internally.

As an example, one could load extensions like this:

.. code-block:: python3

for extension in bot.find_extensions("my_bot.cogs"):
try:
bot.load_extension(extension)
except commands.ExtensionError as e:
logger.warning(f"Failed to load extension {extension}: {e}")

Discovery
+++++++++++

:meth:`.Bot.find_extensions` (and by extension, :meth:`.Bot.load_extensions`) discover modules/extensions
similar to :func:`py:pkgutil.walk_packages`; the given root package name is resolved,
and submodules/-packages are iterated through recursively.

If a package has a ``setup`` function (similar to ``my_bot.cogs.other_complex_thing`` above),
it won't be traversed further, i.e. ``data.py`` and ``models.py`` in the example won't be considered separate extensions.

Namespace packages (see `PEP 420 <https://peps.python.org/pep-0420/>`__) are not supported (other than
the provided root package), meaning every subpackage must have an ``__init__.py`` file.
11 changes: 5 additions & 6 deletions test_bot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import logging
import os
import sys
import traceback

Expand Down Expand Up @@ -51,10 +50,6 @@ async def on_ready(self) -> None:
)
# fmt: on

def add_cog(self, cog: commands.Cog, *, override: bool = False) -> None:
logger.info("Loading cog %s", cog.qualified_name)
return super().add_cog(cog, override=override)

async def on_command_error(self, ctx: commands.Context, error: commands.CommandError) -> None:
msg = f"Command `{ctx.command}` failed due to `{error}`"
logger.error(msg, exc_info=True)
Expand Down Expand Up @@ -126,5 +121,9 @@ async def on_message_command_error(

if __name__ == "__main__":
bot = TestBot()
bot.load_extensions(os.path.join(__package__, Config.cogs_folder))
bot.load_extensions(
".cogs",
package=__package__,
load_callback=lambda e: logger.info("Loaded extension %s.", e),
)
bot.run(Config.token)
1 change: 1 addition & 0 deletions tests/ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: MIT
1 change: 1 addition & 0 deletions tests/ext/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: MIT
45 changes: 45 additions & 0 deletions tests/ext/commands/test_common_bot_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: MIT

import asyncio
from pathlib import Path
from typing import Iterator
from unittest import mock

import pytest

from disnake.ext.commands import errors
from disnake.ext.commands.common_bot_base import CommonBotBase

from ... import helpers


class TestExtensions:
@pytest.fixture
def module_root(self, tmpdir: Path) -> Iterator[str]:
with helpers.chdir_module(tmpdir):
yield str(tmpdir)

@pytest.fixture
def bot(self):
with mock.patch.object(asyncio, "get_event_loop", mock.Mock()), mock.patch.object(
CommonBotBase, "_fill_owners", mock.Mock()
):
bot = CommonBotBase()
return bot

def test_find_path_invalid(self, bot: CommonBotBase) -> None:
with pytest.raises(ValueError, match=r"Paths outside the cwd are not supported"):
bot.find_extensions("../../etc/passwd")

def test_find(self, bot: CommonBotBase, module_root: str) -> None:
helpers.create_dirs(module_root, {"test_cogs": {"__init__.py": "", "admin.py": ""}})

assert bot.find_extensions("test_cogs")

with pytest.raises(errors.ExtensionError, match=r"Unable to find root module 'other_cogs'"):
bot.find_extensions("other_cogs")

with pytest.raises(
errors.ExtensionError, match=r"Module 'test_cogs.admin' is not a package"
):
bot.find_extensions(".admin", package="test_cogs")
Loading