Skip to content

Commit

Permalink
fix race condition to avoid downloading the same artifact in multiple…
Browse files Browse the repository at this point in the history
… threads and trying to store it in the same location of the artifact cache
  • Loading branch information
radoering committed Oct 7, 2023
1 parent 5720160 commit 50e520a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 12 deletions.
22 changes: 16 additions & 6 deletions src/poetry/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import json
import logging
import shutil
import threading
import time

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -188,6 +190,9 @@ def _deserialize(self, data_raw: bytes) -> CacheItem[T]:
class ArtifactCache:
def __init__(self, *, cache_dir: Path) -> None:
# Directory handed in by the caller (keyword-only); stored unchanged.
self._cache_dir = cache_dir
# One lock per archive path. The defaultdict creates a new threading.Lock
# the first time a path is looked up. get_cached_archive_for_link holds the
# lock of an archive while it re-checks for and downloads that archive, so
# the same archive is not downloaded by several threads at once.
self._archive_locks: defaultdict[Path, threading.Lock] = defaultdict(
threading.Lock
)

def get_cache_directory_for_link(self, link: Link) -> Path:
key_parts = {"url": link.url_without_fragment}
Expand Down Expand Up @@ -253,13 +258,18 @@ def get_cached_archive_for_link(
cache_dir, strict=strict, filename=link.filename, env=env
)
if cached_archive is None and strict and download_func is not None:
cache_dir.mkdir(parents=True, exist_ok=True)
cached_archive = cache_dir / link.filename
try:
download_func(link.url, cached_archive)
except BaseException:
cached_archive.unlink(missing_ok=True)
raise
with self._archive_locks[cached_archive]:
# Check again if the archive exists (under the lock) to avoid
# duplicate downloads because it may have already been downloaded
# by another thread in the meantime
if not cached_archive.exists():
cache_dir.mkdir(parents=True, exist_ok=True)
try:
download_func(link.url, cached_archive)
except BaseException:
cached_archive.unlink(missing_ok=True)
raise

return cached_archive

Expand Down
14 changes: 8 additions & 6 deletions tests/installation/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,16 @@ def test_executor_should_delete_incomplete_downloads(
pool: RepositoryPool,
mock_file_downloads: None,
env: MockEnv,
fixture_dir: FixtureDirGetter,
) -> None:
fixture = fixture_dir("distributions") / "demo-0.1.0-py2.py3-none-any.whl"
destination_fixture = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"
shutil.copyfile(str(fixture), str(destination_fixture))
cached_archive = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"

def download_fail(*_: Any) -> None:
cached_archive.touch() # broken archive
raise Exception("Download error")

mocker.patch(
"poetry.installation.executor.Executor._download_archive",
side_effect=Exception("Download error"),
side_effect=download_fail,
)
mocker.patch(
"poetry.utils.cache.ArtifactCache._get_cached_archive",
Expand All @@ -607,7 +609,7 @@ def test_executor_should_delete_incomplete_downloads(
with pytest.raises(Exception, match="Download error"):
executor._download(Install(Package("tomlkit", "0.5.3")))

assert not destination_fixture.exists()
assert not cached_archive.exists()


def verify_installed_distribution(
Expand Down
37 changes: 37 additions & 0 deletions tests/utils/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import concurrent.futures
import shutil
import traceback

from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -322,6 +324,41 @@ def test_get_found_cached_archive_for_link(
assert Path(cached) == archive


def test_get_cached_archive_for_link_no_race_condition(
    tmp_path: Path, mocker: MockerFixture
) -> None:
    """Request the same link from several threads and expect one download.

    Every thread must receive the same cached archive path, and the download
    function must have been called exactly once.
    """
    artifact_cache = ArtifactCache(cache_dir=tmp_path)
    link = Link("https://files.python-poetry.org/demo-0.1.0.tar.gz")

    def slow_download(_: str, dest: Path) -> None:
        # Replace the file and write some data (so it takes a while) to
        # provoke possible race conditions between the worker threads.
        dest.unlink(missing_ok=True)
        dest.write_text("a" * 2**20)

    download_mock = mocker.Mock(side_effect=slow_download)
    worker_count = 4

    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = [
            pool.submit(
                artifact_cache.get_cached_archive_for_link,
                link,
                strict=True,
                download_func=download_mock,
            )
            for _ in range(worker_count)
        ]
        concurrent.futures.wait(futures)

        archives = set()
        for future in futures:
            try:
                archives.add(future.result())
            except Exception:
                # Surface the worker's traceback as the test failure message.
                pytest.fail(traceback.format_exc())

        link_cache_dir = artifact_cache.get_cache_directory_for_link(link)
        assert archives == {link_cache_dir / link.filename}
        download_mock.assert_called_once()


def test_get_cached_archive_for_git() -> None:
"""Smoke test that checks that no assertion is raised."""
cache = ArtifactCache(cache_dir=Path())
Expand Down

0 comments on commit 50e520a

Please sign in to comment.