Skip to content

Commit f54864e

Browse files
chriskuehl authored and radoering committed
perf: don't clear the entire dependency cache when backtracking (python-poetry#7950)
1 parent fba1309 commit f54864e

File tree

2 files changed

+83
-28
lines changed

2 files changed

+83
-28
lines changed

Diff for: src/poetry/mixology/version_solver.py

+36-20
Original file line number | Diff line number | Diff line change
@@ -39,15 +39,20 @@ class DependencyCache:
3939
"""
4040

4141
def __init__(self, provider: Provider) -> None:
42-
self.provider = provider
43-
self.cache: dict[
44-
tuple[str, str | None, str | None, str | None, str | None],
45-
list[DependencyPackage],
46-
] = {}
42+
self._provider = provider
43+
self._cache: dict[
44+
int,
45+
dict[
46+
tuple[str, str | None, str | None, str | None, str | None],
47+
list[DependencyPackage],
48+
],
49+
] = collections.defaultdict(dict)
4750

4851
self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
4952

50-
def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
53+
def _search_for(
54+
self, dependency: Dependency, level: int
55+
) -> list[DependencyPackage]:
5156
key = (
5257
dependency.complete_name,
5358
dependency.source_type,
@@ -56,12 +61,17 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
5661
dependency.source_subdirectory,
5762
)
5863

59-
packages = self.cache.get(key)
60-
61-
if packages:
62-
packages = [
63-
p for p in packages if dependency.constraint.allows(p.package.version)
64-
]
64+
for check_level in range(level, -1, -1):
65+
packages = self._cache[check_level].get(key)
66+
if packages is not None:
67+
packages = [
68+
p
69+
for p in packages
70+
if dependency.constraint.allows(p.package.version)
71+
]
72+
break
73+
else:
74+
packages = None
6575

6676
# provider.search_for() normally does not include pre-release packages
6777
# (unless requested), but will include them if there are no other
@@ -71,14 +81,14 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
7181
# nothing, we need to call provider.search_for() again as it may return
7282
# additional results this time.
7383
if not packages:
74-
packages = self.provider.search_for(dependency)
75-
76-
self.cache[key] = packages
84+
packages = self._provider.search_for(dependency)
7785

86+
self._cache[level][key] = packages
7887
return packages
7988

80-
def clear(self) -> None:
81-
self.cache.clear()
89+
def clear_level(self, level: int) -> None:
90+
self.search_for.cache_clear()
91+
self._cache.pop(level, None)
8292

8393

8494
class VersionSolver:
@@ -318,9 +328,9 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility
318328
self._solution.decision_level, previous_satisfier_level, -1
319329
):
320330
self._contradicted_incompatibilities.pop(level, None)
331+
self._dependency_cache.clear_level(level)
321332

322333
self._solution.backtrack(previous_satisfier_level)
323-
self._dependency_cache.clear()
324334
if new_incompatibility:
325335
self._add_incompatibility(incompatibility)
326336

@@ -418,7 +428,11 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
418428
if locked:
419429
return is_specific_marker, Preference.LOCKED, 1
420430

421-
num_packages = len(self._dependency_cache.search_for(dependency))
431+
num_packages = len(
432+
self._dependency_cache.search_for(
433+
dependency, self._solution.decision_level
434+
)
435+
)
422436

423437
if num_packages < 2:
424438
preference = Preference.NO_CHOICE
@@ -435,7 +449,9 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
435449

436450
locked = self._provider.get_locked(dependency)
437451
if locked is None:
438-
packages = self._dependency_cache.search_for(dependency)
452+
packages = self._dependency_cache.search_for(
453+
dependency, self._solution.decision_level
454+
)
439455
package = next(iter(packages), None)
440456

441457
if package is None:

Diff for: tests/mixology/version_solver/test_dependency_cache.py

+47-8
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from copy import deepcopy
44
from typing import TYPE_CHECKING
5+
from unittest import mock
56

67
from poetry.factory import Factory
78
from poetry.mixology.version_solver import DependencyCache
@@ -32,14 +33,14 @@ def test_solver_dependency_cache_respects_source_type(
3233
cache.search_for.cache_clear()
3334

3435
# ensure cache was never hit for both calls
35-
cache.search_for(dependency_pypi)
36-
cache.search_for(dependency_git)
36+
cache.search_for(dependency_pypi, 0)
37+
cache.search_for(dependency_git, 0)
3738
assert not cache.search_for.cache_info().hits
3839

3940
# increase test coverage by searching for copies
4041
# (when searching for the exact same object, __eq__ is never called)
41-
packages_pypi = cache.search_for(deepcopy(dependency_pypi))
42-
packages_git = cache.search_for(deepcopy(dependency_git))
42+
packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
43+
packages_git = cache.search_for(deepcopy(dependency_git), 0)
4344

4445
assert cache.search_for.cache_info().hits == 2
4546
assert cache.search_for.cache_info().currsize == 2
@@ -60,6 +61,44 @@ def test_solver_dependency_cache_respects_source_type(
6061
assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION
6162

6263

64+
def test_solver_dependency_cache_pulls_from_prior_level_cache(
65+
root: ProjectPackage, provider: Provider, repo: Repository
66+
) -> None:
67+
dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
68+
root.add_dependency(dependency_pypi)
69+
add_to_repo(repo, "demo", "1.0.0")
70+
71+
wrapped_provider = mock.Mock(wraps=provider)
72+
cache = DependencyCache(wrapped_provider)
73+
cache.search_for.cache_clear()
74+
75+
# On first call, provider.search_for() should be called and the level-0
76+
# cache populated.
77+
cache.search_for(dependency_pypi, 0)
78+
assert len(wrapped_provider.search_for.mock_calls) == 1
79+
assert ("demo", None, None, None, None) in cache._cache[0]
80+
assert cache.search_for.cache_info().hits == 0
81+
assert cache.search_for.cache_info().misses == 1
82+
83+
# On second call at level 1, provider.search_for() should not be called
84+
# again and the level-1 cache should be populated from the level-0 cache.
85+
cache.search_for(dependency_pypi, 1)
86+
assert len(wrapped_provider.search_for.mock_calls) == 1
87+
assert ("demo", None, None, None, None) in cache._cache[1]
88+
assert cache._cache[0] == cache._cache[1]
89+
assert cache.search_for.cache_info().hits == 0
90+
assert cache.search_for.cache_info().misses == 2
91+
92+
# Clearing the level 1 cache should invalidate the lru_cache on
93+
# cache.search_for and wipe out the level 1 cache while preserving the
94+
# level 0 cache.
95+
cache.clear_level(1)
96+
assert set(cache._cache.keys()) == {0}
97+
assert ("demo", None, None, None, None) in cache._cache[0]
98+
assert cache.search_for.cache_info().hits == 0
99+
assert cache.search_for.cache_info().misses == 0
100+
101+
63102
def test_solver_dependency_cache_respects_subdirectories(
64103
root: ProjectPackage, provider: Provider, repo: Repository
65104
) -> None:
@@ -87,14 +126,14 @@ def test_solver_dependency_cache_respects_subdirectories(
87126
cache.search_for.cache_clear()
88127

89128
# ensure cache was never hit for both calls
90-
cache.search_for(dependency_one)
91-
cache.search_for(dependency_one_copy)
129+
cache.search_for(dependency_one, 0)
130+
cache.search_for(dependency_one_copy, 0)
92131
assert not cache.search_for.cache_info().hits
93132

94133
# increase test coverage by searching for copies
95134
# (when searching for the exact same object, __eq__ is never called)
96-
packages_one = cache.search_for(deepcopy(dependency_one))
97-
packages_one_copy = cache.search_for(deepcopy(dependency_one_copy))
135+
packages_one = cache.search_for(deepcopy(dependency_one), 0)
136+
packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
98137

99138
assert cache.search_for.cache_info().hits == 2
100139
assert cache.search_for.cache_info().currsize == 2

0 commit comments

Comments
 (0)