Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in artifact cache #8517

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 20 additions & 43 deletions src/poetry/installation/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import csv
import functools
import itertools
import json
import threading
Expand All @@ -15,7 +16,6 @@

from cleo.io.null_io import NullIO
from poetry.core.packages.utils.link import Link
from requests.utils import atomic_open

from poetry.installation.chef import Chef
from poetry.installation.chef import ChefBuildError
Expand All @@ -28,8 +28,8 @@
from poetry.puzzle.exceptions import SolverProblemError
from poetry.utils._compat import decode
from poetry.utils.authenticator import Authenticator
from poetry.utils.cache import ArtifactCache
from poetry.utils.env import EnvCommandError
from poetry.utils.helpers import Downloader
from poetry.utils.helpers import get_file_hash
from poetry.utils.helpers import pluralize
from poetry.utils.helpers import remove_directory
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
else:
self._max_workers = 1

self._artifact_cache = ArtifactCache(cache_dir=config.artifacts_cache_directory)
self._artifact_cache = pool.artifact_cache
self._authenticator = Authenticator(
config, self._io, disable_cache=disable_cache, pool_size=self._max_workers
)
Expand Down Expand Up @@ -748,23 +748,11 @@ def _download(self, operation: Install | Update) -> Path:
def _download_link(self, operation: Install | Update, link: Link) -> Path:
package = operation.package

output_dir = self._artifact_cache.get_cache_directory_for_link(link)
# Try to get cached original package for the link provided
# Get original package for the link provided
download_func = functools.partial(self._download_archive, operation)
original_archive = self._artifact_cache.get_cached_archive_for_link(
link, strict=True
link, strict=True, download_func=download_func
)
if original_archive is None:
            # No cached original distribution was found, so we download and prepare it
try:
original_archive = self._download_archive(operation, link)
except BaseException:
cache_directory = self._artifact_cache.get_cache_directory_for_link(
link
)
cached_file = cache_directory.joinpath(link.filename)
cached_file.unlink(missing_ok=True)

raise

# Get potential higher prioritized cached archive, otherwise it will fall back
# to the original archive.
Expand All @@ -790,7 +778,7 @@ def _download_link(self, operation: Install | Update, link: Link) -> Path:
)
self._write(operation, message)

archive = self._chef.prepare(archive, output_dir=output_dir)
archive = self._chef.prepare(archive, output_dir=original_archive.parent)

# Use the original archive to provide the correct hash.
self._populate_hashes_dict(original_archive, package)
Expand All @@ -815,11 +803,15 @@ def _validate_archive_hash(archive: Path, package: Package) -> str:

return archive_hash

def _download_archive(self, operation: Install | Update, link: Link) -> Path:
response = self._authenticator.request(
"get", link.url, stream=True, io=self._sections.get(id(operation), self._io)
)
wheel_size = response.headers.get("content-length")
def _download_archive(
self,
operation: Install | Update,
url: str,
dest: Path,
) -> None:
downloader = Downloader(url, dest, self._authenticator)
wheel_size = downloader.total_size

