Skip to content

Commit

Permalink
feat(next): expand macro into axum routes (#488)
Browse files Browse the repository at this point in the history
* feat: app codegen model

* refactor: qualify all namespaces

* feat: low-level wasi export fn

* refactor: restrict to supported axum methods
  • Loading branch information
chesedo authored Nov 24, 2022
1 parent f913b8a commit c2b0f63
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 29 deletions.
1 change: 1 addition & 0 deletions codegen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod main;
mod next;

use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;
Expand Down
192 changes: 192 additions & 0 deletions codegen/src/next/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use proc_macro_error::emit_error;
use quote::{quote, ToTokens};
use syn::{Ident, LitStr};

struct Endpoint {
route: LitStr,
method: Ident,
function: Ident,
}

impl ToTokens for Endpoint {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let Self {
route,
method,
function,
} = self;

match method.to_string().as_str() {
"get" | "post" | "delete" | "put" | "options" | "head" | "trace" | "patch" => {}
_ => {
emit_error!(
method,
"method is not supported";
hint = "Try one of the following: `get`, `post`, `delete`, `put`, `options`, `head`, `trace` or `patch`"
)
}
};

let route = quote!(.route(#route, axum::routing::#method(#function)));

route.to_tokens(tokens);
}
}

pub(crate) struct App {
endpoints: Vec<Endpoint>,
}

impl ToTokens for App {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let Self { endpoints } = self;

let app = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use tower_service::Service;

let mut router = axum::Router::new()
#(#endpoints)*
.into_service();

let response = router.call(request).await.unwrap();

response
}
);

app.to_tokens(tokens);
}
}

pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
quote!(
#app

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };

let reader = std::io::BufReader::new(&mut parts_fd);

// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}

let request: http::Request<axum::body::Body> = wrapper
.into_request_builder()
.body(body_buf.into())
.unwrap();

println!("inner router received request: {:?}", &request);
let res = futures_executor::block_on(__app(request));

let (parts, mut body) = res.into_parts();

// wrap and serialize response parts as rmp
let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp();

// write response parts
parts_fd.write_all(&response_parts).unwrap();

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}
)
}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use quote::quote;
use syn::parse_quote;

use crate::next::App;

use super::Endpoint;

#[test]
fn endpoint_to_token() {
let endpoint = Endpoint {
route: parse_quote!("/hello"),
method: parse_quote!(get),
function: parse_quote!(hello),
};

let actual = quote!(#endpoint);
let expected = quote!(.route("/hello", axum::routing::get(hello)));

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn app_to_token() {
let app = App {
endpoints: vec![
Endpoint {
route: parse_quote!("/hello"),
method: parse_quote!(get),
function: parse_quote!(hello),
},
Endpoint {
route: parse_quote!("/goodbye"),
method: parse_quote!(post),
function: parse_quote!(goodbye),
},
],
};

let actual = quote!(#app);
let expected = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::post(goodbye))
.into_service();

let response = router.call(request).await.unwrap();

response
}
);

assert_eq!(actual.to_string(), expected.to_string());
}
}
54 changes: 25 additions & 29 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
use axum::body::{Body, HttpBody};
use axum::{response::Response, routing::get, Router};
use futures_executor::block_on;
use http::Request;
use shuttle_common::wasm::{RequestWrapper, ResponseWrapper};
use std::fs::File;
use std::io::BufReader;
use std::io::{Read, Write};
use std::os::wasi::prelude::*;
use tower_service::Service;

extern crate rmp_serde as rmps;

pub fn handle_request<B>(req: Request<B>) -> Response
pub fn handle_request<B>(req: http::Request<B>) -> axum::response::Response
where
B: HttpBody + Send + 'static,
B: axum::body::HttpBody + Send + 'static,
{
block_on(app(req))
futures_executor::block_on(app(req))
}

async fn app<B>(request: Request<B>) -> Response
async fn app<B>(request: http::Request<B>) -> axum::response::Response
where
B: HttpBody + Send + 'static,
B: axum::body::HttpBody + Send + 'static,
{
let mut router = Router::new()
.route("/hello", get(hello))
.route("/goodbye", get(goodbye))
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::get(goodbye))
.into_service();

let response = router.call(request).await.unwrap();
Expand All @@ -42,19 +31,26 @@ async fn goodbye() -> &'static str {

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { File::from_raw_fd(fd_3) };
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };

let reader = BufReader::new(&mut parts_fd);
let reader = std::io::BufReader::new(&mut parts_fd);

// deserialize request parts from rust messagepack
let wrapper: RequestWrapper = rmps::from_read(reader).unwrap();
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { File::from_raw_fd(fd_4) };
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut body_buf = Vec::new();
Expand All @@ -68,7 +64,7 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
}
}

let request: Request<Body> = wrapper
let request: http::Request<axum::body::Body> = wrapper
.into_request_builder()
.body(body_buf.into())
.unwrap();
Expand All @@ -79,13 +75,13 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
let (parts, mut body) = res.into_parts();

// wrap and serialize response parts as rmp
let response_parts = ResponseWrapper::from(parts).into_rmp();
let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp();

// write response parts
parts_fd.write_all(&response_parts).unwrap();

// write body if there is one
if let Some(body) = block_on(body.data()) {
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
Expand Down

0 comments on commit c2b0f63

Please sign in to comment.