-
Notifications
You must be signed in to change notification settings - Fork 4
fix: Avoid panics when content is longer than content length header #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,6 +110,8 @@ struct Inner { | |
| streamer_state_rx: WatchStream<StreamerState>, | ||
|
|
||
| /// A channel sender to send range requests to the background task | ||
| /// | ||
| /// Contract: All ranges sent must be inside the range of the memory map | ||
| request_tx: tokio::sync::mpsc::Sender<Range<u64>>, | ||
|
|
||
| /// An optional object to reserve a slot in the `request_tx` sender. When in the process of | ||
|
|
@@ -196,7 +198,7 @@ impl AsyncHttpRangeReader { | |
| } | ||
|
|
||
| /// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user | ||
| /// provided response that also has a range of bytes from the end as body) | ||
| /// provided range response) | ||
| pub async fn from_tail_response( | ||
|
konstin marked this conversation as resolved.
|
||
| client: impl Into<reqwest_middleware::ClientWithMiddleware>, | ||
| tail_request_response: Response, | ||
|
|
@@ -212,10 +214,11 @@ impl AsyncHttpRangeReader { | |
| .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)? | ||
| .to_str() | ||
| .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?; | ||
| // The parser ensures finish < complete_length | ||
| let content_range = ContentRange::parse(content_range_header).ok_or_else(|| { | ||
| AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string()) | ||
| })?; | ||
| let (start, finish, complete_length) = match content_range { | ||
| let (start, end_inclusive, complete_length) = match content_range { | ||
| ContentRange::Bytes(ContentRangeBytes { | ||
| first_byte, | ||
| last_byte, | ||
|
|
@@ -236,8 +239,7 @@ impl AsyncHttpRangeReader { | |
| let memory_map_slice = | ||
| unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) }; | ||
|
|
||
| let requested_range = | ||
| SparseRange::from_range(complete_length - (finish - start)..complete_length); | ||
| let requested_range = SparseRange::from_range(start..end_inclusive + 1); | ||
|
|
||
| // adding more than 2 entries to the channel would block the sender. I assumed two would | ||
| // suffice because I would want to 1) prefetch a certain range and 2) read stuff via the | ||
|
|
@@ -249,7 +251,7 @@ impl AsyncHttpRangeReader { | |
| client, | ||
| url, | ||
| extra_headers, | ||
| Some((tail_request_response, start)), | ||
| Some((tail_request_response, start, end_inclusive + 1)), | ||
| memory_map, | ||
| state_tx, | ||
| request_rx, | ||
|
|
@@ -259,7 +261,7 @@ impl AsyncHttpRangeReader { | |
| let mut streamer_state = StreamerState::default(); | ||
| streamer_state | ||
| .requested_ranges | ||
| .push(complete_length - (finish - start)..complete_length); | ||
| .push(start..end_inclusive + 1); | ||
|
|
||
| let reader = Self { | ||
| len: memory_map_slice.len() as u64, | ||
|
|
@@ -416,23 +418,22 @@ async fn run_streamer( | |
| client: reqwest_middleware::ClientWithMiddleware, | ||
| url: Url, | ||
| extra_headers: HeaderMap, | ||
| initial_tail_response: Option<(Response, u64)>, | ||
| initial_tail_response: Option<(Response, u64, u64)>, | ||
| mut memory_map: MmapMut, | ||
| mut state_tx: Sender<StreamerState>, | ||
| mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>, | ||
| ) { | ||
| let mut state = StreamerState::default(); | ||
|
|
||
| if let Some((response, response_start)) = initial_tail_response { | ||
| if let Some((response, start, end_exclusive)) = initial_tail_response { | ||
| // Add the initial range to the state | ||
| state | ||
| .requested_ranges | ||
| .push(response_start..memory_map.len() as u64); | ||
| state.requested_ranges.push(start..memory_map.len() as u64); | ||
|
konstin marked this conversation as resolved.
Outdated
|
||
|
|
||
| // Stream the initial data in memory | ||
| if !stream_response( | ||
| response, | ||
| response_start, | ||
| start, | ||
| end_exclusive, | ||
| &mut memory_map, | ||
| &mut state_tx, | ||
| &mut state, | ||
|
|
@@ -497,6 +498,7 @@ async fn run_streamer( | |
| if !stream_response( | ||
| response, | ||
| *range.start(), | ||
| *range.end() + 1, | ||
| &mut memory_map, | ||
| &mut state_tx, | ||
| &mut state, | ||
|
|
@@ -512,13 +514,25 @@ async fn run_streamer( | |
| /// Streams the data from the specified response to the memory map updating progress in between. | ||
| /// Returns `true` if everything went fine, `false` if anything went wrong. The error state, if any, | ||
| /// is stored in `state_tx` so the "frontend" will consume it. | ||
| /// | ||
| /// The response must return bytes for the range of precisely `start..end_exclusive`. | ||
| async fn stream_response( | ||
| tail_request_response: Response, | ||
| mut offset: u64, | ||
| start: u64, | ||
| end_exclusive: u64, | ||
| memory_map: &mut MmapMut, | ||
| state_tx: &mut Sender<StreamerState>, | ||
| state: &mut StreamerState, | ||
| ) -> bool { | ||
| // Enforce request channel contract | ||
| assert!( | ||
| (end_exclusive as usize) <= memory_map.len(), | ||
| "end is outside of memory map {} > {}", | ||
| end_exclusive, | ||
| memory_map.len() | ||
| ); | ||
|
|
||
| let mut offset = start; | ||
| let mut byte_stream = tail_request_response.bytes_stream(); | ||
| while let Some(bytes) = byte_stream.next().await { | ||
| let bytes = match bytes { | ||
|
|
@@ -534,7 +548,17 @@ async fn stream_response( | |
| let byte_range = offset..offset + bytes.len() as u64; | ||
|
|
||
| // Update the offset | ||
| offset = byte_range.end; | ||
| offset += bytes.len() as u64; | ||
|
Comment on lines
-537
to
+610
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I liked the previous version of this line better :)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh 😅 |
||
|
|
||
| // Prevent the server from sending more bytes than advertised in a response | ||
| if offset > end_exclusive { | ||
| state.error = Some(AsyncHttpRangeReaderError::ContentLengthMismatch { | ||
| expected: end_exclusive - start, | ||
| actual: offset - start, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't quite true. The actual content length if the response is chunked and there's no content length header is unknowable. And regardless, if there is a content length, that's what the actual content length would be. |
||
| }); | ||
| let _ = state_tx.send(state.clone()); | ||
| return false; | ||
| } | ||
|
|
||
| // Copy the data from the stream to memory | ||
| memory_map[byte_range.start as usize..byte_range.end as usize] | ||
|
|
@@ -551,6 +575,16 @@ async fn stream_response( | |
| } | ||
| } | ||
|
|
||
| // Prevent the server from sending less bytes than advertised in a response | ||
| if offset != end_exclusive { | ||
| state.error = Some(AsyncHttpRangeReaderError::ContentLengthMismatch { | ||
| expected: end_exclusive - start, | ||
| actual: offset - start, | ||
| }); | ||
| let _ = state_tx.send(state.clone()); | ||
| return false; | ||
| } | ||
|
|
||
| true | ||
| } | ||
|
|
||
|
|
@@ -658,7 +692,12 @@ mod test { | |
| use crate::static_directory_server::StaticDirectoryServer; | ||
| use assert_matches::assert_matches; | ||
| use async_zip::tokio::read::seek::ZipFileReader; | ||
| use axum::body::Body; | ||
| use axum::extract::Request; | ||
| use axum::response::IntoResponse; | ||
| use futures::AsyncReadExt; | ||
| use reqwest::header; | ||
| use reqwest::Method; | ||
| use reqwest::{Client, StatusCode}; | ||
| use rstest::*; | ||
| use std::path::Path; | ||
|
|
@@ -854,4 +893,121 @@ mod test { | |
| err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND) | ||
| ); | ||
| } | ||
|
|
||
| /// Spawn a server where the HEAD response reports `head_size` bytes, and range requests always | ||
| /// claim to be `pretend_size` bytes, while actually serving `actual_size`. | ||
| async fn spawn_mismatch_server( | ||
| head_content_length: usize, | ||
| pretend_size: usize, | ||
| actual_size: usize, | ||
| ) -> Url { | ||
| let app = | ||
| axum::Router::new().fallback(async move |request: Request| match *request.method() { | ||
| Method::HEAD => { | ||
| let headers = [ | ||
| (header::CONTENT_LENGTH, head_content_length.to_string()), | ||
| (header::ACCEPT_RANGES, "bytes".to_string()), | ||
| ]; | ||
| (StatusCode::OK, headers).into_response() | ||
| } | ||
| Method::GET => { | ||
| let range_header = request | ||
| .headers() | ||
| .get(header::RANGE) | ||
| .unwrap() | ||
| .to_str() | ||
| .unwrap() | ||
| .to_string(); | ||
|
|
||
| let range_spec = range_header.strip_prefix("bytes=").unwrap(); | ||
| let (start_str, _end_str) = range_spec.split_once('-').unwrap(); | ||
| let start = start_str.parse::<usize>().unwrap(); | ||
| // The end is inclusive | ||
| let end = start + pretend_size - 1; | ||
|
|
||
| axum::response::Response::builder() | ||
| .status(StatusCode::PARTIAL_CONTENT) | ||
| // Note that the client ignores this value currently, it only checks the | ||
| // actual size | ||
| .header( | ||
| header::CONTENT_RANGE, | ||
| format!("bytes {start}-{end}/{head_content_length}"), | ||
| ) | ||
| .body(Body::from(vec![1u8; actual_size])) | ||
| .unwrap() | ||
| .into_response() | ||
| } | ||
| _ => StatusCode::METHOD_NOT_ALLOWED.into_response(), | ||
| }); | ||
|
|
||
| let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); | ||
| let local_addr = listener.local_addr().unwrap(); | ||
| tokio::spawn(async move { | ||
| axum::serve(listener, app.into_make_service()) | ||
| .await | ||
| .unwrap(); | ||
| }); | ||
|
|
||
| Url::parse(&format!("http://localhost:{}/file", local_addr.port())).unwrap() | ||
| } | ||
|
|
||
| /// HEAD says 512 bytes, but range responses return 1024 bytes — overflows | ||
| /// the memory map. | ||
| #[tokio::test] | ||
| async fn test_content_length_response_beyond_content_length() { | ||
| let cases = [ | ||
| // Baseline | ||
| (512, 512, 512, true), | ||
| // The requested and declared length is 512, while the actual content is 1024 | ||
| (512, 512, 1024, false), | ||
| // The declared total length is 512, but it says and sends a range of 1024 | ||
| (512, 1024, 1024, false), | ||
| // We ignore the response range end header is lying, we're getting the 512 we ordered | ||
| (512, 1024, 512, true), | ||
| // Baseline | ||
| (1024, 512, 512, true), | ||
| // We requested 512, but we're getting 1024 | ||
| (1024, 512, 1024, false), | ||
| // We requested 512, but we're getting 1024 | ||
| (1024, 1024, 1024, false), | ||
| // We ignore the response range end header is lying, we're getting the 512 we ordered | ||
| (1024, 1024, 512, true), | ||
| ]; | ||
| for (head_content_length, range_header_length, range_actual_length, is_ok) in cases { | ||
| let url = spawn_mismatch_server( | ||
| head_content_length, | ||
| range_header_length, | ||
| range_actual_length, | ||
| ) | ||
| .await; | ||
|
|
||
| let (mut reader, _) = AsyncHttpRangeReader::new( | ||
| Client::new(), | ||
| url, | ||
| CheckSupportMethod::Head, | ||
| HeaderMap::default(), | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| assert_eq!(reader.len(), head_content_length as u64); | ||
| reader.prefetch(0..512).await; | ||
|
|
||
| let mut buf = vec![0u8; 512]; | ||
| let result = reader.read(&mut buf).await; | ||
| if is_ok { | ||
| assert_matches!( | ||
| result, | ||
| Ok(_), | ||
| "{head_content_length} {range_header_length} {range_actual_length}" | ||
| ); | ||
| } else { | ||
| assert_matches!( | ||
| result, | ||
| Err(_), | ||
| "{head_content_length} {range_header_length} {range_actual_length}" | ||
| ); | ||
| } | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.