diff --git a/crates/rattler_cache/Cargo.toml b/crates/rattler_cache/Cargo.toml index 86f14e1ec8..a59f4dd54d 100644 --- a/crates/rattler_cache/Cargo.toml +++ b/crates/rattler_cache/Cargo.toml @@ -23,6 +23,7 @@ rattler_digest = { version = "1.0.6", path = "../rattler_digest", default-featur rattler_networking = { version = "0.22.4", path = "../rattler_networking", default-features = false } rattler_package_streaming = { version = "0.22.28", path = "../rattler_package_streaming", default-features = false, features = ["reqwest"] } reqwest.workspace = true +tempfile.workspace = true tokio = { workspace = true, features = ["macros"] } tracing.workspace = true url.workspace = true @@ -32,6 +33,7 @@ digest.workspace = true fs4 = { workspace = true, features = ["fs-err3-tokio", "tokio"] } simple_spawn_blocking = { version = "1.0.0", path = "../simple_spawn_blocking", features = ["tokio"] } rayon = { workspace = true } +serde_json = { workspace = true } [dev-dependencies] assert_matches.workspace = true @@ -39,7 +41,6 @@ axum.workspace = true bytes.workspace = true futures.workspace = true rstest.workspace = true -tempfile.workspace = true tokio-stream.workspace = true tower-http = { workspace = true, features = ["fs"] } tools = { path = "../tools" } diff --git a/crates/rattler_cache/src/consts.rs b/crates/rattler_cache/src/consts.rs index d635d262e1..a72f4b25ee 100644 --- a/crates/rattler_cache/src/consts.rs +++ b/crates/rattler_cache/src/consts.rs @@ -1,4 +1,5 @@ /// The location in the main cache folder where the conda package cache is stored. pub const PACKAGE_CACHE_DIR: &str = "pkgs"; +pub const RUN_EXPORTS_CACHE_DIR: &str = "run_exports"; /// The location in the main cache folder where the repodata cache is stored. pub const REPODATA_CACHE_DIR: &str = "repodata"; diff --git a/crates/rattler_cache/src/lib.rs b/crates/rattler_cache/src/lib.rs index 753b867a10..1f932e301b 100644 --- a/crates/rattler_cache/src/lib.rs +++ b/crates/rattler_cache/src/lib.rs @@ -1,11 +1,12 @@ use std::path::PathBuf; pub mod package_cache; +pub mod run_exports_cache; pub mod validation; mod consts; -pub use consts::{PACKAGE_CACHE_DIR, REPODATA_CACHE_DIR}; +pub use consts::{PACKAGE_CACHE_DIR, REPODATA_CACHE_DIR, RUN_EXPORTS_CACHE_DIR}; /// Determines the default cache directory for rattler. /// It first checks the environment variable `RATTLER_CACHE_DIR`. diff --git a/crates/rattler_cache/src/run_exports_cache/cache_key.rs b/crates/rattler_cache/src/run_exports_cache/cache_key.rs new file mode 100644 index 0000000000..a5383160ae --- /dev/null +++ b/crates/rattler_cache/src/run_exports_cache/cache_key.rs @@ -0,0 +1,86 @@ +use rattler_conda_types::{package::ArchiveIdentifier, PackageRecord}; +use rattler_digest::{Md5Hash, Sha256Hash}; +use std::fmt::{Display, Formatter}; + +/// Provides a unique identifier for packages in the cache. +#[derive(Debug, Hash, Clone, Eq, PartialEq)] +pub struct CacheKey { + pub(crate) name: String, + pub(crate) version: String, + pub(crate) build_string: String, + pub(crate) sha256: Option, + pub(crate) md5: Option, + pub(crate) extension: String, +} + +impl CacheKey { + /// Potentially adds a sha256 hash of the archive. + pub fn with_opt_sha256(mut self, sha256: Option) -> Self { + self.sha256 = sha256; + self + } + + /// Potentially adds a md5 hash of the archive. + pub fn with_opt_md5(mut self, md5: Option) -> Self { + self.md5 = md5; + self + } +} + +impl CacheKey { + /// Return the sha256 hash of the package if it is known. + pub fn sha256(&self) -> Option { + self.sha256 + } + + /// Return the md5 hash of the package if it is known. + pub fn md5(&self) -> Option { + self.md5 + } + + /// Return the sha256 hash string of the package if it is known. + pub fn sha256_str(&self) -> String { + self.sha256() + .map(|hash| format!("{hash:x}")) + .unwrap_or_default() + } + + /// Try to create a new cache key from a package record and a filename. + pub fn create(record: &PackageRecord, filename: &str) -> Result { + let archive_identifier = ArchiveIdentifier::try_from_filename(filename) + .ok_or_else(|| CacheKeyError::InvalidArchiveIdentifier(filename.to_string()))?; + + Ok(Self { + name: record.name.as_normalized().to_string(), + version: record.version.to_string(), + build_string: record.build.clone(), + sha256: record.sha256, + md5: record.md5, + extension: archive_identifier.archive_type.extension().to_string(), + }) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum CacheKeyError { + #[error("could not identify the archive type from the name: {0}")] + InvalidArchiveIdentifier(String), +} + +impl Display for CacheKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // we need to use either sha256 or md5 hash to display the key + // if both are none, we ignore them + let display_key = match (self.sha256(), self.md5()) { + (Some(sha256), _) => format!("-{sha256:x}"), + (_, Some(md5)) => format!("-{md5:x}"), + _ => "".to_string(), + }; + + write!( + f, + "{}-{}-{}{}{}", + &self.name, &self.version, &self.build_string, display_key, self.extension + ) + } +} diff --git a/crates/rattler_cache/src/run_exports_cache/download.rs b/crates/rattler_cache/src/run_exports_cache/download.rs new file mode 100644 index 0000000000..e9f8d4a619 --- /dev/null +++ b/crates/rattler_cache/src/run_exports_cache/download.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use ::tokio::io::{AsyncSeekExt, AsyncWriteExt}; +use fs_err::tokio; +use futures::StreamExt; +use rattler_package_streaming::DownloadReporter; +use tempfile::NamedTempFile; +use url::Url; + +/// Download the contents of the archive from the specified remote location +/// and store it in a temporary file. +pub(crate) async fn download( + client: reqwest_middleware::ClientWithMiddleware, + url: Url, + suffix: &str, + reporter: Option>, +) -> Result { + let temp_file = NamedTempFile::with_suffix(suffix)?; + + // Send the request for the file + let response = client.get(url.clone()).send().await?.error_for_status()?; + + if let Some(reporter) = &reporter { + reporter.on_download_start(); + } + + let total_bytes = response.content_length(); + let (tmp_file_handle, tmp_path) = temp_file.into_parts(); + // Convert the named temp file into a tokio file + let mut file = tokio::File::from_std(fs_err::File::from_parts(tmp_file_handle, &tmp_path)); + + let mut stream = response.bytes_stream(); + + let mut bytes_received = 0; + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result?; + + if let Some(reporter) = &reporter { + bytes_received += chunk.len() as u64; + reporter.on_download_progress(bytes_received, total_bytes); + } + file.write_all(&chunk).await?; + } + + file.flush().await?; + + file.rewind().await?; + + let file_handle = file.into_parts().0.into_std().await; + + Ok(NamedTempFile::from_parts(file_handle, tmp_path)) +} + +/// An error that can occur when downloading an archive. +#[derive(thiserror::Error, Debug)] +#[allow(missing_docs)] +pub enum DownloadError { + #[error("an io error occurred: {0}")] + Io(#[from] std::io::Error), + + #[error(transparent)] + ReqwestMiddleware(#[from] ::reqwest_middleware::Error), + + #[error(transparent)] + Reqwest(#[from] ::reqwest::Error), +} diff --git a/crates/rattler_cache/src/run_exports_cache/mod.rs b/crates/rattler_cache/src/run_exports_cache/mod.rs new file mode 100644 index 0000000000..627b0a7975 --- /dev/null +++ b/crates/rattler_cache/src/run_exports_cache/mod.rs @@ -0,0 +1,684 @@ +//! This module provides functionality to cache extracted Conda packages. See +//! [`RunExportsCache`]. + +use std::{ + fmt::Debug, + future::Future, + io::Seek, + path::{Path, PathBuf}, + sync::Arc, + time::{Duration, SystemTime}, +}; + +use dashmap::DashMap; +use download::DownloadError; +use fs_err::tokio as tokio_fs; +use parking_lot::Mutex; +use rattler_conda_types::package::{PackageFile, RunExportsJson}; +use rattler_networking::retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy}; +use rattler_package_streaming::{DownloadReporter, ExtractError}; +use tempfile::{NamedTempFile, PersistError}; +use tracing::instrument; +use url::Url; + +mod cache_key; +mod download; + +pub use cache_key::{CacheKey, CacheKeyError}; + +use crate::package_cache::CacheReporter; + +/// A [`RunExportsCache`] manages a cache of `run_exports.json` +/// +/// The store does not provide an implementation to get the data into the store. +/// Instead, this is left up to the user when the `run_exports.json` is requested. If the +/// `run_exports.json` is found in the cache it is returned immediately. However, if the +/// cache is missing a user defined function is called to populate the cache. This +/// separates the corners between caching and fetching of the content. +#[derive(Clone)] +pub struct RunExportsCache { + inner: Arc, +} + +/// A cache entry that contains the path to the package and the `run_exports.json` +#[derive(Clone, Debug)] +pub struct CacheEntry { + /// The `run_exports.json` of the package. + pub(crate) run_exports: Option, + /// The path to the file on disk. + pub(crate) path: PathBuf, +} + +impl CacheEntry { + /// Create a new cache entry. + pub(crate) fn new(run_exports: Option, path: PathBuf) -> Self { + Self { run_exports, path } + } + + /// Returns the `run_exports.json` of the package. + pub fn run_exports(&self) -> Option { + self.run_exports.clone() + } + + /// Returns the path to the file on disk. + pub fn path(&self) -> &Path { + &self.path + } +} + +#[derive(Default)] +struct RunExportsCacheInner { + path: PathBuf, + run_exports: DashMap>>>, +} + +/// A key that defines the actual location of the package in the cache. +#[derive(Debug, Hash, Clone, Eq, PartialEq)] +pub struct BucketKey { + name: String, + version: String, + build_string: String, + sha256_string: String, +} + +impl From for BucketKey { + fn from(key: CacheKey) -> Self { + Self { + name: key.name.clone(), + version: key.version.clone(), + build_string: key.build_string.clone(), + sha256_string: key.sha256_str(), + } + } +} + +impl RunExportsCache { + /// Constructs a new [`RunExportsCache`] located at the specified path. + pub fn new(path: impl Into) -> Self { + Self { + inner: Arc::new(RunExportsCacheInner { + path: path.into(), + run_exports: DashMap::default(), + }), + } + } + + /// Returns the directory that contains the specified package. + /// + /// If the package was previously successfully fetched and stored in the + /// cache the directory containing the data is returned immediately. If + /// the package was not previously fetch the filesystem is checked to + /// see if a directory with valid package content exists. Otherwise, the + /// user provided `fetch` function is called to populate the cache. + /// + /// If the package is already being fetched by another task/thread the + /// request is coalesced. No duplicate fetch is performed. + pub async fn get_or_fetch( + &self, + cache_key: &CacheKey, + fetch: F, + ) -> Result + where + F: (Fn() -> Fut) + Send + 'static, + Fut: Future, E>> + Send + 'static, + E: std::error::Error + Send + Sync + 'static, + { + let cache_path = self.inner.path.join(cache_key.to_string()); + let cache_entry = self + .inner + .run_exports + .entry(cache_key.clone().into()) + .or_default() + .clone(); + + // Acquire the entry. From this point on we can be sure that only one task is + // accessing the cache entry. + let mut entry = cache_entry.lock().await; + + // Check if the cache entry is already stored in the cache. + if let Some(run_exports) = entry.as_ref() { + return Ok(run_exports.clone()); + } + + // Otherwise, defer to populate method to fill our cache. + let run_exports_file = fetch() + .await + .map_err(|e| RunExportsCacheError::Fetch(Arc::new(e)))?; + + if let Some(parent_dir) = cache_path.parent() { + if !parent_dir.exists() { + tokio_fs::create_dir_all(parent_dir).await?; + } + } + + let run_exports = if let Some(file) = run_exports_file { + file.persist(&cache_path)?; + + let run_exports_str = tokio_fs::read_to_string(&cache_path).await?; + Some(RunExportsJson::from_str(&run_exports_str)?) + } else { + None + }; + + let cache_entry = CacheEntry::new(run_exports, cache_path); + + entry.replace(cache_entry.clone()); + + Ok(cache_entry) + } + + /// Returns the directory that contains the specified package. + /// + /// This is a convenience wrapper around `get_or_fetch` which fetches the + /// package from the given URL if the package could not be found in the + /// cache. + pub async fn get_or_fetch_from_url( + &self, + cache_key: &CacheKey, + url: Url, + client: reqwest_middleware::ClientWithMiddleware, + reporter: Option>, + ) -> Result { + self.get_or_fetch_from_url_with_retry(cache_key, url, client, DoNotRetryPolicy, reporter) + .await + } + + /// Returns the directory that contains the specified package. + /// + /// This is a convenience wrapper around `get_or_fetch` which fetches the + /// package from the given URL if the package could not be found in the + /// cache. + /// + /// This function assumes that the `client` is already configured with a + /// retry middleware that will retry any request that fails. This function + /// uses the passed in `retry_policy` if, after the request has been sent + /// and the response is successful, streaming of the package data fails + /// and the whole request must be retried. + #[instrument(skip_all, fields(url=%url))] + pub async fn get_or_fetch_from_url_with_retry( + &self, + cache_key: &CacheKey, + url: Url, + client: reqwest_middleware::ClientWithMiddleware, + retry_policy: impl RetryPolicy + Send + 'static + Clone, + reporter: Option>, + ) -> Result { + let request_start = SystemTime::now(); + // Convert into cache key + let download_reporter = reporter.clone(); + + let extension = cache_key.extension.clone(); + // Get or fetch the package, using the specified fetch function + self.get_or_fetch(cache_key, move || { + + #[derive(Debug, thiserror::Error)] + enum FetchError{ + #[error(transparent)] + Download(#[from] DownloadError), + + #[error(transparent)] + Extract(#[from] ExtractError), + + #[error(transparent)] + Io(#[from] std::io::Error), + + } + + let url = url.clone(); + let client = client.clone(); + let retry_policy = retry_policy.clone(); + let download_reporter = download_reporter.clone(); + let extension = extension.clone(); + + async move { + let mut current_try = 0; + // Retry until the retry policy says to stop + loop { + current_try += 1; + tracing::debug!("downloading {}", &url); + // Extract the package + let result = crate::run_exports_cache::download::download( + client.clone(), + url.clone(), + &extension, + download_reporter.clone().map(|reporter| Arc::new(PassthroughReporter { + reporter, + index: Mutex::new(None), + }) as Arc::), + ) + .await; + + // Extract any potential error + let err = match result { + Ok(result) => { + let temp_file = NamedTempFile::new()?; + // Clone the file handler to be able to pass it to the blocking task + let mut file_handler = temp_file.as_file().try_clone()?; + // now extract run_exports.json from the archive without unpacking + let result = simple_spawn_blocking::tokio::run_blocking_task(move || { + rattler_package_streaming::seek::extract_package_file::(result.as_file(), result.path(), &mut file_handler)?; + file_handler.rewind()?; + Ok(()) + }).await; + + match result { + Ok(()) => { + return Ok(Some(temp_file)); + }, + Err(err) => { + if matches!(err, ExtractError::MissingComponent) { + return Ok(None); + } + return Err(FetchError::Extract(err)); + + } + } + }, + Err(err) => FetchError::Download(err), + }; + + // Only retry on io errors. We assume that the user has + // middleware installed that handles connection retries. + if !matches!(&err, FetchError::Download(_)) { + return Err(err); + } + + // Determine whether to retry based on the retry policy + let execute_after = match retry_policy.should_retry(request_start, current_try) { + RetryDecision::Retry { execute_after } => execute_after, + RetryDecision::DoNotRetry => return Err(err), + }; + let duration = execute_after.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO); + + // Wait for a second to let the remote service restore itself. This increases the + // chance of success. + tracing::warn!( + "failed to download and extract {} {}. Retry #{}, Sleeping {:?} until the next attempt...", + &url, + // destination.display(), + err, + current_try, + duration + ); + tokio::time::sleep(duration).await; + } + } + }) + .await + } +} + +/// An error that might be returned from one of the caching function of the +/// [`RunExportsCache`]. +#[derive(Debug, thiserror::Error)] +pub enum RunExportsCacheError { + /// An error occurred while fetching the package. + #[error(transparent)] + Fetch(#[from] Arc), + + /// A locking error occurred + #[error("{0}")] + Lock(String, #[source] std::io::Error), + + /// An IO error occurred + #[error("{0}")] + Io(#[from] std::io::Error), + + /// An error occurred while persisting the temp file + #[error("{0}")] + Persist(#[from] PersistError), + + /// An error occured when extracting `run_exports` from archive + #[error(transparent)] + Extract(#[from] ExtractError), + + /// An error occured when serializing `run_exports` + #[error(transparent)] + Serialize(#[from] serde_json::Error), + + /// The operation was cancelled + #[error("operation was cancelled")] + Cancelled, +} + +struct PassthroughReporter { + reporter: Arc, + index: Mutex>, +} + +impl DownloadReporter for PassthroughReporter { + fn on_download_start(&self) { + let index = self.reporter.on_download_start(); + assert!( + self.index.lock().replace(index).is_none(), + "on_download_start was called multiple times" + ); + } + + fn on_download_progress(&self, bytes_downloaded: u64, total_bytes: Option) { + let index = self.index.lock().expect("on_download_start was not called"); + self.reporter + .on_download_progress(index, bytes_downloaded, total_bytes); + } + + fn on_download_complete(&self) { + let index = self + .index + .lock() + .take() + .expect("on_download_start was not called"); + self.reporter.on_download_completed(index); + } +} + +#[cfg(test)] +mod test { + use std::{future::IntoFuture, net::SocketAddr, str::FromStr, sync::Arc}; + + use assert_matches::assert_matches; + use axum::{ + body::Body, + extract::State, + http::{Request, StatusCode}, + middleware, + middleware::Next, + response::{Redirect, Response}, + routing::get, + Router, + }; + + use rattler_conda_types::{PackageName, PackageRecord, Version}; + use rattler_digest::{parse_digest_from_hex, Sha256}; + use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder}; + use reqwest::Client; + use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; + use reqwest_retry::RetryTransientMiddleware; + use tempfile::tempdir; + use tokio::sync::Mutex; + + use url::Url; + + use crate::run_exports_cache::CacheKey; + + use super::RunExportsCache; + + #[tokio::test] + pub async fn test_run_exports_cache_when_empty() { + // This archive does not contain a run_exports.json + // so we expect the cache to return None + let package_url = Url::parse("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2").unwrap(); + + let cache_dir = tempdir().unwrap().into_path(); + + let cache = RunExportsCache::new(&cache_dir); + + let mut pkg_record = PackageRecord::new( + PackageName::from_str("ros-noetic-rosbridge-suite").unwrap(), + Version::from_str("0.11.14").unwrap(), + "py39h6fdeb60_14".to_string(), + ); + pkg_record.sha256 = Some( + parse_digest_from_hex::( + "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8", + ) + .unwrap(), + ); + + let cache_key = CacheKey::create( + &pkg_record, + "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2", + ) + .unwrap(); + + // Get the package to the cache + let cached_run_exports = cache + .get_or_fetch_from_url( + &cache_key, + package_url.clone(), + ClientWithMiddleware::from(Client::new()), + None, + ) + .await + .unwrap(); + + assert!(cached_run_exports.run_exports.is_none()); + } + + #[tokio::test] + pub async fn test_run_exports_cache_when_present() { + // This archive contains a run_exports.json + // so we expect the cache to return it + let package_url = + Url::parse("https://repo.prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda") + .unwrap(); + + let cache_dir = tempdir().unwrap().into_path(); + + let cache = RunExportsCache::new(&cache_dir); + + let pkg_record = PackageRecord::new( + PackageName::from_str("zlib").unwrap(), + Version::from_str("1.3.1").unwrap(), + "hb9d3cd8_2".to_string(), + ); + + let cache_key = CacheKey::create(&pkg_record, "zlib-1.3.1-hb9d3cd8_2.conda").unwrap(); + + // Get the package to the cache + let cached_run_exports = cache + .get_or_fetch_from_url( + &cache_key, + package_url.clone(), + ClientWithMiddleware::from(Client::new()), + None, + ) + .await + .unwrap(); + + assert!(cached_run_exports.run_exports.is_some()); + } + + /// A helper middleware function that fails the first two requests. + async fn fail_the_first_two_requests( + State(count): State>>, + req: Request, + next: Next, + ) -> Result { + let count = { + let mut count = count.lock().await; + *count += 1; + *count + }; + + println!("Running middleware for request #{count} for {}", req.uri()); + if count <= 2 { + println!("Discarding request!"); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + + // requires the http crate to get the header name + Ok(next.run(req).await) + } + + enum Middleware { + FailTheFirstTwoRequests, + } + + async fn redirect_to_prefix( + axum::extract::Path((channel, subdir, file)): axum::extract::Path<(String, String, String)>, + ) -> Redirect { + Redirect::permanent(&format!("https://prefix.dev/{channel}/{subdir}/{file}")) + } + + async fn test_flaky_package_cache( + archive_name: &str, + package_record: &PackageRecord, + middleware: Middleware, + ) { + // Construct a service that serves raw files from the test directory + // build our application with a route + let router = Router::new() + // `GET /` goes to `root` + .route("/{channel}/{subdir}/{file}", get(redirect_to_prefix)); + + // Construct a router that returns data from the static dir but fails the first + // try. + let request_count = Arc::new(Mutex::new(0)); + + let router = match middleware { + Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state( + request_count.clone(), + fail_the_first_two_requests, + )), + }; + + // Construct the server that will listen on localhost but with a *random port*. + // The random port is very important because it enables creating + // multiple instances at the same time. We need this to be able to run + // tests in parallel. + let addr = SocketAddr::new([127, 0, 0, 1].into(), 0); + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let service = router.into_make_service(); + tokio::spawn(axum::serve(listener, service).into_future()); + + let packages_dir = tempdir().unwrap(); + let cache = RunExportsCache::new(packages_dir.path()); + + let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap(); + + let client = ClientBuilder::new(Client::default()).build(); + + let cache_key = CacheKey::create(package_record, archive_name).unwrap(); + + // Do the first request without + let result = cache + .get_or_fetch_from_url_with_retry( + &cache_key, + server_url.join(archive_name).unwrap(), + client.clone(), + DoNotRetryPolicy, + None, + ) + .await; + + // First request without retry policy should fail + assert_matches!(result, Err(_)); + { + let request_count_lock = request_count.lock().await; + assert_eq!(*request_count_lock, 1, "Expected there to be 1 request"); + } + + let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3); + let client = ClientBuilder::from_client(client) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + + // The second one should fail after the 2nd try + let result = cache + .get_or_fetch_from_url_with_retry( + &cache_key, + server_url.join(archive_name).unwrap(), + client, + retry_policy, + None, + ) + .await; + + assert!(result.is_ok()); + { + let request_count_lock = request_count.lock().await; + assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests"); + } + } + + #[tokio::test] + async fn test_flaky() { + let tar_bz2 = "conda-forge/win-64/conda-22.9.0-py310h5588dad_2.tar.bz2"; + let conda = "conda-forge/win-64/conda-22.11.1-py38haa244fe_1.conda"; + + let tar_record = PackageRecord::new( + PackageName::from_str("conda").unwrap(), + Version::from_str("22.9.0").unwrap(), + "py310h5588dad_2".to_string(), + ); + + let conda_record = PackageRecord::new( + PackageName::from_str("conda").unwrap(), + Version::from_str("22.11.1").unwrap(), + "py38haa244fe_1".to_string(), + ); + + test_flaky_package_cache(tar_bz2, &tar_record, Middleware::FailTheFirstTwoRequests).await; + test_flaky_package_cache(conda, &conda_record, Middleware::FailTheFirstTwoRequests).await; + } + + #[tokio::test] + // Test if packages with different sha's are replaced even though they share the + // same BucketKey. + pub async fn test_package_cache_key_with_sha() { + let package_url = Url::parse("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2").unwrap(); + + let mut pkg_record = PackageRecord::new( + PackageName::from_str("ros-noetic-rosbridge-suite").unwrap(), + Version::from_str("0.11.14").unwrap(), + "py39h6fdeb60_14".to_string(), + ); + pkg_record.sha256 = Some( + parse_digest_from_hex::( + "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8", + ) + .unwrap(), + ); + + // Create a temporary directory to store the packages + let packages_dir = tempdir().unwrap(); + let cache = RunExportsCache::new(packages_dir.path()); + + let cache_key = CacheKey::create( + &pkg_record, + "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2", + ) + .unwrap(); + + // Get the package to the cache + let first_cache_path = cache + .get_or_fetch_from_url( + &cache_key, + package_url.clone(), + ClientWithMiddleware::from(Client::new()), + None, + ) + .await + .unwrap(); + + // Change the sha256 of the package + // And expect the package to be replaced + let new_sha = parse_digest_from_hex::( + "5dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc9", + ) + .unwrap(); + pkg_record.sha256 = Some(new_sha); + + let cache_key = CacheKey::create( + &pkg_record, + "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2", + ) + .unwrap(); + + // Get the package again + // and verify that the package was replaced + let second_package_cache = cache + .get_or_fetch_from_url( + &cache_key, + package_url.clone(), + ClientWithMiddleware::from(Client::new()), + None, + ) + .await + .unwrap(); + + assert_ne!(first_cache_path.path(), second_package_cache.path()); + } +} diff --git a/crates/rattler_package_streaming/Cargo.toml b/crates/rattler_package_streaming/Cargo.toml index 7e934aa3f9..9d06ef3619 100644 --- a/crates/rattler_package_streaming/Cargo.toml +++ b/crates/rattler_package_streaming/Cargo.toml @@ -22,6 +22,7 @@ rattler_networking = { path = "../rattler_networking", version = "0.22.4", defau rattler_redaction = { version = "0.1.6", path = "../rattler_redaction", features = ["reqwest", "reqwest-middleware"] } reqwest = { workspace = true, features = ["stream"], optional = true } reqwest-middleware = { workspace = true, optional = true } +simple_spawn_blocking = { version = "1.0.0", path = "../simple_spawn_blocking", features = ["tokio"] } serde_json = { workspace = true } tar = { workspace = true } tempfile = { workspace = true } diff --git a/crates/rattler_package_streaming/src/lib.rs b/crates/rattler_package_streaming/src/lib.rs index e64fafefdd..eee0c0ef7c 100644 --- a/crates/rattler_package_streaming/src/lib.rs +++ b/crates/rattler_package_streaming/src/lib.rs @@ -2,6 +2,7 @@ //! This crate provides the ability to extract a Conda package archive or specific parts of it. +use simple_spawn_blocking::Cancelled; use std::path::PathBuf; use zip::result::ZipError; @@ -62,6 +63,12 @@ impl From for ExtractError { } } +impl From for ExtractError { + fn from(_value: Cancelled) -> Self { + Self::Cancelled + } +} + #[cfg(feature = "reqwest")] impl From<::reqwest_middleware::Error> for ExtractError { fn from(err: ::reqwest_middleware::Error) -> Self { diff --git a/crates/rattler_package_streaming/src/seek.rs b/crates/rattler_package_streaming/src/seek.rs index 7af6f984c8..6a1bd3d46a 100644 --- a/crates/rattler_package_streaming/src/seek.rs +++ b/crates/rattler_package_streaming/src/seek.rs @@ -6,6 +6,7 @@ use crate::ExtractError; use rattler_conda_types::package::ArchiveType; use rattler_conda_types::package::PackageFile; use std::fs::File; +use std::io::Write; use std::{ io::{Read, Seek, SeekFrom}, path::Path, @@ -84,6 +85,26 @@ fn get_file_from_archive( Err(ExtractError::MissingComponent) } +/// Read a package file content from archive based on the path +fn read_package_file_content<'a>( + file: impl Read + Seek + 'a, + path: impl AsRef, + package_path: impl AsRef, +) -> Result, ExtractError> { + match ArchiveType::try_from(&path).ok_or(ExtractError::UnsupportedArchiveType)? { + ArchiveType::TarBz2 => { + let mut archive = stream_tar_bz2(file); + let buf = get_file_from_archive(&mut archive, package_path.as_ref())?; + Ok(buf) + } + ArchiveType::Conda => { + let mut info_archive = stream_conda_info(file).unwrap(); + let buf = get_file_from_archive(&mut info_archive, package_path.as_ref())?; + Ok(buf) + } + } +} + /// Read a package file from archive /// Note: If you want to extract multiple `info/*` files then this will be slightly /// slower than manually iterating over the archive entries with @@ -100,19 +121,23 @@ fn get_file_from_archive( pub fn read_package_file(path: impl AsRef) -> Result { // stream extract the file from a package let file = File::open(&path)?; + let content = read_package_file_content(&file, &path, P::package_path())?; - match ArchiveType::try_from(&path).ok_or(ExtractError::UnsupportedArchiveType)? { - ArchiveType::TarBz2 => { - let mut archive = stream_tar_bz2(file); - let buf = get_file_from_archive(&mut archive, P::package_path())?; - P::from_str(&String::from_utf8_lossy(&buf)) - .map_err(|e| ExtractError::ArchiveMemberParseError(P::package_path().to_owned(), e)) - } - ArchiveType::Conda => { - let mut info_archive = stream_conda_info(file).unwrap(); - let buf = get_file_from_archive(&mut info_archive, P::package_path())?; - P::from_str(&String::from_utf8_lossy(&buf)) - .map_err(|e| ExtractError::ArchiveMemberParseError(P::package_path().to_owned(), e)) - } - } + P::from_str(&String::from_utf8_lossy(&content)) + .map_err(|e| ExtractError::ArchiveMemberParseError(P::package_path().to_owned(), e)) +} + +/// Get a [`PackageFile`] from temporary archive and extract it to a writer +pub fn extract_package_file<'a, P: PackageFile>( + reader: impl Read + Seek + 'a, + location: &Path, + writer: &mut impl Write, +) -> Result<(), ExtractError> { + let content = read_package_file_content(reader, location, P::package_path())?; + + writer.write_all(&content)?; + + writer.flush()?; + + Ok(()) } diff --git a/crates/rattler_repodata_gateway/src/fetch/mod.rs b/crates/rattler_repodata_gateway/src/fetch/mod.rs index 7e09585f94..43bda98fa5 100644 --- a/crates/rattler_repodata_gateway/src/fetch/mod.rs +++ b/crates/rattler_repodata_gateway/src/fetch/mod.rs @@ -725,7 +725,10 @@ async fn stream_and_decode_to_file( // Clone the file handle and create a hashing writer so we can compute a hash // while the content is being written to disk. let file = tokio_fs::File::from_std(fs_err::File::from_parts( - temp_file.as_file().try_clone().unwrap(), + temp_file + .as_file() + .try_clone() + .map_err(FetchRepoDataError::IoError)?, temp_file.path(), )); let mut hashing_file_writer = HashingWriter::<_, Blake2b256>::new(file); diff --git a/crates/simple_spawn_blocking/src/lib.rs b/crates/simple_spawn_blocking/src/lib.rs index d82f4737fe..753ed5c408 100644 --- a/crates/simple_spawn_blocking/src/lib.rs +++ b/crates/simple_spawn_blocking/src/lib.rs @@ -1,4 +1,4 @@ -//! A simpel crate that makes it more ergonomic to spawn blocking tasks and +//! A simple crate that makes it more ergonomic to spawn blocking tasks and //! await their completion. #[cfg(feature = "tokio")]