From 9a8413d91081ad5a949276f05337e984c455e251 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 20 Mar 2020 13:58:52 -0700 Subject: [PATCH] feat(http2): add HTTP2 keep-alive support for client and server This adds HTTP2 keep-alive support to client and server connections based losely on GRPC keep-alive. When enabled, after no data has been received for some configured interval, an HTTP2 PING frame is sent. If the PING is not acknowledged with a configured timeout, the connection is closed. Clients have an additional option to enable keep-alive while the connection is otherwise idle. When disabled, keep-alive PINGs are only used while there are open request/response streams. If enabled, PINGs are sent even when there are no active streams. For now, since these features use `tokio::time::Delay`, the `runtime` cargo feature is required to use them. --- src/body/body.rs | 21 +- src/client/conn.rs | 56 +++++ src/client/mod.rs | 54 +++++ src/error.rs | 37 ++- src/proto/h2/bdp.rs | 186 --------------- src/proto/h2/client.rs | 73 ++++-- src/proto/h2/mod.rs | 2 +- src/proto/h2/ping.rs | 509 +++++++++++++++++++++++++++++++++++++++++ src/proto/h2/server.rs | 78 +++++-- src/server/conn.rs | 60 ++++- src/server/mod.rs | 50 +++- tests/client.rs | 192 ++++++++++++++++ tests/server.rs | 103 ++++++++- 13 files changed, 1166 insertions(+), 255 deletions(-) delete mode 100644 src/proto/h2/bdp.rs create mode 100644 src/proto/h2/ping.rs diff --git a/src/body/body.rs b/src/body/body.rs index 228a996cc6..49c6683119 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -12,7 +12,7 @@ use http::HeaderMap; use http_body::{Body as HttpBody, SizeHint}; use crate::common::{task, watch, Future, Never, Pin, Poll}; -use crate::proto::h2::bdp; +use crate::proto::h2::ping; use crate::proto::DecodedLength; use crate::upgrade::OnUpgrade; @@ -38,7 +38,7 @@ enum Kind { rx: mpsc::Receiver>, }, H2 { - bdp: bdp::Sampler, + ping: ping::Recorder, content_length: DecodedLength, recv: h2::RecvStream, }, @@ -180,10 +180,10 @@ impl Body { pub(crate) fn h2( recv: h2::RecvStream, content_length: DecodedLength, - bdp: bdp::Sampler, + ping: ping::Recorder, ) -> Self { let body = Body::new(Kind::H2 { - bdp, + ping, content_length, recv, }); @@ -265,14 +265,14 @@ impl Body { } } Kind::H2 { - ref bdp, + ref ping, recv: ref mut h2, content_length: ref mut len, } => match ready!(h2.poll_data(cx)) { Some(Ok(bytes)) => { let _ = h2.flow_control().release_capacity(bytes.len()); len.sub_if(bytes.len() as u64); - bdp.sample(bytes.len()); + ping.record_data(bytes.len()); Poll::Ready(Some(Ok(bytes))) } Some(Err(e)) => Poll::Ready(Some(Err(crate::Error::new_body(e)))), @@ -321,9 +321,14 @@ impl HttpBody for Body { ) -> Poll, Self::Error>> { match self.kind { Kind::H2 { - recv: ref mut h2, .. + recv: ref mut h2, + ref ping, + .. } => match ready!(h2.poll_trailers(cx)) { - Ok(t) => Poll::Ready(Ok(t)), + Ok(t) => { + ping.record_non_data(); + Poll::Ready(Ok(t)) + } Err(e) => Poll::Ready(Err(crate::Error::new_h2(e))), }, _ => Poll::Ready(Ok(None)), diff --git a/src/client/conn.rs b/src/client/conn.rs index f7b24b5268..81eaf8287f 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -7,9 +7,12 @@ //! //! If don't have need to manage connections yourself, consider using the //! higher-level [Client](super) API. + use std::fmt; use std::mem; use std::sync::Arc; +#[cfg(feature = "runtime")] +use std::time::Duration; use bytes::Bytes; use futures_util::future::{self, Either, FutureExt as _}; @@ -517,6 +520,59 @@ impl Builder { self } + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.h2_builder.keep_alive_while_idle = enabled; + self + } + /// Constructs a connection with the configured options and IO. pub fn handshake( &self, diff --git a/src/client/mod.rs b/src/client/mod.rs index f1eaa1643b..42fad34573 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -933,6 +933,7 @@ impl Builder { self.pool_config.max_idle_per_host = max_idle; self } + // HTTP/1 options /// Set whether HTTP/1 connections should try to use vectored writes, @@ -1036,6 +1037,59 @@ impl Builder { self } + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into>, + ) -> &mut Self { + self.conn_builder.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.conn_builder.http2_keep_alive_timeout(timeout); + self + } + + /// Sets whether HTTP2 keep-alive should apply while the connection is idle. + /// + /// If disabled, keep-alive pings are only sent while there are open + /// request/responses streams. If enabled, pings are also sent when no + /// streams are active. Does nothing if `http2_keep_alive_interval` is + /// disabled. + /// + /// Default is `false`. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_while_idle(&mut self, enabled: bool) -> &mut Self { + self.conn_builder.http2_keep_alive_while_idle(enabled); + self + } + /// Set whether to retry requests that get disrupted before ever starting /// to write. /// diff --git a/src/error.rs b/src/error.rs index 2a9ea923f6..99a994b879 100644 --- a/src/error.rs +++ b/src/error.rs @@ -91,6 +91,10 @@ pub(crate) enum User { ManualUpgrade, } +// Sentinel type to indicate the error was caused by a timeout. +#[derive(Debug)] +pub(crate) struct TimedOut; + impl Error { /// Returns true if this was an HTTP parse error. pub fn is_parse(&self) -> bool { @@ -133,6 +137,11 @@ impl Error { self.inner.kind == Kind::BodyWriteAborted } + /// Returns true if the error was caused by a timeout. + pub fn is_timeout(&self) -> bool { + self.find_source::().is_some() + } + /// Consumes the error, returning its cause. pub fn into_cause(self) -> Option> { self.inner.cause @@ -153,19 +162,25 @@ impl Error { &self.inner.kind } - pub(crate) fn h2_reason(&self) -> h2::Reason { - // Find an h2::Reason somewhere in the cause stack, if it exists, - // otherwise assume an INTERNAL_ERROR. + fn find_source(&self) -> Option<&E> { let mut cause = self.source(); while let Some(err) = cause { - if let Some(h2_err) = err.downcast_ref::() { - return h2_err.reason().unwrap_or(h2::Reason::INTERNAL_ERROR); + if let Some(ref typed) = err.downcast_ref() { + return Some(typed); } cause = err.source(); } // else - h2::Reason::INTERNAL_ERROR + None + } + + pub(crate) fn h2_reason(&self) -> h2::Reason { + // Find an h2::Reason somewhere in the cause stack, if it exists, + // otherwise assume an INTERNAL_ERROR. + self.find_source::() + .and_then(|h2_err| h2_err.reason()) + .unwrap_or(h2::Reason::INTERNAL_ERROR) } pub(crate) fn new_canceled() -> Error { @@ -397,6 +412,16 @@ trait AssertSendSync: Send + Sync + 'static {} #[doc(hidden)] impl AssertSendSync for Error {} +// ===== impl TimedOut ==== + +impl fmt::Display for TimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("operation timed out") + } +} + +impl StdError for TimedOut {} + #[cfg(test)] mod tests { use super::*; diff --git a/src/proto/h2/bdp.rs b/src/proto/h2/bdp.rs deleted file mode 100644 index 2a2c99bc7b..0000000000 --- a/src/proto/h2/bdp.rs +++ /dev/null @@ -1,186 +0,0 @@ -// What should it do? -// -// # BDP Algorithm -// -// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: -// 1a. Record current time. -// 1b. Send a BDP ping. -// 2. Increment the number of received bytes. -// 3. When the BDP ping ack is received: -// 3a. Record duration from sent time. -// 3b. Merge RTT with a running average. -// 3c. Calculate bdp as bytes/rtt. -// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. -// -// -// # Implementation -// -// - `hyper::Body::h2` variant includes a "bdp channel" -// - When the body's `poll_data` yields bytes, call `bdp.sample(bytes.len())` -// - -use std::sync::{Arc, Mutex, Weak}; -use std::task::{self, Poll}; -use std::time::{Duration, Instant}; - -use h2::{Ping, PingPong}; - -type WindowSize = u32; - -/// Any higher than this likely will be hitting the TCP flow control. -const BDP_LIMIT: usize = 1024 * 1024 * 16; - -pub(super) fn disabled() -> Sampler { - Sampler { - shared: Weak::new(), - } -} - -pub(super) fn channel(ping_pong: PingPong, initial_window: WindowSize) -> (Sampler, Estimator) { - let shared = Arc::new(Mutex::new(Shared { - bytes: 0, - ping_pong, - ping_sent: false, - sent_at: Instant::now(), - })); - - ( - Sampler { - shared: Arc::downgrade(&shared), - }, - Estimator { - bdp: initial_window, - max_bandwidth: 0.0, - shared, - samples: 0, - rtt: 0.0, - }, - ) -} - -#[derive(Clone)] -pub(crate) struct Sampler { - shared: Weak>, -} - -pub(super) struct Estimator { - shared: Arc>, - - /// Current BDP in bytes - bdp: u32, - /// Largest bandwidth we've seen so far. - max_bandwidth: f64, - /// Count of samples made (ping sent and received) - samples: usize, - /// Round trip time in seconds - rtt: f64, -} - -struct Shared { - bytes: usize, - ping_pong: PingPong, - ping_sent: bool, - sent_at: Instant, -} - -impl Sampler { - pub(crate) fn sample(&self, bytes: usize) { - let shared = if let Some(shared) = self.shared.upgrade() { - shared - } else { - return; - }; - - let mut inner = shared.lock().unwrap(); - - if !inner.ping_sent { - if let Ok(()) = inner.ping_pong.send_ping(Ping::opaque()) { - inner.ping_sent = true; - inner.sent_at = Instant::now(); - trace!("sending BDP ping"); - } else { - return; - } - } - - inner.bytes += bytes; - } -} - -impl Estimator { - pub(super) fn poll_estimate(&mut self, cx: &mut task::Context<'_>) -> Poll { - let mut inner = self.shared.lock().unwrap(); - if !inner.ping_sent { - // XXX: this doesn't register a waker...? - return Poll::Pending; - } - - let (bytes, rtt) = match ready!(inner.ping_pong.poll_pong(cx)) { - Ok(_pong) => { - let rtt = inner.sent_at.elapsed(); - let bytes = inner.bytes; - inner.bytes = 0; - inner.ping_sent = false; - self.samples += 1; - trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); - (bytes, rtt) - } - Err(e) => { - debug!("bdp pong error: {}", e); - return Poll::Pending; - } - }; - - drop(inner); - - if let Some(bdp) = self.calculate(bytes, rtt) { - Poll::Ready(bdp) - } else { - // XXX: this doesn't register a waker...? - Poll::Pending - } - } - - fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option { - // No need to do any math if we're at the limit. - if self.bdp as usize == BDP_LIMIT { - return None; - } - - // average the rtt - let rtt = seconds(rtt); - if self.samples < 10 { - // Average the first 10 samples - self.rtt += (rtt - self.rtt) / (self.samples as f64); - } else { - self.rtt += (rtt - self.rtt) / 0.9; - } - - // calculate the current bandwidth - let bw = (bytes as f64) / (self.rtt * 1.5); - trace!("current bandwidth = {:.1}B/s", bw); - - if bw < self.max_bandwidth { - // not a faster bandwidth, so don't update - return None; - } else { - self.max_bandwidth = bw; - } - - // if the current `bytes` sample is at least 2/3 the previous - // bdp, increase to double the current sample. - if (bytes as f64) >= (self.bdp as f64) * 0.66 { - self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; - trace!("BDP increased to {}", self.bdp); - Some(self.bdp) - } else { - None - } - } -} - -fn seconds(dur: Duration) -> f64 { - const NANOS_PER_SEC: f64 = 1_000_000_000.0; - let secs = dur.as_secs() as f64; - secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC -} diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 0fc00d33ea..bf4cfccea5 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -1,10 +1,13 @@ +#[cfg(feature = "runtime")] +use std::time::Duration; + use futures_channel::{mpsc, oneshot}; use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; use futures_util::stream::StreamExt as _; use h2::client::{Builder, SendRequest}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{bdp, decode_content_length, PipeToSendStream, SendBuf}; +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; use crate::body::Payload; use crate::common::{task, Exec, Future, Never, Pin, Poll}; use crate::headers; @@ -32,6 +35,12 @@ pub(crate) struct Config { pub(crate) adaptive_window: bool, pub(crate) initial_conn_window_size: u32, pub(crate) initial_stream_window_size: u32, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_while_idle: bool, } impl Default for Config { @@ -40,6 +49,12 @@ impl Default for Config { adaptive_window: false, initial_conn_window_size: DEFAULT_CONN_WINDOW, initial_stream_window_size: DEFAULT_STREAM_WINDOW, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), + #[cfg(feature = "runtime")] + keep_alive_while_idle: false, } } } @@ -75,16 +90,35 @@ where } }); - let sampler = if config.adaptive_window { - let (sampler, mut estimator) = - bdp::channel(conn.ping_pong().unwrap(), config.initial_stream_window_size); + let ping_config = ping::Config { + bdp_initial_window: if config.adaptive_window { + Some(config.initial_stream_window_size) + } else { + None + }, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + #[cfg(feature = "runtime")] + keep_alive_while_idle: config.keep_alive_while_idle, + }; + + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + let (recorder, mut ponger) = ping::channel(pp, ping_config); let conn = future::poll_fn(move |cx| { - match estimator.poll_estimate(cx) { - Poll::Ready(wnd) => { + match ponger.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { conn.set_target_window_size(wnd); conn.set_initial_window_size(wnd)?; } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("connection keep-alive timed out"); + return Poll::Ready(Ok(())); + } Poll::Pending => {} } @@ -93,16 +127,16 @@ where let conn = conn.map_err(|e| debug!("connection error: {}", e)); exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); - sampler + recorder } else { let conn = conn.map_err(|e| debug!("connection error: {}", e)); exec.execute(conn_task(conn, conn_drop_rx, cancel_tx)); - bdp::disabled() + ping::disabled() }; Ok(ClientTask { - bdp: sampler, + ping, conn_drop_ref, conn_eof, executor: exec, @@ -135,7 +169,7 @@ pub(crate) struct ClientTask where B: Payload, { - bdp: bdp::Sampler, + ping: ping::Recorder, conn_drop_ref: ConnDropRef, conn_eof: ConnEof, executor: Exec, @@ -154,6 +188,7 @@ where match ready!(self.h2_tx.poll_ready(cx)) { Ok(()) => (), Err(err) => { + self.ping.ensure_not_timed_out()?; return if err.reason() == Some(::h2::Reason::NO_ERROR) { trace!("connection gracefully shutdown"); Poll::Ready(Ok(Dispatched::Shutdown)) @@ -188,6 +223,7 @@ where } }; + let ping = self.ping.clone(); if !eos { let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| { if let Err(e) = res { @@ -201,8 +237,13 @@ where Poll::Ready(_) => (), Poll::Pending => { let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); let pipe = pipe.map(move |x| { drop(conn_drop_ref); + drop(ping); x }); self.executor.execute(pipe); @@ -210,15 +251,21 @@ where } } - let bdp = self.bdp.clone(); let fut = fut.map(move |result| match result { Ok(res) => { + // record that we got the response headers + ping.record_non_data(); + let content_length = decode_content_length(res.headers()); - let res = - res.map(|stream| crate::Body::h2(stream, content_length, bdp)); + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length, ping) + }); Ok(res) } Err(err) => { + ping.ensure_not_timed_out().map_err(|e| (e, None))?; + debug!("client response error: {}", err); Err((crate::Error::new_h2(err), None)) } diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index 80d52349c7..e25f038cad 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -12,8 +12,8 @@ use crate::body::Payload; use crate::common::{task, Future, Pin, Poll}; use crate::headers::content_length_parse_all; -pub(crate) mod bdp; pub(crate) mod client; +pub(crate) mod ping; pub(crate) mod server; pub(crate) use self::client::ClientTask; diff --git a/src/proto/h2/ping.rs b/src/proto/h2/ping.rs new file mode 100644 index 0000000000..405b7075ee --- /dev/null +++ b/src/proto/h2/ping.rs @@ -0,0 +1,509 @@ +/// HTTP2 Ping usage +/// +/// hyper uses HTTP2 pings for two purposes: +/// +/// 1. Adaptive flow control using BDP +/// 2. Connection keep-alive +/// +/// Both cases are optional. +/// +/// # BDP Algorithm +/// +/// 1. When receiving a DATA frame, if a BDP ping isn't outstanding: +/// 1a. Record current time. +/// 1b. Send a BDP ping. +/// 2. Increment the number of received bytes. +/// 3. When the BDP ping ack is received: +/// 3a. Record duration from sent time. +/// 3b. Merge RTT with a running average. +/// 3c. Calculate bdp as bytes/rtt. +/// 3d. If bdp is over 2/3 max, set new max to bdp and update windows. + +#[cfg(feature = "runtime")] +use std::fmt; +#[cfg(feature = "runtime")] +use std::future::Future; +#[cfg(feature = "runtime")] +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(not(feature = "runtime"))] +use std::time::Instant; + +use h2::{Ping, PingPong}; +#[cfg(feature = "runtime")] +use tokio::time::{Delay, Instant}; + +type WindowSize = u32; + +pub(super) fn disabled() -> Recorder { + Recorder { shared: None } +} + +pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { + debug_assert!( + config.is_enabled(), + "ping channel requires bdp or keep-alive config", + ); + + let bdp = config.bdp_initial_window.map(|wnd| Bdp { + bdp: wnd, + max_bandwidth: 0.0, + samples: 0, + rtt: 0.0, + }); + + let bytes = bdp.as_ref().map(|_| 0); + + #[cfg(feature = "runtime")] + let keep_alive = config.keep_alive_interval.map(|interval| KeepAlive { + interval, + timeout: config.keep_alive_timeout, + while_idle: config.keep_alive_while_idle, + timer: tokio::time::delay_for(interval), + state: KeepAliveState::Init, + }); + + #[cfg(feature = "runtime")] + let last_read_at = keep_alive.as_ref().map(|_| Instant::now()); + + let shared = Arc::new(Mutex::new(Shared { + bytes, + #[cfg(feature = "runtime")] + last_read_at, + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: false, + ping_pong, + ping_sent_at: None, + })); + + ( + Recorder { + shared: Some(shared.clone()), + }, + Ponger { + bdp, + #[cfg(feature = "runtime")] + keep_alive, + shared, + }, + ) +} + +#[derive(Clone)] +pub(super) struct Config { + pub(super) bdp_initial_window: Option, + /// If no frames are received in this amount of time, a PING frame is sent. + #[cfg(feature = "runtime")] + pub(super) keep_alive_interval: Option, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + #[cfg(feature = "runtime")] + pub(super) keep_alive_timeout: Duration, + /// If true, sends pings even when there are no active streams. + #[cfg(feature = "runtime")] + pub(super) keep_alive_while_idle: bool, +} + +#[derive(Clone)] +pub(crate) struct Recorder { + shared: Option>>, +} + +pub(super) struct Ponger { + bdp: Option, + #[cfg(feature = "runtime")] + keep_alive: Option, + shared: Arc>, +} + +struct Shared { + ping_pong: PingPong, + ping_sent_at: Option, + + // bdp + /// If `Some`, bdp is enabled, and this tracks how many bytes have been + /// read during the current sample. + bytes: Option, + + // keep-alive + /// If `Some`, keep-alive is enabled, and the Instant is how long ago + /// the connection read the last frame. + #[cfg(feature = "runtime")] + last_read_at: Option, + + #[cfg(feature = "runtime")] + is_keep_alive_timed_out: bool, +} + +struct Bdp { + /// Current BDP in bytes + bdp: u32, + /// Largest bandwidth we've seen so far. + max_bandwidth: f64, + /// Count of samples made (ping sent and received) + samples: usize, + /// Round trip time in seconds + rtt: f64, +} + +#[cfg(feature = "runtime")] +struct KeepAlive { + /// If no frames are received in this amount of time, a PING frame is sent. + interval: Duration, + /// After sending a keepalive PING, the connection will be closed if + /// a pong is not received in this amount of time. + timeout: Duration, + /// If true, sends pings even when there are no active streams. + while_idle: bool, + + state: KeepAliveState, + timer: Delay, +} + +#[cfg(feature = "runtime")] +enum KeepAliveState { + Init, + Scheduled, + PingSent, +} + +pub(super) enum Ponged { + SizeUpdate(WindowSize), + #[cfg(feature = "runtime")] + KeepAliveTimedOut, +} + +#[cfg(feature = "runtime")] +#[derive(Debug)] +pub(super) struct KeepAliveTimedOut; + +// ===== impl Config ===== + +impl Config { + pub(super) fn is_enabled(&self) -> bool { + #[cfg(feature = "runtime")] + { + self.bdp_initial_window.is_some() || self.keep_alive_interval.is_some() + } + + #[cfg(not(feature = "runtime"))] + { + self.bdp_initial_window.is_some() + } + } +} + +// ===== impl Recorder ===== + +impl Recorder { + pub(crate) fn record_data(&self, len: usize) { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + #[cfg(feature = "runtime")] + locked.update_last_read_at(); + + if let Some(ref mut bytes) = locked.bytes { + *bytes += len; + } else { + // no need to send bdp ping if bdp is disabled + return; + } + + if !locked.is_ping_sent() { + locked.send_ping(); + } + } + + pub(crate) fn record_non_data(&self) { + #[cfg(feature = "runtime")] + { + let shared = if let Some(ref shared) = self.shared { + shared + } else { + return; + }; + + let mut locked = shared.lock().unwrap(); + + locked.update_last_read_at(); + } + } + + /// If the incoming stream is already closed, convert self into + /// a disabled reporter. + pub(super) fn for_stream(self, stream: &h2::RecvStream) -> Self { + if stream.is_end_stream() { + disabled() + } else { + self + } + } + + pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { + #[cfg(feature = "runtime")] + { + if let Some(ref shared) = self.shared { + let locked = shared.lock().unwrap(); + if locked.is_keep_alive_timed_out { + return Err(KeepAliveTimedOut.crate_error()); + } + } + } + + // else + Ok(()) + } +} + +// ===== impl Ponger ===== + +impl Ponger { + pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll { + let mut locked = self.shared.lock().unwrap(); + #[cfg(feature = "runtime")] + let is_idle = self.is_idle(); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + ka.schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); + } + } + + if !locked.is_ping_sent() { + // XXX: this doesn't register a waker...? + return Poll::Pending; + } + + let (bytes, rtt) = match locked.ping_pong.poll_pong(cx) { + Poll::Ready(Ok(_pong)) => { + let rtt = locked + .ping_sent_at + .expect("pong received implies ping_sent_at") + .elapsed(); + locked.ping_sent_at = None; + trace!("recv pong"); + + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + locked.update_last_read_at(); + ka.schedule(is_idle, &locked); + } + } + + if let Some(ref mut bdp) = self.bdp { + let bytes = locked.bytes.expect("bdp enabled implies bytes"); + locked.bytes = Some(0); // reset + bdp.samples += 1; + trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); + (bytes, rtt) + } else { + // no bdp, done! + return Poll::Pending; + } + } + Poll::Ready(Err(e)) => { + debug!("pong error: {}", e); + return Poll::Pending; + } + Poll::Pending => { + #[cfg(feature = "runtime")] + { + if let Some(ref mut ka) = self.keep_alive { + if let Err(KeepAliveTimedOut) = ka.maybe_timeout(cx) { + self.keep_alive = None; + locked.is_keep_alive_timed_out = true; + return Poll::Ready(Ponged::KeepAliveTimedOut); + } + } + } + + return Poll::Pending; + } + }; + + drop(locked); + + if let Some(bdp) = self.bdp.as_mut().and_then(|bdp| bdp.calculate(bytes, rtt)) { + Poll::Ready(Ponged::SizeUpdate(bdp)) + } else { + // XXX: this doesn't register a waker...? + Poll::Pending + } + } + + #[cfg(feature = "runtime")] + fn is_idle(&self) -> bool { + Arc::strong_count(&self.shared) <= 2 + } +} + +// ===== impl Shared ===== + +impl Shared { + fn send_ping(&mut self) { + match self.ping_pong.send_ping(Ping::opaque()) { + Ok(()) => { + self.ping_sent_at = Some(Instant::now()); + trace!("sent ping"); + } + Err(err) => { + debug!("error sending ping: {}", err); + } + } + } + + fn is_ping_sent(&self) -> bool { + self.ping_sent_at.is_some() + } + + #[cfg(feature = "runtime")] + fn update_last_read_at(&mut self) { + if self.last_read_at.is_some() { + self.last_read_at = Some(Instant::now()); + } + } + + #[cfg(feature = "runtime")] + fn last_read_at(&self) -> Instant { + self.last_read_at.expect("keep_alive expects last_read_at") + } +} + +// ===== impl Bdp ===== + +/// Any higher than this likely will be hitting the TCP flow control. +const BDP_LIMIT: usize = 1024 * 1024 * 16; + +impl Bdp { + fn calculate(&mut self, bytes: usize, rtt: Duration) -> Option { + // No need to do any math if we're at the limit. + if self.bdp as usize == BDP_LIMIT { + return None; + } + + // average the rtt + let rtt = seconds(rtt); + if self.samples < 10 { + // Average the first 10 samples + self.rtt += (rtt - self.rtt) / (self.samples as f64); + } else { + self.rtt += (rtt - self.rtt) / 0.9; + } + + // calculate the current bandwidth + let bw = (bytes as f64) / (self.rtt * 1.5); + trace!("current bandwidth = {:.1}B/s", bw); + + if bw < self.max_bandwidth { + // not a faster bandwidth, so don't update + return None; + } else { + self.max_bandwidth = bw; + } + + // if the current `bytes` sample is at least 2/3 the previous + // bdp, increase to double the current sample. + if (bytes as f64) >= (self.bdp as f64) * 0.66 { + self.bdp = (bytes * 2).min(BDP_LIMIT) as WindowSize; + trace!("BDP increased to {}", self.bdp); + Some(self.bdp) + } else { + None + } + } +} + +fn seconds(dur: Duration) -> f64 { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + let secs = dur.as_secs() as f64; + secs + (dur.subsec_nanos() as f64) / NANOS_PER_SEC +} + +// ===== impl KeepAlive ===== + +#[cfg(feature = "runtime")] +impl KeepAlive { + fn schedule(&mut self, is_idle: bool, shared: &Shared) { + match self.state { + KeepAliveState::Init => { + if !self.while_idle && is_idle { + return; + } + + self.state = KeepAliveState::Scheduled; + let interval = shared.last_read_at() + self.interval; + self.timer.reset(interval); + } + KeepAliveState::Scheduled | KeepAliveState::PingSent => (), + } + } + + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { + match self.state { + KeepAliveState::Scheduled => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return; + } + // check if we've received a frame while we were scheduled + if shared.last_read_at() + self.interval > self.timer.deadline() { + self.state = KeepAliveState::Init; + cx.waker().wake_by_ref(); // schedule us again + return; + } + trace!("keep-alive interval ({:?}) reached", self.interval); + shared.send_ping(); + self.state = KeepAliveState::PingSent; + let timeout = Instant::now() + self.timeout; + self.timer.reset(timeout); + } + KeepAliveState::Init | KeepAliveState::PingSent => (), + } + } + + fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { + match self.state { + KeepAliveState::PingSent => { + if Pin::new(&mut self.timer).poll(cx).is_pending() { + return Ok(()); + } + trace!("keep-alive timeout ({:?}) reached", self.timeout); + Err(KeepAliveTimedOut) + } + KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + } + } +} + +// ===== impl KeepAliveTimedOut ===== + +#[cfg(feature = "runtime")] +impl KeepAliveTimedOut { + pub(super) fn crate_error(self) -> crate::Error { + crate::Error::new(crate::error::Kind::Http2).with(self) + } +} + +#[cfg(feature = "runtime")] +impl fmt::Display for KeepAliveTimedOut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("keep-alive timed out") + } +} + +#[cfg(feature = "runtime")] +impl std::error::Error for KeepAliveTimedOut { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&crate::error::TimedOut) + } +} diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index b8d1afb925..bf81c1190f 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -1,12 +1,14 @@ use std::error::Error as StdError; use std::marker::Unpin; +#[cfg(feature = "runtime")] +use std::time::Duration; use h2::server::{Connection, Handshake, SendResponse}; use h2::Reason; use pin_project::{pin_project, project}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{bdp, decode_content_length, PipeToSendStream, SendBuf}; +use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; use crate::body::Payload; use crate::common::exec::H2Exec; use crate::common::{task, Future, Pin, Poll}; @@ -31,6 +33,10 @@ pub(crate) struct Config { pub(crate) initial_conn_window_size: u32, pub(crate) initial_stream_window_size: u32, pub(crate) max_concurrent_streams: Option, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_interval: Option, + #[cfg(feature = "runtime")] + pub(crate) keep_alive_timeout: Duration, } impl Default for Config { @@ -40,6 +46,10 @@ impl Default for Config { initial_conn_window_size: DEFAULT_CONN_WINDOW, initial_stream_window_size: DEFAULT_STREAM_WINDOW, max_concurrent_streams: None, + #[cfg(feature = "runtime")] + keep_alive_interval: None, + #[cfg(feature = "runtime")] + keep_alive_timeout: Duration::from_secs(20), } } } @@ -60,10 +70,7 @@ where B: Payload, { Handshaking { - /// If Some, bdp is enabled with the initial size. - /// - /// If None, bdp is disabled. - bdp_initial_size: Option, + ping_config: ping::Config, hs: Handshake>, }, Serving(Serving), @@ -74,7 +81,7 @@ struct Serving where B: Payload, { - bdp: Option<(bdp::Sampler, bdp::Estimator)>, + ping: Option<(ping::Recorder, ping::Ponger)>, conn: Connection>, closing: Option, } @@ -103,10 +110,22 @@ where None }; + let ping_config = ping::Config { + bdp_initial_window: bdp, + #[cfg(feature = "runtime")] + keep_alive_interval: config.keep_alive_interval, + #[cfg(feature = "runtime")] + keep_alive_timeout: config.keep_alive_timeout, + // If keep-alive is enabled for servers, always enabled while + // idle, so it can more aggresively close dead connections. + #[cfg(feature = "runtime")] + keep_alive_while_idle: true, + }; + Server { exec, state: State::Handshaking { - bdp_initial_size: bdp, + ping_config, hs: handshake, }, service, @@ -149,13 +168,17 @@ where let next = match me.state { State::Handshaking { ref mut hs, - ref bdp_initial_size, + ref ping_config, } => { let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; - let bdp = bdp_initial_size - .map(|wnd| bdp::channel(conn.ping_pong().expect("ping_pong"), wnd)); + let ping = if ping_config.is_enabled() { + let pp = conn.ping_pong().expect("conn.ping_pong"); + Some(ping::channel(pp, ping_config.clone())) + } else { + None + }; State::Serving(Serving { - bdp, + ping, conn, closing: None, }) @@ -193,7 +216,7 @@ where { if self.closing.is_none() { loop { - self.poll_bdp(cx); + self.poll_ping(cx); // Check that the service is ready to accept a new request. // @@ -231,14 +254,16 @@ where Some(Ok((req, respond))) => { trace!("incoming request"); let content_length = decode_content_length(req.headers()); - let bdp_sampler = self - .bdp + let ping = self + .ping .as_ref() - .map(|bdp| bdp.0.clone()) - .unwrap_or_else(bdp::disabled); + .map(|ping| ping.0.clone()) + .unwrap_or_else(ping::disabled); - let req = - req.map(|stream| crate::Body::h2(stream, content_length, bdp_sampler)); + // Record the headers received + ping.record_non_data(); + + let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); let fut = H2Stream::new(service.call(req), respond); exec.execute_h2stream(fut); } @@ -247,6 +272,10 @@ where } None => { // no more incoming streams... + if let Some((ref ping, _)) = self.ping { + ping.ensure_not_timed_out()?; + } + trace!("incoming connection complete"); return Poll::Ready(Ok(())); } @@ -264,13 +293,18 @@ where Poll::Ready(Err(self.closing.take().expect("polled after error"))) } - fn poll_bdp(&mut self, cx: &mut task::Context<'_>) { - if let Some((_, ref mut estimator)) = self.bdp { - match estimator.poll_estimate(cx) { - Poll::Ready(wnd) => { + fn poll_ping(&mut self, cx: &mut task::Context<'_>) { + if let Some((_, ref mut estimator)) = self.ping { + match estimator.poll(cx) { + Poll::Ready(ping::Ponged::SizeUpdate(wnd)) => { self.conn.set_target_window_size(wnd); let _ = self.conn.set_initial_window_size(wnd); } + #[cfg(feature = "runtime")] + Poll::Ready(ping::Ponged::KeepAliveTimedOut) => { + debug!("keep-alive timed out, closing connection"); + self.conn.abrupt_shutdown(h2::Reason::NO_ERROR); + } Poll::Pending => {} } } diff --git a/src/server/conn.rs b/src/server/conn.rs index 889cb6b206..74bb18dbfc 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -13,6 +13,8 @@ use std::fmt; use std::mem; #[cfg(feature = "tcp")] use std::net::SocketAddr; +#[cfg(feature = "runtime")] +use std::time::Duration; use bytes::Bytes; use pin_project::{pin_project, project}; @@ -46,10 +48,10 @@ pub use super::tcp::{AddrIncoming, AddrStream}; pub struct Http { exec: E, h1_half_close: bool, + h1_keep_alive: bool, h1_writev: bool, h2_builder: proto::h2::server::Config, mode: ConnectionMode, - keep_alive: bool, max_buf_size: Option, pipeline_flush: bool, } @@ -182,10 +184,10 @@ impl Http { Http { exec: Exec::Default, h1_half_close: false, + h1_keep_alive: true, h1_writev: true, h2_builder: Default::default(), mode: ConnectionMode::Fallback, - keep_alive: true, max_buf_size: None, pipeline_flush: false, } @@ -218,6 +220,21 @@ impl Http { self } + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + pub fn http1_keep_alive(&mut self, val: bool) -> &mut Self { + self.h1_keep_alive = val; + self + } + + // renamed due different semantics of http2 keep alive + #[doc(hidden)] + #[deprecated(note = "renamed to `http1_keep_alive`")] + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + self.http1_keep_alive(val) + } + /// Set whether HTTP/1 connections should try to use vectored writes, /// or always flatten into a single buffer. /// @@ -303,11 +320,38 @@ impl Http { self } - /// Enables or disables HTTP keep-alive. + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. /// - /// Default is true. - pub fn keep_alive(&mut self, val: bool) -> &mut Self { - self.keep_alive = val; + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval( + &mut self, + interval: impl Into>, + ) -> &mut Self { + self.h2_builder.keep_alive_interval = interval.into(); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.h2_builder.keep_alive_timeout = timeout; self } @@ -344,10 +388,10 @@ impl Http { Http { exec, h1_half_close: self.h1_half_close, + h1_keep_alive: self.h1_keep_alive, h1_writev: self.h1_writev, h2_builder: self.h2_builder, mode: self.mode, - keep_alive: self.keep_alive, max_buf_size: self.max_buf_size, pipeline_flush: self.pipeline_flush, } @@ -392,7 +436,7 @@ impl Http { let proto = match self.mode { ConnectionMode::H1Only | ConnectionMode::Fallback => { let mut conn = proto::Conn::new(io); - if !self.keep_alive { + if !self.h1_keep_alive { conn.disable_keep_alive(); } if self.h1_half_close { diff --git a/src/server/mod.rs b/src/server/mod.rs index c6a16a211b..ed6068c867 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -240,7 +240,7 @@ impl Builder { /// /// Default is `true`. pub fn http1_keepalive(mut self, val: bool) -> Self { - self.protocol.keep_alive(val); + self.protocol.http1_keep_alive(val); self } @@ -257,11 +257,11 @@ impl Builder { self } - /// Sets whether HTTP/1 is required. + /// Set the maximum buffer size. /// - /// Default is `false`. - pub fn http1_only(mut self, val: bool) -> Self { - self.protocol.http1_only(val); + /// Default is ~ 400kb. + pub fn http1_max_buf_size(mut self, val: usize) -> Self { + self.protocol.max_buf_size(val); self } @@ -290,6 +290,14 @@ impl Builder { self } + /// Sets whether HTTP/1 is required. + /// + /// Default is `false`. + pub fn http1_only(mut self, val: bool) -> Self { + self.protocol.http1_only(val); + self + } + /// Sets whether HTTP/2 is required. /// /// Default is `false`. @@ -343,11 +351,35 @@ impl Builder { self } - /// Set the maximum buffer size. + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. /// - /// Default is ~ 400kb. - pub fn http1_max_buf_size(mut self, val: usize) -> Self { - self.protocol.max_buf_size(val); + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_interval(mut self, interval: impl Into>) -> Self { + self.protocol.http2_keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + /// Requires the `runtime` cargo feature to be enabled. + #[cfg(feature = "runtime")] + pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> Self { + self.protocol.http2_keep_alive_timeout(timeout); self } diff --git a/tests/client.rs b/tests/client.rs index f7db372273..576423768f 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2537,6 +2537,198 @@ mod conn { .expect_err("client should be closed"); } + #[tokio::test] + async fn http2_keep_alive_detects_unresponsive_server() { + let _ = pretty_env_logger::try_init(); + + let mut listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + // spawn a server that reads but doesn't write + tokio::spawn(async move { + let mut sock = listener.accept().await.unwrap().0; + let mut buf = [0u8; 1024]; + loop { + let n = sock.read(&mut buf).await.expect("server read"); + if n == 0 { + // server closed, lets go! + break; + } + } + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (_client, conn) = conn::Builder::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + // enable while idle since we aren't sending requests + .http2_keep_alive_while_idle(true) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + conn.await.expect_err("conn should time out"); + } + + #[tokio::test] + async fn http2_keep_alive_not_while_idle() { + // This tests that not setting `http2_keep_alive_while_idle(true)` + // will use the default behavior which will NOT detect the server + // is unresponsive while no streams are active. + + let _ = pretty_env_logger::try_init(); + + let mut listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + // spawn a server that reads but doesn't write + tokio::spawn(async move { + let sock = listener.accept().await.unwrap().0; + drain_til_eof(sock).await.expect("server read"); + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (mut client, conn) = conn::Builder::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + tokio::spawn(async move { + conn.await.expect("client conn shouldn't error"); + }); + + // sleep longer than keepalive would trigger + tokio::time::delay_for(Duration::from_secs(4)).await; + + future::poll_fn(|ctx| client.poll_ready(ctx)) + .await + .expect("client should be open"); + } + + #[tokio::test] + async fn http2_keep_alive_closes_open_streams() { + let _ = pretty_env_logger::try_init(); + + let mut listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + // spawn a server that reads but doesn't write + tokio::spawn(async move { + let sock = listener.accept().await.unwrap().0; + drain_til_eof(sock).await.expect("server read"); + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (mut client, conn) = conn::Builder::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + tokio::spawn(async move { + let err = conn.await.expect_err("client conn should timeout"); + assert!(err.is_timeout()); + }); + + let req = http::Request::new(hyper::Body::empty()); + let err = client + .send_request(req) + .await + .expect_err("request should timeout"); + assert!(err.is_timeout()); + + let err = future::poll_fn(|ctx| client.poll_ready(ctx)) + .await + .expect_err("client should be closed"); + assert!( + err.is_closed(), + "poll_ready error should be closed: {:?}", + err + ); + } + + #[tokio::test] + async fn http2_keep_alive_with_responsive_server() { + // Test that a responsive server works just when client keep + // alive is enabled + use hyper::service::service_fn; + + let _ = pretty_env_logger::try_init(); + + let mut listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn an HTTP2 server that reads the whole body and responds + tokio::spawn(async move { + let sock = listener.accept().await.unwrap().0; + hyper::server::conn::Http::new() + .http2_only(true) + .serve_connection( + sock, + service_fn(|req| async move { + tokio::spawn(async move { + let _ = hyper::body::aggregate(req.into_body()) + .await + .expect("server req body aggregate"); + }); + Ok::<_, hyper::Error>(http::Response::new(hyper::Body::empty())) + }), + ) + .await + .expect("serve_connection"); + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (mut client, conn) = conn::Builder::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + tokio::spawn(async move { + conn.await.expect("client conn shouldn't error"); + }); + + // Use a channel to keep request stream open + let (_tx, body) = hyper::Body::channel(); + let req1 = http::Request::new(body); + let _resp = client.send_request(req1).await.expect("send_request"); + + // sleep longer than keepalive would trigger + tokio::time::delay_for(Duration::from_secs(4)).await; + + future::poll_fn(|ctx| client.poll_ready(ctx)) + .await + .expect("client should be open"); + } + + async fn drain_til_eof(mut sock: T) -> io::Result<()> { + let mut buf = [0u8; 1024]; + loop { + let n = sock.read(&mut buf).await?; + if n == 0 { + // socket closed, lets go! + return Ok(()); + } + } + } + struct DebugStream { tcp: TcpStream, shutdown_called: bool, diff --git a/tests/server.rs b/tests/server.rs index 6818b2f903..cdc79ef7e7 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -18,7 +18,7 @@ use futures_util::future::{self, Either, FutureExt, TryFutureExt}; #[cfg(feature = "stream")] use futures_util::stream::StreamExt as _; use http::header::{HeaderName, HeaderValue}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream as TkTcpStream}; use tokio::runtime::Runtime; @@ -1818,6 +1818,91 @@ fn skips_content_length_and_body_for_304_responses() { assert_eq!(lines.next(), Some("")); assert_eq!(lines.next(), None); } + +#[tokio::test] +async fn http2_keep_alive_detects_unresponsive_client() { + let _ = pretty_env_logger::try_init(); + + let mut listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn a "client" conn that only reads until EOF + tokio::spawn(async move { + let mut conn = connect_async(addr).await; + + // write h2 magic preface and settings frame + conn.write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + .await + .expect("client preface"); + conn.write_all(&[ + 0, 0, 0, // len + 4, // kind + 0, // flag + 0, 0, 0, // stream id + ]) + .await + .expect("client settings"); + + // read until eof + let mut buf = [0u8; 1024]; + loop { + let n = conn.read(&mut buf).await.expect("client.read"); + if n == 0 { + // eof + break; + } + } + }); + + let (socket, _) = listener.accept().await.expect("accept"); + + let err = Http::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + .serve_connection(socket, unreachable_service()) + .await + .expect_err("serve_connection should error"); + + assert!(err.is_timeout()); +} + +#[tokio::test] +async fn http2_keep_alive_with_responsive_client() { + let _ = pretty_env_logger::try_init(); + + let mut listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (socket, _) = listener.accept().await.expect("accept"); + + Http::new() + .http2_only(true) + .http2_keep_alive_interval(Duration::from_secs(1)) + .http2_keep_alive_timeout(Duration::from_secs(1)) + .serve_connection(socket, HelloWorld) + .await + .expect("serve_connection"); + }); + + let tcp = connect_async(addr).await; + let (mut client, conn) = hyper::client::conn::Builder::new() + .http2_only(true) + .handshake::<_, Body>(tcp) + .await + .expect("http handshake"); + + tokio::spawn(async move { + conn.await.expect("client conn"); + }); + + tokio::time::delay_for(Duration::from_secs(4)).await; + + let req = http::Request::new(hyper::Body::empty()); + client.send_request(req).await.expect("client.send_request"); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // ------------------------------------------------- @@ -1864,6 +1949,7 @@ impl Serve { } type BoxError = Box; +type BoxFuture = Pin, BoxError>> + Send>>; struct ReplyBuilder<'a> { tx: &'a Mutex>, @@ -1965,7 +2051,7 @@ enum Msg { impl tower_service::Service> for TestService { type Response = Response; type Error = BoxError; - type Future = Pin, BoxError>> + Send>>; + type Future = BoxFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Ok(()).into() @@ -2039,6 +2125,15 @@ impl tower_service::Service> for HelloWorld { } } +fn unreachable_service() -> impl tower_service::Service< + http::Request, + Response = http::Response, + Error = BoxError, + Future = BoxFuture, +> { + service_fn(|_req| Box::pin(async { Err("request shouldn't be received".into()) }) as BoxFuture) +} + fn connect(addr: &SocketAddr) -> TcpStream { let req = TcpStream::connect(addr).unwrap(); req.set_read_timeout(Some(Duration::from_secs(1))).unwrap(); @@ -2046,6 +2141,10 @@ fn connect(addr: &SocketAddr) -> TcpStream { req } +async fn connect_async(addr: SocketAddr) -> TkTcpStream { + TkTcpStream::connect(addr).await.expect("connect_async") +} + fn serve() -> Serve { serve_opts().serve() }