Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return streaming body from wasm router #558

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ uuid = { workspace = true, features = ["v4"] }
wasi-common = "4.0.0"
wasmtime = "4.0.0"
wasmtime-wasi = "4.0.0"
futures = "0.3.25"

[dependencies.shuttle-common]
workspace = true
Expand Down
41 changes: 18 additions & 23 deletions runtime/src/axum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::sync::Mutex;

use async_trait::async_trait;
use cap_std::os::unix::net::UnixStream;
use futures::TryStreamExt;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response};
use shuttle_common::wasm::{RequestWrapper, ResponseWrapper};
Expand Down Expand Up @@ -197,18 +198,24 @@ impl RouterInner {
.unwrap();

let (mut parts_stream, parts_client) = UnixStream::pair().unwrap();
let (mut body_stream, body_client) = UnixStream::pair().unwrap();
let (body_read_stream, body_read_client) = UnixStream::pair().unwrap();
let (mut body_write_stream, body_write_client) = UnixStream::pair().unwrap();

let parts_client = WasiUnixStream::from_cap_std(parts_client);
let body_client = WasiUnixStream::from_cap_std(body_client);
let body_read_client = WasiUnixStream::from_cap_std(body_read_client);
let body_write_client = WasiUnixStream::from_cap_std(body_write_client);

store
.data_mut()
.insert_file(3, Box::new(parts_client), FileCaps::all());

store
.data_mut()
.insert_file(4, Box::new(body_client), FileCaps::all());
.insert_file(4, Box::new(body_write_client), FileCaps::all());

store
.data_mut()
.insert_file(5, Box::new(body_read_client), FileCaps::all());

let (parts, body) = req.into_parts();

Expand All @@ -219,21 +226,19 @@ impl RouterInner {
parts_stream.write_all(&request_rmp).unwrap();

// write body
body_stream
body_write_stream
.write_all(hyper::body::to_bytes(body).await.unwrap().as_ref())
oddgrd marked this conversation as resolved.
Show resolved Hide resolved
.unwrap();
// signal to the receiver that end of file has been reached
body_stream.write_all(&[0]).unwrap();

println!("calling inner Router");
// println!("calling inner Router");
self.linker
.get(&mut store, "axum", "__SHUTTLE_Axum_call")
.unwrap()
.into_func()
.unwrap()
.typed::<(RawFd, RawFd), ()>(&store)
.typed::<(RawFd, RawFd, RawFd), ()>(&store)
.unwrap()
.call(&mut store, (3, 4))
.call(&mut store, (3, 4, 5))
.unwrap();

// read response parts from host
Expand All @@ -243,21 +248,11 @@ impl RouterInner {
let wrapper: ResponseWrapper = rmps::from_read(reader).unwrap();

// read response body from wasm router
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_stream.read_exact(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
let reader = BufReader::new(body_read_stream);
let stream = futures::stream::iter(reader.bytes()).try_chunks(2);
let body = hyper::Body::wrap_stream(stream);

let response: Response<Body> = wrapper
.into_response_builder()
.body(body_buf.into())
.unwrap();
let response: Response<Body> = wrapper.into_response_builder().body(body).unwrap();

Ok(response)
}
Expand Down
1 change: 1 addition & 0 deletions tmp/axum-wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-executor = "0.3.21"
http = "0.2.7"
tower-service = "0.3.1"
rmp-serde = { version = "1.1.1" }
futures = "0.3.25"

[dependencies.shuttle-common]
path = "../../common"
Expand Down
44 changes: 18 additions & 26 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ async fn goodbye() -> &'static str {
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
fd_5: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use axum::body::{Body, HttpBody};
use futures::stream::TryStreamExt;
use std::io::{BufReader, Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");
// println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };
Expand All @@ -48,27 +49,17 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
// file descriptor 4 for reading http body into wasm
let body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) };

let request: http::Request<axum::body::Body> = wrapper
.into_request_builder()
.body(body_buf.into())
.unwrap();
let reader = BufReader::new(body_read_stream);
let stream = futures::stream::iter(reader.bytes()).try_chunks(2);
let body = Body::wrap_stream(stream);

println!("inner router received request: {:?}", &request);
let request: http::Request<axum::body::Body> =
wrapper.into_request_builder().body(body).unwrap();

// println!("inner router received request: {:?}", &request);
let res = handle_request(request);

let (parts, mut body) = res.into_parts();
Expand All @@ -79,10 +70,11 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// write response parts
parts_fd.write_all(&response_parts).unwrap();

// file descriptor 5 for writing http body to host
let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) };

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
body_write_stream.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}