Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request ID to response headers #2438

Merged
merged 24 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [
references = ["smithy-rs#2474"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = """Servers can send the `ServerRequestId` in the response headers.
Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`:
```
let app = app
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")));
```
"""
references = ["smithy-rs#2438"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server"}
author = "82marbag"
142 changes: 133 additions & 9 deletions rust-runtime/aws-smithy-http-server/src/request/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
//! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request.
//! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and
//! data related to a single operation.
//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
//!
//! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service.
//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
//!
//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
//!
Expand All @@ -34,20 +36,24 @@
//! .operation(handler)
//! .build().unwrap();
//!
//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */
//! let app = app
//! .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */
//!
//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
//! .parse()
//! .expect("unable to parse the server bind address and port");
//! let server = hyper::Server::bind(&bind).serve(app.into_make_service());
//! ```

use std::future::Future;
use std::{
fmt::Display,
task::{Context, Poll},
};

use futures_util::TryFuture;
use http::request::Parts;
use http::{header::HeaderName, HeaderValue, Response};
use thiserror::Error;
use tower::{Layer, Service};
use uuid::Uuid;
Expand All @@ -74,6 +80,10 @@ impl ServerRequestId {
pub fn new() -> Self {
Self { id: Uuid::new_v4() }
}

pub(crate) fn to_header(&self) -> HeaderValue {
HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
}
}

impl Display for ServerRequestId {
Expand All @@ -99,17 +109,28 @@ impl Default for ServerRequestId {
#[derive(Clone)]
pub struct ServerRequestIdProvider<S> {
inner: S,
header_key: Option<HeaderName>,
}

/// A layer that provides services with a unique request ID instance
#[derive(Debug)]
#[non_exhaustive]
pub struct ServerRequestIdProviderLayer;
pub struct ServerRequestIdProviderLayer {
header_key: Option<HeaderName>,
}

impl ServerRequestIdProviderLayer {
/// Generate a new unique request ID
/// Generate a new unique request ID and do not add it as a response header
/// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header
pub fn new() -> Self {
Self {}
Self { header_key: None }
}

/// Generate a new unique request ID and add it as a response header
pub fn new_with_response_header(header_key: HeaderName) -> Self {
Self {
header_key: Some(header_key),
}
}
}

Expand All @@ -123,25 +144,47 @@ impl<S> Layer<S> for ServerRequestIdProviderLayer {
type Service = ServerRequestIdProvider<S>;

fn layer(&self, inner: S) -> Self::Service {
ServerRequestIdProvider { inner }
ServerRequestIdProvider {
inner,
header_key: self.header_key.clone(),
}
}
}

impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
where
S: Service<http::Request<Body>>,
S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
S::Future: std::marker::Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
type Future = ServerRequestIdResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
req.extensions_mut().insert(ServerRequestId::new());
self.inner.call(req)
let request_id = ServerRequestId::new();
match &self.header_key {
Some(header_key) => {
req.extensions_mut().insert(request_id.clone());
ServerRequestIdResponseFuture {
response_package: Some(ResponsePackage {
request_id,
header_key: header_key.clone(),
}),
fut: self.inner.call(req),
}
}
None => {
req.extensions_mut().insert(request_id);
ServerRequestIdResponseFuture {
response_package: None,
fut: self.inner.call(req),
}
}
}
}
}

Expand All @@ -150,3 +193,84 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
internal_server_error()
}
}

struct ResponsePackage {
request_id: ServerRequestId,
header_key: HeaderName,
}

pin_project_lite::pin_project! {
pub struct ServerRequestIdResponseFuture<Fut> {
response_package: Option<ResponsePackage>,
#[pin]
fut: Fut,
}
}

impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
where
Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
{
type Output = Result<Fut::Ok, Fut::Error>;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let fut = this.fut;
let response_package = this.response_package;
fut.try_poll(cx).map_ok(|mut res| {
if let Some(response_package) = response_package.take() {
res.headers_mut()
.insert(response_package.header_key, response_package.request_id.to_header());
}
res
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::body::{Body, BoxBody};
use crate::request::Request;
use http::HeaderValue;
use std::convert::Infallible;
use tower::{service_fn, ServiceBuilder, ServiceExt};

#[test]
fn test_request_id_parsed_by_header_value_infallible() {
ServerRequestId::new().to_header();
}

#[tokio::test]
async fn test_request_id_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new_with_response_header(
HeaderName::from_static("x-request-id"),
))
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

assert!(HeaderValue::from_str(request_id).is_ok());
}

#[tokio::test]
async fn test_request_id_not_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new())
.service(service_fn(|_req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

assert!(res.headers().is_empty());
}
}