Skip to content

Commit

Permalink
feat: serialize the full HTTP req/res to rmp
Browse files Browse the repository at this point in the history
  • Loading branch information
oddgrd committed Nov 15, 2022
1 parent 4b9e47c commit f9d73c2
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 66 deletions.
3 changes: 2 additions & 1 deletion runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ cap-std = "0.26.0"
clap ={ version = "4.0.18", features = ["derive"] }
serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model"] }
thiserror = "1.0.37"
hyper = "1.0.0-rc.1"
hyper = "0.14.23"
tokio = { version = "=1.20.1", features = ["full"] }
tokio-stream = "0.1.11"
tonic = "0.8.2"
Expand All @@ -22,6 +22,7 @@ uuid = { version = "1.1.2", features = ["v4"] }
wasi-common = "2.0.0"
wasmtime = "2.0.0"
wasmtime-wasi = "2.0.0"
http-body = "0.4.5"

[dependencies.shuttle-common]
version = "0.7.0"
Expand Down
35 changes: 16 additions & 19 deletions runtime/src/axum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use cap_std::os::unix::net::UnixStream;
use http_body::Full;
use hyper::body::Bytes;
use hyper::Response;
use shuttle_axum_utils::{RequestWrapper, ResponseWrapper};
use shuttle_axum_utils::{wrap_request, RequestWrapper, ResponseWrapper};
use shuttle_proto::runtime::runtime_server::Runtime;
use shuttle_proto::runtime::{LoadRequest, LoadResponse, StartRequest, StartResponse};
use tonic::Status;
Expand Down Expand Up @@ -125,9 +127,8 @@ struct RouterInner {
}

