Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ itertools = "0.13.0"
mime = "0.3.17"
pretty_assertions = "1.4.0"
regex = {version = "1.10.5", default-features = false}
# TODO use a stable version
# reqwest = {version = "0.12.23", default-features = false}
reqwest = {git = "https://github.com/LucasPickering/reqwest", branch = "2374-custom-boundary", default-features = false}
reqwest = {version = "0.12.23", default-features = false}
rstest = {version = "0.24.0", default-features = false}
saphyr = "0.0.6"
schemars = "1.0.2"
Expand Down
13 changes: 6 additions & 7 deletions crates/core/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,12 @@ impl RenderedBody {
RenderedBody::FormMultipart(fields) => {
let mut form = Form::new();

// Use a static boundary in tests for assertions. Test-only
// code can be dangerous, but in non-test we're just using the
// default library behavior. There's also plenty of tests in
// other crates that hit this code path, and cfg(test) won't
// be enabled for those.
if cfg!(test) {
form.set_boundary("BOUNDARY");
#[cfg(test)]
{
// Hack alert!! Reqwest uses a random boundary between parts
// in a multipart request. We share this with the tests via
// TLS. See the TLS declaration for more info.
tests::MULTIPART_BOUNDARY.set(form.boundary().to_owned());
}

for (field, stream) in fields {
Expand Down
76 changes: 55 additions & 21 deletions crates/core/src/http/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,25 @@ use reqwest::{Body, StatusCode, header};
use rstest::rstest;
use serde_json::json;
use slumber_util::{Factory, assert_err, test_data_dir};
use std::{path, ptr};
use std::{cell::RefCell, path, ptr};
use wiremock::{Mock, MockServer, ResponseTemplate, matchers};

thread_local! {
/// Out-of-band communication that the render code uses to share the
/// boundary used for whatever multipart form was rendered most recently.
/// This is really hacky but it's the best solution I can think of.
///
/// Some alternatives:
/// - Control the randomness to make the boundary predictable. Reqwest
/// doesn't provide any way to do this.
/// - Use a regex in expectations instead of a static string. That blows up
/// the dependency tree and also gives much worse assertion messages.
/// - Set the boundary to a static value. Right now that's not possible but
/// if https://github.com/seanmonstar/reqwest/pull/2814 ever gets merged,
/// we can use form.set_boundary()
pub static MULTIPART_BOUNDARY: RefCell<String> = RefCell::default();
}

/// Create a template context. Take a set of extra recipes to add to the created
/// collection
fn template_context(recipe: Recipe, host: Option<&str>) -> TemplateContext {
Expand Down Expand Up @@ -390,25 +406,25 @@ async fn test_body(
"user_id".into() => "{{ user_id }}".into(),
}),
// Normally the boundary is random, but we make it static for testing
Some("multipart/form-data; boundary=BOUNDARY"),
"--BOUNDARY\r
Some("multipart/form-data; boundary={BOUNDARY}"),
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"user_id\"\r
\r
1\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[case::form_multipart_file(
RecipeBody::FormMultipart(indexmap! {
"file".into() => "{{ file('data.json') }}".into(),
}),
Some("multipart/form-data; boundary=BOUNDARY"),
"--BOUNDARY\r
Some("multipart/form-data; boundary={BOUNDARY}"),
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"file\"; filename=\"data.json\"\r
Content-Type: application/json\r
\r
{ \"a\": 1, \"b\": 2 }\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[case::form_multipart_file_multichunk(
Expand All @@ -417,24 +433,24 @@ Content-Type: application/json\r
// because it's not *just* the file
"file".into() => "data: {{ file('data.json') }}".into(),
}),
Some("multipart/form-data; boundary=BOUNDARY"),
"--BOUNDARY\r
Some("multipart/form-data; boundary={BOUNDARY}"),
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"file\"\r
\r
data: { \"a\": 1, \"b\": 2 }\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[case::form_multipart_command(
RecipeBody::FormMultipart(indexmap! {
"command".into() => "{{ command(['cat', 'data.json']) }}".into(),
}),
Some("multipart/form-data; boundary=BOUNDARY"),
"--BOUNDARY\r
Some("multipart/form-data; boundary={BOUNDARY}"),
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"command\"\r
\r
{ \"a\": 1, \"b\": 2 }\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[tokio::test]
Expand Down Expand Up @@ -476,6 +492,16 @@ async fn test_body_stream(
let seed = seed(&context, BuildOptions::default());
let ticket = http_engine.build(seed, &context).await.unwrap();

// The rendering code should set the correct boundary in TLS
let (expected_content_type, expected_body) = MULTIPART_BOUNDARY
.with_borrow(|boundary| {
(
expected_content_type
.map(|s| s.replace("{BOUNDARY}", boundary)),
expected_body.replace("{BOUNDARY}", boundary),
)
});

let exchange = ticket.send().await.unwrap();

// Note: this doesn't actually enforce that the body was streamed
Expand All @@ -490,9 +516,13 @@ async fn test_body_stream(
content_type.to_str().expect("Invalid Content-Type header")
},
);
assert_eq!(actual_content_type, expected_content_type);
assert_eq!(
actual_content_type,
expected_content_type.as_deref(),
"Incorrect Content-Type header"
);
let body = exchange.response.body.text().expect("Invalid UTF-8 body");
assert_eq!(body, expected_body);
assert_eq!(body, expected_body, "Incorrect body");
}

/// Test overriding authentication in BuildOptions
Expand Down Expand Up @@ -739,11 +769,11 @@ async fn test_override_body_form(http_engine: HttpEngine) {
RecipeBody::FormMultipart(indexmap!{
"file".into() => "{{ stream_prompt }}".into(),
}),
"--BOUNDARY\r
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"file\"\r
\r
first\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[case::multipart_body_multiple(
Expand All @@ -756,17 +786,17 @@ first\r
"f1".into() => "{{ stream_prompt }}".into(),
"f2".into() => "{{ stream_prompt }}".into(),
}),
"--BOUNDARY\r
"--{BOUNDARY}\r
Content-Disposition: form-data; name=\"f1\"; filename=\"first.txt\"\r
Content-Type: text/plain\r
\r
first\r
--BOUNDARY\r
--{BOUNDARY}\r
Content-Disposition: form-data; name=\"f2\"; filename=\"second.txt\"\r
Content-Type: text/plain\r
\r
second\r
--BOUNDARY--\r
--{BOUNDARY}--\r
",
)]
#[tokio::test]
Expand Down Expand Up @@ -801,6 +831,10 @@ async fn test_profile_duplicate(
let seed = seed(&context, BuildOptions::default());
let ticket = http_engine.build(seed, &context).await.unwrap();

// The rendering code should set the correct boundary in TLS
let expected_body = MULTIPART_BOUNDARY
.with_borrow(|boundary| expected_body.replace("{BOUNDARY}", boundary));

// Make sure the URL rendered correctly before sending
let expected_url: Url = format!("{host}/first/first").parse().unwrap();
let exchange = ticket.send().await.unwrap();
Expand All @@ -810,7 +844,7 @@ async fn test_profile_duplicate(
assert_eq!(
// The response body is the same as the request body
std::str::from_utf8(exchange.response.body.bytes()).ok(),
Some(expected_body)
Some(expected_body.as_str())
);
}

Expand Down
Loading