Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{path::Path, str::FromStr, sync::Arc, time::SystemTime};
use super::ShardedRepodata;
use crate::{
fetch::CacheAction,
gateway::sharded_subdir::decode_zst_bytes_async,
gateway::{error::SubdirNotFoundError, sharded_subdir::decode_zst_bytes_async},
reporter::{DownloadReporter, ResponseReporterExt},
utils::url_to_cache_filename,
GatewayError, Reporter,
Expand All @@ -12,8 +12,9 @@ use async_fd_lock::{LockWrite, RwLockWriteGuard};
use bytes::Bytes;
use fs_err::tokio as tokio_fs;
use futures::{future::OptionFuture, TryFutureExt};
use http::{HeaderMap, Method, Uri};
use http::{HeaderMap, Method, StatusCode, Uri};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
use rattler_conda_types::Channel;
use rattler_networking::LazyClient;
use rattler_redaction::Redact;
use reqwest::Response;
Expand All @@ -27,6 +28,20 @@ use url::Url;

const REPODATA_SHARDS_FILENAME: &str = "repodata_shards.msgpack.zst";

/// Creates a `SubdirNotFoundError` for when sharded repodata is not available.
fn create_subdir_not_found_error(channel_base_url: &Url) -> GatewayError {
GatewayError::SubdirNotFoundError(Box::new(SubdirNotFoundError {
channel: Channel::from_url(channel_base_url.clone()),
subdir: channel_base_url
.path_segments()
.and_then(|mut s| s.next_back())
.unwrap_or("unknown")
.to_string(),
source: std::io::Error::new(std::io::ErrorKind::NotFound, "sharded repodata not found")
.into(),
}))
}

// Fetches the shard index from the url or read it from the cache.
pub async fn fetch_index(
client: LazyClient,
Expand Down Expand Up @@ -149,6 +164,12 @@ pub async fn fetch_index(
// Try reading the cached file
if cache_action != CacheAction::NoCache {
if let Ok(cache_header) = read_cached_index(&mut cache_reader).await {
// Check if the cache indicates the resource was not found (404)
if cache_header.not_found {
tracing::debug!("cached 404 for sharded index at {channel_base_url}");
return Err(create_subdir_not_found_error(channel_base_url));
}

// If we are in cache-only mode we can't fetch the index from the server
if cache_action == CacheAction::ForceCacheOnly {
if let Ok(shard_index) = read_shard_index_from_reader(&mut cache_reader).await {
Expand Down Expand Up @@ -207,6 +228,34 @@ pub async fn fetch_index(
.map(|r| (r, r.on_download_start(&shards_url)));
let response = client.client().execute(request).await?;

// Check if the resource was not found (404)
if response.status() == StatusCode::NOT_FOUND {
tracing::debug!(
"sharded index not found (404) at {channel_base_url}, caching this result"
);

// Cache the 404 response
let policy = CachePolicy::new(&canonical_request, &response);
write_not_found_cache(cache_reader.into_inner().inner_mut(), policy)
.await
.map_err(|e| {
GatewayError::IoError(
format!(
"failed to write 404 cache for shard index to {}",
cache_path.display()
),
e,
)
})?;

if let Some((reporter, index)) = download_reporter {
reporter.on_download_complete(response.url(), index);
}

// Return SubdirNotFoundError to trigger fallback
return Err(create_subdir_not_found_error(channel_base_url));
}

match cache_header.policy.after_response(
&state_request,
&response,
Expand Down Expand Up @@ -297,6 +346,28 @@ pub async fn fetch_index(
)
.await?;

// Check if the resource was not found (404)
if response.status() == StatusCode::NOT_FOUND {
tracing::debug!("sharded index not found (404) at {channel_base_url}, caching this result");

// Cache the 404 response
let policy = CachePolicy::new(&canonical_request, &response);
write_not_found_cache(cache_reader.into_inner().inner_mut(), policy)
.await
.map_err(|e| {
GatewayError::IoError(
format!(
"failed to write 404 cache for shard index to {}",
cache_path.display()
),
e,
)
})?;

// Return SubdirNotFoundError to trigger fallback
return Err(create_subdir_not_found_error(channel_base_url));
}

let policy = CachePolicy::new(&canonical_request, &response);
from_response(
cache_reader.into_inner(),
Expand All @@ -312,14 +383,14 @@ pub async fn fetch_index(
/// Magic number that identifies the cache file format.
const MAGIC_NUMBER: &[u8] = b"SHARD-CACHE-V1";

/// Writes the shard index cache to disk.
pub async fn write_shard_index_cache(
/// Writes cache data to disk with the given header and optional body.
async fn write_cache(
cache_file: &mut File,
policy: CachePolicy,
decoded_bytes: Bytes,
cache_header: CacheHeader,
body: Option<&[u8]>,
) -> std::io::Result<()> {
let cache_header =
rmp_serde::encode::to_vec(&CacheHeader { policy }).expect("failed to encode cache header");
let encoded_header =
rmp_serde::encode::to_vec(&cache_header).expect("failed to encode cache header");

// Move to the start of the file
cache_file.rewind().await?;
Expand All @@ -328,10 +399,15 @@ pub async fn write_shard_index_cache(
let mut writer = BufWriter::new(cache_file);
writer.write_all(MAGIC_NUMBER).await?;
writer
.write_all(&(cache_header.len() as u32).to_le_bytes())
.write_all(&(encoded_header.len() as u32).to_le_bytes())
.await?;
writer.write_all(&cache_header).await?;
writer.write_all(decoded_bytes.as_ref()).await?;
writer.write_all(&encoded_header).await?;

// Write body if present
if let Some(body_bytes) = body {
writer.write_all(body_bytes).await?;
}

writer.flush().await?;

// Truncate the file to the correct size
Expand All @@ -342,6 +418,36 @@ pub async fn write_shard_index_cache(
Ok(())
}

/// Writes the shard index cache to disk.
pub async fn write_shard_index_cache(
cache_file: &mut File,
policy: CachePolicy,
decoded_bytes: Bytes,
) -> std::io::Result<()> {
write_cache(
cache_file,
CacheHeader {
policy,
not_found: false,
},
Some(decoded_bytes.as_ref()),
)
.await
}

/// Writes a 404 (not found) marker to the cache file.
async fn write_not_found_cache(cache_file: &mut File, policy: CachePolicy) -> std::io::Result<()> {
write_cache(
cache_file,
CacheHeader {
policy,
not_found: true,
},
None,
)
.await
}

/// Read the shard index from a reader and deserialize it.
pub async fn read_shard_index_from_reader<R: AsyncRead + Unpin>(
reader: &mut BufReader<R>,
Expand All @@ -366,6 +472,9 @@ pub async fn read_shard_index_from_reader<R: AsyncRead + Unpin>(
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CacheHeader {
pub policy: CachePolicy,
/// Indicates whether the resource was not found (404) on the remote.
#[serde(default)]
pub not_found: bool,
}

/// Try reading the cache file from disk.
Expand Down