Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 0.14.x: feat(http1): support configurable max_headers #3773

Open
wants to merge 2 commits into
base: 0.14.x
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ want = "0.3"
# Optional

libc = { version = "0.2", optional = true }
smallvec = { version = "1.12", features = ["const_generics", "const_new"], optional = true }
socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] }

[dev-dependencies]
@@ -87,8 +88,8 @@ http1 = []
http2 = ["h2"]

# Client/Server
client = []
server = []
client = ["dep:smallvec"]
server = ["dep:smallvec"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just ported it from #3523.
I'm not sure if it's required or not.


# `impl Stream` for things
stream = []
23 changes: 23 additions & 0 deletions src/client/conn/http1.rs
Original file line number Diff line number Diff line change
@@ -115,6 +115,7 @@ pub struct Builder {
h1_writev: Option<bool>,
h1_title_case_headers: bool,
h1_preserve_header_case: bool,
h1_max_headers: Option<usize>,
#[cfg(feature = "ffi")]
h1_preserve_header_order: bool,
h1_read_buf_exact_size: Option<usize>,
@@ -302,6 +303,7 @@ impl Builder {
h1_parser_config: Default::default(),
h1_title_case_headers: false,
h1_preserve_header_case: false,
h1_max_headers: None,
#[cfg(feature = "ffi")]
h1_preserve_header_order: false,
h1_max_buf_size: None,
@@ -434,6 +436,24 @@ impl Builder {
self
}

/// Sets the maximum number of headers accepted in a response.
///
/// Before parsing a response, the parser reserves a buffer able to hold this many headers.
///
/// If the client receives more headers than the buffer can hold, the error
/// "message header too large" is returned.
///
/// By default the headers are kept on the stack, which is faster. Once this value is set,
/// the headers are allocated on the heap instead: a heap allocation happens for every
/// response, at a performance cost of roughly 5%.
///
/// Default is 100.
pub fn max_headers(&mut self, val: usize) -> &mut Self {
    let limit = Some(val);
    self.h1_max_headers = limit;
    self
}

/// Set whether to support preserving original header order.
///
/// Currently, this will record the order in which headers are received, and store this
@@ -514,6 +534,9 @@ impl Builder {
if opts.h1_preserve_header_case {
conn.set_preserve_header_case();
}
if let Some(max_headers) = opts.h1_max_headers {
conn.set_http1_max_headers(max_headers);
}
#[cfg(feature = "ffi")]
if opts.h1_preserve_header_order {
conn.set_preserve_header_order();
7 changes: 7 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
@@ -53,6 +53,7 @@ where
keep_alive: KA::Busy,
method: None,
h1_parser_config: ParserConfig::default(),
h1_max_headers: None,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: None,
#[cfg(all(feature = "server", feature = "runtime"))]
@@ -125,6 +126,10 @@ where
self.state.h09_responses = true;
}

/// Stores `val` as the HTTP/1 header-count limit in this connection's state.
pub(crate) fn set_http1_max_headers(&mut self, val: usize) {
    let limit = Some(val);
    self.state.h1_max_headers = limit;
}

#[cfg(all(feature = "server", feature = "runtime"))]
pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) {
self.state.h1_header_read_timeout = Some(val);
@@ -198,6 +203,7 @@ where
cached_headers: &mut self.state.cached_headers,
req_method: &mut self.state.method,
h1_parser_config: self.state.h1_parser_config.clone(),
h1_max_headers: self.state.h1_max_headers,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: self.state.h1_header_read_timeout,
#[cfg(all(feature = "server", feature = "runtime"))]
@@ -822,6 +828,7 @@ struct State {
/// a body or not.
method: Option<Method>,
h1_parser_config: ParserConfig,
h1_max_headers: Option<usize>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
#[cfg(all(feature = "server", feature = "runtime"))]
2 changes: 2 additions & 0 deletions src/proto/h1/io.rs
Original file line number Diff line number Diff line change
@@ -191,6 +191,7 @@ where
cached_headers: parse_ctx.cached_headers,
req_method: parse_ctx.req_method,
h1_parser_config: parse_ctx.h1_parser_config.clone(),
h1_max_headers: parse_ctx.h1_max_headers,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
#[cfg(all(feature = "server", feature = "runtime"))]
@@ -741,6 +742,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
1 change: 1 addition & 0 deletions src/proto/h1/mod.rs
Original file line number Diff line number Diff line change
@@ -76,6 +76,7 @@ pub(crate) struct ParseContext<'a> {
cached_headers: &'a mut Option<HeaderMap>,
req_method: &'a mut Option<Method>,
h1_parser_config: ParserConfig,
h1_max_headers: Option<usize>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
#[cfg(all(feature = "server", feature = "runtime"))]
192 changes: 176 additions & 16 deletions src/proto/h1/role.rs
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ use bytes::BytesMut;
use http::header::ValueIter;
use http::header::{self, Entry, HeaderName, HeaderValue};
use http::{HeaderMap, Method, StatusCode, Version};
use smallvec::{smallvec, smallvec_inline, SmallVec};
#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Instant;
use tracing::{debug, error, trace, trace_span, warn};
@@ -24,7 +25,7 @@ use crate::proto::h1::{
};
use crate::proto::{BodyLength, MessageHead, RequestHead, RequestLine};

const MAX_HEADERS: usize = 100;
const DEFAULT_MAX_HEADERS: usize = 100;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
#[cfg(feature = "server")]
const MAX_URI_LEN: usize = (u16::MAX - 1) as usize;
@@ -169,14 +170,17 @@ impl Http1Transaction for Server {
// but we *never* read any of it until after httparse has assigned
// values into it. By not zeroing out the stack memory, this saves
// a good ~5% on pipeline benchmarks.
let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe {
// SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
MaybeUninit::uninit().assume_init()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I should preserve unsafe and call assume_init

};
let mut headers_indices: SmallVec<[MaybeUninit<HeaderIndices>; DEFAULT_MAX_HEADERS]> =
match ctx.h1_max_headers {
Some(cap) => smallvec![MaybeUninit::uninit(); cap],
None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
};
{
/* SAFETY: it is safe to go from MaybeUninit array to array of MaybeUninit */
let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut headers: SmallVec<[MaybeUninit<httparse::Header<'_>>; DEFAULT_MAX_HEADERS]> =
match ctx.h1_max_headers {
Some(cap) => smallvec![MaybeUninit::uninit(); cap],
None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
};
trace!(bytes = buf.len(), "Request.parse");
let mut req = httparse::Request::new(&mut []);
let bytes = buf.as_ref();
@@ -966,15 +970,18 @@ impl Http1Transaction for Client {

// Loop to skip information status code headers (100 Continue, etc).
loop {
// Unsafe: see comment in Server Http1Transaction, above.
let mut headers_indices: [MaybeUninit<HeaderIndices>; MAX_HEADERS] = unsafe {
// SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
MaybeUninit::uninit().assume_init()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I should preserve unsafe and call assume_init

};
let mut headers_indices: SmallVec<[MaybeUninit<HeaderIndices>; DEFAULT_MAX_HEADERS]> =
match ctx.h1_max_headers {
Some(cap) => smallvec![MaybeUninit::uninit(); cap],
None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
};
let (len, status, reason, version, headers_len) = {
// SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit
let mut headers: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut headers: SmallVec<
[MaybeUninit<httparse::Header<'_>>; DEFAULT_MAX_HEADERS],
> = match ctx.h1_max_headers {
Some(cap) => smallvec![MaybeUninit::uninit(); cap],
None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
};
trace!(bytes = buf.len(), "Response.parse");
let mut res = httparse::Response::new(&mut []);
let bytes = buf.as_ref();
@@ -1555,6 +1562,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut method,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1590,6 +1598,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(crate::Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1620,6 +1629,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1648,6 +1658,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(crate::Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1678,6 +1689,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(crate::Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1712,6 +1724,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(crate::Method::GET),
h1_parser_config,
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1743,6 +1756,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(crate::Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1769,6 +1783,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1816,6 +1831,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -1844,6 +1860,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2081,6 +2098,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2109,6 +2127,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(m),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2137,6 +2156,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2642,6 +2662,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut Some(Method::GET),
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2664,6 +2685,143 @@ mod tests {
assert_eq!(parsed.head.headers["server"], "hello\tworld");
}

/// Checks that both `Server::parse` and `Client::parse` enforce the header-count limit:
/// the default of 100 when `h1_max_headers` is `None`, and the configured value otherwise.
#[test]
fn parse_too_large_headers() {
    // Builds `start_line`, then `num` headers of the form `key{i}: val{i}`, then the
    // blank line that terminates the message head.
    fn gen_message(start_line: &str, num: usize) -> String {
        let mut msg = String::from(start_line);
        msg.push_str("\r\n");
        for i in 0..num {
            msg.push_str(&format!("key{i}: val{i}\r\n"));
        }
        msg.push_str("\r\n");
        msg
    }
    fn gen_req_with_headers(num: usize) -> String {
        gen_message("GET / HTTP/1.1", num)
    }
    fn gen_resp_with_headers(num: usize) -> String {
        gen_message("HTTP/1.1 200 OK", num)
    }
    // Parses a generated request (server side) and response (client side) carrying
    // `gen_size` headers under the given limit, and checks the outcome of each.
    fn parse(max_headers: Option<usize>, gen_size: usize, should_success: bool) {
        {
            // server side
            let mut bytes = BytesMut::from(gen_req_with_headers(gen_size).as_str());
            let result = Server::parse(
                &mut bytes,
                ParseContext {
                    cached_headers: &mut None,
                    req_method: &mut None,
                    h1_parser_config: Default::default(),
                    h1_max_headers: max_headers,
                    // The header-read-timeout fields are feature-gated on the struct, so
                    // they are gated here the same way as in the other tests of this module.
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout: None,
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout_fut: &mut None,
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout_running: &mut false,
                    // NOTE(review): `timer: Time::Empty` looks carried over from the branch
                    // this test was ported from -- confirm `ParseContext` has a `timer`
                    // field on this branch, and drop this line if it does not.
                    timer: Time::Empty,
                    preserve_header_case: false,
                    #[cfg(feature = "ffi")]
                    preserve_header_order: false,
                    h09_responses: false,
                    #[cfg(feature = "ffi")]
                    on_informational: &mut None,
                },
            );
            if should_success {
                result.expect("parse ok").expect("parse complete");
            } else {
                result.expect_err("parse should err");
            }
        }
        {
            // client side
            let mut bytes = BytesMut::from(gen_resp_with_headers(gen_size).as_str());
            let result = Client::parse(
                &mut bytes,
                ParseContext {
                    cached_headers: &mut None,
                    req_method: &mut None,
                    h1_parser_config: Default::default(),
                    h1_max_headers: max_headers,
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout: None,
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout_fut: &mut None,
                    #[cfg(feature = "runtime")]
                    h1_header_read_timeout_running: &mut false,
                    // NOTE(review): see the note on `timer` in the server-side literal above.
                    timer: Time::Empty,
                    preserve_header_case: false,
                    #[cfg(feature = "ffi")]
                    preserve_header_order: false,
                    h09_responses: false,
                    #[cfg(feature = "ffi")]
                    on_informational: &mut None,
                },
            );
            if should_success {
                result.expect("parse ok").expect("parse complete");
            } else {
                result.expect_err("parse should err");
            }
        }
    }

    // check generator
    assert_eq!(
        gen_req_with_headers(0),
        String::from("GET / HTTP/1.1\r\n\r\n")
    );
    assert_eq!(
        gen_req_with_headers(1),
        String::from("GET / HTTP/1.1\r\nkey0: val0\r\n\r\n")
    );
    assert_eq!(
        gen_req_with_headers(2),
        String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\n\r\n")
    );
    assert_eq!(
        gen_req_with_headers(3),
        String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\nkey2: val2\r\n\r\n")
    );

    // default max_headers is 100, so
    //
    // - less than or equal to 100, accepted
    //
    parse(None, 0, true);
    parse(None, 1, true);
    parse(None, 50, true);
    parse(None, 99, true);
    parse(None, 100, true);
    //
    // - more than 100, rejected
    //
    parse(None, 101, false);
    parse(None, 102, false);
    parse(None, 200, false);

    // max_headers is 0, parser will reject any headers
    //
    // - without header, accepted
    //
    parse(Some(0), 0, true);
    //
    // - with header(s), rejected
    //
    parse(Some(0), 1, false);
    parse(Some(0), 100, false);

    // max_headers is 200
    //
    // - less than or equal to 200, accepted
    //
    parse(Some(200), 0, true);
    parse(Some(200), 1, true);
    parse(Some(200), 100, true);
    parse(Some(200), 200, true);
    //
    // - more than 200, rejected
    //
    parse(Some(200), 201, false);
    parse(Some(200), 210, false);
}

#[test]
fn test_is_complete_fast() {
let s = b"GET / HTTP/1.1\r\na: b\r\n\r\n";
@@ -2756,6 +2914,7 @@ mod tests {
cached_headers: &mut headers,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
@@ -2804,6 +2963,7 @@ mod tests {
cached_headers: &mut headers,
req_method: &mut None,
h1_parser_config: Default::default(),
h1_max_headers: None,
#[cfg(feature = "runtime")]
h1_header_read_timeout: None,
#[cfg(feature = "runtime")]
26 changes: 26 additions & 0 deletions src/server/conn.rs
Original file line number Diff line number Diff line change
@@ -113,6 +113,7 @@ pub struct Http<E = Exec> {
h1_keep_alive: bool,
h1_title_case_headers: bool,
h1_preserve_header_case: bool,
h1_max_headers: Option<usize>,
#[cfg(all(feature = "http1", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
h1_writev: Option<bool>,
@@ -260,6 +261,7 @@ impl Http {
h1_title_case_headers: false,
h1_preserve_header_case: false,
// `h1_max_headers` is declared unconditionally on the struct, so it must be initialized
// outside the `cfg` gate; the attribute belongs to `h1_header_read_timeout` only.
h1_max_headers: None,
#[cfg(all(feature = "http1", feature = "runtime"))]
h1_header_read_timeout: None,
h1_writev: None,
#[cfg(feature = "http2")]
@@ -349,6 +351,26 @@ impl<E> Http<E> {
self
}

/// Sets the maximum number of headers accepted in a request.
///
/// Before parsing a request, the parser reserves a buffer able to hold this many headers.
///
/// If the server receives more headers than the buffer can hold, it responds to the client
/// with "431 Request Header Fields Too Large".
///
/// By default the headers are kept on the stack, which is faster. Once this value is set,
/// the headers are allocated on the heap instead: a heap allocation happens for every
/// request, at a performance cost of roughly 5%.
///
/// Default is 100.
#[cfg(feature = "http1")]
#[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
pub fn http1_max_headers(&mut self, val: usize) -> &mut Self {
    let limit = Some(val);
    self.h1_max_headers = limit;
    self
}

/// Set a timeout for reading client request headers. If a client does not
/// transmit the entire header within this time, the connection is closed.
///
@@ -623,6 +645,7 @@ impl<E> Http<E> {
h1_keep_alive: self.h1_keep_alive,
h1_title_case_headers: self.h1_title_case_headers,
h1_preserve_header_case: self.h1_preserve_header_case,
h1_max_headers: self.h1_max_headers,
#[cfg(all(feature = "http1", feature = "runtime"))]
h1_header_read_timeout: self.h1_header_read_timeout,
h1_writev: self.h1_writev,
@@ -687,6 +710,9 @@ impl<E> Http<E> {
if self.h1_preserve_header_case {
conn.set_preserve_header_case();
}
if let Some(max_headers) = self.h1_max_headers {
conn.set_http1_max_headers(max_headers);
}
#[cfg(all(feature = "http1", feature = "runtime"))]
if let Some(header_read_timeout) = self.h1_header_read_timeout {
conn.set_http1_header_read_timeout(header_read_timeout);
23 changes: 23 additions & 0 deletions src/server/conn/http1.rs
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ pub struct Builder {
h1_keep_alive: bool,
h1_title_case_headers: bool,
h1_preserve_header_case: bool,
h1_max_headers: Option<usize>,
h1_header_read_timeout: Option<Duration>,
h1_writev: Option<bool>,
max_buf_size: Option<usize>,
@@ -208,6 +209,7 @@ impl Builder {
h1_keep_alive: true,
h1_title_case_headers: false,
h1_preserve_header_case: false,
h1_max_headers: None,
h1_header_read_timeout: None,
h1_writev: None,
max_buf_size: None,
@@ -260,6 +262,24 @@ impl Builder {
self
}

/// Sets the maximum number of headers accepted in a request.
///
/// Before parsing a request, the parser reserves a buffer able to hold this many headers.
///
/// If the server receives more headers than the buffer can hold, it responds to the client
/// with "431 Request Header Fields Too Large".
///
/// By default the headers are kept on the stack, which is faster. Once this value is set,
/// the headers are allocated on the heap instead: a heap allocation happens for every
/// request, at a performance cost of roughly 5%.
///
/// Default is 100.
pub fn max_headers(&mut self, val: usize) -> &mut Self {
    let limit = Some(val);
    self.h1_max_headers = limit;
    self
}

/// Set a timeout for reading client request headers. If a client does not
/// transmit the entire header within this time, the connection is closed.
///
@@ -370,6 +390,9 @@ impl Builder {
if self.h1_preserve_header_case {
conn.set_preserve_header_case();
}
if let Some(max_headers) = self.h1_max_headers {
conn.set_http1_max_headers(max_headers);
}
if let Some(header_read_timeout) = self.h1_header_read_timeout {
conn.set_http1_header_read_timeout(header_read_timeout);
}