Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cargo-progenitor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ pub fn dependencies(builder: Generator, include_client: bool) -> Vec<String> {
deps.push(format!("base64 = \"{}\"", DEPENDENCIES.base64));
deps.push(format!("rand = \"{}\"", DEPENDENCIES.rand));
}
if type_space.uses_serde_json() || needs_serde_json {
if type_space.uses_serde_json() || needs_serde_json || builder.uses_serde_json() {
deps.push(format!("serde_json = \"{}\"", DEPENDENCIES.serde_json));
}
deps.sort_unstable();
Expand Down
161 changes: 161 additions & 0 deletions progenitor-client/src/progenitor_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,62 @@ pub fn encode_path(pc: &str) -> String {
}

#[doc(hidden)]
/// A part in a multipart/related message (RFC 2387).
///
/// Each part has a Content-Type, Content-ID, and binary content.
///
/// Uses Cow to avoid cloning binary payloads while allowing owned data
/// for serialized JSON content.
pub struct MultipartPart<'a> {
/// The MIME type of this part (e.g., "application/json", "image/png").
pub content_type: &'a str,
/// The Content-ID value (appears as `<content_id>` in headers).
pub content_id: &'a str,
/// The binary content of this part.
/// Uses Cow to allow both borrowed (binary fields) and owned (serialized JSON) data.
pub bytes: std::borrow::Cow<'a, [u8]>,
}

/// Trait for types that can be serialized as RFC 2387 multipart/related bodies.
///
/// This trait is automatically implemented for multipart/related request body types
/// generated from OpenAPI specifications.
///
/// # Panics
///
/// Implementations may panic if:
/// - JSON serialization of structured fields fails (indicates non-serializable type)
/// - Required fields are not populated (indicates programmer error)
///
/// # Notes
///
/// - The first part returned is the "root" part (per RFC 2387 Section 3.2)
/// - Optional fields should be excluded from the result vec when not provided
/// - Part order should match the OpenAPI schema property order
#[doc(hidden)]
pub trait MultipartRelatedBody {
/// Convert this body into a vector of multipart parts.
///
/// The returned vector should:
/// - Be non-empty (at least one part)
/// - Have the root part first
/// - Exclude optional parts that are None/empty
/// - Preserve schema property order
fn as_multipart_parts(&self) -> Vec<MultipartPart<'_>>;
}

// Blanket impl for references - allows .multipart_related(&body) to work
impl<T: MultipartRelatedBody + ?Sized> MultipartRelatedBody for &T {
fn as_multipart_parts(&self) -> Vec<MultipartPart<'_>> {
(*self).as_multipart_parts()
}
}

#[doc(hidden)]
#[allow(clippy::result_large_err)]
pub trait RequestBuilderExt<E> {
fn form_urlencoded<T: Serialize + ?Sized>(self, body: &T) -> Result<RequestBuilder, Error<E>>;
fn multipart_related<T: MultipartRelatedBody + ?Sized>(self, body: &T) -> Result<RequestBuilder, Error<E>>;
}

impl<E> RequestBuilderExt<E> for RequestBuilder {
Expand All @@ -543,6 +597,113 @@ impl<E> RequestBuilderExt<E> for RequestBuilder {
.map_err(|_| Error::InvalidRequest("failed to serialize body".to_string()))?,
))
}

fn multipart_related<T: MultipartRelatedBody + ?Sized>(self, body: &T) -> Result<Self, Error<E>> {
// Generate a unique boundary
let boundary = generate_multipart_boundary();

// Get the parts from the body
let parts = body.as_multipart_parts();

// RFC 2387 requires at least one part
if parts.is_empty() {
return Err(Error::InvalidRequest(
"multipart/related requires at least one part".to_string()
));
}

// RFC 2387: The 'type' parameter must be the content-type of the root part (first part)
let root_content_type = &parts[0].content_type;

// Quote the type parameter per RFC 2045 if it contains special characters
let quoted_type = quote_mime_parameter(root_content_type);

// Preallocate body buffer based on estimated size
let estimated_size: usize = parts.iter()
.map(|p| p.bytes.len() + 200) // 200 bytes for headers per part
.sum::<usize>() + 500; // Extra for boundaries and final marker
let mut body_bytes = Vec::with_capacity(estimated_size);

// Write each part
for part in &parts {
// Validate Content-ID to prevent header injection
validate_content_id(part.content_id)?;

// Write boundary
body_bytes.extend_from_slice(format!("--{}\r\n", boundary).as_bytes());

// Write Content-Type header
body_bytes.extend_from_slice(format!("Content-Type: {}\r\n", part.content_type).as_bytes());

// Write Content-ID header
body_bytes.extend_from_slice(format!("Content-ID: <{}>\r\n\r\n", part.content_id).as_bytes());

// Write content bytes
body_bytes.extend_from_slice(&part.bytes);

// Write CRLF after content
body_bytes.extend_from_slice(b"\r\n");
}

// Write closing boundary
body_bytes.extend_from_slice(format!("--{}--\r\n", boundary).as_bytes());

// Set Content-Type header with boundary and type from root part
let content_type = format!("multipart/related; boundary={}; type={}", boundary, quoted_type);

Ok(self
.header(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(&content_type)
.map_err(|e| Error::InvalidRequest(format!("invalid content-type header: {}", e)))?,
)
.body(body_bytes))
}
}

