diff --git a/rust-runtime/aws-smithy-http-server/src/request/request_id.rs b/rust-runtime/aws-smithy-http-server/src/request/request_id.rs index 1a4bb44477..49f5c6013d 100644 --- a/rust-runtime/aws-smithy-http-server/src/request/request_id.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/request_id.rs @@ -80,6 +80,10 @@ impl ServerRequestId { pub fn new() -> Self { Self { id: Uuid::new_v4() } } + + pub(crate) fn to_header(&self) -> HeaderValue { + HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII") + } } impl Display for ServerRequestId { @@ -162,7 +166,7 @@ where ServerRequestIdResponseFuture { response_package: Some(ResponsePackage { request_id, - header_key: Some(header_key.clone()), + header_key: header_key.clone(), }), fut: self.inner.call(req), } @@ -186,7 +190,7 @@ impl IntoResponse for MissingServerRequestId { struct ResponsePackage { request_id: ServerRequestId, - header_key: Option, + header_key: HeaderName, } pin_project_lite::pin_project! { @@ -209,10 +213,8 @@ where let response_package = this.response_package; fut.try_poll(cx) .map_ok(|mut res| { - if let Some(response_package) = response_package { - if let Ok(value) = HeaderValue::from_str(&response_package.request_id.id.to_string()) { - res.headers_mut().insert(response_package.header_key.take().expect("Futures should not be polled after completion"), value); - } + if let Some(response_package) = response_package.take() { + res.headers_mut().insert(response_package.header_key, response_package.request_id.to_header()); } res }) @@ -225,23 +227,23 @@ mod tests { use crate::body::{Body, BoxBody}; use crate::request::Request; use http::HeaderValue; - use tower::{service_fn, Service, ServiceBuilder, ServiceExt}; + use tower::{service_fn, ServiceBuilder, ServiceExt}; use std::convert::Infallible; #[test] - fn test_request_id_parsed_by_header_value() { - assert!(HeaderValue::from_str(&ServerRequestId::new().id.to_string()).is_ok()); + fn test_request_id_parsed_by_header_value_infallible() { + ServerRequestId::new().to_header(); } #[tokio::test] async fn test_request_id_in_response_header() { let svc = ServiceBuilder::new() .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))) - .service(service_fn(|req: Request| async move { + .service(service_fn(|_req: Request| async move { Ok::<_, Infallible>(Response::new(BoxBody::default())) })); - let mut req = Request::new(Body::empty()); + let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap(); @@ -253,11 +255,11 @@ mod tests { async fn test_request_id_not_in_response_header() { let svc = ServiceBuilder::new() .layer(&ServerRequestIdProviderLayer::new()) - .service(service_fn(|req: Request| async move { + .service(service_fn(|_req: Request| async move { Ok::<_, Infallible>(Response::new(BoxBody::default())) })); - let mut req = Request::new(Body::empty()); + let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap();