diff --git a/src/poetry/mixology/version_solver.py b/src/poetry/mixology/version_solver.py
index 4999a35fd9b..abfafd49806 100644
--- a/src/poetry/mixology/version_solver.py
+++ b/src/poetry/mixology/version_solver.py
@@ -5,6 +5,8 @@
 import time
 
 from typing import TYPE_CHECKING
+from typing import Optional
+from typing import Tuple
 
 from poetry.core.packages.dependency import Dependency
 
@@ -29,6 +31,11 @@
 _conflict = object()
 
 
+DependencyCacheKey = Tuple[
+    str, Optional[str], Optional[str], Optional[str], Optional[str]
+]
+
+
 class DependencyCache:
     """
     A cache of the valid dependencies.
@@ -40,18 +47,46 @@ class DependencyCache:
 
     def __init__(self, provider: Provider) -> None:
         self.provider = provider
-        self.cache: dict[
-            int,
-            dict[
-                tuple[str, str | None, str | None, str | None, str | None],
-                list[DependencyPackage],
-            ],
-        ] = collections.defaultdict(dict)
-
-        self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
+
+        # self.cache maps a package name to a stack of cached package lists,
+        # ordered by the decision level which added them to the cache. This is
+        # done so that when backtracking we can maintain cache entries from
+        # previous decision levels, while clearing cache entries from only the
+        # rolled back levels.
+        #
+        # In order to maintain the integrity of the cache, `clear_level()`
+        # needs to be called in descending order as decision levels are
+        # backtracked so that the correct items can be popped from the stack.
+        self.cache: dict[DependencyCacheKey, list[list[DependencyPackage]]] = (
+            collections.defaultdict(list)
+        )
+        self.cached_dependencies_by_level: dict[int, list[DependencyCacheKey]] = (
+            collections.defaultdict(list)
+        )
+
+        self._search_for_cached = functools.lru_cache(maxsize=128)(self._search_for)
 
     def _search_for(
-        self, dependency: Dependency, level: int
+        self,
+        dependency: Dependency,
+        key: DependencyCacheKey,
+    ) -> list[DependencyPackage]:
+        cache_entries = self.cache.get(key)
+        if cache_entries:
+            packages = [
+                p
+                for p in cache_entries[-1]
+                if dependency.constraint.allows(p.package.version)
+            ]
+        else:
+            packages = self.provider.search_for(dependency)
+
+        return packages
+
+    def search_for(
+        self,
+        dependency: Dependency,
+        decision_level: int,
     ) -> list[DependencyPackage]:
         key = (
             dependency.complete_name,
@@ -60,26 +95,17 @@
             dependency.source_reference,
             dependency.source_subdirectory,
         )
-
-        for check_level in range(level, -1, -1):
-            packages = self.cache[check_level].get(key)
-            if packages is not None:
-                packages = [
-                    p
-                    for p in packages
-                    if dependency.constraint.allows(p.package.version)
-                ]
-                break
-        else:
-            packages = self.provider.search_for(dependency)
-
-        self.cache[level][key] = packages
-
+        packages = self._search_for_cached(dependency, key)
+        if not self.cache[key] or self.cache[key][-1] is not packages:
+            self.cache[key].append(packages)
+            self.cached_dependencies_by_level[decision_level].append(key)
         return packages
 
     def clear_level(self, level: int) -> None:
-        self.search_for.cache_clear()
-        self.cache.pop(level, None)
+        if level in self.cached_dependencies_by_level:
+            self._search_for_cached.cache_clear()
+            for key in self.cached_dependencies_by_level.pop(level):
+                self.cache[key].pop()
 
 
 class VersionSolver:
@@ -318,9 +344,10 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility
         for level in range(
             self._solution.decision_level, previous_satisfier_level, -1
         ):
-            self._contradicted_incompatibilities.difference_update(
-                self._contradicted_incompatibilities_by_level.pop(level, set()),
-            )
+            if level in self._contradicted_incompatibilities_by_level:
+                self._contradicted_incompatibilities.difference_update(
+                    self._contradicted_incompatibilities_by_level.pop(level),
+                )
             self._dependency_cache.clear_level(level)
 
         self._solution.backtrack(previous_satisfier_level)
diff --git a/tests/mixology/version_solver/test_dependency_cache.py b/tests/mixology/version_solver/test_dependency_cache.py
index 5a3d3e7cbe3..3e20e762250 100644
--- a/tests/mixology/version_solver/test_dependency_cache.py
+++ b/tests/mixology/version_solver/test_dependency_cache.py
@@ -30,20 +30,20 @@ def test_solver_dependency_cache_respects_source_type(
     add_to_repo(repo, "demo", "1.0.0")
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
     cache.search_for(dependency_pypi, 0)
     cache.search_for(dependency_git, 0)
-    assert not cache.search_for.cache_info().hits
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
     packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
     packages_git = cache.search_for(deepcopy(dependency_git), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
 
     assert len(packages_pypi) == len(packages_git) == 1
     assert packages_pypi != packages_git
@@ -65,38 +65,59 @@ def test_solver_dependency_cache_pulls_from_prior_level_cache(
     root: ProjectPackage, provider: Provider, repo: Repository
 ) -> None:
     dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
+    dependency_pypi_constrained = Factory.create_dependency("demo", ">=0.1.0,<2.0.0")
     root.add_dependency(dependency_pypi)
+    root.add_dependency(dependency_pypi_constrained)
     add_to_repo(repo, "demo", "1.0.0")
 
     wrapped_provider = mock.Mock(wraps=provider)
     cache = DependencyCache(wrapped_provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
-    # On first call, provider.search_for() should be called and the level-0
-    # cache populated.
+    # On first call, provider.search_for() should be called and the cache
+    # populated.
     cache.search_for(dependency_pypi, 0)
     assert len(wrapped_provider.search_for.mock_calls) == 1
-    assert ("demo", None, None, None, None) in cache.cache[0]
-    assert cache.search_for.cache_info().hits == 0
-    assert cache.search_for.cache_info().misses == 1
-
-    # On second call at level 1, provider.search_for() should not be called
-    # again and the level-1 cache should be populated from the level-0 cache.
+    assert ("demo", None, None, None, None) in cache.cache
+    assert ("demo", None, None, None, None) in cache.cached_dependencies_by_level[0]
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # On second call at level 1, neither provider.search_for() nor
+    # cache._search_for_cached() should have been called again, and the cache
+    # should remain the same.
     cache.search_for(dependency_pypi, 1)
     assert len(wrapped_provider.search_for.mock_calls) == 1
-    assert ("demo", None, None, None, None) in cache.cache[1]
-    assert cache.cache[0] == cache.cache[1]
-    assert cache.search_for.cache_info().hits == 0
-    assert cache.search_for.cache_info().misses == 2
-
-    # Clearing the level 1 cache should invalidate the lru_cache on
-    # cache.search_for and wipe out the level 1 cache while preserving the
+    assert ("demo", None, None, None, None) in cache.cache
+    assert ("demo", None, None, None, None) in cache.cached_dependencies_by_level[0]
+    assert set(cache.cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # The third call at level 2, with an updated constraint for the `demo`
+    # package, should not call provider.search_for(), but should call
+    # cache._search_for_cached() and update the cache.
+    cache.search_for(dependency_pypi_constrained, 2)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache.cache
+    assert ("demo", None, None, None, None) in cache.cached_dependencies_by_level[0]
+    assert ("demo", None, None, None, None) in cache.cached_dependencies_by_level[2]
+    assert set(cache.cached_dependencies_by_level.keys()) == {0, 2}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 2
+
+    # Clearing the level 2 and level 1 caches should invalidate the lru_cache
+    # on cache._search_for_cached and wipe out the level 2 cache, preserving the
     # level 0 cache.
+    cache.clear_level(2)
     cache.clear_level(1)
-    assert set(cache.cache.keys()) == {0}
-    assert ("demo", None, None, None, None) in cache.cache[0]
-    assert cache.search_for.cache_info().hits == 0
-    assert cache.search_for.cache_info().misses == 0
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache.cache
+    assert ("demo", None, None, None, None) in cache.cached_dependencies_by_level[0]
+    assert set(cache.cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
 
 
 def test_solver_dependency_cache_respects_subdirectories(
@@ -123,20 +144,20 @@
     root.add_dependency(dependency_one_copy)
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
     cache.search_for(dependency_one, 0)
     cache.search_for(dependency_one_copy, 0)
-    assert not cache.search_for.cache_info().hits
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
    packages_one = cache.search_for(deepcopy(dependency_one), 0)
     packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
 
     assert len(packages_one) == len(packages_one_copy) == 1
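
The mechanism the patch introduces can be shown in isolation with a minimal sketch. This is illustrative only and not part of the patch: `LevelCache`, `searcher`, and the integer "versions" are hypothetical stand-ins for Poetry's `DependencyCache`, `Provider.search_for()`, and `DependencyPackage` objects, and the constraint is reduced to a hashable `(low, high)` pair so it can pass through `functools.lru_cache`.

```python
import collections
import functools


class LevelCache:
    """Sketch of the stack-per-key cache: each key maps to a stack of cached
    results, and each decision level records the keys it pushed so that
    backtracking pops exactly the entries that level added."""

    def __init__(self, searcher):
        self.searcher = searcher  # stand-in for Provider.search_for()
        self.cache = collections.defaultdict(list)
        self.cached_keys_by_level = collections.defaultdict(list)
        self._search_cached = functools.lru_cache(maxsize=128)(self._search)

    def _search(self, key, low, high):
        entries = self.cache.get(key)
        if entries:
            # Narrow the most recent cached result instead of re-searching.
            return [v for v in entries[-1] if low <= v <= high]
        return self.searcher(key, low, high)

    def search_for(self, key, low, high, level):
        result = self._search_cached(key, low, high)
        # Push onto the stack only when this is a new result object.
        if not self.cache[key] or self.cache[key][-1] is not result:
            self.cache[key].append(result)
            self.cached_keys_by_level[level].append(key)
        return result

    def clear_level(self, level):
        # Must be called in descending order while backtracking, as the
        # docstring in the patch requires, so the right entries get popped.
        if level in self.cached_keys_by_level:
            self._search_cached.cache_clear()
            for key in self.cached_keys_by_level.pop(level):
                self.cache[key].pop()


versions = {"demo": [1, 2, 3]}
cache = LevelCache(lambda k, lo, hi: [v for v in versions[k] if lo <= v <= hi])

assert cache.search_for("demo", 1, 3, level=0) == [1, 2, 3]  # provider consulted
assert cache.search_for("demo", 1, 2, level=2) == [1, 2]     # narrowed from level 0
cache.clear_level(2)                                         # backtrack past level 2
assert cache.cache["demo"] == [[1, 2, 3]]                    # level-0 entry preserved
```

Note that the identity check (`is not`) carries weight here, just as in the patched `search_for()`: after an `lru_cache` hit the very same list object comes back, so nothing is re-pushed onto the stack and `clear_level()` pops one entry per push.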