Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request ID to response headers #2438

Merged
merged 24 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [
references = ["smithy-rs#2474"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = """Servers can send the `ServerRequestId` in the response headers.
Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`:
```
let app = app
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")));
```
"""
references = ["smithy-rs#2438"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server"}
Copy link
Contributor

@hlbarber hlbarber Mar 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, this is a breaking change:

- type Future = S::Future;
+ type Future = ServerRequestIdResponseFuture<S::Future>;

author = "82marbag"
134 changes: 126 additions & 8 deletions rust-runtime/aws-smithy-http-server/src/request/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
//!
//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
//!
//! To optionally add the [`ServerRequestId`] to the response headers, use [`ServerRequestIdProviderLayer::new_with_response_header`].
//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We speak of the fact that the server request id can be returned to the caller a couple of paragraphs above, I think it makes sense to mention the relevant constructor there.

//!
//! ## Examples
//!
//! Your handler can now optionally take as input a [`ServerRequestId`].
Expand All @@ -34,7 +38,8 @@
//! .operation(handler)
//! .build().unwrap();
//!
//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */
//! let app = app
//! .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */
//!
//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
//! .parse()
Expand All @@ -46,8 +51,11 @@ use std::{
fmt::Display,
task::{Context, Poll},
};
use std::future::Future;

use futures_util::TryFuture;
use http::request::Parts;
use http::{header::HeaderName, HeaderValue, Response};
use thiserror::Error;
use tower::{Layer, Service};
use uuid::Uuid;
Expand All @@ -74,6 +82,10 @@ impl ServerRequestId {
pub fn new() -> Self {
Self { id: Uuid::new_v4() }
}

pub(crate) fn to_header(&self) -> HeaderValue {
HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
}
}

impl Display for ServerRequestId {
Expand All @@ -99,17 +111,25 @@ impl Default for ServerRequestId {
#[derive(Clone)]
pub struct ServerRequestIdProvider<S> {
inner: S,
header_key: Option<HeaderName>,
}

/// A layer that provides services with a unique request ID instance
#[derive(Debug)]
#[non_exhaustive]
pub struct ServerRequestIdProviderLayer;
pub struct ServerRequestIdProviderLayer {
header_key: Option<HeaderName>,
}

impl ServerRequestIdProviderLayer {
/// Generate a new unique request ID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd update this to mention this other constructor and explain that it won't add it to the response headers (and vice versa in the docs for the other method).

pub fn new() -> Self {
Self {}
Self { header_key: None, }
}

/// Generate a new unique request ID and add it as a response header
pub fn new_with_response_header(header_key: HeaderName) -> Self {
Self { header_key: Some(header_key), }
}
}

Expand All @@ -123,25 +143,44 @@ impl<S> Layer<S> for ServerRequestIdProviderLayer {
type Service = ServerRequestIdProvider<S>;

fn layer(&self, inner: S) -> Self::Service {
ServerRequestIdProvider { inner }
ServerRequestIdProvider { inner, header_key: self.header_key.clone() }
}
}

impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
where
S: Service<http::Request<Body>>,
S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
S::Future: std::marker::Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
type Future = ServerRequestIdResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
req.extensions_mut().insert(ServerRequestId::new());
self.inner.call(req)
let request_id = ServerRequestId::new();
match &self.header_key {
Some(header_key) => {
req.extensions_mut().insert(request_id.clone());
ServerRequestIdResponseFuture {
response_package: Some(ResponsePackage {
request_id,
header_key: header_key.clone(),
}),
fut: self.inner.call(req),
}
}
None => {
req.extensions_mut().insert(request_id);
ServerRequestIdResponseFuture {
response_package: None,
fut: self.inner.call(req),
}
}
}
}
}

Expand All @@ -150,3 +189,82 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
internal_server_error()
}
}

struct ResponsePackage {
request_id: ServerRequestId,
header_key: HeaderName,
}

pin_project_lite::pin_project! {
pub struct ServerRequestIdResponseFuture<Fut> {
response_package: Option<ResponsePackage>,
#[pin]
fut: Fut,
}
}

impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
where
Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
{
type Output = Result<Fut::Ok, Fut::Error>;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let fut = this.fut;
let response_package = this.response_package;
fut.try_poll(cx)
.map_ok(|mut res| {
if let Some(response_package) = response_package.take() {
res.headers_mut().insert(response_package.header_key, response_package.request_id.to_header());
}
res
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::body::{Body, BoxBody};
use crate::request::Request;
use http::HeaderValue;
use tower::{service_fn, ServiceBuilder, ServiceExt};
use std::convert::Infallible;

#[test]
fn test_request_id_parsed_by_header_value_infallible() {
ServerRequestId::new().to_header();
}

#[tokio::test]
async fn test_request_id_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")))
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

assert!(HeaderValue::from_str(request_id).is_ok());
}

#[tokio::test]
async fn test_request_id_not_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new())
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

assert!(res.headers().is_empty());
}
}