Skip to content

Commit

Permalink
take() ResponsePackage instead
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Mar 22, 2023
1 parent 38d0302 commit a244b9d
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions rust-runtime/aws-smithy-http-server/src/request/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl ServerRequestId {
pub fn new() -> Self {
Self { id: Uuid::new_v4() }
}

pub(crate) fn to_header(&self) -> HeaderValue {
HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
}
}

impl Display for ServerRequestId {
Expand Down Expand Up @@ -162,7 +166,7 @@ where
ServerRequestIdResponseFuture {
response_package: Some(ResponsePackage {
request_id,
header_key: Some(header_key.clone()),
header_key: header_key.clone(),
}),
fut: self.inner.call(req),
}
Expand All @@ -186,7 +190,7 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {

struct ResponsePackage {
request_id: ServerRequestId,
header_key: Option<HeaderName>,
header_key: HeaderName,
}

pin_project_lite::pin_project! {
Expand All @@ -209,10 +213,8 @@ where
let response_package = this.response_package;
fut.try_poll(cx)
.map_ok(|mut res| {
if let Some(response_package) = response_package {
if let Ok(value) = HeaderValue::from_str(&response_package.request_id.id.to_string()) {
res.headers_mut().insert(response_package.header_key.take().expect("Futures should not be polled after completion"), value);
}
if let Some(response_package) = response_package.take() {
res.headers_mut().insert(response_package.header_key, response_package.request_id.to_header());
}
res
})
Expand All @@ -225,23 +227,23 @@ mod tests {
use crate::body::{Body, BoxBody};
use crate::request::Request;
use http::HeaderValue;
use tower::{service_fn, Service, ServiceBuilder, ServiceExt};
use tower::{service_fn, ServiceBuilder, ServiceExt};
use std::convert::Infallible;

#[test]
fn test_request_id_parsed_by_header_value() {
assert!(HeaderValue::from_str(&ServerRequestId::new().id.to_string()).is_ok());
fn test_request_id_parsed_by_header_value_infallible() {
ServerRequestId::new().to_header();
}

#[tokio::test]
async fn test_request_id_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")))
.service(service_fn(|req: Request<Body>| async move {
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let mut req = Request::new(Body::empty());
let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();
Expand All @@ -253,11 +255,11 @@ mod tests {
async fn test_request_id_not_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new())
.service(service_fn(|req: Request<Body>| async move {
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let mut req = Request::new(Body::empty());
let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

Expand Down

0 comments on commit a244b9d

Please sign in to comment.