From 95bc9049c1af146884cb55b9a352e27bcc8e5d40 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Sun, 13 Nov 2022 21:26:28 +0100 Subject: [PATCH] feat: serialize the full HTTP req/res to rmp --- runtime/Cargo.toml | 3 +- runtime/src/axum/mod.rs | 35 ++++++------- tmp/axum-wasm/src/lib.rs | 30 ++++++----- tmp/utils/Cargo.toml | 13 +++-- tmp/utils/src/lib.rs | 107 ++++++++++++++++++++++++++++----------- 5 files changed, 122 insertions(+), 66 deletions(-) diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 5478309bea..4e3e418933 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -12,7 +12,7 @@ cap-std = "0.26.0" clap ={ version = "4.0.18", features = ["derive"] } serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model"] } thiserror = "1.0.37" -hyper = "1.0.0-rc.1" +hyper = "0.14.23" tokio = { version = "=1.20.1", features = ["full"] } tonic = "0.8.2" tracing = "0.1.37" @@ -20,6 +20,7 @@ tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } wasi-common = "2.0.0" wasmtime = "2.0.0" wasmtime-wasi = "2.0.0" +http-body = "0.4.5" [dependencies.shuttle-common] version = "0.7.0" diff --git a/runtime/src/axum/mod.rs b/runtime/src/axum/mod.rs index 9f8a559a66..cd9b4a0dd3 100644 --- a/runtime/src/axum/mod.rs +++ b/runtime/src/axum/mod.rs @@ -6,8 +6,10 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; use cap_std::os::unix::net::UnixStream; +use http_body::Full; +use hyper::body::Bytes; use hyper::Response; -use shuttle_axum_utils::{RequestWrapper, ResponseWrapper}; +use shuttle_axum_utils::{wrap_request, RequestWrapper, ResponseWrapper}; use shuttle_proto::runtime::runtime_server::Runtime; use shuttle_proto::runtime::{LoadRequest, LoadResponse, StartRequest, StartResponse}; use tonic::Status; @@ -125,9 +127,8 @@ struct RouterInner { } impl RouterInner { - /// Send a HTTP request to given endpoint on the axum-wasm router and return the response - /// todo: also send and receive the body - pub async fn send_request(&mut self, req: hyper::Request) -> Response { + /// Send a HTTP request with body to given endpoint on the axum-wasm router and return the response + pub async fn send_request(&mut self, req: hyper::Request>) -> Response> { let (mut host, client) = UnixStream::pair().unwrap(); let client = WasiUnixStream::from_cap_std(client); @@ -136,7 +137,7 @@ impl RouterInner { .insert_file(3, Box::new(client), FileCaps::all()); // serialise request to rmp - let request_rmp = RequestWrapper::from(req).into_rmp(); + let request_rmp = wrap_request(req).await.into_rmp(); host.write_all(&request_rmp).unwrap(); host.write(&[0]).unwrap(); @@ -158,14 +159,8 @@ impl RouterInner { // deserialize response from rmp let res = ResponseWrapper::from_rmp(res_buf); - // todo: clean up conversion of wrapper to request - let mut response = Response::builder().status(res.status).version(res.version); - response - .headers_mut() - .unwrap() - .extend(res.headers.into_iter()); - - response.body("Some body".to_string()).unwrap() + // consume the wrapper and return response + res.into_response() } } @@ -196,36 +191,38 @@ pub mod tests { let mut inner = axum.inner.lock().unwrap(); // GET /hello - let request: Request = Request::builder() + let request: Request> = Request::builder() .method(Method::GET) .version(Version::HTTP_11) .header("test", HeaderValue::from_static("hello")) .uri(format!("https://axum-wasm.example/hello")) - .body("Some body".to_string()) + .body(Full::new(Bytes::from_static(b"some body"))) .unwrap(); let res = inner.send_request(request).await; assert_eq!(res.status(), StatusCode::OK); + assert_eq!(std::str::from_utf8(&res.body()).unwrap(), "Hello, World!"); // GET /goodbye - let request: Request = Request::builder() + let request: Request> = Request::builder() .method(Method::GET) .version(Version::HTTP_11) .header("test", HeaderValue::from_static("goodbye")) .uri(format!("https://axum-wasm.example/goodbye")) - .body("Some body".to_string()) + .body(Full::new(Bytes::from_static(b"some body"))) .unwrap(); let res = inner.send_request(request).await; assert_eq!(res.status(), StatusCode::OK); + assert_eq!(std::str::from_utf8(&res.body()).unwrap(), "Goodbye, World!"); // GET /invalid - let request: Request = Request::builder() + let request: Request> = Request::builder() .method(Method::GET) .version(Version::HTTP_11) .header("test", HeaderValue::from_static("invalid")) .uri(format!("https://axum-wasm.example/invalid")) - .body("Some body".to_string()) + .body(Full::new(Bytes::from_static(b"some body"))) .unwrap(); let res = inner.send_request(request).await; diff --git a/tmp/axum-wasm/src/lib.rs b/tmp/axum-wasm/src/lib.rs index f8b0a2eebb..230032a927 100644 --- a/tmp/axum-wasm/src/lib.rs +++ b/tmp/axum-wasm/src/lib.rs @@ -1,17 +1,24 @@ +use axum::body::HttpBody; use axum::{response::Response, routing::get, Router}; use futures_executor::block_on; use http::Request; -use shuttle_axum_utils::{RequestWrapper, ResponseWrapper}; +use shuttle_axum_utils::{wrap_response, RequestWrapper}; use std::fs::File; use std::io::{Read, Write}; use std::os::wasi::prelude::*; use tower_service::Service; -pub fn handle_request(req: Request) -> Response { +pub fn handle_request(req: Request) -> Response +where + B: HttpBody + Send + 'static, +{ block_on(app(req)) } -async fn app(request: Request) -> Response { +async fn app(request: Request) -> Response +where + B: HttpBody + Send + 'static, +{ let mut router = Router::new() .route("/hello", get(hello)) .route("/goodbye", get(goodbye)) @@ -48,23 +55,18 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd: RawFd) { } } + // deserialize request from rust messagepack let req = RequestWrapper::from_rmp(req_buf); - // todo: clean up conversion of wrapper to request - let mut request: Request = Request::builder() - .method(req.method) - .version(req.version) - .uri(req.uri) - .body("Some body".to_string()) - .unwrap(); - - request.headers_mut().extend(req.headers.into_iter()); + // consume wrapper and return Request + let request = req.into_request(); println!("inner router received request: {:?}", &request); let res = handle_request(request); println!("inner router sending response: {:?}", &res); - let response = ResponseWrapper::from(res); + // wrap inner response and serialize it as rust messagepack + let response = block_on(wrap_response(res)).into_rmp(); - f.write_all(&response.into_rmp()).unwrap(); + f.write_all(&response).unwrap(); } diff --git a/tmp/utils/Cargo.toml b/tmp/utils/Cargo.toml index 691214143b..52d422674b 100644 --- a/tmp/utils/Cargo.toml +++ b/tmp/utils/Cargo.toml @@ -2,11 +2,18 @@ name = "shuttle-axum-utils" version = "0.1.0" edition = "2021" - +description = "Utilities for serializing requests to and from rust messagepack" [lib] [dependencies] http = "0.2.7" -serde = { version = "1.0.137", features = [ "derive" ] } -http-serde = { version ="1.1.2" } +http-body = "0.4.5" +http-serde = { version = "1.1.2" } +# hyper dep because I was struggling to turn http body to bytes with hyper::body::to_bytes +hyper = "0.14.23" rmp-serde = { version = "1.1.1" } +serde = { version = "1.0.137", features = [ "derive" ] } + +[dev-dependencies] +# unit tests have to call an async function to wrap req/res +futures-executor = "0.3.21" diff --git a/tmp/utils/src/lib.rs b/tmp/utils/src/lib.rs index 477f1d86db..82ccf43d2f 100644 --- a/tmp/utils/src/lib.rs +++ b/tmp/utils/src/lib.rs @@ -1,10 +1,12 @@ use http::{HeaderMap, Method, Request, Response, StatusCode, Uri, Version}; +use http_body::{Body, Full}; +use hyper::body::Bytes; use rmps::{Deserializer, Serializer}; use serde::{Deserialize, Serialize}; extern crate rmp_serde as rmps; -// todo: add extensions +// todo: add http extensions field #[derive(Serialize, Deserialize, Debug)] pub struct RequestWrapper { #[serde(with = "http_serde::method")] @@ -18,18 +20,27 @@ pub struct RequestWrapper { #[serde(with = "http_serde::header_map")] pub headers: HeaderMap, -} -impl From> for RequestWrapper { - fn from(req: Request) -> Self { - let (parts, _) = req.into_parts(); + // I used Vec since it can derive serialize/deserialize + pub body: Vec, +} - Self { - method: parts.method, - uri: parts.uri, - version: parts.version, - headers: parts.headers, - } +/// Wrap HTTP Request in a struct that can be serialized to and from Rust MessagePack +pub async fn wrap_request(req: Request) -> RequestWrapper +where + B: Body, + B::Error: std::fmt::Debug, +{ + let (parts, body) = req.into_parts(); + + let body = hyper::body::to_bytes(body).await.unwrap(); + + RequestWrapper { + method: parts.method, + uri: parts.uri, + version: parts.version, + headers: parts.headers, + body: body.into(), } } @@ -48,9 +59,23 @@ impl RequestWrapper { Deserialize::deserialize(&mut de).unwrap() } + + /// Consume wrapper and return Request + pub fn into_request(self) -> Request> { + let mut request: Request> = Request::builder() + .method(self.method) + .version(self.version) + .uri(self.uri) + .body(Full::new(self.body.into())) + .unwrap(); + + request.headers_mut().extend(self.headers.into_iter()); + + request + } } -// todo: add extensions +// todo: add http extensions field #[derive(Serialize, Deserialize, Debug)] pub struct ResponseWrapper { #[serde(with = "http_serde::status_code")] @@ -61,17 +86,26 @@ pub struct ResponseWrapper { #[serde(with = "http_serde::header_map")] pub headers: HeaderMap, -} -impl From> for ResponseWrapper { - fn from(res: Response) -> Self { - let (parts, _) = res.into_parts(); + // I used Vec since it can derive serialize/deserialize + pub body: Vec, +} - Self { - status: parts.status, - version: parts.version, - headers: parts.headers, - } +/// Wrap HTTP Response in a struct that can be serialized to and from Rust MessagePack +pub async fn wrap_response(res: Response) -> ResponseWrapper +where + B: Body, + B::Error: std::fmt::Debug, +{ + let (parts, body) = res.into_parts(); + + let body = hyper::body::to_bytes(body).await.unwrap(); + + ResponseWrapper { + status: parts.status, + version: parts.version, + headers: parts.headers, + body: body.into(), } } @@ -90,25 +124,38 @@ impl ResponseWrapper { Deserialize::deserialize(&mut de).unwrap() } + + /// Consume wrapper and return Response + pub fn into_response(self) -> Response> { + let mut response = Response::builder() + .status(self.status) + .version(self.version); + response + .headers_mut() + .unwrap() + .extend(self.headers.into_iter()); + + response.body(self.body).unwrap() + } } #[cfg(test)] mod test { - use http::HeaderValue; - use super::*; + use futures_executor::block_on; + use http::HeaderValue; #[test] fn request_roundtrip() { - let request: Request = Request::builder() + let request: Request> = Request::builder() .method(Method::PUT) .version(Version::HTTP_11) .header("test", HeaderValue::from_static("request")) .uri(format!("https://axum-wasm.example/hello")) - .body("Some body".to_string()) + .body(Full::new(Bytes::from_static(b"request body"))) .unwrap(); - let rmp = RequestWrapper::from(request).into_rmp(); + let rmp = block_on(wrap_request(request)).into_rmp(); let back = RequestWrapper::from_rmp(rmp); @@ -122,18 +169,19 @@ mod test { back.uri.to_string(), "https://axum-wasm.example/hello".to_string() ); + assert_eq!(std::str::from_utf8(&back.body).unwrap(), "request body"); } #[test] fn response_roundtrip() { - let response: Response = Response::builder() + let response: Response> = Response::builder() .version(Version::HTTP_11) .header("test", HeaderValue::from_static("response")) .status(StatusCode::NOT_MODIFIED) - .body("Some body".to_string()) + .body(Full::new(Bytes::from_static(b"response body"))) .unwrap(); - let rmp = ResponseWrapper::from(response).into_rmp(); + let rmp = block_on(wrap_response(response)).into_rmp(); let back = ResponseWrapper::from_rmp(rmp); @@ -143,5 +191,6 @@ mod test { ); assert_eq!(back.status, StatusCode::NOT_MODIFIED); assert_eq!(back.version, Version::HTTP_11); + assert_eq!(std::str::from_utf8(&back.body).unwrap(), "response body"); } }