diff --git a/docs/source/changes.md b/docs/source/changes.md index 3f7ce9b6..c2220e59 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -26,6 +26,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`579` fixes an interaction with `--pdb` and `--trace` and task that return. The debugging modes swallowed the return and `None` was returned. Closes {issue}`574`. - {pull}`581` simplifies the code for tracebacks and unpublishes some utility functions. +- {pull}`585` implements two more import modes, `prepend` and `append`. ## 0.4.6 - 2024-03-13 diff --git a/src/_pytask/build.py b/src/_pytask/build.py index ef3fd857..80063063 100644 --- a/src/_pytask/build.py +++ b/src/_pytask/build.py @@ -27,6 +27,7 @@ from _pytask.exceptions import ResolvingDependenciesError from _pytask.outcomes import ExitCode from _pytask.path import HashPathCache +from _pytask.path import ImportMode from _pytask.pluginmanager import get_plugin_manager from _pytask.pluginmanager import hookimpl from _pytask.pluginmanager import storage @@ -79,6 +80,8 @@ def build( # noqa: C901, PLR0912, PLR0913 expression: str = "", force: bool = False, ignore: Iterable[str] = (), + import_mode: Literal["prepend", "append", "importlib"] + | ImportMode = ImportMode.importlib, marker_expression: str = "", max_failures: float = float("inf"), n_entries_in_table: int = 15, @@ -132,6 +135,8 @@ def build( # noqa: C901, PLR0912, PLR0913 ignore A pattern to ignore files or directories. Refer to ``pathlib.Path.match`` for more info. + import_mode + The import mode used to import task modules. marker_expression Same as ``-m`` on the command line. Select tasks via marker expressions. 
max_failures @@ -192,6 +197,7 @@ def build( # noqa: C901, PLR0912, PLR0913 "expression": expression, "force": force, "ignore": ignore, + "import_mode": import_mode, "marker_expression": marker_expression, "max_failures": max_failures, "n_entries_in_table": n_entries_in_table, diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 1b1f9a2e..21a623b2 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -206,7 +206,9 @@ def pytask_collect_file( ) -> list[CollectionReport] | None: """Collect a file.""" if any(path.match(pattern) for pattern in session.config["task_files"]): - mod = import_path(path, session.config["root"]) + mod = import_path( + path, root=session.config["root"], mode=session.config["import_mode"] + ) collected_reports = [] for name, obj in inspect.getmembers(mod): diff --git a/src/_pytask/config.py b/src/_pytask/config.py index 86423ac3..4f8eaa8c 100644 --- a/src/_pytask/config.py +++ b/src/_pytask/config.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from typing import Any +from _pytask.path import ImportMode from _pytask.pluginmanager import hookimpl from _pytask.shared import parse_markers from _pytask.shared import parse_paths @@ -110,6 +111,8 @@ def pytask_parse_config(config: dict[str, Any]) -> None: config["pm"].trace.root.setwriter(print) config["pm"].enable_tracing() + config["import_mode"] = ImportMode(config["import_mode"]) + @hookimpl def pytask_post_parse(config: dict[str, Any]) -> None: diff --git a/src/_pytask/parameters.py b/src/_pytask/parameters.py index c3e32f0b..ce18475f 100644 --- a/src/_pytask/parameters.py +++ b/src/_pytask/parameters.py @@ -13,7 +13,9 @@ from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError +from _pytask.click import EnumChoice from _pytask.config_utils import set_defaults_from_config +from _pytask.path import ImportMode from _pytask.path import import_path from _pytask.pluginmanager import hookimpl from _pytask.pluginmanager import 
register_hook_impls_from_modules @@ -138,7 +140,7 @@ def _hook_module_callback( "Please provide a valid path." ) raise click.BadParameter(msg) - module = import_path(path, ctx.params["root"]) + module = import_path(path, root=ctx.params["root"]) parsed_modules.append(module.__name__) else: spec = importlib.util.find_spec(module_name) @@ -177,12 +179,19 @@ def pytask_add_hooks(pm: PluginManager) -> None: callback=_hook_module_callback, ) +_IMPORT_MODE_OPTION = click.Option( + ["--import-mode"], + type=EnumChoice(ImportMode), + default=ImportMode.importlib, + help="Choose the import mode.", +) + @hookimpl(trylast=True) def pytask_extend_command_line_interface(cli: click.Group) -> None: """Register general markers.""" for command in ("build", "clean", "collect", "dag", "profile"): - cli.commands[command].params.extend((_DATABASE_URL_OPTION,)) + cli.commands[command].params.extend((_DATABASE_URL_OPTION, _IMPORT_MODE_OPTION)) for command in ("build", "clean", "collect", "dag", "markers", "profile"): cli.commands[command].params.extend( (_CONFIG_OPTION, _HOOK_MODULE_OPTION, _PATH_ARGUMENT) diff --git a/src/_pytask/path.py b/src/_pytask/path.py index 05027092..289c26da 100644 --- a/src/_pytask/path.py +++ b/src/_pytask/path.py @@ -5,8 +5,10 @@ import contextlib import functools import importlib.util +import itertools import os import sys +from enum import Enum from pathlib import Path from types import ModuleType from typing import Sequence @@ -15,6 +17,7 @@ from _pytask.cache import Cache __all__ = [ + "ImportMode", "find_case_sensitive_path", "find_closest_ancestor", "find_common_ancestor", @@ -122,87 +125,356 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path: return path.resolve() if platform == "win32" else path -def import_path(path: Path, root: Path) -> ModuleType: +######################################################################################## +# The following code is copied from pytest and adapted to be used in pytask. 
+ + +class ImportMode(Enum): + """Possible values for `mode` parameter of `import_path`.""" + + prepend = "prepend" + append = "append" + importlib = "importlib" + + +class ImportPathMismatchError(ImportError): + """Raised on import_path() if there is a mismatch of __file__'s. + + This can happen when `import_path` is called multiple times with different filenames + that share a basename but reside in packages (for example "/tests1/test_foo.py" + and "/tests2/test_foo.py"). + """ + + +def import_path( # noqa: C901, PLR0912, PLR0915 + path: str | os.PathLike[str], + *, + mode: str | ImportMode = ImportMode.importlib, + root: Path, + consider_namespace_packages: bool = False, +) -> ModuleType: """Import and return a module from the given path. - The function is taken from pytest when the import mode is set to ``importlib``. It - pytest's recommended import mode for new projects although the default is set to - ``prepend``. More discussion and information can be found in :issue:`373`. + The module can be a file (a module) or a directory (a package). + + :param path: + Path to the file to import. + + :param mode: + Controls the underlying import mechanism that will be used: + + * ImportMode.prepend: the directory containing the module (or package, taking + `__init__.py` files into account) will be put at the *start* of `sys.path` + before being imported with `importlib.import_module`. + + * ImportMode.append: same as `prepend`, but the directory will be appended to + the end of `sys.path`, if not already in `sys.path`. + * ImportMode.importlib: uses more fine control mechanisms provided by + `importlib` to import the module, which avoids having to muck with `sys.path` + at all. It effectively allows having same-named test modules in different + places. + + :param root: + Used as an anchor when mode == ImportMode.importlib to obtain a unique name for + the module being imported so it can safely be stored into ``sys.modules``. 
+ + :param consider_namespace_packages: + If True, consider namespace packages when resolving module names. + + :raises ImportPathMismatchError: + If, after importing, the given `path` and the module's `__file__` differ. + Only raised in `prepend` and `append` modes. """ - module_name = _module_name_from_path(path, root) - with contextlib.suppress(KeyError): - return sys.modules[module_name] + path = Path(path) + mode = ImportMode(mode) + + if not path.exists(): + raise ImportError(path) + + if mode is ImportMode.importlib: + # Try to import this module using the standard import mechanisms, but + # without touching sys.path. + try: + pkg_root, module_name = resolve_pkg_root_and_module_name( + path, consider_namespace_packages=consider_namespace_packages + ) + except CouldNotResolvePathError: + pass + else: + # If the given module name is already in sys.modules, do not import it + # again. + with contextlib.suppress(KeyError): + return sys.modules[module_name] + + mod = _import_module_using_spec( + module_name, path, pkg_root, insert_modules=False + ) + if mod is not None: + return mod + + # Could not import the module with the current sys.path, so we fall back + # to importing the file as a single module, not being a part of a package. 
+ module_name = _module_name_from_path(path, root) + with contextlib.suppress(KeyError): + return sys.modules[module_name] + + mod = _import_module_using_spec( + module_name, path, path.parent, insert_modules=True + ) + if mod is None: + msg = f"Can't find module {module_name} at location {path}" + raise ImportError(msg) + return mod + + try: + pkg_root, module_name = resolve_pkg_root_and_module_name( + path, consider_namespace_packages=consider_namespace_packages + ) + except CouldNotResolvePathError: + pkg_root, module_name = path.parent, path.stem + + # Change sys.path permanently: restoring it at the end of this function would cause + # surprising problems because of delayed imports: for example, a conftest.py file + # imported by this function might have local imports, which would fail at runtime if + # we restored sys.path. + if mode is ImportMode.append: + if str(pkg_root) not in sys.path: + sys.path.append(str(pkg_root)) + elif mode is ImportMode.prepend: + if str(pkg_root) != sys.path[0]: + sys.path.insert(0, str(pkg_root)) + else: + msg = f"Unknown mode: {mode}" + raise ValueError(msg) + + importlib.import_module(module_name) + + mod = sys.modules[module_name] + if path.name == "__init__.py": + return mod - spec = importlib.util.spec_from_file_location(module_name, str(path)) + ignore = os.environ.get("PY_IGNORE_IMPORTMISMATCH", "") + if ignore != "1": + module_file = mod.__file__ + if module_file is None: + raise ImportPathMismatchError(module_name, module_file, path) - if spec is None: - msg = f"Can't find module {module_name!r} at location {path}." 
- raise ImportError(msg) + if module_file.endswith((".pyc", ".pyo")): + module_file = module_file[:-1] + if module_file.endswith(os.sep + "__init__.py"): + module_file = module_file[: -(len(os.sep + "__init__.py"))] + + try: + is_same = _is_same(str(path), module_file) + except FileNotFoundError: + is_same = False + + if not is_same: + raise ImportPathMismatchError(module_name, module_file, path) - mod = importlib.util.module_from_spec(spec) - sys.modules[module_name] = mod - spec.loader.exec_module(mod) # type: ignore[union-attr] - _insert_missing_modules(sys.modules, module_name) return mod -def _module_name_from_path(path: Path, root: Path) -> str: - """Return a dotted module name based on the given path, anchored on root. +def _import_module_using_spec( + module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool +) -> ModuleType | None: + """Import a module using the spec. + + Try to import a module by its canonical name, path to the .py file, and its parent + location. + + :param insert_modules: + If True, will call insert_missing_modules to create empty intermediate modules + for made-up module names (when importing test files not reachable from + sys.path). Note: we can probably drop insert_missing_modules altogether: instead + of generating module names such as "src.tests.test_foo", which require + intermediate empty modules, we might just as well generate unique module names + like "src_tests_test_foo". + """ + # Checking with sys.meta_path first in case one of its hooks can import this module, + # such as our own assertion-rewrite hook. 
+ for meta_importer in sys.meta_path: + spec = meta_importer.find_spec(module_name, [str(module_location)]) + if spec is not None: + break + else: + spec = importlib.util.spec_from_file_location(module_name, str(module_path)) + if spec is not None: + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) # type: ignore[union-attr] + if insert_modules: + _insert_missing_modules(sys.modules, module_name) + return mod + + return None + + +# Implement a special _is_same function on Windows which returns True if the two +# filenames compare equal, to circumvent os.path.samefile returning False for mounts in +# UNC (#7678). +if sys.platform.startswith("win"): + + def _is_same(f1: str, f2: str) -> bool: + return Path(f1) == Path(f2) or os.path.samefile(f1, f2) # noqa: PTH121 + +else: - For example: path="projects/src/project/task_foo.py" and root="/projects", the - resulting module name will be "src.project.task_foo". + def _is_same(f1: str, f2: str) -> bool: + return os.path.samefile(f1, f2) # noqa: PTH121 + +def _module_name_from_path(path: Path, root: Path) -> str: + """ + Return a dotted module name based on the given path, anchored on root. + + For example: path="projects/src/tests/test_foo.py" and root="/projects", the + resulting module name will be "src.tests.test_foo". """ path = path.with_suffix("") try: relative_path = path.relative_to(root) except ValueError: - # If we can't get a relative path to root, use the full path, except for the - # first part ("d:\\" or "/" depending on the platform, for example). + # If we can't get a relative path to root, use the full path, except + # for the first part ("d:\\" or "/" depending on the platform, for example). path_parts = path.parts[1:] else: # Use the parts for the relative path to the root path. path_parts = relative_path.parts - # Module name for packages do not contain the __init__ file, unless the - # `__init__.py` file is at the root. 
+ # Module name for packages do not contain the __init__ file, unless + # the `__init__.py` file is at the root. if len(path_parts) >= 2 and path_parts[-1] == "__init__": # noqa: PLR2004 path_parts = path_parts[:-1] + # Module names cannot contain ".", normalize them to "_". This prevents a directory + # having a "." in the name (".env.310" for example) causing extra intermediate + # modules. Also, important to replace "." at the start of paths, as those are + # considered relative imports. + path_parts = tuple(x.replace(".", "_") for x in path_parts) + return ".".join(path_parts) def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None: - """Insert missing modules when importing modules with :func:`import_path`. + """Insert missing modules in sys.modules. - When we want to import a module as ``src.project.task_foo`` for example, we need to - create empty modules ``src`` and ``src.project`` after inserting - ``src.project.task_foo``, otherwise ``src.project.task_foo`` is not importable by - ``__import__``. + Used by ``import_path`` to create intermediate modules when using mode=importlib. + When we want to import a module as "src.tests.test_foo" for example, we need + to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo", + otherwise "src.tests.test_foo" is not importable by ``__import__``. """ module_parts = module_name.split(".") + child_module: ModuleType | None = None + module: ModuleType | None = None + child_name: str = "" while module_name: if module_name not in modules: try: - # If sys.meta_path is empty, calling import_module will issue a warning - # and raise ModuleNotFoundError. To avoid the warning, we check - # sys.meta_path explicitly and raise the error ourselves to fall back to - # creating a dummy module. + # If sys.meta_path is empty, calling import_module will issue + # a warning and raise ModuleNotFoundError. 
To avoid the + # warning, we check sys.meta_path explicitly and raise the error + # ourselves to fall back to creating a dummy module. if not sys.meta_path: raise ModuleNotFoundError - importlib.import_module(module_name) + module = importlib.import_module(module_name) except ModuleNotFoundError: module = ModuleType( module_name, - doc="Empty module created by pytask.", + doc="Empty module created by pytest's importmode=importlib.", ) - modules[module_name] = module + else: + module = modules[module_name] + # Add child attribute to the parent that can reference the child modules. + if child_module and not hasattr(module, child_name): + setattr(module, child_name, child_module) + modules[module_name] = module + # Keep track of the child module while moving up the tree. + child_module, child_name = module, module_name.rpartition(".")[-1] module_parts.pop(-1) module_name = ".".join(module_parts) +def resolve_package_path(path: Path) -> Path | None: + """Resolve the package path. + + Return the Python package path by looking for the last directory upwards which still + contains an __init__.py. + + Returns None if it can not be determined. + """ + result = None + for parent in itertools.chain((path,), path.parents): + if parent.is_dir(): + if not (parent / "__init__.py").is_file(): + break + if not parent.name.isidentifier(): + break + result = parent + return result + + +def resolve_pkg_root_and_module_name( + path: Path, *, consider_namespace_packages: bool = False +) -> tuple[Path, str]: + """Resolve package root and module name. + + Return the path to the directory of the root package that contains the given Python + file, and its module name: + + .. codeblock:: + src/ + app/ + __init__.py + core/ + __init__.py + models.py + + Passing the full path to `models.py` will yield Path("src") and "app.core.models". 
+ + If consider_namespace_packages is True, then we additionally check upwards in the + hierarchy until we find a directory that is reachable from sys.path, which marks it + as a namespace package: + + https://packaging.python.org/en/latest/guides/packaging-namespace-packages + + Raises CouldNotResolvePathError if the given path does not belong to a package + (missing any __init__.py files). + """ + pkg_path = resolve_package_path(path) + if pkg_path is not None: + pkg_root = pkg_path.parent + # https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ + if consider_namespace_packages: + # Go upwards in the hierarchy, if we find a parent path included + # in sys.path, it means the package found by resolve_package_path() + # actually belongs to a namespace package. + for parent in pkg_root.parents: + # If any of the parent paths has a __init__.py, it means it is not + # a namespace package (see the docs linked above). + if (parent / "__init__.py").is_file(): + break + if str(parent) in sys.path: + # Point the pkg_root to the root of the namespace package. + pkg_root = parent + break + + names = list(path.with_suffix("").relative_to(pkg_root).parts) + if names[-1] == "__init__": + names.pop() + module_name = ".".join(names) + return pkg_root, module_name + + msg = f"Could not resolve for {path}" + raise CouldNotResolvePathError(msg) + + +class CouldNotResolvePathError(Exception): + """Custom exception raised by resolve_pkg_root_and_module_name.""" + + def shorten_path(path: Path, paths: Sequence[Path]) -> str: """Shorten a path. 
@@ -223,6 +495,9 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str: return relative_to(path, ancestor).as_posix() +######################################################################################## + + HashPathCache = Cache() diff --git a/tests/conftest.py b/tests/conftest.py index cf1f65d8..4b823187 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def _remove_variable_info_from_output(data: str, path: Any) -> str: # noqa: ARG # Remove dynamic versions. index_root = next(i for i, line in enumerate(lines) if line.startswith("Root:")) - new_info_line = " ".join(lines[1:index_root]) + new_info_line = "".join(lines[1:index_root]) for platform in ("linux", "win32", "darwin"): new_info_line = new_info_line.replace(platform, "") pattern = re.compile(version.VERSION_PATTERN, flags=re.IGNORECASE | re.VERBOSE) diff --git a/tests/test_path.py b/tests/test_path.py index e7df4389..e8cf48d4 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -10,12 +10,16 @@ from typing import Any import pytest +from _pytask.path import CouldNotResolvePathError +from _pytask.path import ImportMode from _pytask.path import _insert_missing_modules from _pytask.path import _module_name_from_path from _pytask.path import find_case_sensitive_path from _pytask.path import find_closest_ancestor from _pytask.path import find_common_ancestor from _pytask.path import relative_to +from _pytask.path import resolve_pkg_root_and_module_name +from pytask import cli from pytask.path import import_path @@ -178,183 +182,377 @@ def test_no_meta_path_found( import_path(simple_module, root=tmp_path) -@pytest.mark.unit() -def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None: +@pytest.fixture(params=[True, False]) +def ns_param(request: pytest.FixtureRequest) -> bool: """ - Ensure that importlib mode works with a module containing dataclasses (#373, - pytest#7856). 
+ Simple parametrized fixture for tests which call import_path() with + consider_namespace_packages using True and False. """ - fn = tmp_path.joinpath("_src/project/task_dataclass.py") - fn.parent.mkdir(parents=True) - fn.write_text( - textwrap.dedent( - """ - from dataclasses import dataclass - - @dataclass - class Data: - value: str - """ - ) - ) - - module = import_path(fn, root=tmp_path) - Data: Any = module.Data # noqa: N806 - data = Data(value="foo") - assert data.value == "foo" - assert data.__module__ == "_src.project.task_dataclass" + return bool(request.param) @pytest.mark.unit() -def test_importmode_importlib_with_pickle(tmp_path: Path) -> None: - """Ensure that importlib mode works with pickle (#373, pytest#7859).""" - fn = tmp_path.joinpath("_src/project/task_pickle.py") - fn.parent.mkdir(parents=True) - fn.write_text( - textwrap.dedent( - """ - import pickle - - def _action(): - return 42 - - def round_trip(): - s = pickle.dumps(_action) - return pickle.loads(s) - """ +class TestImportLibMode: + def test_importmode_importlib_with_dataclass( + self, tmp_path: Path, ns_param: bool + ) -> None: + """ + Ensure that importlib mode works with a module containing dataclasses (#373, + pytest#7856). 
+ """ + fn = tmp_path.joinpath("_src/project/task_dataclass.py") + fn.parent.mkdir(parents=True) + fn.write_text( + textwrap.dedent( + """ + from dataclasses import dataclass + + @dataclass + class Data: + value: str + """ + ) ) - ) - module = import_path(fn, root=tmp_path) - round_trip = module.round_trip - action = round_trip() - assert action() == 42 + module = import_path(fn, mode="importlib", root=tmp_path) + Data: Any = module.Data # noqa: N806 + data = Data(value="foo") + assert data.value == "foo" + assert data.__module__ == "_src.project.task_dataclass" - -@pytest.mark.unit() -def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None: - """ - Ensure that importlib mode works can load pickles that look similar but are - defined in separate modules. - """ - fn1 = tmp_path.joinpath("_src/m1/project/task.py") - fn1.parent.mkdir(parents=True) - fn1.write_text( - textwrap.dedent( - """ - import dataclasses - import pickle - - @dataclasses.dataclass - class Data: - x: int = 42 - """ + # Ensure we do not import the same module again (pytest#11475). 
+ module2 = import_path( + fn, mode="importlib", root=tmp_path, consider_namespace_packages=ns_param ) - ) - - fn2 = tmp_path.joinpath("_src/m2/project/task.py") - fn2.parent.mkdir(parents=True) - fn2.write_text( - textwrap.dedent( - """ - import dataclasses - import pickle - - @dataclasses.dataclass - class Data: - x: str = "" - """ + assert module is module2 + + def test_importmode_importlib_with_pickle( + self, tmp_path: Path, ns_param: bool + ) -> None: + """Ensure that importlib mode works with pickle (#373, pytest#7859).""" + fn = tmp_path.joinpath("_src/project/task_pickle.py") + fn.parent.mkdir(parents=True) + fn.write_text( + textwrap.dedent( + """ + import pickle + + def _action(): + return 42 + + def round_trip(): + s = pickle.dumps(_action) + return pickle.loads(s) + """ + ) ) - ) - - import pickle - - def round_trip(obj): - s = pickle.dumps(obj) - return pickle.loads(s) # noqa: S301 - - module = import_path(fn1, root=tmp_path) - Data1 = module.Data # noqa: N806 - - module = import_path(fn2, root=tmp_path) - Data2 = module.Data # noqa: N806 - - assert round_trip(Data1(20)) == Data1(20) - assert round_trip(Data2("hello")) == Data2("hello") - assert Data1.__module__ == "_src.m1.project.task" - assert Data2.__module__ == "_src.m2.project.task" + module = import_path( + fn, mode="importlib", root=tmp_path, consider_namespace_packages=ns_param + ) + round_trip = module.round_trip + action = round_trip() + assert action() == 42 + + def test_importmode_importlib_with_pickle_separate_modules( + self, tmp_path: Path, ns_param: bool + ) -> None: + """ + Ensure that importlib mode can load pickles that look similar but are + defined in separate modules. 
+ """ + fn1 = tmp_path.joinpath("_src/m1/project/task.py") + fn1.parent.mkdir(parents=True) + fn1.write_text( + textwrap.dedent( + """ + import dataclasses + import pickle + + @dataclasses.dataclass + class Data: + x: int = 42 + """ + ) + ) -@pytest.mark.unit() -def test_module_name_from_path(tmp_path: Path) -> None: - result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path) - assert result == "src.project.task_foo" + fn2 = tmp_path.joinpath("_src/m2/project/task.py") + fn2.parent.mkdir(parents=True) + fn2.write_text( + textwrap.dedent( + """ + import dataclasses + import pickle + + @dataclasses.dataclass + class Data: + x: str = "" + """ + ) + ) - # Path is not relative to root dir: use the full path to obtain the module name. - result = _module_name_from_path(Path("/home/foo/task_foo.py"), Path("/bar")) - assert result == "home.foo.task_foo" + import pickle - # Importing __init__.py files should return the package as module name. - result = _module_name_from_path(tmp_path / "src/app/__init__.py", tmp_path) - assert result == "src.app" + def round_trip(obj): + s = pickle.dumps(obj) + return pickle.loads(s) # noqa: S301 + module = import_path( + fn1, mode="importlib", root=tmp_path, consider_namespace_packages=ns_param + ) + Data1 = module.Data # noqa: N806 -@pytest.mark.unit() -def test_insert_missing_modules( - monkeypatch: pytest.MonkeyPatch, tmp_path: Path -) -> None: - monkeypatch.chdir(tmp_path) - # Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and - # don't end up being imported. 
- modules = {"xxx.project.foo": ModuleType("xxx.project.foo")} - _insert_missing_modules(modules, "xxx.project.foo") - assert sorted(modules) == ["xxx", "xxx.project", "xxx.project.foo"] + module = import_path( + fn2, mode="importlib", root=tmp_path, consider_namespace_packages=ns_param + ) + Data2 = module.Data # noqa: N806 + + assert round_trip(Data1(20)) == Data1(20) + assert round_trip(Data2("hello")) == Data2("hello") + assert Data1.__module__ == "_src.m1.project.task" + assert Data2.__module__ == "_src.m2.project.task" + + def test_module_name_from_path(self, tmp_path: Path) -> None: + result = _module_name_from_path(tmp_path / "src/project/task_foo.py", tmp_path) + assert result == "src.project.task_foo" + + # Path is not relative to root dir: use the full path to obtain the module name. + result = _module_name_from_path(Path("/home/foo/task_foo.py"), Path("/bar")) + assert result == "home.foo.task_foo" + + # Importing __init__.py files should return the package as module name. + result = _module_name_from_path(tmp_path / "src/app/__init__.py", tmp_path) + assert result == "src.app" + + # Unless __init__.py file is at the root, in which case we cannot have an empty + # module name. + result = _module_name_from_path(tmp_path / "__init__.py", tmp_path) + assert result == "__init__" + + # Modules which start with "." are considered relative and will not be imported + # unless part of a package, so we replace it with a "_" when generating the fake + # module name. + result = _module_name_from_path(tmp_path / ".env/tasks/task_foo.py", tmp_path) + assert result == "_env.tasks.task_foo" + + # We want to avoid generating extra intermediate modules if some directory just + # happens to contain a "." in the name. 
+ result = _module_name_from_path( + tmp_path / ".env.310/tasks/task_foo.py", tmp_path + ) + assert result == "_env_310.tasks.task_foo" + + def test_resolve_pkg_root_and_module_name( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Create a directory structure first without __init__.py files. + (tmp_path / "src/app/core").mkdir(parents=True) + models_py = tmp_path / "src/app/core/models.py" + models_py.touch() + with pytest.raises(CouldNotResolvePathError): + _ = resolve_pkg_root_and_module_name(models_py) + + # Create the __init__.py files, it should now resolve to a proper module name. + (tmp_path / "src/app/__init__.py").touch() + (tmp_path / "src/app/core/__init__.py").touch() + assert resolve_pkg_root_and_module_name( + models_py, consider_namespace_packages=True + ) == (tmp_path / "src", "app.core.models") + + # If we add tmp_path to sys.path, src becomes a namespace package. + monkeypatch.syspath_prepend(tmp_path) + assert resolve_pkg_root_and_module_name( + models_py, consider_namespace_packages=True + ) == (tmp_path, "src.app.core.models") + assert resolve_pkg_root_and_module_name( + models_py, consider_namespace_packages=False + ) == (tmp_path / "src", "app.core.models") + + def test_insert_missing_modules( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + monkeypatch.chdir(tmp_path) + # Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and + # don't end up being imported. 
+ modules = {"xxx.project.foo": ModuleType("xxx.project.foo")} + _insert_missing_modules(modules, "xxx.project.foo") + assert sorted(modules) == ["xxx", "xxx.project", "xxx.project.foo"] + + mod = ModuleType("mod", doc="My Module") + modules = {"xxy": mod} + _insert_missing_modules(modules, "xxy") + assert modules == {"xxy": mod} + + modules = {} + _insert_missing_modules(modules, "") + assert not modules + + def test_parent_contains_child_module_attribute( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ): + monkeypatch.chdir(tmp_path) + # Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and + # don't end up being imported. + modules = {"xxx.tasks.foo": ModuleType("xxx.tasks.foo")} + _insert_missing_modules(modules, "xxx.tasks.foo") + assert sorted(modules) == ["xxx", "xxx.tasks", "xxx.tasks.foo"] + assert modules["xxx"].tasks is modules["xxx.tasks"] + assert modules["xxx.tasks"].foo is modules["xxx.tasks.foo"] + + def test_importlib_package( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ns_param: bool + ) -> None: + """ + Importing a package using --importmode=importlib should not import the + package's __init__.py file more than once (#11306). 
+ """ + monkeypatch.chdir(tmp_path) + monkeypatch.syspath_prepend(tmp_path) + + package_name = "importlib_import_package" + tmp_path.joinpath(package_name).mkdir() + init = tmp_path.joinpath(f"{package_name}/__init__.py") + init.write_text( + textwrap.dedent( + """ + from .singleton import Singleton + instance = Singleton() + """ + ), + encoding="ascii", + ) + singleton = tmp_path.joinpath(f"{package_name}/singleton.py") + singleton.write_text( + textwrap.dedent( + """ + class Singleton: + INSTANCES = [] + def __init__(self) -> None: + self.INSTANCES.append(self) + if len(self.INSTANCES) > 1: + raise RuntimeError("Already initialized") + """ + ), + encoding="ascii", + ) - mod = ModuleType("mod", doc="My Module") - modules = {"xxy": mod} - _insert_missing_modules(modules, "xxy") - assert modules == {"xxy": mod} + mod = import_path( + init, + root=tmp_path, + mode=ImportMode.importlib, + consider_namespace_packages=ns_param, + ) + assert len(mod.instance.INSTANCES) == 1 + # Ensure we do not import the same module again (#11475). + mod2 = import_path( + init, + root=tmp_path, + mode=ImportMode.importlib, + consider_namespace_packages=ns_param, + ) + assert mod is mod2 + + def test_importlib_root_is_package(self, runner, tmp_path: Path) -> None: + """ + Regression for importing a `__init__.py` file that is at the root (#11417). + """ + tmp_path.joinpath("__init__.py").touch() + tmp_path.joinpath("task_example.py").write_text( + "def task_my_task(): assert True" + ) + result = runner.invoke(cli, ["--import-mode=importlib", tmp_path.as_posix()]) + assert "1 Succeeded" in result.output + + +class TestNamespacePackages: + """Test import_path support when importing from proper namespace packages.""" + + def setup_directories( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> tuple[Path, Path]: + # Set up a namespace package "com.company", containing + # two subpackages, "app" and "calc". 
+ (tmp_path / "src/dist1/com/company/app/core").mkdir(parents=True) + (tmp_path / "src/dist1/com/company/app/__init__.py").touch() + (tmp_path / "src/dist1/com/company/app/core/__init__.py").touch() + models_py = tmp_path / "src/dist1/com/company/app/core/models.py" + models_py.touch() + + (tmp_path / "src/dist2/com/company/calc/algo").mkdir(parents=True) + (tmp_path / "src/dist2/com/company/calc/__init__.py").touch() + (tmp_path / "src/dist2/com/company/calc/algo/__init__.py").touch() + algorithms_py = tmp_path / "src/dist2/com/company/calc/algo/algorithms.py" + algorithms_py.touch() + + monkeypatch.syspath_prepend(tmp_path / "src/dist1") + monkeypatch.syspath_prepend(tmp_path / "src/dist2") + return models_py, algorithms_py + + @pytest.mark.parametrize("import_mode", ImportMode) + def test_resolve_pkg_root_and_module_name_ns_multiple_levels( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, import_mode: ImportMode + ) -> None: + models_py, algorithms_py = self.setup_directories(tmp_path, monkeypatch) + + pkg_root, module_name = resolve_pkg_root_and_module_name( + models_py, consider_namespace_packages=True + ) + assert (pkg_root, module_name) == ( + tmp_path / "src/dist1", + "com.company.app.core.models", + ) - modules = {} - _insert_missing_modules(modules, "") - assert not modules + mod = import_path( + models_py, mode=import_mode, root=tmp_path, consider_namespace_packages=True + ) + assert mod.__name__ == "com.company.app.core.models" + assert mod.__file__ == str(models_py) + # Ensure we do not import the same module again (#11475). + mod2 = import_path( + models_py, mode=import_mode, root=tmp_path, consider_namespace_packages=True + ) + assert mod is mod2 -@pytest.mark.unit() -def test_importlib_package(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): - """ - Importing a package using --importmode=importlib should not import the - package's __init__.py file more than once (#11306). 
- """ - monkeypatch.chdir(tmp_path) - monkeypatch.syspath_prepend(tmp_path) - - package_name = "importlib_import_package" - tmp_path.joinpath(package_name).mkdir() - init = tmp_path.joinpath(f"{package_name}/__init__.py") - init.write_text( - textwrap.dedent( - """ - from .singleton import Singleton - instance = Singleton() - """ - ), - encoding="ascii", - ) - singleton = tmp_path.joinpath(f"{package_name}/singleton.py") - singleton.write_text( - textwrap.dedent( - """ - class Singleton: - INSTANCES = [] - def __init__(self) -> None: - self.INSTANCES.append(self) - if len(self.INSTANCES) > 1: - raise RuntimeError("Already initialized") - """ - ), - encoding="ascii", - ) + pkg_root, module_name = resolve_pkg_root_and_module_name( + algorithms_py, consider_namespace_packages=True + ) + assert (pkg_root, module_name) == ( + tmp_path / "src/dist2", + "com.company.calc.algo.algorithms", + ) - mod = import_path(init, root=tmp_path) - assert len(mod.instance.INSTANCES) == 1 + mod = import_path( + algorithms_py, + mode=import_mode, + root=tmp_path, + consider_namespace_packages=True, + ) + assert mod.__name__ == "com.company.calc.algo.algorithms" + assert mod.__file__ == str(algorithms_py) + + # Ensure we do not import the same module again (#11475). + mod2 = import_path( + algorithms_py, + mode=import_mode, + root=tmp_path, + consider_namespace_packages=True, + ) + assert mod is mod2 + + def test_incorrect_namespace_package( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + models_py, algorithms_py = self.setup_directories(tmp_path, monkeypatch) + # Namespace packages must not have an __init__.py at any of its + # directories; if it does, we then fall back to importing just the + # part of the package containing the __init__.py files. 
+ (tmp_path / "src/dist1/com/__init__.py").touch() + + pkg_root, module_name = resolve_pkg_root_and_module_name( + models_py, consider_namespace_packages=True + ) + assert (pkg_root, module_name) == ( + tmp_path / "src/dist1/com/company", + "app.core.models", + )