diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 7b85f5978d..1f568abe06 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [ references = ["smithy-rs#2474"] meta = { "breaking" = false, "tada" = false, "bug" = false } author = "rcoh" + +[[smithy-rs]] +message = """Servers can send the `ServerRequestId` in the response headers. +Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`: +``` +let app = app + .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); +``` +""" +references = ["smithy-rs#2438"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server"} +author = "82marbag" diff --git a/rust-runtime/aws-smithy-http-server/src/request/request_id.rs b/rust-runtime/aws-smithy-http-server/src/request/request_id.rs index 7894d9806e..7d6ee70ab7 100644 --- a/rust-runtime/aws-smithy-http-server/src/request/request_id.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/request_id.rs @@ -12,8 +12,10 @@ //! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request. //! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and //! data related to a single operation. +//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler. //! //! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service. +//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers. //! //! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry). //! @@ -34,7 +36,8 @@ //! .operation(handler) //! .build().unwrap(); //! -//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */ +//! let app = app +//! .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */ //! //! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port) //! .parse() @@ -42,12 +45,15 @@ //! let server = hyper::Server::bind(&bind).serve(app.into_make_service()); //! ``` +use std::future::Future; use std::{ fmt::Display, task::{Context, Poll}, }; +use futures_util::TryFuture; use http::request::Parts; +use http::{header::HeaderName, HeaderValue, Response}; use thiserror::Error; use tower::{Layer, Service}; use uuid::Uuid; @@ -74,6 +80,10 @@ impl ServerRequestId { pub fn new() -> Self { Self { id: Uuid::new_v4() } } + + pub(crate) fn to_header(&self) -> HeaderValue { + HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII") + } } impl Display for ServerRequestId { @@ -99,17 +109,28 @@ impl Default for ServerRequestId { #[derive(Clone)] pub struct ServerRequestIdProvider { inner: S, + header_key: Option, } /// A layer that provides services with a unique request ID instance #[derive(Debug)] #[non_exhaustive] -pub struct ServerRequestIdProviderLayer; +pub struct ServerRequestIdProviderLayer { + header_key: Option, +} impl ServerRequestIdProviderLayer { - /// Generate a new unique request ID + /// Generate a new unique request ID and do not add it as a response header + /// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header pub fn new() -> Self { - Self {} + Self { header_key: None } + } + + /// Generate a new unique request ID and add it as a response header + pub fn new_with_response_header(header_key: HeaderName) -> Self { + Self { + header_key: Some(header_key), + } } } @@ -123,25 +144,47 @@ impl Layer for ServerRequestIdProviderLayer { type Service = ServerRequestIdProvider; fn layer(&self, inner: S) -> Self::Service { - ServerRequestIdProvider { inner } + ServerRequestIdProvider { + inner, + header_key: self.header_key.clone(), + } } } impl Service> for ServerRequestIdProvider where - S: Service>, + S: Service, Response = Response>, + S::Future: std::marker::Send + 'static, { type Response = S::Response; type Error = S::Error; - type Future = S::Future; + type Future = ServerRequestIdResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: http::Request) -> Self::Future { - req.extensions_mut().insert(ServerRequestId::new()); - self.inner.call(req) + let request_id = ServerRequestId::new(); + match &self.header_key { + Some(header_key) => { + req.extensions_mut().insert(request_id.clone()); + ServerRequestIdResponseFuture { + response_package: Some(ResponsePackage { + request_id, + header_key: header_key.clone(), + }), + fut: self.inner.call(req), + } + } + None => { + req.extensions_mut().insert(request_id); + ServerRequestIdResponseFuture { + response_package: None, + fut: self.inner.call(req), + } + } + } } } @@ -150,3 +193,84 @@ impl IntoResponse for MissingServerRequestId { internal_server_error() } } + +struct ResponsePackage { + request_id: ServerRequestId, + header_key: HeaderName, +} + +pin_project_lite::pin_project! { + pub struct ServerRequestIdResponseFuture { + response_package: Option, + #[pin] + fut: Fut, + } +} + +impl Future for ServerRequestIdResponseFuture +where + Fut: TryFuture>, +{ + type Output = Result; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let fut = this.fut; + let response_package = this.response_package; + fut.try_poll(cx).map_ok(|mut res| { + if let Some(response_package) = response_package.take() { + res.headers_mut() + .insert(response_package.header_key, response_package.request_id.to_header()); + } + res + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::body::{Body, BoxBody}; + use crate::request::Request; + use http::HeaderValue; + use std::convert::Infallible; + use tower::{service_fn, ServiceBuilder, ServiceExt}; + + #[test] + fn test_request_id_parsed_by_header_value_infallible() { + ServerRequestId::new().to_header(); + } + + #[tokio::test] + async fn test_request_id_in_response_header() { + let svc = ServiceBuilder::new() + .layer(&ServerRequestIdProviderLayer::new_with_response_header( + HeaderName::from_static("x-request-id"), + )) + .service(service_fn(|_req: Request| async move { + Ok::<_, Infallible>(Response::new(BoxBody::default())) + })); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap(); + + assert!(HeaderValue::from_str(request_id).is_ok()); + } + + #[tokio::test] + async fn test_request_id_not_in_response_header() { + let svc = ServiceBuilder::new() + .layer(&ServerRequestIdProviderLayer::new()) + .service(service_fn(|_req: Request| async move { + Ok::<_, Infallible>(Response::new(BoxBody::default())) + })); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert!(res.headers().is_empty()); + } +}