From 37ec724fd6405dd97c5873dddc956df1711b29ab Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 8 Oct 2018 17:52:08 -0700 Subject: [PATCH] fix(http2): add Date header if not present for HTTP2 server responses --- src/proto/h1/date.rs | 10 +++ src/proto/h1/mod.rs | 2 +- src/proto/h2/server.rs | 9 +++ tests/integration.rs | 24 +++++++ tests/support/mod.rs | 140 ++++++++++++++++++++++++++++++----------- 5 files changed, 149 insertions(+), 36 deletions(-) diff --git a/src/proto/h1/date.rs b/src/proto/h1/date.rs index 48bcdfcd6b..abaa3f9c60 100644 --- a/src/proto/h1/date.rs +++ b/src/proto/h1/date.rs @@ -2,6 +2,7 @@ use std::cell::RefCell; use std::fmt::{self, Write}; use std::str; +use http::header::HeaderValue; use time::{self, Duration}; // "Sun, 06 Nov 1994 08:49:37 GMT".len() @@ -19,6 +20,15 @@ pub fn update() { }) } +pub(crate) fn update_and_header_value() -> HeaderValue { + CACHED.with(|cache| { + let mut cache = cache.borrow_mut(); + cache.check(); + HeaderValue::from_bytes(cache.buffer()) + .expect("Date format should be valid HeaderValue") + }) +} + struct CachedDate { bytes: [u8; DATE_VALUE_LENGTH], pos: usize, diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 15faa2135f..5facd13534 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -11,7 +11,7 @@ pub use self::io::Cursor; //TODO: move out of h1::io pub use self::io::MINIMUM_MAX_BUFFER_SIZE; mod conn; -mod date; +pub(super) mod date; mod decode; pub(crate) mod dispatch; mod encode; diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 1ded63b33e..0dc8e7a866 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -193,6 +193,15 @@ where let (head, body) = res.into_parts(); let mut res = ::http::Response::from_parts(head, ()); super::strip_connection_headers(res.headers_mut(), false); + + // set Date header if it isn't already set... + res + .headers_mut() + .entry(::http::header::DATE) + .expect("DATE is a valid HeaderName") + .or_insert_with(::proto::h1::date::update_and_header_value); + + // automatically set Content-Length from body... if let Some(len) = body.content_length() { headers::set_content_length_if_missing(res.headers_mut(), len); } diff --git a/tests/integration.rs b/tests/integration.rs index 9e7450b685..68cc4c403a 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -11,6 +11,9 @@ t! { ; response: status: 200, + headers: { + "date" => SOME, + }, ; server: request: @@ -37,6 +40,27 @@ t! { ; } +t! { + date_isnt_overwritten, + client: + request: + ; + response: + status: 200, + headers: { + "date" => "let me through", + }, + ; + server: + request: + ; + response: + headers: { + "date" => "let me through", + }, + ; +} + t! { get_body, client: diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 4052465564..ab8cc660e3 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -95,21 +95,21 @@ macro_rules! t { fn $name() { let c = vec![$(( __CReq { - $($c_req_prop: __internal_req_res_prop!($c_req_prop: $c_req_val),)* + $($c_req_prop: __internal_map_prop!($c_req_prop: $c_req_val),)* ..Default::default() }, __CRes { - $($c_res_prop: __internal_req_res_prop!($c_res_prop: $c_res_val),)* + $($c_res_prop: __internal_eq_prop!($c_res_prop: $c_res_val),)* ..Default::default() } ),)*]; let s = vec![$(( __SReq { - $($s_req_prop: __internal_req_res_prop!($s_req_prop: $s_req_val),)* + $($s_req_prop: __internal_eq_prop!($s_req_prop: $s_req_val),)* ..Default::default() }, __SRes { - $($s_res_prop: __internal_req_res_prop!($s_res_prop: $s_res_val),)* + $($s_res_prop: __internal_map_prop!($s_res_prop: $s_res_val),)* ..Default::default() } ),)*]; @@ -157,27 +157,47 @@ macro_rules! t { ); } -macro_rules! __internal_req_res_prop { - (method: $prop_val:expr) => ( - $prop_val - ); - (status: $prop_val:expr) => ( - StatusCode::from_u16($prop_val).expect("status code") - ); +macro_rules! __internal_map_prop { (headers: $map:tt) => ({ #[allow(unused_mut)] { let mut headers = HeaderMap::new(); - __internal_headers!(headers, $map); + __internal_headers_map!(headers, $map); + headers + } + }); + ($name:tt: $val:tt) => ({ + __internal_req_res_prop!($name: $val) + }); +} + +macro_rules! __internal_eq_prop { + (headers: $map:tt) => ({ + #[allow(unused_mut)] + { + let mut headers = Vec::new(); + __internal_headers_eq!(headers, $map); headers } }); + ($name:tt: $val:tt) => ({ + __internal_req_res_prop!($name: $val) + }); +} + +macro_rules! __internal_req_res_prop { + (method: $prop_val:expr) => ( + $prop_val + ); + (status: $prop_val:expr) => ( + StatusCode::from_u16($prop_val).expect("status code") + ); ($prop_name:ident: $prop_val:expr) => ( From::from($prop_val) ) } -macro_rules! __internal_headers { +macro_rules! __internal_headers_map { ($headers:ident, { $($name:expr => $val:expr,)* }) => { $( $headers.insert($name, $val.to_string().parse().expect("header value")); @@ -185,7 +205,39 @@ macro_rules! __internal_headers { } } -#[derive(Clone, Debug, Default)] +macro_rules! __internal_headers_eq { + (@pat $name: expr, $pat:pat) => { + ::std::sync::Arc::new(move |__hdrs: &::hyper::HeaderMap| { + match __hdrs.get($name) { + $pat => (), + other => panic!("headers[{}] was not {}: {:?}", stringify!($name), stringify!($pat), other), + } + }) as ::std::sync::Arc + }; + (@val $name: expr, NONE) => { + __internal_headers_eq!(@pat $name, None); + }; + (@val $name: expr, SOME) => { + __internal_headers_eq!(@pat $name, Some(_)); + }; + (@val $name: expr, $val:expr) => ({ + let __val = Option::from($val); + ::std::sync::Arc::new(move |__hdrs: &::hyper::HeaderMap| { + if let Some(ref val) = __val { + assert_eq!(__hdrs.get($name).expect(stringify!($name)), val.to_string().as_str(), stringify!($name)); + } else { + assert_eq!(__hdrs.get($name), None, stringify!($name)); + } + }) as ::std::sync::Arc + }); + ($headers:ident, { $($name:expr => $val:tt,)* }) => { + $( + $headers.push(__internal_headers_eq!(@val $name, $val)); + )* + } +} + +#[derive(Clone, Debug)] pub struct __CReq { pub method: &'static str, pub uri: &'static str, @@ -193,21 +245,43 @@ pub struct __CReq { pub body: Vec, } -#[derive(Clone, Debug, Default)] +impl Default for __CReq { + fn default() -> __CReq { + __CReq { + method: "GET", + uri: "/", + headers: HeaderMap::new(), + body: Vec::new(), + } + } +} + +#[derive(Clone, Default)] pub struct __CRes { pub status: hyper::StatusCode, pub body: Vec, - pub headers: HeaderMap, + pub headers: __HeadersEq, } -#[derive(Clone, Debug, Default)] +#[derive(Clone)] pub struct __SReq { pub method: &'static str, pub uri: &'static str, - pub headers: HeaderMap, + pub headers: __HeadersEq, pub body: Vec, } +impl Default for __SReq { + fn default() -> __SReq { + __SReq { + method: "GET", + uri: "/", + headers: Vec::new(), + body: Vec::new(), + } + } +} + #[derive(Clone, Debug, Default)] pub struct __SRes { pub status: hyper::StatusCode, @@ -215,6 +289,8 @@ pub struct __SRes { pub headers: HeaderMap, } +pub type __HeadersEq = Vec>; + pub struct __TestConfig { pub client_version: usize, pub client_msgs: Vec<(__CReq, __CRes)>, @@ -257,20 +333,17 @@ pub fn __run_test(cfg: __TestConfig) { .unwrap() .remove(0); - assert_eq!(req.uri().path(), sreq.uri); - assert_eq!(req.method(), &sreq.method); - assert_eq!(req.version(), version); - for (name, value) in &sreq.headers { - assert_eq!( - req.headers()[name], - value - ); + assert_eq!(req.uri().path(), sreq.uri, "client path"); + assert_eq!(req.method(), &sreq.method, "client method"); + assert_eq!(req.version(), version, "client version"); + for func in &sreq.headers { + func(&req.headers()); } let sbody = sreq.body; req.into_body() .concat2() .map(move |body| { - assert_eq!(body.as_ref(), sbody.as_slice()); + assert_eq!(body.as_ref(), sbody.as_slice(), "client body"); let mut res = Response::builder() .status(sres.status) @@ -339,18 +412,15 @@ pub fn __run_test(cfg: __TestConfig) { client.request(req) .and_then(move |res| { - assert_eq!(res.status(), cstatus); - assert_eq!(res.version(), version); - for (name, value) in &cheaders { - assert_eq!( - res.headers()[name], - value - ); + assert_eq!(res.status(), cstatus, "server status"); + assert_eq!(res.version(), version, "server version"); + for func in &cheaders { + func(&res.headers()); } res.into_body().concat2() }) .map(move |body| { - assert_eq!(body.as_ref(), cbody.as_slice()); + assert_eq!(body.as_ref(), cbody.as_slice(), "server body"); }) .map_err(|e| panic!("client error: {}", e)) });