operation_message = self.get_operation_message(operation)
message = (
f" <fg=blue;options=bold>•</> {operation_message}: <info>Downloading...</>"
Expand All @@ -841,30 +833,15 @@ def _download_archive(self, operation: Install | Update, link: Link) -> Path:
self._sections[id(operation)].clear()
progress.start()

done = 0
archive = (
self._artifact_cache.get_cache_directory_for_link(link) / link.filename
)
archive.parent.mkdir(parents=True, exist_ok=True)
with atomic_open(archive) as f:
for chunk in response.iter_content(chunk_size=4096):
if not chunk:
break

done += len(chunk)

if progress:
with self._lock:
progress.set_progress(done)

f.write(chunk)
for fetched_size in downloader.download_with_progress(chunk_size=4096):
if progress:
with self._lock:
progress.set_progress(fetched_size)

if progress:
with self._lock:
progress.finish()

return archive

def _should_write_operation(self, operation: Operation) -> bool:
return (
not operation.skipped or self._dry_run or self._verbose or not self._enabled
Expand Down
11 changes: 3 additions & 8 deletions src/poetry/packages/direct_origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,9 @@ def get_package_from_directory(cls, directory: Path) -> Package:

def get_package_from_url(self, url: str) -> Package:
link = Link(url)
artifact = self._artifact_cache.get_cached_archive_for_link(link, strict=True)

if not artifact:
artifact = (
self._artifact_cache.get_cache_directory_for_link(link) / link.filename
)
artifact.parent.mkdir(parents=True, exist_ok=True)
download_file(url, artifact)
artifact = self._artifact_cache.get_cached_archive_for_link(
link, strict=True, download_func=download_file
)

package = self.get_package_from_file(artifact)
package.files = [
Expand Down
47 changes: 45 additions & 2 deletions src/poetry/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import json
import logging
import shutil
import threading
import time

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import TypeVar
from typing import overload

from poetry.utils._compat import decode
from poetry.utils._compat import encode
Expand Down Expand Up @@ -187,6 +190,9 @@ def _deserialize(self, data_raw: bytes) -> CacheItem[T]:
class ArtifactCache:
def __init__(self, *, cache_dir: Path) -> None:
self._cache_dir = cache_dir
self._archive_locks: defaultdict[Path, threading.Lock] = defaultdict(
threading.Lock
)

def get_cache_directory_for_link(self, link: Link) -> Path:
key_parts = {"url": link.url_without_fragment}
Expand Down Expand Up @@ -218,18 +224,54 @@ def get_cache_directory_for_git(

return self._get_directory_from_hash(key_parts)

@overload
def get_cached_archive_for_link(
self,
link: Link,
*,
strict: bool,
env: Env | None = ...,
download_func: Callable[[str, Path], None],
) -> Path: ...

@overload
def get_cached_archive_for_link(
self,
link: Link,
*,
strict: bool,
env: Env | None = ...,
download_func: None = ...,
) -> Path | None: ...

def get_cached_archive_for_link(
self,
link: Link,
*,
strict: bool,
env: Env | None = None,
download_func: Callable[[str, Path], None] | None = None,
) -> Path | None:
cache_dir = self.get_cache_directory_for_link(link)

return self._get_cached_archive(
cached_archive = self._get_cached_archive(
cache_dir, strict=strict, filename=link.filename, env=env
)
if cached_archive is None and strict and download_func is not None:
cached_archive = cache_dir / link.filename
with self._archive_locks[cached_archive]:
# Check again if the archive exists (under the lock) to avoid
# duplicate downloads because it may have already been downloaded
# by another thread in the meantime
if not cached_archive.exists():
cache_dir.mkdir(parents=True, exist_ok=True)
try:
download_func(link.url, cached_archive)
except BaseException:
cached_archive.unlink(missing_ok=True)
raise

return cached_archive

def get_cached_archive_for_git(
self, url: str, reference: str, subdirectory: str | None, env: Env
Expand All @@ -246,8 +288,9 @@ def _get_cached_archive(
filename: str | None = None,
env: Env | None = None,
) -> Path | None:
# implication "not strict -> env must not be None"
assert strict or env is not None
# implication "strict -> filename should not be None"
# implication "strict -> filename must not be None"
assert not strict or filename is not None

archives = self._get_cached_archives(cache_dir)
Expand Down
64 changes: 43 additions & 21 deletions src/poetry/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@

from collections.abc import Mapping
from contextlib import contextmanager
from contextlib import suppress
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import overload

import requests

from requests.utils import atomic_open

from poetry.utils.constants import REQUESTS_TIMEOUT
Expand Down Expand Up @@ -100,43 +104,61 @@ def download_file(
session: Authenticator | Session | None = None,
chunk_size: int = 1024,
) -> None:
import requests

from poetry.puzzle.provider import Indicator

get = requests.get if not session else session.get

response = get(url, stream=True, timeout=REQUESTS_TIMEOUT)
response.raise_for_status()
downloader = Downloader(url, dest, session)

set_indicator = False
with Indicator.context() as update_context:
update_context(f"Downloading {url}")

if "Content-Length" in response.headers:
try:
total_size = int(response.headers["Content-Length"])
except ValueError:
total_size = 0

total_size = downloader.total_size
if total_size > 0:
fetched_size = 0
last_percent = 0

# if less than 1MB, we simply show that we're downloading
# but skip the updating
set_indicator = total_size > 1024 * 1024

with atomic_open(dest) as f:
for chunk in response.iter_content(chunk_size=chunk_size):
for fetched_size in downloader.download_with_progress(chunk_size):
if set_indicator:
percent = (fetched_size * 100) // total_size
if percent > last_percent:
last_percent = percent
update_context(f"Downloading {url} {percent:3}%")


class Downloader:
    """Stream a remote file to *dest*, writing it atomically.

    The HTTP GET request is issued eagerly in the constructor (with
    ``stream=True`` and ``raise_for_status()``), so connection and HTTP
    errors surface before any file is touched.
    """

    def __init__(
        self,
        url: str,
        dest: Path,
        session: Authenticator | Session | None = None,
    ):
        self._dest = dest

        # Fall back to a plain requests call when no (authenticated) session
        # is supplied.
        get = requests.get if not session else session.get

        self._response = get(url, stream=True, timeout=REQUESTS_TIMEOUT)
        self._response.raise_for_status()

    @cached_property
    def total_size(self) -> int:
        """Size in bytes from the Content-Length header, or 0 if unknown
        or unparsable."""
        total_size = 0
        if "Content-Length" in self._response.headers:
            with suppress(ValueError):
                total_size = int(self._response.headers["Content-Length"])
        return total_size

    def download_with_progress(self, chunk_size: int = 1024) -> Iterator[int]:
        """Write the response body to the destination file chunk by chunk.

        Yields the cumulative number of bytes fetched after each non-empty
        chunk so callers can drive a progress display.  The file is written
        via ``atomic_open``, so a partial download never replaces an
        existing file.
        """
        fetched_size = 0
        with atomic_open(self._dest) as f:
            for chunk in self._response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    fetched_size += len(chunk)
                    yield fetched_size


def get_package_version_display_string(
Expand Down
Loading