diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 88b177203e..744e1b2ae5 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -1,10 +1,12 @@ use std::fmt; +#[cfg(feature = "server")] +use std::future::Future; use std::io; use std::marker::{PhantomData, Unpin}; use std::pin::Pin; use std::task::{Context, Poll}; #[cfg(feature = "server")] -use std::time::Duration; +use std::time::{Duration, Instant}; use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; @@ -209,33 +211,67 @@ where debug_assert!(self.can_read_head()); trace!("Conn::read_head"); - let msg = match ready!(self.io.parse::( + #[cfg(feature = "server")] + if !self.state.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout) = self.state.h1_header_read_timeout { + let deadline = Instant::now() + h1_header_read_timeout; + self.state.h1_header_read_timeout_running = true; + match self.state.h1_header_read_timeout_fut { + Some(ref mut h1_header_read_timeout_fut) => { + trace!("resetting h1 header read timeout timer"); + self.state.timer.reset(h1_header_read_timeout_fut, deadline); + } + None => { + trace!("setting h1 header read timeout timer"); + self.state.h1_header_read_timeout_fut = + Some(self.state.timer.sleep_until(deadline)); + } + } + } + } + + let msg = match self.io.parse::( cx, ParseContext { cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, h1_parser_config: self.state.h1_parser_config.clone(), h1_max_headers: self.state.h1_max_headers, - #[cfg(feature = "server")] - h1_header_read_timeout: self.state.h1_header_read_timeout, - #[cfg(feature = "server")] - h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut, - #[cfg(feature = "server")] - h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running, - #[cfg(feature = "server")] - timer: self.state.timer.clone(), preserve_header_case: self.state.preserve_header_case, #[cfg(feature = "ffi")] preserve_header_order: self.state.preserve_header_order, h09_responses: self.state.h09_responses, #[cfg(feature = "ffi")] on_informational: &mut self.state.on_informational, + }, + ) { + Poll::Ready(Ok(msg)) => msg, + Poll::Ready(Err(e)) => return self.on_read_head_error(e), + Poll::Pending => { + #[cfg(feature = "server")] + if self.state.h1_header_read_timeout_running { + if let Some(ref mut h1_header_read_timeout_fut) = + self.state.h1_header_read_timeout_fut + { + if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { + self.state.h1_header_read_timeout_running = false; + + warn!("read header from client timeout"); + return Poll::Ready(Some(Err(crate::Error::new_header_timeout()))); + } + } + } + + return Poll::Pending; } - )) { - Ok(msg) => msg, - Err(e) => return self.on_read_head_error(e), }; + #[cfg(feature = "server")] + { + self.state.h1_header_read_timeout_running = false; + self.state.h1_header_read_timeout_fut = None; + } + // Note: don't deconstruct `msg` into local variables, it appears // the optimizer doesn't remove the extra copies. diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 34eb477fb9..4ad2fca1f4 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -1,7 +1,5 @@ use std::cmp; use std::fmt; -#[cfg(feature = "server")] -use std::future::Future; use std::io::{self, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -183,14 +181,6 @@ where req_method: parse_ctx.req_method, h1_parser_config: parse_ctx.h1_parser_config.clone(), h1_max_headers: parse_ctx.h1_max_headers, - #[cfg(feature = "server")] - h1_header_read_timeout: parse_ctx.h1_header_read_timeout, - #[cfg(feature = "server")] - h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, - #[cfg(feature = "server")] - h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, - #[cfg(feature = "server")] - timer: parse_ctx.timer.clone(), preserve_header_case: parse_ctx.preserve_header_case, #[cfg(feature = "ffi")] preserve_header_order: parse_ctx.preserve_header_order, @@ -201,12 +191,6 @@ where )? { Some(msg) => { debug!("parsed {} headers", msg.head.headers.len()); - - #[cfg(feature = "server")] - { - *parse_ctx.h1_header_read_timeout_running = false; - parse_ctx.h1_header_read_timeout_fut.take(); - } return Poll::Ready(Ok(msg)); } None => { @@ -215,20 +199,6 @@ where debug!("max_buf_size ({}) reached, closing", max); return Poll::Ready(Err(crate::Error::new_too_large())); } - - #[cfg(feature = "server")] - if *parse_ctx.h1_header_read_timeout_running { - if let Some(h1_header_read_timeout_fut) = - parse_ctx.h1_header_read_timeout_fut - { - if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { - *parse_ctx.h1_header_read_timeout_running = false; - - warn!("read header from client timeout"); - return Poll::Ready(Err(crate::Error::new_header_timeout())); - } - } - } } } if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { @@ -660,10 +630,8 @@ enum WriteStrategy { #[cfg(test)] mod tests { - use crate::common::io::Compat; - use crate::common::time::Time; - use super::*; + use crate::common::io::Compat; use std::time::Duration; use tokio_test::io::Builder as Mock; @@ -726,10 +694,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 074d0b88a8..fe397d24c4 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -1,16 +1,9 @@ -#[cfg(feature = "server")] -use std::{pin::Pin, time::Duration}; - use bytes::BytesMut; use http::{HeaderMap, Method}; use httparse::ParserConfig; use crate::body::DecodedLength; -#[cfg(feature = "server")] -use crate::common::time::Time; use crate::proto::{BodyLength, MessageHead}; -#[cfg(feature = "server")] -use crate::rt::Sleep; pub(crate) use self::conn::Conn; pub(crate) use self::decode::Decoder; @@ -79,14 +72,6 @@ pub(crate) struct ParseContext<'a> { req_method: &'a mut Option, h1_parser_config: ParserConfig, h1_max_headers: Option, - #[cfg(feature = "server")] - h1_header_read_timeout: Option, - #[cfg(feature = "server")] - h1_header_read_timeout_fut: &'a mut Option>>, - #[cfg(feature = "server")] - h1_header_read_timeout_running: &'a mut bool, - #[cfg(feature = "server")] - timer: Time, preserve_header_case: bool, #[cfg(feature = "ffi")] preserve_header_order: bool, diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 847453a08c..26be74d87f 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -2,8 +2,6 @@ use std::mem::MaybeUninit; #[cfg(feature = "client")] use std::fmt::{self, Write as _}; -#[cfg(feature = "server")] -use std::time::Instant; use bytes::Bytes; use bytes::BytesMut; @@ -80,24 +78,6 @@ where let _entered = trace_span!("parse_headers"); - #[cfg(feature = "server")] - if !*ctx.h1_header_read_timeout_running { - if let Some(h1_header_read_timeout) = ctx.h1_header_read_timeout { - let deadline = Instant::now() + h1_header_read_timeout; - *ctx.h1_header_read_timeout_running = true; - match ctx.h1_header_read_timeout_fut { - Some(h1_header_read_timeout_fut) => { - debug!("resetting h1 header read timeout timer"); - ctx.timer.reset(h1_header_read_timeout_fut, deadline); - } - None => { - debug!("setting h1 header read timeout timer"); - *ctx.h1_header_read_timeout_fut = Some(ctx.timer.sleep_until(deadline)); - } - } - } - } - T::parse(bytes, ctx) } @@ -1631,8 +1611,6 @@ fn extend(dst: &mut Vec, data: &[u8]) { mod tests { use bytes::BytesMut; - use crate::common::time::Time; - use super::*; #[test] @@ -1647,10 +1625,6 @@ mod tests { req_method: &mut method, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1679,10 +1653,6 @@ mod tests { req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1706,10 +1676,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1731,10 +1697,6 @@ mod tests { req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1758,10 +1720,6 @@ mod tests { req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1789,10 +1747,6 @@ mod tests { req_method: &mut Some(crate::Method::GET), h1_parser_config, h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1817,10 +1771,6 @@ mod tests { req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1840,10 +1790,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: true, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1884,10 +1830,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1909,10 +1851,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2143,10 +2081,6 @@ mod tests { req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2168,10 +2102,6 @@ mod tests { req_method: &mut Some(m), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2193,10 +2123,6 @@ mod tests { req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2756,10 +2682,6 @@ mod tests { req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2803,10 +2725,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: max_headers, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2831,10 +2749,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: max_headers, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2982,10 +2896,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -3031,10 +2941,6 @@ mod tests { req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, - h1_header_read_timeout: None, - h1_header_read_timeout_fut: &mut None, - h1_header_read_timeout_running: &mut false, - timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, diff --git a/tests/server.rs b/tests/server.rs index f3efbf3c6e..82bc80669c 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1508,6 +1508,27 @@ async fn header_read_timeout_slow_writes() { conn.without_shutdown().await.expect_err("header timeout"); } +#[tokio::test] +async fn header_read_timeout_starts_immediately() { + let (listener, addr) = setup_tcp_listener(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + thread::sleep(Duration::from_secs(3)); + let mut buf = [0u8; 256]; + let n = tcp.read(&mut buf).expect("read 1"); + assert_eq!(n, 0); //eof + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + let conn = http1::Builder::new() + .timer(TokioTimer) + .header_read_timeout(Duration::from_secs(2)) + .serve_connection(socket, unreachable_service()); + conn.await.expect_err("header timeout"); +} + #[tokio::test] async fn header_read_timeout_slow_writes_multiple_requests() { let (listener, addr) = setup_tcp_listener();