Skip to content

Commit

Permalink
Add request ID to response headers (#2438)
Browse files Browse the repository at this point in the history
* Add request ID to response headers

Signed-off-by: Daniele Ahmed <[email protected]>

* Add parsing test

Signed-off-by: Daniele Ahmed <[email protected]>

* Style

Signed-off-by: Daniele Ahmed <[email protected]>

* CHANGELOG

Signed-off-by: Daniele Ahmed <[email protected]>

* Fix import

Signed-off-by: Daniele Ahmed <[email protected]>

* Panic if ServerRequestIdProviderLayer is not present

Signed-off-by: Daniele Ahmed <[email protected]>

* Own value

Signed-off-by: Daniele Ahmed <[email protected]>

* Correct docs

Signed-off-by: Daniele Ahmed <[email protected]>

* Add order of layer to expect() message

Signed-off-by: Daniele Ahmed <[email protected]>

* Remove Box

Signed-off-by: Daniele Ahmed <[email protected]>

* Require order of request ID layers

Signed-off-by: Daniele Ahmed <[email protected]>

* Revert "Require order of request ID layers"

This reverts commit 147eef2.

* One layer to generate and inject the header

Signed-off-by: Daniele Ahmed <[email protected]>

* HeaderName for header name

Signed-off-by: Daniele Ahmed <[email protected]>

* CHANGELOG

Signed-off-by: Daniele Ahmed <[email protected]>

* Remove additional layer

Signed-off-by: Daniele Ahmed <[email protected]>

* Remove to_owned

Signed-off-by: Daniele Ahmed <[email protected]>

* Add tests, remove unnecessary clone

Signed-off-by: Daniele Ahmed <[email protected]>

* take() ResponsePackage instead

Signed-off-by: Daniele Ahmed <[email protected]>

* Update docs

Signed-off-by: Daniele Ahmed <[email protected]>

* Update docs

Signed-off-by: Daniele Ahmed <[email protected]>

* cargo fmt

Signed-off-by: Daniele Ahmed <[email protected]>

* Update CHANGELOG

Signed-off-by: Daniele Ahmed <[email protected]>

---------

Signed-off-by: Daniele Ahmed <[email protected]>
Co-authored-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag and Daniele Ahmed authored Mar 23, 2023
1 parent abbf78f commit d89a90d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 9 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [
references = ["smithy-rs#2474"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = """Servers can send the `ServerRequestId` in the response headers.
Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`:
```
let app = app
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")));
```
"""
references = ["smithy-rs#2438"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server"}
author = "82marbag"
142 changes: 133 additions & 9 deletions rust-runtime/aws-smithy-http-server/src/request/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
//! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request.
//! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and
//! data related to a single operation.
//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
//!
//! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service.
//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
//!
//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
//!
Expand All @@ -34,20 +36,24 @@
//! .operation(handler)
//! .build().unwrap();
//!
//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */
//! let app = app
//! .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */
//!
//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
//! .parse()
//! .expect("unable to parse the server bind address and port");
//! let server = hyper::Server::bind(&bind).serve(app.into_make_service());
//! ```
use std::future::Future;
use std::{
fmt::Display,
task::{Context, Poll},
};

use futures_util::TryFuture;
use http::request::Parts;
use http::{header::HeaderName, HeaderValue, Response};
use thiserror::Error;
use tower::{Layer, Service};
use uuid::Uuid;
Expand All @@ -74,6 +80,10 @@ impl ServerRequestId {
pub fn new() -> Self {
Self { id: Uuid::new_v4() }
}

pub(crate) fn to_header(&self) -> HeaderValue {
HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
}
}

impl Display for ServerRequestId {
Expand All @@ -99,17 +109,28 @@ impl Default for ServerRequestId {
#[derive(Clone)]
pub struct ServerRequestIdProvider<S> {
inner: S,
header_key: Option<HeaderName>,
}

/// A layer that provides services with a unique request ID instance
#[derive(Debug)]
#[non_exhaustive]
pub struct ServerRequestIdProviderLayer;
pub struct ServerRequestIdProviderLayer {
header_key: Option<HeaderName>,
}

impl ServerRequestIdProviderLayer {
/// Generate a new unique request ID
/// Generate a new unique request ID and do not add it as a response header
/// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header
pub fn new() -> Self {
Self {}
Self { header_key: None }
}

/// Generate a new unique request ID and add it as a response header
pub fn new_with_response_header(header_key: HeaderName) -> Self {
Self {
header_key: Some(header_key),
}
}
}

Expand All @@ -123,25 +144,47 @@ impl<S> Layer<S> for ServerRequestIdProviderLayer {
type Service = ServerRequestIdProvider<S>;

fn layer(&self, inner: S) -> Self::Service {
ServerRequestIdProvider { inner }
ServerRequestIdProvider {
inner,
header_key: self.header_key.clone(),
}
}
}

impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
where
S: Service<http::Request<Body>>,
S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
S::Future: std::marker::Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
type Future = ServerRequestIdResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
req.extensions_mut().insert(ServerRequestId::new());
self.inner.call(req)
let request_id = ServerRequestId::new();
match &self.header_key {
Some(header_key) => {
req.extensions_mut().insert(request_id.clone());
ServerRequestIdResponseFuture {
response_package: Some(ResponsePackage {
request_id,
header_key: header_key.clone(),
}),
fut: self.inner.call(req),
}
}
None => {
req.extensions_mut().insert(request_id);
ServerRequestIdResponseFuture {
response_package: None,
fut: self.inner.call(req),
}
}
}
}
}

Expand All @@ -150,3 +193,84 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
internal_server_error()
}
}

struct ResponsePackage {
request_id: ServerRequestId,
header_key: HeaderName,
}

pin_project_lite::pin_project! {
pub struct ServerRequestIdResponseFuture<Fut> {
response_package: Option<ResponsePackage>,
#[pin]
fut: Fut,
}
}

impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
where
Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
{
type Output = Result<Fut::Ok, Fut::Error>;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let fut = this.fut;
let response_package = this.response_package;
fut.try_poll(cx).map_ok(|mut res| {
if let Some(response_package) = response_package.take() {
res.headers_mut()
.insert(response_package.header_key, response_package.request_id.to_header());
}
res
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::body::{Body, BoxBody};
use crate::request::Request;
use http::HeaderValue;
use std::convert::Infallible;
use tower::{service_fn, ServiceBuilder, ServiceExt};

#[test]
fn test_request_id_parsed_by_header_value_infallible() {
ServerRequestId::new().to_header();
}

#[tokio::test]
async fn test_request_id_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new_with_response_header(
HeaderName::from_static("x-request-id"),
))
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

assert!(HeaderValue::from_str(request_id).is_ok());
}

#[tokio::test]
async fn test_request_id_not_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new())
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

assert!(res.headers().is_empty());
}
}

0 comments on commit d89a90d

Please sign in to comment.