Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ pub enum AsyncHttpRangeReaderError {
MemoryMapError(#[source] Arc<std::io::Error>),

/// Error from `http-content-range`
#[error("Invalid Content-Range header: {0}")]
#[error("invalid Content-Range header: {0}")]
ContentRangeParser(String),

/// The server returned fewer or more bytes than the range request asked for
#[error("expected {expected} bytes from range response, got {actual}")]
ContentLengthMismatch { expected: u64, actual: u64 },
Comment thread
konstin marked this conversation as resolved.
Outdated
}

impl From<std::io::Error> for AsyncHttpRangeReaderError {
Expand Down
184 changes: 170 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ struct Inner {
streamer_state_rx: WatchStream<StreamerState>,

/// A channel sender to send range requests to the background task
///
/// Contract: All ranges sent must be inside the range of the memory map
request_tx: tokio::sync::mpsc::Sender<Range<u64>>,

/// An optional object to reserve a slot in the `request_tx` sender. When in the process of
Expand Down Expand Up @@ -196,7 +198,7 @@ impl AsyncHttpRangeReader {
}

/// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user
/// provided response that also has a range of bytes from the end as body)
/// provided range response)
pub async fn from_tail_response(
Comment thread
konstin marked this conversation as resolved.
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
tail_request_response: Response,
Expand All @@ -212,10 +214,11 @@ impl AsyncHttpRangeReader {
.ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
.to_str()
.map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
// The parser ensures finish < complete_length
let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
})?;
let (start, finish, complete_length) = match content_range {
let (start, end_inclusive, complete_length) = match content_range {
ContentRange::Bytes(ContentRangeBytes {
first_byte,
last_byte,
Expand All @@ -236,8 +239,7 @@ impl AsyncHttpRangeReader {
let memory_map_slice =
unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };

let requested_range =
SparseRange::from_range(complete_length - (finish - start)..complete_length);
let requested_range = SparseRange::from_range(start..end_inclusive + 1);

// adding more than 2 entries to the channel would block the sender. I assumed two would
// suffice because I would want to 1) prefetch a certain range and 2) read stuff via the
Expand All @@ -249,7 +251,7 @@ impl AsyncHttpRangeReader {
client,
url,
extra_headers,
Some((tail_request_response, start)),
Some((tail_request_response, start, end_inclusive + 1)),
memory_map,
state_tx,
request_rx,
Expand All @@ -259,7 +261,7 @@ impl AsyncHttpRangeReader {
let mut streamer_state = StreamerState::default();
streamer_state
.requested_ranges
.push(complete_length - (finish - start)..complete_length);
.push(start..end_inclusive + 1);

let reader = Self {
len: memory_map_slice.len() as u64,
Expand Down Expand Up @@ -416,23 +418,22 @@ async fn run_streamer(
client: reqwest_middleware::ClientWithMiddleware,
url: Url,
extra_headers: HeaderMap,
initial_tail_response: Option<(Response, u64)>,
initial_tail_response: Option<(Response, u64, u64)>,
mut memory_map: MmapMut,
mut state_tx: Sender<StreamerState>,
mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>,
) {
let mut state = StreamerState::default();

if let Some((response, response_start)) = initial_tail_response {
if let Some((response, start, end_exclusive)) = initial_tail_response {
// Add the initial range to the state
state
.requested_ranges
.push(response_start..memory_map.len() as u64);
state.requested_ranges.push(start..memory_map.len() as u64);
Comment thread
konstin marked this conversation as resolved.
Outdated

// Stream the initial data in memory
if !stream_response(
response,
response_start,
start,
end_exclusive,
&mut memory_map,
&mut state_tx,
&mut state,
Expand Down Expand Up @@ -497,6 +498,7 @@ async fn run_streamer(
if !stream_response(
response,
*range.start(),
*range.end() + 1,
&mut memory_map,
&mut state_tx,
&mut state,
Expand All @@ -512,13 +514,25 @@ async fn run_streamer(
/// Streams the data from the specified response to the memory map updating progress in between.
/// Returns `true` if everything went fine, `false` if anything went wrong. The error state, if any,
/// is stored in `state_tx` so the "frontend" will consume it.
///
/// The response must return bytes for the range of precisely `start..end_exclusive`.
async fn stream_response(
tail_request_response: Response,
mut offset: u64,
start: u64,
end_exclusive: u64,
memory_map: &mut MmapMut,
state_tx: &mut Sender<StreamerState>,
state: &mut StreamerState,
) -> bool {
// Enforce request channel contract
assert!(
(end_exclusive as usize) <= memory_map.len(),
"end is outside of memory map {} > {}",
end_exclusive,
memory_map.len()
);

let mut offset = start;
let mut byte_stream = tail_request_response.bytes_stream();
while let Some(bytes) = byte_stream.next().await {
let bytes = match bytes {
Expand All @@ -534,7 +548,17 @@ async fn stream_response(
let byte_range = offset..offset + bytes.len() as u64;

// Update the offset
offset = byte_range.end;
offset += bytes.len() as u64;
Comment on lines -537 to +610

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I liked the previous version of this line better :)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh 😅


// Prevent the server from sending more bytes than advertised in a response
if offset > end_exclusive {
state.error = Some(AsyncHttpRangeReaderError::ContentLengthMismatch {
expected: end_exclusive - start,
actual: offset - start,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite true. The actual content length if the response is chunked and there's no content length header is unknowable. And regardless, if there is a content length, that's what the actual content length would be.

});
let _ = state_tx.send(state.clone());
return false;
}

// Copy the data from the stream to memory
memory_map[byte_range.start as usize..byte_range.end as usize]
Expand All @@ -551,6 +575,16 @@ async fn stream_response(
}
}

// Prevent the server from sending less bytes than advertised in a response
if offset != end_exclusive {
state.error = Some(AsyncHttpRangeReaderError::ContentLengthMismatch {
expected: end_exclusive - start,
actual: offset - start,
});
let _ = state_tx.send(state.clone());
return false;
}

true
}

Expand Down Expand Up @@ -658,7 +692,12 @@ mod test {
use crate::static_directory_server::StaticDirectoryServer;
use assert_matches::assert_matches;
use async_zip::tokio::read::seek::ZipFileReader;
use axum::body::Body;
use axum::extract::Request;
use axum::response::IntoResponse;
use futures::AsyncReadExt;
use reqwest::header;
use reqwest::Method;
use reqwest::{Client, StatusCode};
use rstest::*;
use std::path::Path;
Expand Down Expand Up @@ -854,4 +893,121 @@ mod test {
err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND)
);
}

/// Spawn a server where the HEAD response reports `head_size` bytes, and range requests always
/// claim to be `pretend_size` bytes, while actually serving `actual_size`.
async fn spawn_mismatch_server(
head_content_length: usize,
pretend_size: usize,
actual_size: usize,
) -> Url {
let app =
axum::Router::new().fallback(async move |request: Request| match *request.method() {
Method::HEAD => {
let headers = [
(header::CONTENT_LENGTH, head_content_length.to_string()),
(header::ACCEPT_RANGES, "bytes".to_string()),
];
(StatusCode::OK, headers).into_response()
}
Method::GET => {
let range_header = request
.headers()
.get(header::RANGE)
.unwrap()
.to_str()
.unwrap()
.to_string();

let range_spec = range_header.strip_prefix("bytes=").unwrap();
let (start_str, _end_str) = range_spec.split_once('-').unwrap();
let start = start_str.parse::<usize>().unwrap();
// The end is inclusive
let end = start + pretend_size - 1;

axum::response::Response::builder()
.status(StatusCode::PARTIAL_CONTENT)
// Note that the client ignores this value currently, it only checks the
// actual size
.header(
header::CONTENT_RANGE,
format!("bytes {start}-{end}/{head_content_length}"),
)
.body(Body::from(vec![1u8; actual_size]))
.unwrap()
.into_response()
}
_ => StatusCode::METHOD_NOT_ALLOWED.into_response(),
});

let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app.into_make_service())
.await
.unwrap();
});

Url::parse(&format!("http://localhost:{}/file", local_addr.port())).unwrap()
}

/// HEAD says 512 bytes, but range responses return 1024 bytes — overflows
/// the memory map.
#[tokio::test]
async fn test_content_length_response_beyond_content_length() {
let cases = [
// Baseline
(512, 512, 512, true),
// The requested and declared length is 512, while the actual content is 1024
(512, 512, 1024, false),
// The declared total length is 512, but it says and sends a range of 1024
(512, 1024, 1024, false),
// We ignore the response range end header is lying, we're getting the 512 we ordered
(512, 1024, 512, true),
// Baseline
(1024, 512, 512, true),
// We requested 512, but we're getting 1024
(1024, 512, 1024, false),
// We requested 512, but we're getting 1024
(1024, 1024, 1024, false),
// We ignore the response range end header is lying, we're getting the 512 we ordered
(1024, 1024, 512, true),
];
for (head_content_length, range_header_length, range_actual_length, is_ok) in cases {
let url = spawn_mismatch_server(
head_content_length,
range_header_length,
range_actual_length,
)
.await;

let (mut reader, _) = AsyncHttpRangeReader::new(
Client::new(),
url,
CheckSupportMethod::Head,
HeaderMap::default(),
)
.await
.unwrap();

assert_eq!(reader.len(), head_content_length as u64);
reader.prefetch(0..512).await;

let mut buf = vec![0u8; 512];
let result = reader.read(&mut buf).await;
if is_ok {
assert_matches!(
result,
Ok(_),
"{head_content_length} {range_header_length} {range_actual_length}"
);
} else {
assert_matches!(
result,
Err(_),
"{head_content_length} {range_header_length} {range_actual_length}"
);
}
}
}
}