diff --git a/src/client/conn.rs b/src/client/conn.rs index d377a24da8..e35d435c6e 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -531,7 +531,7 @@ impl Future for HandshakeInner where T: AsyncRead + AsyncWrite + Send + 'static, B: Payload, - R: proto::Http1Transaction< + R: proto::h1::Http1Transaction< Incoming=StatusCode, Outgoing=proto::RequestLine, >, diff --git a/src/client/mod.rs b/src/client/mod.rs index 8f8ee03bf5..1b4aebf063 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -193,7 +193,6 @@ where C: Connect + Sync + 'static, Version::HTTP_11 => (), other => { error!("Request has unsupported version \"{:?}\"", other); - //TODO: replace this with a proper variant return ResponseFuture::new(Box::new(future::err(::Error::new_user_unsupported_version()))); } } diff --git a/src/error.rs b/src/error.rs index 23da310238..f1326e2d99 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,8 +26,6 @@ pub(crate) enum Kind { Parse(Parse), /// A message reached EOF, but is not complete. Incomplete, - /// A protocol upgrade was encountered, but not yet supported in hyper. - Upgrade, /// A client connection received a response when not waiting for one. MismatchedResponse, /// A pending item was dropped before ever being processed. @@ -74,6 +72,9 @@ pub(crate) enum Parse { Header, TooLarge, Status, + + /// A protocol upgrade was encountered, but not yet supported in hyper. + UpgradeNotSupported, } /* @@ -141,10 +142,6 @@ impl Error { Error::new(Kind::Canceled, cause.map(Into::into)) } - pub(crate) fn new_upgrade() -> Error { - Error::new(Kind::Upgrade, None) - } - pub(crate) fn new_incomplete() -> Error { Error::new(Kind::Incomplete, None) } @@ -161,10 +158,6 @@ impl Error { Error::new(Kind::Parse(Parse::Status), None) } - pub(crate) fn new_version() -> Error { - Error::new(Kind::Parse(Parse::Version), None) - } - pub(crate) fn new_version_h2() -> Error { Error::new(Kind::Parse(Parse::VersionH2), None) } @@ -260,8 +253,8 @@ impl StdError for Error { Kind::Parse(Parse::Header) => "invalid Header provided", Kind::Parse(Parse::TooLarge) => "message head is too large", Kind::Parse(Parse::Status) => "invalid Status provided", + Kind::Parse(Parse::UpgradeNotSupported) => "unsupported protocol upgrade", Kind::Incomplete => "message is incomplete", - Kind::Upgrade => "unsupported protocol upgrade", Kind::MismatchedResponse => "response received without matching request", Kind::Closed => "connection closed", Kind::Connect => "an error occurred trying to connect", @@ -325,8 +318,8 @@ impl From for Parse { } } -impl From for Parse { - fn from(_: http::uri::InvalidUriBytes) -> Parse { +impl From for Parse { + fn from(_: http::uri::InvalidUri) -> Parse { Parse::Uri } } diff --git a/src/headers.rs b/src/headers.rs index d8a0ed3117..70b963f69b 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -2,45 +2,43 @@ use std::fmt::Write; use bytes::BytesMut; use http::HeaderMap; -use http::header::{CONNECTION, CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING}; +use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING}; use http::header::{HeaderValue, OccupiedEntry, ValueIter}; /// Maximum number of bytes needed to serialize a u64 into ASCII decimal. 
const MAX_DECIMAL_U64_BYTES: usize = 20; -pub fn connection_keep_alive(headers: &HeaderMap) -> bool { - for line in headers.get_all(CONNECTION) { - if let Ok(s) = line.to_str() { - for val in s.split(',') { - if eq_ascii(val.trim(), "keep-alive") { - return true; - } - } - } - } +pub fn connection_keep_alive(value: &HeaderValue) -> bool { + connection_has(value, "keep-alive") +} - false +pub fn connection_close(value: &HeaderValue) -> bool { + connection_has(value, "close") } -pub fn connection_close(headers: &HeaderMap) -> bool { - for line in headers.get_all(CONNECTION) { - if let Ok(s) = line.to_str() { - for val in s.split(',') { - if eq_ascii(val.trim(), "close") { - return true; - } +fn connection_has(value: &HeaderValue, needle: &str) -> bool { + if let Ok(s) = value.to_str() { + for val in s.split(',') { + if eq_ascii(val.trim(), needle) { + return true; } } } - false } -pub fn content_length_parse(headers: &HeaderMap) -> Option { - content_length_parse_all(headers.get_all(CONTENT_LENGTH).into_iter()) +pub fn content_length_parse(value: &HeaderValue) -> Option { + value + .to_str() + .ok() + .and_then(|s| s.parse().ok()) +} + +pub fn content_length_parse_all(headers: &HeaderMap) -> Option { + content_length_parse_all_values(headers.get_all(CONTENT_LENGTH).into_iter()) } -pub fn content_length_parse_all(values: ValueIter) -> Option { +pub fn content_length_parse_all_values(values: ValueIter) -> Option { // If multiple Content-Length headers were sent, everything can still // be alright if they all contain the same value, and all parse // correctly. If not, then it's an error. @@ -70,10 +68,6 @@ pub fn content_length_parse_all(values: ValueIter) -> Option { } } -pub fn content_length_zero(headers: &mut HeaderMap) { - headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); -} - pub fn content_length_value(len: u64) -> HeaderValue { let mut len_buf = BytesMut::with_capacity(MAX_DECIMAL_U64_BYTES); write!(len_buf, "{}", len) @@ -84,10 +78,6 @@ pub fn content_length_value(len: u64) -> HeaderValue { } } -pub fn expect_continue(headers: &HeaderMap) -> bool { - Some(&b"100-continue"[..]) == headers.get(EXPECT).map(|v| v.as_bytes()) -} - pub fn transfer_encoding_is_chunked(headers: &HeaderMap) -> bool { is_chunked(headers.get_all(TRANSFER_ENCODING).into_iter()) } @@ -95,10 +85,17 @@ pub fn transfer_encoding_is_chunked(headers: &HeaderMap) -> bool { pub fn is_chunked(mut encodings: ValueIter) -> bool { // chunked must always be the last encoding, according to spec if let Some(line) = encodings.next_back() { - if let Ok(s) = line.to_str() { - if let Some(encoding) = s.rsplit(',').next() { - return eq_ascii(encoding.trim(), "chunked"); - } + return is_chunked_(line); + } + + false +} + +pub fn is_chunked_(value: &HeaderValue) -> bool { + // chunked must always be the last encoding, according to spec + if let Ok(s) = value.to_str() { + if let Some(encoding) = s.rsplit(',').next() { + return eq_ascii(encoding.trim(), "chunked"); } } diff --git a/src/lib.rs b/src/lib.rs index 6ab96494ba..a603fe2573 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![doc(html_root_url = "https://docs.rs/hyper/0.11.22")] #![deny(missing_docs)] -#![deny(warnings)] +//#![deny(warnings)] #![deny(missing_debug_implementations)] #![cfg_attr(all(test, feature = "nightly"), feature(test))] diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 055cccc02f..f0309a9a83 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -4,13 +4,13 @@ use std::marker::PhantomData; use bytes::{Buf, Bytes}; 
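Aside: the `headers.rs` hunks above change the connection helpers to operate on a single `HeaderValue` instead of scanning the whole `HeaderMap`, so the parser can inspect each header exactly once as it walks the raw header indices. A minimal sketch of the token check, assuming only the `http` crate (this is not hyper's internal code; std's `eq_ignore_ascii_case` stands in for hyper's private `eq_ascii`):

```rust
// Sketch of the new per-value helper: scan one comma-separated
// `Connection` value for a token, case-insensitively.
use http::header::HeaderValue;

fn has_token(value: &HeaderValue, needle: &str) -> bool {
    value
        .to_str()
        .map(|s| s.split(',').any(|tok| tok.trim().eq_ignore_ascii_case(needle)))
        .unwrap_or(false)
}

fn main() {
    let value = HeaderValue::from_static("Upgrade, Keep-Alive");
    assert!(has_token(&value, "keep-alive"));
    assert!(!has_token(&value, "close"));
}
```
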
use futures::{Async, Poll}; -use http::{Method, Version}; +use http::{HeaderMap, Method, Version}; use tokio_io::{AsyncRead, AsyncWrite}; use ::Chunk; -use proto::{BodyLength, Decode, Http1Transaction, MessageHead}; +use proto::{BodyLength, MessageHead}; use super::io::{Buffered}; -use super::{EncodedBuf, Encoder, Decoder}; +use super::{EncodedBuf, Encode, Encoder, Decode, Decoder, Http1Transaction, ParseContext}; const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; @@ -36,6 +36,7 @@ where I: AsyncRead + AsyncWrite, Conn { io: Buffered::new(io), state: State { + cached_headers: None, error: None, keep_alive: KA::Busy, method: None, @@ -118,8 +119,11 @@ where I: AsyncRead + AsyncWrite, trace!("Conn::read_head"); loop { - let (version, head) = match self.io.parse::() { - Ok(Async::Ready(head)) => (head.version, head), + let msg = match self.io.parse::(ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + }) { + Ok(Async::Ready(msg)) => msg, Ok(Async::NotReady) => return Ok(Async::NotReady), Err(e) => { // If we are currently waiting on a message, then an empty @@ -141,48 +145,32 @@ where I: AsyncRead + AsyncWrite, } }; - match version { - Version::HTTP_10 | - Version::HTTP_11 => {}, - _ => { - error!("unimplemented HTTP Version = {:?}", version); - self.state.close_read(); - //TODO: replace this with a more descriptive error - return Err(::Error::new_version()); - } - }; - self.state.version = version; - - let decoder = match T::decoder(&head, &mut self.state.method) { - Ok(Decode::Normal(d)) => { + self.state.version = msg.head.version; + let head = msg.head; + let decoder = match msg.decode { + Decode::Normal(d) => { d }, - Ok(Decode::Final(d)) => { + Decode::Final(d) => { trace!("final decoder, HTTP ending"); debug_assert!(d.is_eof()); self.state.close_read(); d }, - Ok(Decode::Ignore) => { + Decode::Ignore => { // likely a 1xx message that we can ignore continue; } - Err(e) => { - debug!("decoder error = {:?}", e); - self.state.close_read(); - return self.on_parse_error(e) - .map(|()| Async::NotReady); - } }; debug!("incoming body is {}", decoder); self.state.busy(); - if head.expecting_continue() { - let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; - self.io.write_buf_mut().extend_from_slice(msg); + if msg.expect_continue { + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.write_buf_mut().extend_from_slice(cont); } - let wants_keep_alive = head.should_keep_alive(); + let wants_keep_alive = msg.keep_alive; self.state.keep_alive &= wants_keep_alive; let (body, reading) = if decoder.is_eof() { (false, Reading::KeepAlive) @@ -410,8 +398,17 @@ where I: AsyncRead + AsyncWrite, self.enforce_version(&mut head); let buf = self.io.write_buf_mut(); - self.state.writing = match T::encode(head, body, &mut self.state.method, self.state.title_case_headers, buf) { + self.state.writing = match T::encode(Encode { + head: &mut head, + body, + keep_alive: self.state.wants_keep_alive(), + req_method: &mut self.state.method, + title_case_headers: self.state.title_case_headers, + }, buf) { Ok(encoder) => { + debug_assert!(self.state.cached_headers.is_none()); + debug_assert!(head.headers.is_empty()); + self.state.cached_headers = Some(head.headers); if !encoder.is_eof() { Writing::Body(encoder) } else if encoder.is_last() { @@ -430,24 +427,12 @@ where I: AsyncRead + AsyncWrite, // If we know the remote speaks an older version, we try to fix up any messages // to work with our older peer. 
fn enforce_version(&mut self, head: &mut MessageHead) { - //use header::Connection; - - let wants_keep_alive = if self.state.wants_keep_alive() { - let ka = head.should_keep_alive(); - self.state.keep_alive &= ka; - ka - } else { - false - }; match self.state.version { Version::HTTP_10 => { // If the remote only knows HTTP/1.0, we should force ourselves // to do only speak HTTP/1.0 as well. head.version = Version::HTTP_10; - if wants_keep_alive { - //TODO: head.headers.set(Connection::keep_alive()); - } }, _ => { // If the remote speaks HTTP/1.1, then it *should* be fine with @@ -617,13 +602,27 @@ impl fmt::Debug for Conn { } struct State { + /// Re-usable HeaderMap to reduce allocating new ones. + cached_headers: Option, + /// If an error occurs when there wasn't a direct way to return it + /// back to the user, this is set. error: Option<::Error>, + /// Current keep-alive status. keep_alive: KA, + /// If mid-message, the HTTP Method that started it. + /// + /// This is used to know things such as if the message can include + /// a body or not. method: Option, title_case_headers: bool, + /// Set to true when the Dispatcher should poll read operations + /// again. See the `maybe_notify` method for more. notify_read: bool, + /// State of allowed reads reading: Reading, + /// State of allowed writes writing: Writing, + /// Either HTTP/1.0 or 1.1 connection version: Version, } diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 1965f78748..2de161ba22 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -4,7 +4,8 @@ use http::{Request, Response, StatusCode}; use tokio_io::{AsyncRead, AsyncWrite}; use body::{Body, Payload}; -use proto::{BodyLength, Conn, Http1Transaction, MessageHead, RequestHead, RequestLine, ResponseHead}; +use proto::{BodyLength, Conn, MessageHead, RequestHead, RequestLine, ResponseHead}; +use super::Http1Transaction; use service::Service; pub(crate) struct Dispatcher { diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 41d0135d3d..cfda1c30c2 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -7,7 +7,7 @@ use iovec::IoVec; use common::StaticBuf; /// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Encoder { kind: Kind, is_last: bool, @@ -70,8 +70,9 @@ impl Encoder { } } - pub fn set_last(&mut self) { - self.is_last = true; + pub fn set_last(mut self, is_last: bool) -> Self { + self.is_last = is_last; + self } pub fn is_last(&self) -> bool { diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 27ae2cef38..0e65545a99 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -8,7 +8,7 @@ use futures::{Async, Poll}; use iovec::IoVec; use tokio_io::{AsyncRead, AsyncWrite}; -use proto::{Http1Transaction, MessageHead}; +use super::{Http1Transaction, ParseContext, ParsedMessage}; /// The initial buffer size allocated before trying to read from IO. 
pub(crate) const INIT_BUFFER_SIZE: usize = 8192; @@ -126,12 +126,16 @@ where } } - pub(super) fn parse(&mut self) -> Poll, ::Error> { + pub(super) fn parse(&mut self, ctx: ParseContext) + -> Poll, ::Error> + where + S: Http1Transaction, + { loop { - match try!(S::parse(&mut self.read_buf)) { - Some((head, len)) => { - debug!("parsed {} headers ({} bytes)", head.headers.len(), len); - return Ok(Async::Ready(head)) + match try!(S::parse(&mut self.read_buf, ParseContext { cached_headers: ctx.cached_headers, req_method: ctx.req_method, })) { + Some(msg) => { + debug!("parsed {} headers", msg.head.headers.len()); + return Ok(Async::Ready(msg)) }, None => { if self.read_buf.capacity() >= self.max_buf_size { @@ -617,7 +621,11 @@ mod tests { let mock = AsyncIo::new_buf(raw, raw.len()); let mut buffered = Buffered::<_, Cursor>>::new(mock); - assert_eq!(buffered.parse::<::proto::ClientTransaction>().unwrap(), Async::NotReady); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + assert!(buffered.parse::<::proto::ClientTransaction>(ctx).unwrap().is_not_ready()); assert!(buffered.io.blocked()); } diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index e2cc9087a4..d3576fa4d1 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -1,3 +1,8 @@ +use bytes::BytesMut; +use http::{HeaderMap, Method}; + +use proto::{MessageHead, BodyLength}; + pub(crate) use self::conn::Conn; pub(crate) use self::dispatch::Dispatcher; pub use self::decode::Decoder; @@ -11,5 +16,58 @@ mod decode; pub(crate) mod dispatch; mod encode; mod io; -pub mod role; +mod role; + + +pub(crate) type ServerTransaction = self::role::Server; +//pub type ServerTransaction = self::role::Server; +//pub type ServerUpgradeTransaction = self::role::Server; + +pub(crate) type ClientTransaction = self::role::Client; +pub(crate) type ClientUpgradeTransaction = self::role::Client; + +pub(crate) trait Http1Transaction { + type Incoming; + type Outgoing: Default; + fn parse(bytes: &mut BytesMut, ctx: ParseContext) -> ParseResult; + fn encode(enc: Encode, dst: &mut Vec) -> ::Result; + + fn on_error(err: &::Error) -> Option>; + + fn should_error_on_parse_eof() -> bool; + fn should_read_first() -> bool; +} + +pub(crate) type ParseResult = Result>, ::error::Parse>; + +#[derive(Debug)] +pub(crate) struct ParsedMessage { + head: MessageHead, + decode: Decode, + expect_continue: bool, + keep_alive: bool, +} + +pub(crate) struct ParseContext<'a> { + cached_headers: &'a mut Option, + req_method: &'a mut Option, +} + +/// Passed to Http1Transaction::encode +pub(crate) struct Encode<'a, T: 'a> { + head: &'a mut MessageHead, + body: Option, + keep_alive: bool, + req_method: &'a mut Option, + title_case_headers: bool, +} +#[derive(Debug, PartialEq)] +pub enum Decode { + /// Decode normally. + Normal(Decoder), + /// After this decoder is done, HTTP is done. + Final(Decoder), + /// A header block that should be ignored, like unknown 1xx responses. 
+ Ignore, +} diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 1fae294ede..4bd40f690a 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -2,13 +2,14 @@ use std::fmt::{self, Write}; use std::mem; use bytes::{BytesMut, Bytes}; -use http::header::{CONTENT_LENGTH, DATE, Entry, HeaderName, HeaderValue, TRANSFER_ENCODING}; -use http::{HeaderMap, Method, StatusCode, Uri, Version}; +use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::{HeaderMap, Method, StatusCode, Version}; use httparse; +use error::Parse; use headers; -use proto::{BodyLength, Decode, MessageHead, Http1Transaction, ParseResult, RequestLine, RequestHead}; -use proto::h1::{Encoder, Decoder, date}; +use proto::{BodyLength, MessageHead, RequestLine, RequestHead}; +use proto::h1::{Decode, Decoder, Encode, Encoder, Http1Transaction, ParseResult, ParseContext, ParsedMessage, date}; const MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific @@ -18,9 +19,9 @@ const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific // There is 1 modifier, OnUpgrade, which can wrap Client and Server, // to signal that HTTP upgrades are not supported. -pub struct Client(T); +pub(crate) struct Client(T); -pub struct Server(T); +pub(crate) struct Server(T); impl Http1Transaction for Server where @@ -29,7 +30,7 @@ where type Incoming = RequestLine; type Outgoing = StatusCode; - fn parse(buf: &mut BytesMut) -> ParseResult { + fn parse(buf: &mut BytesMut, ctx: ParseContext) -> ParseResult { if buf.len() == 0 { return Ok(None); } @@ -38,7 +39,7 @@ where // values into it. By not zeroing out the stack memory, this saves // a good ~5% on pipeline benchmarks. let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, method, path, version, headers_len) = { + let (len, subject, version, headers_len) = { let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; trace!("Request.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); let mut req = httparse::Request::new(&mut headers); @@ -46,11 +47,8 @@ where httparse::Status::Complete(len) => { trace!("Request.parse Complete({})", len); let method = Method::from_bytes(req.method.unwrap().as_bytes())?; - let path = req.path.unwrap(); - let bytes_ptr = buf.as_ref().as_ptr() as usize; - let path_start = path.as_ptr() as usize - bytes_ptr; - let path_end = path_start + path.len(); - let path = (path_start, path_end); + let path = req.path.unwrap().parse()?; + let subject = RequestLine(method, path); let version = if req.version.unwrap() == 1 { Version::HTTP_11 } else { @@ -59,31 +57,13 @@ where record_header_indices(buf.as_ref(), &req.headers, &mut headers_indices); let headers_len = req.headers.len(); - (len, method, path, version, headers_len) + (len, subject, version, headers_len) } httparse::Status::Partial => return Ok(None), } }; let slice = buf.split_to(len).freeze(); - let path = slice.slice(path.0, path.1); - let path = Uri::from_shared(path)?; - let subject = RequestLine( - method, - path, - ); - - let headers = fill_headers(slice, &headers_indices[..headers_len]); - - Ok(Some((MessageHead { - version: version, - subject: subject, - headers: headers, - }, len))) - } - - fn decoder(head: &MessageHead, method: &mut Option) -> ::Result { - *method = Some(head.subject.0.clone()); // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. (irrelevant to Request) @@ -94,6 +74,117 @@ where // 6. Length 0. // 7. 
(irrelevant to Request) + + let mut decoder = None; + let mut expect_continue = false; + let mut keep_alive = version == Version::HTTP_11; + let mut con_len = None; + let mut is_te = false; + let mut is_te_chunked = false; + + let mut headers = ctx.cached_headers + .take() + .unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + + for header in &headers_indices[..headers_len] { + let name = HeaderName::from_bytes(&slice[header.name.0..header.name.1]) + .expect("header name already validated"); + let val = slice.slice(header.value.0, header.value.1); + // Unsafe: httparse already validated header value + let value = unsafe { + HeaderValue::from_shared_unchecked(val) + }; + + match name { + header::TRANSFER_ENCODING => { + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If Transfer-Encoding header is present, and 'chunked' is + // not the final encoding, and this is a Request, then it is + // mal-formed. A server should respond with 400 Bad Request. + if version == Version::HTTP_10 { + debug!("HTTP/1.0 cannot have Transfer-Encoding header"); + return Err(Parse::Header); + } + is_te = true; + if headers::is_chunked_(&value) { + is_te_chunked = true; + decoder = Some(Decoder::chunked()); + //debug!("request with transfer-encoding header, but not chunked, bad request"); + //return Err(Parse::Header); + } + }, + header::CONTENT_LENGTH => { + if is_te { + continue; + } + let len = value.to_str() + .map_err(|_| Parse::Header) + .and_then(|s| s.parse().map_err(|_| Parse::Header))?; + if let Some(prev) = con_len { + if prev != len { + debug!( + "multiple Content-Length headers with different values: [{}, {}]", + prev, + len, + ); + return Err(Parse::Header); + } + // we don't need to append this secondary length + continue; + } + con_len = Some(len); + decoder = Some(Decoder::length(len)); + }, + header::CONNECTION => { + // keep_alive was previously set to default for Version + if keep_alive { + // HTTP/1.1 + keep_alive = !headers::connection_close(&value); + + } else { + // HTTP/1.0 + keep_alive = headers::connection_keep_alive(&value); + } + }, + header::EXPECT => { + expect_continue = value.as_bytes() == b"100-continue"; + }, + + _ => (), + } + + headers.append(name, value); + } + + let decoder = if let Some(decoder) = decoder { + decoder + } else { + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::Header); + } + Decoder::length(0) + }; + + *ctx.req_method = Some(subject.0.clone()); + + Ok(Some(ParsedMessage { + head: MessageHead { + version, + subject, + headers, + }, + decode: Decode::Normal(decoder), + expect_continue, + keep_alive, + })) + } + + /* + fn decoder(head: &MessageHead, method: &mut Option) -> ::Result { + *method = Some(head.subject.0.clone()); if head.headers.contains_key(TRANSFER_ENCODING) { // https://tools.ietf.org/html/rfc7230#section-3.3.3 // If Transfer-Encoding header is present, and 'chunked' is @@ -116,70 +207,233 @@ where } else { Ok(Decode::Normal(Decoder::length(0))) } - } + }*/ - fn encode( - mut head: MessageHead, - body: Option, - method: &mut Option, - _title_case_headers: bool, - dst: &mut Vec, - ) -> ::Result { - trace!("Server::encode body={:?}, method={:?}", body, method); + fn encode(mut msg: Encode, dst: &mut Vec) -> ::Result { + trace!("Server::encode body={:?}, method={:?}", msg.body, msg.req_method); + debug_assert!(!msg.title_case_headers, "no server config for title case headers"); // hyper currently doesn't support returning 1xx status codes as a 
Response // This is because Service only allows returning a single Response, and // so if you try to reply with a e.g. 100 Continue, you have no way of // replying with the latter status code response. - let ret = if StatusCode::SWITCHING_PROTOCOLS == head.subject { - T::on_encode_upgrade(&mut head) - .map(|_| { - let mut enc = Server::set_length(&mut head, body, method.as_ref()); - enc.set_last(); - enc - }) - } else if head.subject.is_informational() { + let (ret, mut is_last) = if StatusCode::SWITCHING_PROTOCOLS == msg.head.subject { + (T::on_encode_upgrade(&mut msg), true) + } else if msg.head.subject.is_informational() { error!("response with 1xx status code not supported"); - head = MessageHead::default(); - head.subject = StatusCode::INTERNAL_SERVER_ERROR; - headers::content_length_zero(&mut head.headers); + *msg.head = MessageHead::default(); + msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; //TODO: change this to a more descriptive error than just a parse error - Err(::Error::new_status()) + (Err(::Error::new_status()), true) } else { - Ok(Server::set_length(&mut head, body, method.as_ref())) + (Ok(()), !msg.keep_alive) }; + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let orig_len = dst.len(); + let rewind = |dst: &mut Vec| { + dst.truncate(orig_len); + }; - let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; dst.reserve(init_cap); - if head.version == Version::HTTP_11 && head.subject == StatusCode::OK { + if msg.head.version == Version::HTTP_11 && msg.head.subject == StatusCode::OK { extend(dst, b"HTTP/1.1 200 OK\r\n"); } else { - match head.version { + match msg.head.version { Version::HTTP_10 => extend(dst, b"HTTP/1.0 "), Version::HTTP_11 => extend(dst, b"HTTP/1.1 "), _ => unreachable!(), } - extend(dst, head.subject.as_str().as_bytes()); + extend(dst, msg.head.subject.as_str().as_bytes()); extend(dst, b" "); // a reason MUST be written, as many parsers will expect it. - extend(dst, head.subject.canonical_reason().unwrap_or("").as_bytes()); + extend(dst, msg.head.subject.canonical_reason().unwrap_or("").as_bytes()); extend(dst, b"\r\n"); } - write_headers(&head.headers, dst); - // using http::h1::date is quite a lot faster than generating a unique Date header each time - // like req/s goes up about 10% - if !head.headers.contains_key(DATE) { + + let mut encoder = Encoder::length(0); + let mut wrote_len = false; + let mut wrote_date = false; + 'headers: for (name, mut values) in msg.head.headers.drain() { + match name { + header::CONTENT_LENGTH => { + if wrote_len { + warn!("transfer-encoding and content-length both found, canceling"); + rewind(dst); + return Err(::Error::new_header()); + } + match msg.body { + Some(BodyLength::Known(len)) => { + // The Payload claims to know a length, and + // the headers are already set. For performance + // reasons, we are just going to trust that + // the values match. + // + // In debug builds, we'll assert they are the + // same to help developers find bugs. + encoder = Encoder::length(len); + }, + Some(BodyLength::Unknown) => { + // The Payload impl didn't know how long the + // body is, but a length header was included. + // We have to parse the value to return our + // Encoder... 
+ let mut folded = None::<(u64, HeaderValue)>; + for value in values { + if let Some(len) = headers::content_length_parse(&value) { + if let Some(fold) = folded { + if fold.0 != len { + warn!("multiple Content-Length values found: [{}, {}]", fold.0, len); + rewind(dst); + return Err(::Error::new_header()); + } + folded = Some(fold); + } else { + folded = Some((len, value)); + } + } else { + warn!("illegal Content-Length value: {:?}", value); + rewind(dst); + return Err(::Error::new_header()); + } + } + if let Some((len, value)) = folded { + encoder = Encoder::length(len); + extend(dst, b"content-length: "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + wrote_len = true; + continue 'headers; + } else { + // No values in content-length... ignore? + continue 'headers; + } + }, + None => { + // We have no body to actually send, + // but the headers claim a content-length. + // There's only 2 ways this makes sense: + // + // - The header says the length is `0`. + // - This is a response to a `HEAD` request. + if msg.req_method == &Some(Method::HEAD) { + debug_assert_eq!(encoder, Encoder::length(0)); + } else { + for value in values { + if value.as_bytes() != b"0" { + warn!("content-length value found, but empty body provided: {:?}", value); + } + } + continue 'headers; + } + } + } + wrote_len = true; + }, + header::TRANSFER_ENCODING => { + if wrote_len { + warn!("transfer-encoding and content-length both found, canceling"); + rewind(dst); + return Err(::Error::new_header()); + } + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 || !Server::can_chunked(msg.req_method, msg.head.subject) { + continue; + } + wrote_len = true; + encoder = Encoder::chunked(); + + extend(dst, b"transfer-encoding: "); + + let mut saw_chunked; + if let Some(te) = values.next() { + extend(dst, te.as_bytes()); + saw_chunked = headers::is_chunked_(&te); + for value in values { + extend(dst, b", "); + extend(dst, value.as_bytes()); + saw_chunked = headers::is_chunked_(&value); + } + if !saw_chunked { + extend(dst, b", chunked\r\n"); + } else { + extend(dst, b"\r\n"); + } + } else { + // zero lines? add a chunked line then + extend(dst, b"chunked\r\n"); + } + continue 'headers; + }, + header::CONNECTION => { + if !is_last { + for value in values { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + + if headers::connection_close(&value) { + is_last = true; + } + } + continue 'headers; + } + }, + header::DATE => { + wrote_date = true; + }, + _ => (), + } + //TODO: this should perhaps instead combine them into + //single lines, as RFC7230 suggests is preferable. 
+ for value in values { + extend(dst, name.as_str().as_bytes()); + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + } + + if !wrote_len { + encoder = match msg.body { + Some(BodyLength::Unknown) => { + if msg.head.version == Version::HTTP_10 || !Server::can_chunked(msg.req_method, msg.head.subject) { + Encoder::close_delimited() + } else { + extend(dst, b"transfer-encoding: chunked\r\n"); + Encoder::chunked() + } + }, + None | + Some(BodyLength::Known(0)) => { + extend(dst, b"content-length: 0\r\n"); + Encoder::length(0) + }, + Some(BodyLength::Known(len)) => { + let _ = write!(FastWrite(dst), "content-length: {}\r\n", len); + Encoder::length(len) + }, + }; + } + + // cached date is much faster than formatting every request + if !wrote_date { dst.reserve(date::DATE_VALUE_LENGTH + 8); extend(dst, b"date: "); date::extend(dst); + extend(dst, b"\r\n\r\n"); + } else { extend(dst, b"\r\n"); } - extend(dst, b"\r\n"); - ret + ret.map(|()| encoder.set_last(is_last)) } fn on_error(err: &::Error) -> Option> { @@ -213,6 +467,7 @@ where } impl Server<()> { + /* fn set_length(head: &mut MessageHead, body: Option, method: Option<&Method>) -> Encoder { // these are here thanks to borrowck // `if method == Some(&Method::Get)` says the RHS doesn't live long enough @@ -239,13 +494,31 @@ impl Server<()> { if let (Some(body), true) = (body, can_have_body) { set_length(&mut head.headers, body, head.version == Version::HTTP_11) } else { - head.headers.remove(TRANSFER_ENCODING); + head.headers.remove(header::TRANSFER_ENCODING); if can_have_body { headers::content_length_zero(&mut head.headers); } Encoder::length(0) } } + */ + + fn can_chunked(method: &Option, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) { + false + } else if method == &Some(Method::CONNECT) && status.is_success() { + false + } else { + match status { + // TODO: support for 1xx codes needs improvement everywhere + // would be 100...199 => false + StatusCode::SWITCHING_PROTOCOLS | + StatusCode::NO_CONTENT | + StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } } impl Http1Transaction for Client @@ -255,7 +528,7 @@ where type Incoming = StatusCode; type Outgoing = RequestLine; - fn parse(buf: &mut BytesMut) -> ParseResult { + fn parse(buf: &mut BytesMut, ctx: ParseContext) -> ParseResult { if buf.len() == 0 { return Ok(None); } @@ -285,16 +558,80 @@ where let slice = buf.split_to(len).freeze(); - let headers = fill_headers(slice, &headers_indices[..headers_len]); + let mut headers = ctx.cached_headers + .take() + .unwrap_or_else(HeaderMap::new); + + headers.reserve(headers_len); + fill_headers(&mut headers, slice, &headers_indices[..headers_len]); + + let keep_alive = version == Version::HTTP_11; - Ok(Some((MessageHead { - version: version, + let head = MessageHead { + version, subject: status, - headers: headers, - }, len))) + headers, + }; + let decode = Client::::decoder(&head, ctx.req_method)?; + + Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + keep_alive, + })) + } + + fn encode(msg: Encode, dst: &mut Vec) -> ::Result { + trace!("Client::encode body={:?}, method={:?}", msg.body, msg.req_method); + + *msg.req_method = Some(msg.head.subject.0.clone()); + + let body = Client::set_length(msg.head, msg.body); + + let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; + dst.reserve(init_cap); + + + extend(dst, msg.head.subject.0.as_str().as_bytes()); + extend(dst, b" "); + //TODO: add API to http::Uri to encode without std::fmt + let _ = 
write!(FastWrite(dst), "{} ", msg.head.subject.1); + + match msg.head.version { + Version::HTTP_10 => extend(dst, b"HTTP/1.0"), + Version::HTTP_11 => extend(dst, b"HTTP/1.1"), + _ => unreachable!(), + } + extend(dst, b"\r\n"); + + if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + extend(dst, b"\r\n"); + msg.head.headers.clear(); //TODO: remove when switching to drain() + + Ok(body) + } + + fn on_error(_err: &::Error) -> Option> { + // we can't tell the server about any errors it creates + None } - fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result { + fn should_error_on_parse_eof() -> bool { + true + } + + fn should_read_first() -> bool { + false + } +} + +impl Client { + fn decoder(inc: &MessageHead, method: &mut Option) -> Result { // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -332,82 +669,30 @@ where } } - if inc.headers.contains_key(TRANSFER_ENCODING) { + if inc.headers.contains_key(header::TRANSFER_ENCODING) { // https://tools.ietf.org/html/rfc7230#section-3.3.3 // If Transfer-Encoding header is present, and 'chunked' is // not the final encoding, and this is a Request, then it is // mal-formed. A server should respond with 400 Bad Request. if inc.version == Version::HTTP_10 { debug!("HTTP/1.0 cannot have Transfer-Encoding header"); - Err(::Error::new_header()) + Err(Parse::Header) } else if headers::transfer_encoding_is_chunked(&inc.headers) { Ok(Decode::Normal(Decoder::chunked())) } else { trace!("not chunked, read till eof"); Ok(Decode::Normal(Decoder::eof())) } - } else if let Some(len) = headers::content_length_parse(&inc.headers) { + } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { Ok(Decode::Normal(Decoder::length(len))) - } else if inc.headers.contains_key(CONTENT_LENGTH) { + } else if inc.headers.contains_key(header::CONTENT_LENGTH) { debug!("illegal Content-Length header"); - Err(::Error::new_header()) + Err(Parse::Header) } else { trace!("neither Transfer-Encoding nor Content-Length"); Ok(Decode::Normal(Decoder::eof())) } } - - fn encode( - mut head: MessageHead, - body: Option, - method: &mut Option, - title_case_headers: bool, - dst: &mut Vec, - ) -> ::Result { - trace!("Client::encode body={:?}, method={:?}", body, method); - - *method = Some(head.subject.0.clone()); - - let body = Client::set_length(&mut head, body); - - let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; - dst.reserve(init_cap); - - - extend(dst, head.subject.0.as_str().as_bytes()); - extend(dst, b" "); - //TODO: add API to http::Uri to encode without std::fmt - let _ = write!(FastWrite(dst), "{} ", head.subject.1); - - match head.version { - Version::HTTP_10 => extend(dst, b"HTTP/1.0"), - Version::HTTP_11 => extend(dst, b"HTTP/1.1"), - _ => unreachable!(), - } - extend(dst, b"\r\n"); - - if title_case_headers { - write_headers_title_case(&head.headers, dst); - } else { - write_headers(&head.headers, dst); - } - extend(dst, b"\r\n"); - - Ok(body) - } - - fn on_error(_err: &::Error) -> Option> { - // we can't tell the server about any errors it creates - None - } - - fn should_error_on_parse_eof() -> bool { - true - } - - fn should_read_first() -> bool { - false - } } impl Client<()> { @@ -419,7 +704,7 @@ impl Client<()> { && (head.subject.0 != Method::CONNECT); set_length(&mut head.headers, body, can_chunked) } else { - 
head.headers.remove(TRANSFER_ENCODING); + head.headers.remove(header::TRANSFER_ENCODING); Encoder::length(0) } } @@ -433,13 +718,13 @@ fn set_length(headers: &mut HeaderMap, body: BodyLength, can_chunked: bool) -> E // Content-Length header while holding an `Entry` for the Transfer-Encoding // header, so unfortunately, we must do the check here, first. - let existing_con_len = headers::content_length_parse(headers); + let existing_con_len = headers::content_length_parse_all(headers); let mut should_remove_con_len = false; if can_chunked { // If the user set a transfer-encoding, respect that. Let's just // make sure `chunked` is the final encoding. - let encoder = match headers.entry(TRANSFER_ENCODING) + let encoder = match headers.entry(header::TRANSFER_ENCODING) .expect("TRANSFER_ENCODING is valid HeaderName") { Entry::Occupied(te) => { should_remove_con_len = true; @@ -485,7 +770,7 @@ fn set_length(headers: &mut HeaderMap, body: BodyLength, can_chunked: bool) -> E // content-length header. if let Some(encoder) = encoder { if should_remove_con_len && existing_con_len.is_some() { - headers.remove(CONTENT_LENGTH); + headers.remove(header::CONTENT_LENGTH); } return encoder; } @@ -504,7 +789,7 @@ fn set_length(headers: &mut HeaderMap, body: BodyLength, can_chunked: bool) -> E // Chunked isn't legal, so if it is set, we need to remove it. // Also, if it *is* set, then we shouldn't replace with a length, // since the user tried to imply there isn't a length. - let encoder = if headers.remove(TRANSFER_ENCODING).is_some() { + let encoder = if headers.remove(header::TRANSFER_ENCODING).is_some() { trace!("removing illegal transfer-encoding header"); should_remove_con_len = true; Encoder::close_delimited() @@ -517,7 +802,7 @@ fn set_length(headers: &mut HeaderMap, body: BodyLength, can_chunked: bool) -> E }; if should_remove_con_len && existing_con_len.is_some() { - headers.remove(CONTENT_LENGTH); + headers.remove(header::CONTENT_LENGTH); } encoder @@ -533,12 +818,12 @@ fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { // so perhaps only do that while the user is developing/testing. if cfg!(debug_assertions) { - match headers.entry(CONTENT_LENGTH) + match headers.entry(header::CONTENT_LENGTH) .expect("CONTENT_LENGTH is valid HeaderName") { Entry::Occupied(mut cl) => { // Internal sanity check, we should have already determined // that the header was illegal before calling this function. - debug_assert!(headers::content_length_parse_all(cl.iter()).is_none()); + debug_assert!(headers::content_length_parse_all_values(cl.iter()).is_none()); // Uh oh, the user set `Content-Length` headers, but set bad ones. // This would be an illegal message anyways, so let's try to repair // with our known good length. 
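Aside: the `Client::decoder` logic above encodes the RFC 7230 §3.3.3 rules for choosing a response body decoder. A simplified, self-contained sketch of that decision order follows; the enum is hypothetical and merely stands in for hyper's `Decode`/`Decoder` types. Status and request-method rules come first, then Transfer-Encoding, then Content-Length, otherwise the body is read until EOF:

```rust
#[derive(Debug, PartialEq)]
enum BodyKind {
    Empty,          // HEAD response, 1xx, 204, 304, or 2xx to CONNECT
    Chunked,        // Transfer-Encoding ends in "chunked"
    Length(u64),    // a single, consistent Content-Length
    CloseDelimited, // neither header present: body runs until EOF
}

fn response_body_kind(
    is_head: bool,
    is_connect: bool,
    status: u16,
    chunked: bool,
    content_length: Option<u64>,
) -> BodyKind {
    if is_head || (100..200).contains(&status) || status == 204 || status == 304 {
        BodyKind::Empty
    } else if is_connect && (200..300).contains(&status) {
        BodyKind::Empty
    } else if chunked {
        BodyKind::Chunked
    } else if let Some(len) = content_length {
        BodyKind::Length(len)
    } else {
        BodyKind::CloseDelimited
    }
}

fn main() {
    assert_eq!(response_body_kind(false, false, 200, false, Some(8)), BodyKind::Length(8));
    assert_eq!(response_body_kind(true, false, 200, false, Some(8)), BodyKind::Empty);
    assert_eq!(response_body_kind(false, false, 200, true, Some(10)), BodyKind::Chunked);
    assert_eq!(response_body_kind(false, false, 200, false, None), BodyKind::CloseDelimited);
}
```
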
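Aside: the rewritten `Server::encode` above writes headers directly into `dst` while it validates them, so when it detects an invalid combination (for example, conflicting Content-Length values) it truncates `dst` back to its starting length rather than leaving a half-written head in the buffer. A standalone sketch of that rewind pattern (hypothetical `encode_head`, not the hyper code):

```rust
fn encode_head(dst: &mut Vec<u8>, reject: bool) -> Result<(), &'static str> {
    let orig_len = dst.len();
    dst.extend_from_slice(b"HTTP/1.1 200 OK\r\n");
    dst.extend_from_slice(b"content-length: 0\r\n");
    if reject {
        // e.g. conflicting content-length values were discovered here
        dst.truncate(orig_len); // rewind: drop everything written in this call
        return Err("invalid header");
    }
    dst.extend_from_slice(b"\r\n");
    Ok(())
}

fn main() {
    let mut dst = b"...earlier pipelined response...".to_vec();
    let before = dst.len();
    assert!(encode_head(&mut dst, true).is_err());
    assert_eq!(dst.len(), before); // nothing half-written remains
    assert!(encode_head(&mut dst, false).is_ok());
}
```
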
@@ -553,26 +838,26 @@ fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { } } } else { - headers.insert(CONTENT_LENGTH, headers::content_length_value(len)); + headers.insert(header::CONTENT_LENGTH, headers::content_length_value(len)); Encoder::length(len) } } -pub trait OnUpgrade { - fn on_encode_upgrade(head: &mut MessageHead) -> ::Result<()>; - fn on_decode_upgrade() -> ::Result; +pub(crate) trait OnUpgrade { + fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()>; + fn on_decode_upgrade() -> Result; } -pub enum YesUpgrades {} +pub(crate) enum YesUpgrades {} -pub enum NoUpgrades {} +pub(crate) enum NoUpgrades {} impl OnUpgrade for YesUpgrades { - fn on_encode_upgrade(_head: &mut MessageHead) -> ::Result<()> { + fn on_encode_upgrade(_: &mut Encode) -> ::Result<()> { Ok(()) } - fn on_decode_upgrade() -> ::Result { + fn on_decode_upgrade() -> Result { debug!("101 response received, upgrading"); // 101 upgrades always have no body Ok(Decoder::length(0)) @@ -580,18 +865,18 @@ impl OnUpgrade for YesUpgrades { } impl OnUpgrade for NoUpgrades { - fn on_encode_upgrade(head: &mut MessageHead) -> ::Result<()> { + fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()> { error!("response with 101 status code not supported"); - *head = MessageHead::default(); - head.subject = ::StatusCode::INTERNAL_SERVER_ERROR; - headers::content_length_zero(&mut head.headers); + *msg.head = MessageHead::default(); + msg.head.subject = ::StatusCode::INTERNAL_SERVER_ERROR; + msg.body = None; //TODO: replace with more descriptive error - return Err(::Error::new_status()); + Err(::Error::new_status()) } - fn on_decode_upgrade() -> ::Result { + fn on_decode_upgrade() -> Result { debug!("received 101 upgrade response, not supported"); - return Err(::Error::new_upgrade()); + Err(Parse::UpgradeNotSupported) } } @@ -613,8 +898,7 @@ fn record_header_indices(bytes: &[u8], headers: &[httparse::Header], indices: &m } } -fn fill_headers(slice: Bytes, indices: &[HeaderIndices]) -> HeaderMap { - let mut headers = HeaderMap::with_capacity(indices.len()); +fn fill_headers(headers: &mut HeaderMap, slice: Bytes, indices: &[HeaderIndices]) { for header in indices { let name = HeaderName::from_bytes(&slice[header.name.0..header.name.1]) .expect("header name already validated"); @@ -625,7 +909,6 @@ fn fill_headers(slice: Bytes, indices: &[HeaderIndices]) -> HeaderMap { }; headers.append(name, value); } - headers } // Write header names as title case. 
The header name is assumed to be ASCII, @@ -699,49 +982,29 @@ fn extend(dst: &mut Vec, data: &[u8]) { mod tests { use bytes::BytesMut; - use proto::{Decode, MessageHead}; - use super::{Decoder, Server as S, Client as C, NoUpgrades, Http1Transaction}; + use super::*; + use super::{Server as S, Client as C}; type Server = S; type Client = C; - impl Decode { - fn final_(self) -> Decoder { - match self { - Decode::Final(d) => d, - other => panic!("expected Final, found {:?}", other), - } - } - - fn normal(self) -> Decoder { - match self { - Decode::Normal(d) => d, - other => panic!("expected Normal, found {:?}", other), - } - } - - fn ignore(self) { - match self { - Decode::Ignore => {}, - other => panic!("expected Ignore, found {:?}", other), - } - } - } - - #[test] fn test_parse_request() { extern crate pretty_env_logger; let _ = pretty_env_logger::try_init(); let mut raw = BytesMut::from(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n".to_vec()); - let expected_len = raw.len(); - let (req, len) = Server::parse(&mut raw).unwrap().unwrap(); - assert_eq!(len, expected_len); - assert_eq!(req.subject.0, ::Method::GET); - assert_eq!(req.subject.1, "/echo"); - assert_eq!(req.version, ::Version::HTTP_11); - assert_eq!(req.headers.len(), 1); - assert_eq!(req.headers["Host"], "hyper.rs"); + let mut method = None; + let msg = Server::parse(&mut raw, ParseContext { + cached_headers: &mut None, + req_method: &mut method, + }).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject.0, ::Method::GET); + assert_eq!(msg.head.subject.1, "/echo"); + assert_eq!(msg.head.version, ::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Host"], "hyper.rs"); + assert_eq!(method, Some(::Method::GET)); } @@ -750,142 +1013,294 @@ mod tests { extern crate pretty_env_logger; let _ = pretty_env_logger::try_init(); let mut raw = BytesMut::from(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec()); - let expected_len = raw.len(); - let (req, len) = Client::parse(&mut raw).unwrap().unwrap(); - assert_eq!(len, expected_len); - assert_eq!(req.subject, ::StatusCode::OK); - assert_eq!(req.version, ::Version::HTTP_11); - assert_eq!(req.headers.len(), 1); - assert_eq!(req.headers["Content-Length"], "0"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut Some(::Method::GET), + }; + let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); + assert_eq!(raw.len(), 0); + assert_eq!(msg.head.subject, ::StatusCode::OK); + assert_eq!(msg.head.version, ::Version::HTTP_11); + assert_eq!(msg.head.headers.len(), 1); + assert_eq!(msg.head.headers["Content-Length"], "0"); } #[test] fn test_parse_request_errors() { let mut raw = BytesMut::from(b"GET htt:p// HTTP/1.1\r\nHost: hyper.rs\r\n\r\n".to_vec()); - Server::parse(&mut raw).unwrap_err(); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }; + Server::parse(&mut raw, ctx).unwrap_err(); } + #[test] fn test_decoder_request() { use super::Decoder; - let method = &mut None; - let mut head = MessageHead::<::proto::RequestLine>::default(); + fn parse(s: &str) -> ParsedMessage { + let mut bytes = BytesMut::from(s); + Server::parse(&mut bytes, ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }) + .expect("parse ok") + .expect("parse complete") + } - head.subject.0 = ::Method::GET; - assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); - assert_eq!(*method, Some(::Method::GET)); + fn parse_err(s: &str, comment: &str) -> 
::error::Parse { + let mut bytes = BytesMut::from(s); + Server::parse(&mut bytes, ParseContext { + cached_headers: &mut None, + req_method: &mut None, + }) + .expect_err(comment) + } - head.subject.0 = ::Method::POST; - assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); - assert_eq!(*method, Some(::Method::POST)); + // no length or transfer-encoding means 0-length body + assert_eq!(parse("\ + GET / HTTP/1.1\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(0))); + + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(0))); + + // transfer-encoding: chunked + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip, chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + // content-length + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(10))); - head.headers.insert("transfer-encoding", ::http::header::HeaderValue::from_static("chunked")); - assert_eq!(Decoder::chunked(), Server::decoder(&head, method).unwrap().normal()); // transfer-encoding and content-length = chunked - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("10")); - assert_eq!(Decoder::chunked(), Server::decoder(&head, method).unwrap().normal()); - - head.headers.remove("transfer-encoding"); - assert_eq!(Decoder::length(10), Server::decoder(&head, method).unwrap().normal()); - - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("5")); - head.headers.append("content-length", ::http::header::HeaderValue::from_static("5")); - assert_eq!(Decoder::length(5), Server::decoder(&head, method).unwrap().normal()); - - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("5")); - head.headers.append("content-length", ::http::header::HeaderValue::from_static("6")); - Server::decoder(&head, method).unwrap_err(); - - head.headers.remove("content-length"); - - head.headers.insert("transfer-encoding", ::http::header::HeaderValue::from_static("gzip")); - Server::decoder(&head, method).unwrap_err(); + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 10\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + + // multiple content-lengths of same value are fine + assert_eq!(parse("\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 10\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(10))); + + + // multiple content-lengths with different values is an error + parse_err("\ + POST / HTTP/1.1\r\n\ + content-length: 10\r\n\ + content-length: 11\r\n\ + \r\n\ + ", "multiple content-lengths"); + + // transfer-encoding that isn't chunked is an error + parse_err("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: gzip\r\n\ + \r\n\ + ", "transfer-encoding but 
not chunked"); + + parse_err("\ + POST / HTTP/1.1\r\n\ + transfer-encoding: chunked, gzip\r\n\ + \r\n\ + ", "transfer-encoding doesn't end in chunked"); // http/1.0 - head.version = ::Version::HTTP_10; - head.headers.clear(); - // 1.0 requests can only have bodies if content-length is set - assert_eq!(Decoder::length(0), Server::decoder(&head, method).unwrap().normal()); + assert_eq!(parse("\ + POST / HTTP/1.0\r\n\ + content-length: 10\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(10))); - head.headers.insert("transfer-encoding", ::http::header::HeaderValue::from_static("chunked")); - Server::decoder(&head, method).unwrap_err(); - head.headers.remove("transfer-encoding"); - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("15")); - assert_eq!(Decoder::length(15), Server::decoder(&head, method).unwrap().normal()); + // 1.0 doesn't understand chunked, so its an error + parse_err("\ + POST / HTTP/1.0\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ", "1.0 chunked"); } #[test] fn test_decoder_response() { - use super::Decoder; - let method = &mut Some(::Method::GET); - let mut head = MessageHead::<::StatusCode>::default(); - - head.subject = ::StatusCode::from_u16(204).unwrap(); - assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); - head.subject = ::StatusCode::from_u16(304).unwrap(); - assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); - - head.subject = ::StatusCode::OK; - assert_eq!(Decoder::eof(), Client::decoder(&head, method).unwrap().normal()); - - *method = Some(::Method::HEAD); - assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().normal()); - - *method = Some(::Method::CONNECT); - assert_eq!(Decoder::length(0), Client::decoder(&head, method).unwrap().final_()); + fn parse(s: &str) -> ParsedMessage { + parse_with_method(s, Method::GET) + } + fn parse_with_method(s: &str, m: Method) -> ParsedMessage { + let mut bytes = BytesMut::from(s); + Client::parse(&mut bytes, ParseContext { + cached_headers: &mut None, + req_method: &mut Some(m), + }) + .expect("parse ok") + .expect("parse complete") + } - // CONNECT receiving non 200 can have a body - head.subject = ::StatusCode::NOT_FOUND; - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("10")); - assert_eq!(Decoder::length(10), Client::decoder(&head, method).unwrap().normal()); - head.headers.remove("content-length"); + fn parse_err(s: &str) -> ::error::Parse { + let mut bytes = BytesMut::from(s); + Client::parse(&mut bytes, ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }) + .expect_err("parse should err") + } - *method = Some(::Method::GET); - head.headers.insert("transfer-encoding", ::http::header::HeaderValue::from_static("chunked")); - assert_eq!(Decoder::chunked(), Client::decoder(&head, method).unwrap().normal()); + // no content-length or transfer-encoding means close-delimited + assert_eq!(parse("\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::eof())); + + // 204 and 304 never have a body + assert_eq!(parse("\ + HTTP/1.1 204 No Content\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(0))); + + assert_eq!(parse("\ + HTTP/1.1 304 Not Modified\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(0))); + + // content-length + assert_eq!(parse("\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(8))); + + assert_eq!(parse("\ + HTTP/1.1 200 OK\r\n\ + content-length: 
8\r\n\ + content-length: 8\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::length(8))); + + parse_err("\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + content-length: 9\r\n\ + \r\n\ + "); + + + // transfer-encoding + assert_eq!(parse("\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); // transfer-encoding and content-length = chunked - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("10")); - assert_eq!(Decoder::chunked(), Client::decoder(&head, method).unwrap().normal()); + assert_eq!(parse("\ + HTTP/1.1 200 OK\r\n\ + content-length: 10\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::chunked())); + + + // HEAD can have content-length, but not body + assert_eq!(parse_with_method("\ + HTTP/1.1 200 OK\r\n\ + content-length: 8\r\n\ + \r\n\ + ", Method::HEAD).decode, Decode::Normal(Decoder::length(0))); + + // CONNECT with 200 never has body + assert_eq!(parse_with_method("\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", Method::CONNECT).decode, Decode::Final(Decoder::length(0))); - head.headers.remove("transfer-encoding"); - assert_eq!(Decoder::length(10), Client::decoder(&head, method).unwrap().normal()); - - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("5")); - head.headers.append("content-length", ::http::header::HeaderValue::from_static("5")); - assert_eq!(Decoder::length(5), Client::decoder(&head, method).unwrap().normal()); + // CONNECT receiving non 200 can have a body + assert_eq!(parse_with_method("\ + HTTP/1.1 400 Bad Request\r\n\ + \r\n\ + ", Method::CONNECT).decode, Decode::Normal(Decoder::eof())); - head.headers.insert("content-length", ::http::header::HeaderValue::from_static("5")); - head.headers.append("content-length", ::http::header::HeaderValue::from_static("6")); - Client::decoder(&head, method).unwrap_err(); - head.headers.clear(); // 1xx status codes - head.subject = ::StatusCode::CONTINUE; - Client::decoder(&head, method).unwrap().ignore(); + assert_eq!(parse("\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + ").decode, Decode::Ignore); - head.subject = ::StatusCode::from_u16(103).unwrap(); - Client::decoder(&head, method).unwrap().ignore(); + assert_eq!(parse("\ + HTTP/1.1 103 Early Hints\r\n\ + \r\n\ + ").decode, Decode::Ignore); // 101 upgrade not supported yet - head.subject = ::StatusCode::SWITCHING_PROTOCOLS; - Client::decoder(&head, method).unwrap_err(); - head.subject = ::StatusCode::OK; + parse_err("\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + "); - // http/1.0 - head.version = ::Version::HTTP_10; - - assert_eq!(Decoder::eof(), Client::decoder(&head, method).unwrap().normal()); - head.headers.insert("transfer-encoding", ::http::header::HeaderValue::from_static("chunked")); - Client::decoder(&head, method).unwrap_err(); + // http/1.0 + assert_eq!(parse("\ + HTTP/1.0 200 OK\r\n\ + \r\n\ + ").decode, Decode::Normal(Decoder::eof())); + + // 1.0 doesn't understand chunked + parse_err("\ + HTTP/1.0 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + "); } #[test] @@ -898,7 +1313,13 @@ mod tests { head.headers.insert("content-type", HeaderValue::from_static("application/json")); let mut vec = Vec::new(); - Client::encode(head, Some(BodyLength::Known(10)), &mut None, true, &mut vec).unwrap(); + Client::encode(Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, &mut vec).unwrap(); assert_eq!(vec, b"GET / 
HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n\r\n".to_vec()); } @@ -929,10 +1350,15 @@ mod tests { \r\n\r\n".to_vec() ); let len = raw.len(); + let mut headers = Some(HeaderMap::new()); b.bytes = len as u64; b.iter(|| { - Server::parse(&mut raw).unwrap(); + let msg = Server::parse(&mut raw, ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }).unwrap().unwrap(); + headers = Some(msg.head.headers); restart(&mut raw, len); }); @@ -952,10 +1378,15 @@ mod tests { b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_vec() ); let len = raw.len(); + let mut headers = Some(HeaderMap::new()); b.bytes = len as u64; b.iter(|| { - Server::parse(&mut raw).unwrap(); + let msg = Server::parse(&mut raw, ParseContext { + cached_headers: &mut headers, + req_method: &mut None, + }).unwrap().unwrap(); + headers = Some(msg.head.headers); restart(&mut raw, len); }); @@ -978,12 +1409,20 @@ mod tests { b.bytes = len as u64; let mut head = MessageHead::default(); - head.headers.insert("content-length", HeaderValue::from_static("10")); - head.headers.insert("content-type", HeaderValue::from_static("application/json")); + let mut headers = HeaderMap::new(); + headers.insert("content-length", HeaderValue::from_static("10")); + headers.insert("content-type", HeaderValue::from_static("application/json")); b.iter(|| { let mut vec = Vec::new(); - Server::encode(head.clone(), Some(BodyLength::Known(10)), &mut None, false, &mut vec).unwrap(); + head.headers = headers.clone(); + Server::encode(Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, &mut vec).unwrap(); assert_eq!(vec.len(), len); ::test::black_box(vec); }) @@ -1001,7 +1440,13 @@ mod tests { b.iter(|| { let mut vec = Vec::new(); - Server::encode(head.clone(), Some(BodyLength::Known(10)), &mut None, false, &mut vec).unwrap(); + Server::encode(Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut Some(Method::GET), + title_case_headers: false, + }, &mut vec).unwrap(); assert_eq!(vec.len(), len); ::test::black_box(vec); }) diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 2987754fe9..131fb23209 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -1,10 +1,7 @@ //! Pieces pertaining to the HTTP message protocol. -use bytes::BytesMut; use http::{HeaderMap, Method, StatusCode, Uri, Version}; -use headers; - -pub(crate) use self::h1::{dispatch, Conn}; +pub(crate) use self::h1::{dispatch, Conn, ClientTransaction, ClientUpgradeTransaction, ServerTransaction}; pub(crate) mod h1; pub(crate) mod h2; @@ -30,6 +27,7 @@ pub struct RequestLine(pub Method, pub Uri); /// An incoming response message. 
 pub type ResponseHead = MessageHead;
 
+/*
 impl MessageHead {
     pub fn should_keep_alive(&self) -> bool {
         should_keep_alive(self.version, &self.headers)
@@ -55,33 +53,7 @@ pub fn should_keep_alive(version: Version, headers: &HeaderMap) -> bool {
 pub fn expecting_continue(version: Version, headers: &HeaderMap) -> bool {
     version == Version::HTTP_11 && headers::expect_continue(headers)
 }
-
-pub(crate) type ServerTransaction = h1::role::Server;
-//pub type ServerTransaction = h1::role::Server;
-//pub type ServerUpgradeTransaction = h1::role::Server;
-
-pub(crate) type ClientTransaction = h1::role::Client;
-pub(crate) type ClientUpgradeTransaction = h1::role::Client;
-
-pub(crate) trait Http1Transaction {
-    type Incoming;
-    type Outgoing: Default;
-    fn parse(bytes: &mut BytesMut) -> ParseResult;
-    fn decoder(head: &MessageHead, method: &mut Option) -> ::Result;
-    fn encode(
-        head: MessageHead,
-        body: Option,
-        method: &mut Option,
-        title_case_headers: bool,
-        dst: &mut Vec,
-    ) -> ::Result;
-    fn on_error(err: &::Error) -> Option>;
-
-    fn should_error_on_parse_eof() -> bool;
-    fn should_read_first() -> bool;
-}
-
-pub(crate) type ParseResult = Result, usize)>, ::error::Parse>;
+*/
 
 #[derive(Debug)]
 pub enum BodyLength {
@@ -91,17 +63,7 @@ pub enum BodyLength {
     Unknown,
 }
 
-
-#[derive(Debug)]
-pub enum Decode {
-    /// Decode normally.
-    Normal(h1::Decoder),
-    /// After this decoder is done, HTTP is done.
-    Final(h1::Decoder),
-    /// A header block that should be ignored, like unknown 1xx responses.
-    Ignore,
-}
-
+/*
 #[test]
 fn test_should_keep_alive() {
     let mut headers = HeaderMap::new();
@@ -129,3 +91,4 @@ fn test_expecting_continue() {
     assert!(!expecting_continue(Version::HTTP_10, &headers));
     assert!(expecting_continue(Version::HTTP_11, &headers));
 }
+*/
diff --git a/tests/server.rs b/tests/server.rs
index 4b86795854..878f23dbfd 100644
--- a/tests/server.rs
+++ b/tests/server.rs
@@ -181,6 +181,13 @@ mod response_body_lengths {
             has_header(&body, "transfer-encoding:"),
             "expects_chunked"
         );
+
+        assert_eq!(
+            case.expects_chunked,
+            has_header(&body, "chunked\r\n"),
+            "expects_chunked"
+        );
+
         assert_eq!(
             case.expects_con_len,
             has_header(&body, "content-length:"),
@@ -200,7 +207,7 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_fixed_response_known() {
+    fn fixed_response_known() {
         run_test(TestCase {
             version: 1,
             headers: &[("content-length", "11")],
@@ -211,7 +218,7 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_fixed_response_unknown() {
+    fn fixed_response_unknown() {
         run_test(TestCase {
             version: 1,
             headers: &[("content-length", "11")],
@@ -222,7 +229,18 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_chunked_response_known() {
+    fn fixed_response_known_empty() {
+        run_test(TestCase {
+            version: 1,
+            headers: &[("content-length", "0")],
+            body: Bd::Known(""),
+            expects_chunked: false,
+            expects_con_len: true,
+        });
+    }
+
+    #[test]
+    fn chunked_response_known() {
         run_test(TestCase {
             version: 1,
             headers: &[("transfer-encoding", "chunked")],
@@ -234,7 +252,7 @@ mod response_body_lengths {
     }
 
    #[test]
-    fn get_chunked_response_unknown() {
+    fn chunked_response_unknown() {
         run_test(TestCase {
             version: 1,
             headers: &[("transfer-encoding", "chunked")],
@@ -245,7 +263,22 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_chunked_response_trumps_length() {
+    fn te_response_adds_chunked() {
+        run_test(TestCase {
+            version: 1,
+            headers: &[("transfer-encoding", "gzip")],
+            body: Bd::Unknown("foo bar baz"),
+            expects_chunked: true,
+            expects_con_len: false,
+        });
+    }
+
+    #[test]
+    #[ignore]
+    // This used to be the case, but providing this functionality got in the
+    // way of performance. It can probably be brought back later, and doing
+    // so should be backwards-compatible...
+    fn chunked_response_trumps_length() {
         run_test(TestCase {
             version: 1,
             headers: &[
@@ -260,7 +293,7 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_auto_response_with_entity_unknown_length() {
+    fn auto_response_with_unknown_length() {
         run_test(TestCase {
             version: 1,
             // no headers means trying to guess from Payload
@@ -272,7 +305,7 @@ mod response_body_lengths {
     }
 
     #[test]
-    fn get_auto_response_with_entity_known_length() {
+    fn auto_response_with_known_length() {
         run_test(TestCase {
             version: 1,
             // no headers means trying to guess from Payload
@@ -283,9 +316,20 @@ mod response_body_lengths {
         });
     }
 
+    #[test]
+    fn auto_response_known_empty() {
+        run_test(TestCase {
+            version: 1,
+            // no headers means trying to guess from Payload
+            headers: &[],
+            body: Bd::Known(""),
+            expects_chunked: false,
+            expects_con_len: true,
+        });
+    }
 
     #[test]
-    fn http_10_get_auto_response_with_entity_unknown_length() {
+    fn http10_auto_response_with_unknown_length() {
         run_test(TestCase {
             version: 0,
             // no headers means trying to guess from Payload
@@ -298,7 +342,7 @@ mod response_body_lengths {
 
 
     #[test]
-    fn http_10_get_chunked_response() {
+    fn http10_chunked_response() {
         run_test(TestCase {
             version: 0,
             // http/1.0 should strip this header
@@ -620,6 +664,62 @@ fn disable_keep_alive() {
     }
 }
 
+#[test]
+fn header_connection_close() {
+    let foo_bar = b"foo bar baz";
+    let server = serve();
+    server.reply()
+        .header("content-length", foo_bar.len().to_string())
+        .header("connection", "close")
+        .body(foo_bar);
+    let mut req = connect(server.addr());
+    req.write_all(b"\
+        GET / HTTP/1.1\r\n\
+        Host: example.domain\r\n\
+        Connection: keep-alive\r\n\
+        \r\n\
+        ").expect("writing 1");
+
+    let mut buf = [0; 1024 * 8];
+    loop {
+        let n = req.read(&mut buf[..]).expect("reading 1");
+        if n < buf.len() {
+            if &buf[n - foo_bar.len()..n] == foo_bar {
+                break;
+            } else {
+            }
+        }
+    }
+
+    // try again!
+    // but since the server responded with connection: close, the internal
+    // state should have noticed and shutdown
+
+    let quux = b"zar quux";
+    server.reply()
+        .header("content-length", quux.len().to_string())
+        .body(quux);
+
+    // the write can possibly succeed, since it fills the kernel buffer on the first write
+    let _ = req.write_all(b"\
+        GET /quux HTTP/1.1\r\n\
+        Host: example.domain\r\n\
+        Connection: close\r\n\
+        \r\n\
+        ");
+
+    let mut buf = [0; 1024 * 8];
+    match req.read(&mut buf[..]) {
+        // Ok(0) means EOF, so a proper shutdown
+        // Err(_) could mean ConnReset or something, also fine
+        Ok(0) |
+        Err(_) => {}
+        Ok(n) => {
+            panic!("read {} bytes on a disabled keep-alive socket", n);
+        }
+    }
+}
+
 #[test]
 fn expect_continue() {
     let server = serve();