fn generate_multipart_boundary() -> String {
use std::sync::atomic::{AtomicU64, Ordering};
static COUNTER: AtomicU64 = AtomicU64::new(0);

// Generate a unique boundary using timestamp, counter, and process ID for better uniqueness
let count = COUNTER.fetch_add(1, Ordering::SeqCst);
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or_else(|_| {
// If system time is before epoch, use a large random-like number
// based on monotonic time
std::time::Instant::now().elapsed().as_nanos()
});
let pid = std::process::id();

// Use nanos for more precision and include PID for multi-process uniqueness
format!("progenitor_boundary_{:x}_{:x}_{:x}", timestamp, pid, count)
}

/// Quote a MIME type parameter value per RFC 2045 if it contains special characters
fn quote_mime_parameter(value: &str) -> String {
// Check if value needs quoting (contains tspecials per RFC 2045)
if value.contains(|c: char| {
matches!(c, '(' | ')' | '<' | '>' | '@' | ',' | ';' | ':' | '\\' | '"' | '/' | '[' | ']' | '?' | '=' | ' ' | '\t')
}) {
// Quote and escape the value
format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
} else {
value.to_string()
}
}

/// Validate Content-ID to prevent header injection attacks
#[allow(clippy::result_large_err)]
fn validate_content_id<E>(content_id: &str) -> Result<(), Error<E>> {
// Content-ID must not contain control characters, angle brackets, or CRLF
if content_id.contains(|c: char| c == '<' || c == '>' || c.is_control()) {
return Err(Error::InvalidRequest(
format!("Invalid Content-ID contains forbidden characters: {}", content_id)
));
}
Ok(())
}

#[doc(hidden)]
Expand Down
5 changes: 3 additions & 2 deletions progenitor-impl/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl Generator {
}

let first_page_required = first_page_required_set
.map_or(false, |required| required.contains(&param.api_name));
.is_some_and(|required| required.contains(&param.api_name));

let volitionality = if innately_required || first_page_required {
Volitionality::Required
Expand Down Expand Up @@ -440,7 +440,8 @@ impl Generator {
// are currently...
OperationParameterType::RawBody => None,

OperationParameterType::Type(body_type_id) => Some(body_type_id),
OperationParameterType::Type(body_type_id)
| OperationParameterType::MultipartRelated(body_type_id) => Some(body_type_id),
});

if let Some(body_type_id) = maybe_body_type_id {
Expand Down
17 changes: 15 additions & 2 deletions progenitor-impl/src/httpmock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,22 @@ impl Generator {
kind,
api_name,
description: _,
..
}| {
let arg_type_name = match typ {
OperationParameterType::Type(arg_type_id) => self
.type_space
.get_type(arg_type_id)
.unwrap()
.parameter_ident(),
OperationParameterType::MultipartRelated(arg_type_id) => self
.type_space
.get_type(arg_type_id)
.unwrap()
.parameter_ident(),
OperationParameterType::RawBody => match kind {
OperationParameterKind::Body(BodyContentType::OctetStream) => quote! {
OperationParameterKind::Body(BodyContentType::OctetStream)
| OperationParameterKind::Body(BodyContentType::MultipartRelated) => quote! {
::serde_json::Value
},
OperationParameterKind::Body(BodyContentType::Text(_)) => quote! {
Expand Down Expand Up @@ -250,8 +257,14 @@ impl Generator {

},
),
OperationParameterType::MultipartRelated(_) => (
true,
quote! {
Self(self.0.json_body_obj(value))
},
),
OperationParameterType::RawBody => match body_content_type {
BodyContentType::OctetStream => (
BodyContentType::OctetStream | BodyContentType::MultipartRelated => (
true,
quote! {
Self(self.0.json_body(value))
Expand Down
Loading