impl RouterInner {
/// Send a HTTP request to given endpoint on the axum-wasm router and return the response
/// todo: also send and receive the body
pub async fn send_request(&mut self, req: hyper::Request<String>) -> Response<String> {
/// Send a HTTP request with body to given endpoint on the axum-wasm router and return the response
pub async fn send_request(&mut self, req: hyper::Request<Full<Bytes>>) -> Response<Vec<u8>> {
let (mut host, client) = UnixStream::pair().unwrap();
let client = WasiUnixStream::from_cap_std(client);

Expand All @@ -136,7 +137,7 @@ impl RouterInner {
.insert_file(3, Box::new(client), FileCaps::all());

// serialise request to rmp
let request_rmp = RequestWrapper::from(req).into_rmp();
let request_rmp = wrap_request(req).await.into_rmp();

host.write_all(&request_rmp).unwrap();
host.write(&[0]).unwrap();
Expand All @@ -158,14 +159,8 @@ impl RouterInner {
// deserialize response from rmp
let res = ResponseWrapper::from_rmp(res_buf);

// todo: clean up conversion of wrapper to request
let mut response = Response::builder().status(res.status).version(res.version);
response
.headers_mut()
.unwrap()
.extend(res.headers.into_iter());

response.body("Some body".to_string()).unwrap()
// consume the wrapper and return response
res.into_response()
}
}

Expand Down Expand Up @@ -196,36 +191,38 @@ pub mod tests {
let mut inner = axum.inner.lock().unwrap();

// GET /hello
let request: Request<String> = Request::builder()
let request: Request<Full<Bytes>> = Request::builder()
.method(Method::GET)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("hello"))
.uri(format!("https://axum-wasm.example/hello"))
.body("Some body".to_string())
.body(Full::new(Bytes::from_static(b"some body")))
.unwrap();

let res = inner.send_request(request).await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(std::str::from_utf8(&res.body()).unwrap(), "Hello, World!");

// GET /goodbye
let request: Request<String> = Request::builder()
let request: Request<Full<Bytes>> = Request::builder()
.method(Method::GET)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("goodbye"))
.uri(format!("https://axum-wasm.example/goodbye"))
.body("Some body".to_string())
.body(Full::new(Bytes::from_static(b"some body")))
.unwrap();

let res = inner.send_request(request).await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(std::str::from_utf8(&res.body()).unwrap(), "Goodbye, World!");

// GET /invalid
let request: Request<String> = Request::builder()
let request: Request<Full<Bytes>> = Request::builder()
.method(Method::GET)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("invalid"))
.uri(format!("https://axum-wasm.example/invalid"))
.body("Some body".to_string())
.body(Full::new(Bytes::from_static(b"some body")))
.unwrap();

let res = inner.send_request(request).await;
Expand Down
30 changes: 16 additions & 14 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
use axum::body::HttpBody;
use axum::{response::Response, routing::get, Router};
use futures_executor::block_on;
use http::Request;
use shuttle_axum_utils::{RequestWrapper, ResponseWrapper};
use shuttle_axum_utils::{wrap_response, RequestWrapper};
use std::fs::File;
use std::io::{Read, Write};
use std::os::wasi::prelude::*;
use tower_service::Service;

pub fn handle_request(req: Request<String>) -> Response {
pub fn handle_request<B>(req: Request<B>) -> Response
where
B: HttpBody + Send + 'static,
{
block_on(app(req))
}

async fn app(request: Request<String>) -> Response {
async fn app<B>(request: Request<B>) -> Response
where
B: HttpBody + Send + 'static,
{
let mut router = Router::new()
.route("/hello", get(hello))
.route("/goodbye", get(goodbye))
Expand Down Expand Up @@ -48,23 +55,18 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd: RawFd) {
}
}

// deserialize request from rust messagepack
let req = RequestWrapper::from_rmp(req_buf);

// todo: clean up conversion of wrapper to request
let mut request: Request<String> = Request::builder()
.method(req.method)
.version(req.version)
.uri(req.uri)
.body("Some body".to_string())
.unwrap();

request.headers_mut().extend(req.headers.into_iter());
// consume wrapper and return Request
let request = req.into_request();

println!("inner router received request: {:?}", &request);
let res = handle_request(request);

println!("inner router sending response: {:?}", &res);
let response = ResponseWrapper::from(res);
// wrap inner response and serialize it as rust messagepack
let response = block_on(wrap_response(res)).into_rmp();

f.write_all(&response.into_rmp()).unwrap();
f.write_all(&response).unwrap();
}
13 changes: 10 additions & 3 deletions tmp/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
name = "shuttle-axum-utils"
version = "0.1.0"
edition = "2021"

description = "Utilities for serializing requests to and from rust messagepack"
[lib]

[dependencies]
http = "0.2.7"
serde = { version = "1.0.137", features = [ "derive" ] }
http-serde = { version ="1.1.2" }
http-body = "0.4.5"
http-serde = { version = "1.1.2" }
# hyper dep because I was struggling to turn http body to bytes with hyper::body::to_bytes
hyper = "0.14.23"
rmp-serde = { version = "1.1.1" }
serde = { version = "1.0.137", features = [ "derive" ] }

[dev-dependencies]
# unit tests have to call an async function to wrap req/res
futures-executor = "0.3.21"
107 changes: 78 additions & 29 deletions tmp/utils/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use http::{HeaderMap, Method, Request, Response, StatusCode, Uri, Version};
use http_body::{Body, Full};
use hyper::body::Bytes;
use rmps::{Deserializer, Serializer};
use serde::{Deserialize, Serialize};

extern crate rmp_serde as rmps;

// todo: add extensions
// todo: add http extensions field
#[derive(Serialize, Deserialize, Debug)]
pub struct RequestWrapper {
#[serde(with = "http_serde::method")]
Expand All @@ -18,18 +20,27 @@ pub struct RequestWrapper {

#[serde(with = "http_serde::header_map")]
pub headers: HeaderMap,
}

impl<B> From<Request<B>> for RequestWrapper {
fn from(req: Request<B>) -> Self {
let (parts, _) = req.into_parts();
// I used Vec<u8> since it can derive serialize/deserialize
pub body: Vec<u8>,
}

Self {
method: parts.method,
uri: parts.uri,
version: parts.version,
headers: parts.headers,
}
/// Wrap HTTP Request in a struct that can be serialized to and from Rust MessagePack
pub async fn wrap_request<B>(req: Request<B>) -> RequestWrapper
where
B: Body<Data = Bytes>,
B::Error: std::fmt::Debug,
{
let (parts, body) = req.into_parts();

let body = hyper::body::to_bytes(body).await.unwrap();

RequestWrapper {
method: parts.method,
uri: parts.uri,
version: parts.version,
headers: parts.headers,
body: body.into(),
}
}

Expand All @@ -48,9 +59,23 @@ impl RequestWrapper {

Deserialize::deserialize(&mut de).unwrap()
}

/// Consume wrapper and return Request
pub fn into_request(self) -> Request<Full<Bytes>> {
let mut request: Request<Full<Bytes>> = Request::builder()
.method(self.method)
.version(self.version)
.uri(self.uri)
.body(Full::new(self.body.into()))
.unwrap();

request.headers_mut().extend(self.headers.into_iter());

request
}
}

// todo: add extensions
// todo: add http extensions field
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseWrapper {
#[serde(with = "http_serde::status_code")]
Expand All @@ -61,17 +86,26 @@ pub struct ResponseWrapper {

#[serde(with = "http_serde::header_map")]
pub headers: HeaderMap,
}

impl<B> From<Response<B>> for ResponseWrapper {
fn from(res: Response<B>) -> Self {
let (parts, _) = res.into_parts();
// I used Vec<u8> since it can derive serialize/deserialize
pub body: Vec<u8>,
}

Self {
status: parts.status,
version: parts.version,
headers: parts.headers,
}
/// Wrap HTTP Response in a struct that can be serialized to and from Rust MessagePack
pub async fn wrap_response<B>(res: Response<B>) -> ResponseWrapper
where
B: Body<Data = Bytes>,
B::Error: std::fmt::Debug,
{
let (parts, body) = res.into_parts();

let body = hyper::body::to_bytes(body).await.unwrap();

ResponseWrapper {
status: parts.status,
version: parts.version,
headers: parts.headers,
body: body.into(),
}
}

Expand All @@ -90,25 +124,38 @@ impl ResponseWrapper {

Deserialize::deserialize(&mut de).unwrap()
}

/// Consume wrapper and return Response
pub fn into_response(self) -> Response<Vec<u8>> {
let mut response = Response::builder()
.status(self.status)
.version(self.version);
response
.headers_mut()
.unwrap()
.extend(self.headers.into_iter());

response.body(self.body).unwrap()
}
}

#[cfg(test)]
mod test {
use http::HeaderValue;

use super::*;
use futures_executor::block_on;
use http::HeaderValue;

#[test]
fn request_roundtrip() {
let request: Request<String> = Request::builder()
let request: Request<Full<Bytes>> = Request::builder()
.method(Method::PUT)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("request"))
.uri(format!("https://axum-wasm.example/hello"))
.body("Some body".to_string())
.body(Full::new(Bytes::from_static(b"request body")))
.unwrap();

let rmp = RequestWrapper::from(request).into_rmp();
let rmp = block_on(wrap_request(request)).into_rmp();

let back = RequestWrapper::from_rmp(rmp);

Expand All @@ -122,18 +169,19 @@ mod test {
back.uri.to_string(),
"https://axum-wasm.example/hello".to_string()
);
assert_eq!(std::str::from_utf8(&back.body).unwrap(), "request body");
}

#[test]
fn response_roundtrip() {
let response: Response<String> = Response::builder()
let response: Response<Full<Bytes>> = Response::builder()
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("response"))
.status(StatusCode::NOT_MODIFIED)
.body("Some body".to_string())
.body(Full::new(Bytes::from_static(b"response body")))
.unwrap();

let rmp = ResponseWrapper::from(response).into_rmp();
let rmp = block_on(wrap_response(response)).into_rmp();

let back = ResponseWrapper::from_rmp(rmp);

Expand All @@ -143,5 +191,6 @@ mod test {
);
assert_eq!(back.status, StatusCode::NOT_MODIFIED);
assert_eq!(back.version, Version::HTTP_11);
assert_eq!(std::str::from_utf8(&back.body).unwrap(), "response body");
}
}

0 comments on commit f9d73c2

Please sign in to comment.