diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 358cdfb660..148d43ce8c 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -5,10 +5,12 @@ use std::marker::PhantomData; use bytes::{Buf, Bytes}; use futures::{Async, Poll}; use http::{HeaderMap, Method, Version}; +use http::header::{HeaderValue, CONNECTION}; use tokio_io::{AsyncRead, AsyncWrite}; use ::Chunk; use proto::{BodyLength, DecodedLength, MessageHead}; +use headers::connection_keep_alive; use super::io::{Buffered}; use super::{EncodedBuf, Encode, Encoder, /*Decode,*/ Decoder, Http1Transaction, ParseContext}; @@ -438,12 +440,38 @@ where I: AsyncRead + AsyncWrite, } } + // Fix keep-alives when Connection: keep-alive header is not present + fn fix_keep_alive(&mut self, head: &mut MessageHead) { + let outgoing_is_keep_alive = head + .headers + .get(CONNECTION) + .and_then(|value| Some(connection_keep_alive(value))) + .unwrap_or(false); + + if !outgoing_is_keep_alive { + match head.version { + // If response is version 1.0 and keep-alive is not present in the response, + // disable keep-alive so the server closes the connection + Version::HTTP_10 => self.state.disable_keep_alive(), + // If response is version 1.1 and keep-alive is wanted, add + // Connection: keep-alive header when not present + Version::HTTP_11 => if self.state.wants_keep_alive() { + head.headers + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + }, + _ => (), + } + } + } + // If we know the remote speaks an older version, we try to fix up any messages // to work with our older peer. fn enforce_version(&mut self, head: &mut MessageHead) { match self.state.version { Version::HTTP_10 => { + // Fixes response or connection when keep-alive header is not present + self.fix_keep_alive(head); // If the remote only knows HTTP/1.0, we should force ourselves // to do only speak HTTP/1.0 as well. head.version = Version::HTTP_10; diff --git a/tests/server.rs b/tests/server.rs index fd31a5ac07..f5c13fab2a 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -30,7 +30,7 @@ use tokio::reactor::Handle; use tokio_io::{AsyncRead, AsyncWrite}; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{Body, Request, Response, StatusCode, Version}; use hyper::client::Client; use hyper::server::conn::Http; use hyper::server::Server; @@ -637,6 +637,7 @@ fn keep_alive() { fn http_10_keep_alive() { let foo_bar = b"foo bar baz"; let server = serve(); + // Response version 1.1 with no keep-alive header will downgrade to 1.0 when served server.reply() .header("content-length", foo_bar.len().to_string()) .body(foo_bar); @@ -658,6 +659,10 @@ fn http_10_keep_alive() { } } + // Connection: keep-alive header should be added when downgrading to a 1.0 response + let response = String::from_utf8(buf.to_vec()).unwrap(); + response.contains("Connection: keep-alive\r\n"); + // try again! let quux = b"zar quux"; @@ -682,6 +687,69 @@ fn http_10_keep_alive() { } } +#[test] +fn http_10_close_on_no_ka() { + let foo_bar = b"foo bar baz"; + let server = serve(); + + // A server response with version 1.0 but no keep-alive header + server + .reply() + .version(Version::HTTP_10) + .header("content-length", foo_bar.len().to_string()) + .body(foo_bar); + let mut req = connect(server.addr()); + + // The client request with version 1.0 that may have the keep-alive header + req.write_all( + b"\ + GET / HTTP/1.0\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ", + ).expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } else { + } + } + } + + // try again! + + let quux = b"zar quux"; + server + .reply() + .header("content-length", quux.len().to_string()) + .body(quux); + + // the write can possibly succeed, since it fills the kernel buffer on the first write + let _ = req.write_all( + b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ", + ); + + let mut buf = [0; 1024 * 8]; + match req.read(&mut buf[..]) { + // Ok(0) means EOF, so a proper shutdown + // Err(_) could mean ConnReset or something, also fine + Ok(0) | Err(_) => {} + Ok(n) => { + panic!("read {} bytes on a disabled keep-alive socket", n); + } + } +} + #[test] fn disable_keep_alive() { let foo_bar = b"foo bar baz";