Skip to content

Commit

Permalink
Add tests, remove unnecessary clone
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Mar 22, 2023
1 parent abf8226 commit a322ef6
Showing 1 changed file with 64 additions and 13 deletions.
77 changes: 64 additions & 13 deletions rust-runtime/aws-smithy-http-server/src/request/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,24 @@ where

fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
let request_id = ServerRequestId::new();
req.extensions_mut().insert(request_id.clone());
let header_key = self.header_key.clone();
ServerRequestIdResponseFuture {
request_id,
header_key: Some(header_key),
fut: self.inner.call(req),
match &self.header_key {
Some(header_key) => {
req.extensions_mut().insert(request_id.clone());
ServerRequestIdResponseFuture {
response_package: Some(ResponsePackage {
request_id,
header_key: Some(header_key.clone()),
}),
fut: self.inner.call(req),
}
}
None => {
req.extensions_mut().insert(request_id);
ServerRequestIdResponseFuture {
response_package: None,
fut: self.inner.call(req),
}
}
}
}
}
Expand All @@ -172,10 +184,14 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
}
}

struct ResponsePackage {
request_id: ServerRequestId,
header_key: Option<HeaderName>,
}

pin_project_lite::pin_project! {
pub struct ServerRequestIdResponseFuture<Fut> {
request_id: ServerRequestId,
header_key: Option<Option<HeaderName>>,
response_package: Option<ResponsePackage>,
#[pin]
fut: Fut,
}
Expand All @@ -190,12 +206,12 @@ where
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let fut = this.fut;
let request_id = this.request_id;
let response_package = this.response_package;
fut.try_poll(cx)
.map_ok(|mut res| {
if let Some(header_key) = this.header_key.take().expect("Futures should not be polled after completion") {
if let Ok(value) = HeaderValue::from_str(&request_id.id.to_string()) {
res.headers_mut().insert(header_key, value);
if let Some(response_package) = response_package {
if let Ok(value) = HeaderValue::from_str(&response_package.request_id.id.to_string()) {
res.headers_mut().insert(response_package.header_key.take().expect("Futures should not be polled after completion"), value);
}
}
res
Expand All @@ -205,11 +221,46 @@ where

#[cfg(test)]
mod tests {
use super::ServerRequestId;
use super::*;
use crate::body::{Body, BoxBody};
use crate::request::Request;
use http::HeaderValue;
use tower::{service_fn, Service, ServiceBuilder, ServiceExt};
use std::convert::Infallible;

#[test]
fn test_request_id_parsed_by_header_value() {
assert!(HeaderValue::from_str(&ServerRequestId::new().id.to_string()).is_ok());
}

#[tokio::test]
async fn test_request_id_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")))
.service(service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let mut req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

assert!(HeaderValue::from_str(request_id).is_ok());
}

#[tokio::test]
async fn test_request_id_not_in_response_header() {
let svc = ServiceBuilder::new()
.layer(&ServerRequestIdProviderLayer::new())
.service(service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(BoxBody::default()))
}));

let mut req = Request::new(Body::empty());

let res = svc.oneshot(req).await.unwrap();

assert!(res.headers().is_empty());
}
}

0 comments on commit a322ef6

Please sign in to comment.