diff --git a/Cargo.lock b/Cargo.lock index bc7113212..bcafc3d30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,6 +1530,7 @@ dependencies = [ "serde", "serde_json", "shuttle-common", + "shuttle-proto", "shuttle-secrets", "shuttle-service", "sqlx", @@ -1542,6 +1543,7 @@ dependencies = [ "tokiotest-httpserver", "toml", "toml_edit 0.15.0", + "tonic", "tracing", "tracing-subscriber", "url", @@ -6232,6 +6234,7 @@ dependencies = [ "opentelemetry-datadog", "opentelemetry-http", "pipe", + "portpicker", "rand 0.8.5", "serde", "serde_json", @@ -6307,11 +6310,16 @@ dependencies = [ name = "shuttle-proto" version = "0.8.0" dependencies = [ + "anyhow", + "chrono", "prost", "prost-types", "shuttle-common", + "tokio", "tonic", "tonic-build", + "tracing", + "uuid 1.2.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 5b2ed622d..483679463 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,5 +44,6 @@ uuid = "1.2.2" thiserror = "1.0.37" serde = "1.0.148" serde_json = "1.0.89" +tonic = "0.8.3" tracing = "0.1.37" tracing-subscriber = "0.3.16" diff --git a/cargo-shuttle/Cargo.toml b/cargo-shuttle/Cargo.toml index 3b2a6916f..d2a37e792 100644 --- a/cargo-shuttle/Cargo.toml +++ b/cargo-shuttle/Cargo.toml @@ -43,6 +43,7 @@ tokio = { version = "1.22.0", features = ["macros"] } tokio-tungstenite = { version = "0.17.2", features = ["native-tls"] } toml = "0.5.9" toml_edit = "0.15.0" +tonic = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } url = "2.3.1" @@ -53,6 +54,9 @@ webbrowser = "0.8.2" workspace = true features= ["models"] +[dependencies.shuttle-proto] +workspace = true + [dependencies.shuttle-secrets] version = "0.8.0" path = "../resources/secrets" diff --git a/cargo-shuttle/build.rs b/cargo-shuttle/build.rs new file mode 100644 index 000000000..4503a843d --- /dev/null +++ b/cargo-shuttle/build.rs @@ -0,0 +1,17 @@ +use std::{env, process::Command}; + +fn main() { + println!("cargo:rerun-if-changed=../runtime"); + + // Build binary for runtime so that it can be embedded in the binary for the cli + let out_dir = env::var_os("OUT_DIR").unwrap(); + Command::new("cargo") + .arg("build") + .arg("--package") + .arg("shuttle-runtime") + .arg("--target-dir") + .arg(out_dir) + .arg("--release") + .output() + .expect("failed to build the shuttle runtime"); +} diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index 13236628a..3709bffff 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -1,11 +1,12 @@ mod args; mod client; pub mod config; -mod factory; mod init; +mod provisioner_server; use shuttle_common::project::ProjectName; -use std::collections::BTreeMap; +use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest}; +use std::collections::HashMap; use std::ffi::OsString; use std::fs::{read_to_string, File}; use std::io::stdout; @@ -21,26 +22,26 @@ use clap_complete::{generate, Shell}; use config::RequestContext; use crossterm::style::Stylize; use dialoguer::{theme::ColorfulTheme, Confirm, FuzzySelect, Input, Password}; -use factory::LocalFactory; use flate2::write::GzEncoder; use flate2::Compression; -use futures::StreamExt; +use futures::{StreamExt, TryFutureExt}; use git2::{Repository, StatusOptions}; use ignore::overrides::OverrideBuilder; use ignore::WalkBuilder; +use provisioner_server::LocalProvisioner; use shuttle_common::models::{project, secret}; -use shuttle_service::loader::{build_crate, Loader, Runtime}; -use shuttle_service::Logger; +use shuttle_service::loader::{build_crate, Runtime}; use std::fmt::Write; use strum::IntoEnumIterator; use tar::Builder; -use tokio::sync::mpsc; use tracing::trace; use uuid::Uuid; use crate::args::{DeploymentCommand, ProjectCommand}; use crate::client::Client; +const BINARY_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/release/shuttle-runtime")); + pub struct Shuttle { ctx: RequestContext, } @@ -392,60 +393,116 @@ impl Shuttle { "Building".bold().green(), working_directory.display() ); - let so_path = match build_crate(id, working_directory, false, false, tx).await? { - Runtime::Legacy(path) => path, - Runtime::Next(_) => todo!(), - }; + let runtime = build_crate(id, working_directory, false, tx).await?; trace!("loading secrets"); let secrets_path = working_directory.join("Secrets.toml"); - let secrets: BTreeMap = - if let Ok(secrets_str) = read_to_string(secrets_path) { - let secrets: BTreeMap = - secrets_str.parse::()?.try_into()?; + let secrets: HashMap = if let Ok(secrets_str) = read_to_string(secrets_path) + { + let secrets: HashMap = + secrets_str.parse::()?.try_into()?; - trace!(keys = ?secrets.keys(), "available secrets"); + trace!(keys = ?secrets.keys(), "available secrets"); - secrets - } else { - trace!("no Secrets.toml was found"); - Default::default() - }; + secrets + } else { + trace!("no Secrets.toml was found"); + Default::default() + }; + + let service_name = self.ctx.project_name().to_string(); - let loader = Loader::from_so_file(so_path)?; + let (is_wasm, so_path) = match runtime { + Runtime::Next(path) => (true, path), + Runtime::Legacy(path) => (false, path), + }; - let mut factory = LocalFactory::new( - self.ctx.project_name().clone(), + let provisioner = LocalProvisioner::new()?; + let provisioner_server = provisioner.start(SocketAddr::new( + Ipv4Addr::LOCALHOST.into(), + run_args.port + 1, + )); + let (mut runtime, mut runtime_client) = runtime::start( + BINARY_BYTES, + is_wasm, + runtime::StorageManagerType::WorkingDir(working_directory.to_path_buf()), + &format!("http://localhost:{}", run_args.port + 1), + ) + .await + .map_err(|err| { + provisioner_server.abort(); + + err + })?; + + let load_request = tonic::Request::new(LoadRequest { + path: so_path + .into_os_string() + .into_string() + .expect("to convert path to string"), + service_name: service_name.clone(), secrets, - working_directory.to_path_buf(), - )?; - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), run_args.port); + }); + trace!("loading service"); + let _ = runtime_client + .load(load_request) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; + + Err(err) + }) + .await?; - trace!("loading project"); - println!( - "\n{:>12} {} on http://{}", - "Starting".bold().green(), - self.ctx.project_name(), - addr - ); - let (tx, mut rx) = mpsc::unbounded_channel(); + let mut stream = runtime_client + .subscribe_logs(tonic::Request::new(SubscribeLogsRequest {})) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; + + Err(err) + }) + .await? + .into_inner(); tokio::spawn(async move { - while let Some(log) = rx.recv().await { + while let Some(log) = stream.message().await.expect("to get log from stream") { + let log: shuttle_common::LogItem = log.into(); println!("{log}"); } }); - let logger = Logger::new(tx, id); - let (handle, so) = loader.load(&mut factory, addr, logger).await?; + let start_request = StartRequest { + deployment_id: id.as_bytes().to_vec(), + service_name, + port: run_args.port as u32, + }; - handle.await??; + trace!(?start_request, "starting service"); + let response = runtime_client + .start(tonic::Request::new(start_request)) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; - tokio::task::spawn_blocking(move || { - trace!("closing so file"); - so.close().unwrap(); - }); + Err(err) + }) + .await? + .into_inner(); + + trace!(response = ?response, "client response: "); + + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), run_args.port); + + println!( + "\n{:>12} {} on http://{}", + "Starting".bold().green(), + self.ctx.project_name(), + addr + ); + + runtime.wait().await?; Ok(()) } diff --git a/cargo-shuttle/src/factory.rs b/cargo-shuttle/src/provisioner_server.rs similarity index 86% rename from cargo-shuttle/src/factory.rs rename to cargo-shuttle/src/provisioner_server.rs index fa9b5ff52..e3640255f 100644 --- a/cargo-shuttle/src/factory.rs +++ b/cargo-shuttle/src/provisioner_server.rs @@ -9,55 +9,53 @@ use bollard::{ }; use crossterm::{ cursor::{MoveDown, MoveUp}, - style::Stylize, terminal::{Clear, ClearType}, QueueableCommand, }; use futures::StreamExt; use portpicker::pick_unused_port; -use shuttle_common::{ - database::{AwsRdsEngine, SharedEngine}, - DatabaseReadyInfo, +use shuttle_common::database::{AwsRdsEngine, SharedEngine}; +use shuttle_proto::provisioner::{ + provisioner_server::{Provisioner, ProvisionerServer}, + DatabaseRequest, DatabaseResponse, }; -use shuttle_service::{database::Type, error::CustomError, Factory, ServiceName}; -use std::{ - collections::{BTreeMap, HashMap}, - io::stdout, - path::PathBuf, - time::Duration, +use shuttle_service::database::Type; +use std::{collections::HashMap, io::stdout, net::SocketAddr, time::Duration}; +use tokio::{task::JoinHandle, time::sleep}; +use tonic::{ + transport::{self, Server}, + Request, Response, Status, }; -use tokio::time::sleep; use tracing::{error, trace}; -pub struct LocalFactory { +/// A provisioner for local runs +/// It uses Docker to create Databases +pub struct LocalProvisioner { docker: Docker, - service_name: ServiceName, - secrets: BTreeMap, - working_directory: PathBuf, } -impl LocalFactory { - pub fn new( - service_name: ServiceName, - secrets: BTreeMap, - working_directory: PathBuf, - ) -> Result { +impl LocalProvisioner { + pub fn new() -> Result { Ok(Self { docker: Docker::connect_with_local_defaults()?, - service_name, - secrets, - working_directory, }) } -} -#[async_trait] -impl Factory for LocalFactory { + pub fn start(self, address: SocketAddr) -> JoinHandle> { + tokio::spawn(async move { + Server::builder() + .add_service(ProvisionerServer::new(self)) + .serve(address) + .await + }) + } + async fn get_db_connection_string( - &mut self, + &self, + service_name: &str, db_type: Type, - ) -> Result { - trace!("getting sql string for service '{}'", self.service_name); + ) -> Result { + trace!("getting sql string for service '{}'", service_name); let EngineConfig { r#type, @@ -70,7 +68,7 @@ impl Factory for LocalFactory { env, is_ready_cmd, } = db_type_to_config(db_type); - let container_name = format!("shuttle_{}_{}", self.service_name, r#type); + let container_name = format!("shuttle_{}_{}", service_name, r#type); let container = match self.docker.inspect_container(&container_name, None).await { Ok(container) => { @@ -118,7 +116,7 @@ impl Factory for LocalFactory { } Err(error) => { error!("got unexpected error while inspecting docker container: {error}"); - return Err(shuttle_service::Error::Custom(CustomError::new(error))); + return Err(Status::internal(error.to_string())); } }; @@ -153,52 +151,24 @@ impl Factory for LocalFactory { self.wait_for_ready(&container_name, is_ready_cmd).await?; - let db_info = DatabaseReadyInfo::new( + let res = DatabaseResponse { engine, username, password, database_name, port, - "localhost".to_string(), - "localhost".to_string(), - ); - - let conn_str = db_info.connection_string_private(); - - println!( - "{:>12} can be reached at {}\n", - "DB ready".bold().cyan(), - conn_str - ); - - Ok(conn_str) - } - - async fn get_secrets( - &mut self, - ) -> Result, shuttle_service::Error> { - Ok(self.secrets.clone()) - } - - fn get_service_name(&self) -> ServiceName { - self.service_name.clone() - } + address_private: "localhost".to_string(), + address_public: "localhost".to_string(), + }; - fn get_build_path(&self) -> Result { - Ok(self.working_directory.clone()) + Ok(res) } - fn get_storage_path(&self) -> Result { - Ok(self.working_directory.clone()) - } -} - -impl LocalFactory { async fn wait_for_ready( &self, container_name: &str, is_ready_cmd: Vec, - ) -> Result<(), shuttle_service::Error> { + ) -> Result<(), Status> { loop { trace!("waiting for '{container_name}' to be ready for connections"); @@ -276,6 +246,27 @@ impl LocalFactory { } } +#[async_trait] +impl Provisioner for LocalProvisioner { + async fn provision_database( + &self, + request: Request, + ) -> Result, Status> { + let DatabaseRequest { + project_name, + db_type, + } = request.into_inner(); + + let db_type: Option = db_type.unwrap().into(); + + let res = self + .get_db_connection_string(&project_name, db_type.unwrap()) + .await?; + + Ok(Response::new(res)) + } +} + fn print_layers(layers: &Vec) { for info in layers { stdout() diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index 00b26b436..407b53b5b 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -18,3 +18,7 @@ syn = { version = "1.0.104", features = ["full", "extra-traits"] } [dev-dependencies] pretty_assertions = "1.3.0" trybuild = "1.0.72" + +[features] +frameworks = [] +next = [] diff --git a/codegen/src/lib.rs b/codegen/src/lib.rs index 2adc85247..eddf72e07 100644 --- a/codegen/src/lib.rs +++ b/codegen/src/lib.rs @@ -1,27 +1,33 @@ +#[cfg(feature = "next")] mod next; +#[cfg(feature = "frameworks")] mod shuttle_main; -use next::App; use proc_macro::TokenStream; use proc_macro_error::proc_macro_error; -use syn::{parse_macro_input, File}; +#[cfg(feature = "frameworks")] #[proc_macro_error] #[proc_macro_attribute] pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream { shuttle_main::r#impl(attr, item) } +#[cfg(feature = "next")] #[proc_macro_error] #[proc_macro] pub fn app(item: TokenStream) -> TokenStream { + use next::App; + use syn::{parse_macro_input, File}; + let mut file = parse_macro_input!(item as File); let app = App::from_file(&mut file); + let bindings = next::wasi_bindings(app); quote::quote!( #file - #app + #bindings ) .into() } diff --git a/codegen/src/next/mod.rs b/codegen/src/next/mod.rs index 520b9ee95..4de2e98ef 100644 --- a/codegen/src/next/mod.rs +++ b/codegen/src/next/mod.rs @@ -255,7 +255,6 @@ impl ToTokens for App { } } -#[allow(dead_code)] pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { quote!( #app @@ -270,7 +269,8 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { use axum::body::HttpBody; use std::io::{Read, Write}; use std::os::wasi::io::FromRawFd; - println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}"); + + println!("inner handler awoken; interacting with fd={},{},{}", fd_3, fd_4, fd_5); // file descriptor 3 for reading and writing http parts let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) }; @@ -295,7 +295,7 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { .unwrap(); println!("inner router received request: {:?}", &request); - let res = handle_request(request); + let res = futures_executor::block_on(__app(request)); let (parts, mut body) = res.into_parts(); diff --git a/common/src/lib.rs b/common/src/lib.rs index 6443b90a3..a7a3a66fc 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -26,7 +26,6 @@ pub type ApiKey = String; pub type ApiUrl = String; pub type Host = String; pub type DeploymentId = Uuid; -pub type Port = u16; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DatabaseReadyInfo { diff --git a/common/src/storage_manager.rs b/common/src/storage_manager.rs index 5a5fa1300..f2270243a 100644 --- a/common/src/storage_manager.rs +++ b/common/src/storage_manager.rs @@ -2,13 +2,25 @@ use std::{fs, io, path::PathBuf}; use uuid::Uuid; -/// Manager to take care of directories for storing project, services and deployment files +pub trait StorageManager: Clone + Sync + Send { + /// Path for a specific service build files + fn service_build_path>(&self, service_name: S) -> Result; + + /// Path to folder for storing deployment files + fn deployment_storage_path>( + &self, + service_name: S, + deployment_id: &Uuid, + ) -> Result; +} + +/// Manager to take care of directories for storing project, services and deployment files for deployer #[derive(Clone)] -pub struct StorageManager { +pub struct ArtifactsStorageManager { artifacts_path: PathBuf, } -impl StorageManager { +impl ArtifactsStorageManager { pub fn new(artifacts_path: PathBuf) -> Self { Self { artifacts_path } } @@ -21,14 +33,6 @@ impl StorageManager { Ok(builds_path) } - /// Path for a specific service - pub fn service_build_path>(&self, service_name: S) -> Result { - let builds_path = self.builds_path()?.join(service_name.as_ref()); - fs::create_dir_all(&builds_path)?; - - Ok(builds_path) - } - /// The directory in which compiled '.so' files are stored. pub fn libraries_path(&self) -> Result { let libs_path = self.artifacts_path.join("shuttle-libs"); @@ -51,9 +55,17 @@ impl StorageManager { Ok(storage_path) } +} - /// Path to folder for storing deployment files - pub fn deployment_storage_path>( +impl StorageManager for ArtifactsStorageManager { + fn service_build_path>(&self, service_name: S) -> Result { + let builds_path = self.builds_path()?.join(service_name.as_ref()); + fs::create_dir_all(&builds_path)?; + + Ok(builds_path) + } + + fn deployment_storage_path>( &self, service_name: S, deployment_id: &Uuid, @@ -67,3 +79,29 @@ impl StorageManager { Ok(storage_path) } } + +/// Manager to take care of directories for storing project, services and deployment files for the local runner +#[derive(Clone)] +pub struct WorkingDirStorageManager { + working_dir: PathBuf, +} + +impl WorkingDirStorageManager { + pub fn new(working_dir: PathBuf) -> Self { + Self { working_dir } + } +} + +impl StorageManager for WorkingDirStorageManager { + fn service_build_path>(&self, _service_name: S) -> Result { + Ok(self.working_dir.clone()) + } + + fn deployment_storage_path>( + &self, + _service_name: S, + _deployment_id: &Uuid, + ) -> Result { + Ok(self.working_dir.clone()) + } +} diff --git a/deployer/Cargo.toml b/deployer/Cargo.toml index db2804a32..fa10591f3 100644 --- a/deployer/Cargo.toml +++ b/deployer/Cargo.toml @@ -27,6 +27,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-datadog = { version = "0.6.0", features = ["reqwest-client"] } opentelemetry-http = "0.7.0" pipe = "0.4.0" +portpicker = "0.1.1" serde = { workspace = true } serde_json = { workspace = true } sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "sqlite", "chrono", "json", "migrate", "uuid"] } @@ -35,7 +36,7 @@ tar = "0.4.38" thiserror = { workspace = true } tokio = { version = "1.22.0", features = ["fs", "process"] } toml = "0.5.9" -tonic = "0.8.3" +tonic = { workspace = true } tower = { version = "0.4.13", features = ["make"] } tower-http = { version = "0.3.4", features = ["auth", "trace"] } tracing = { workspace = true } diff --git a/deployer/build.rs b/deployer/build.rs new file mode 100644 index 000000000..ad2a67032 --- /dev/null +++ b/deployer/build.rs @@ -0,0 +1,17 @@ +use std::{env, process::Command}; + +fn main() { + println!("cargo:rerun-if-changed=../runtime"); + + // Build binary for runtime so that it can be embedded in the binary for deployer + let out_dir = env::var_os("OUT_DIR").unwrap(); + Command::new("cargo") + .arg("build") + .arg("--package") + .arg("shuttle-runtime") + .arg("--target-dir") + .arg(out_dir) + .arg("--release") + .output() + .expect("failed to build the shuttle runtime"); +} diff --git a/deployer/src/args.rs b/deployer/src/args.rs index 87b467bbc..75b417384 100644 --- a/deployer/src/args.rs +++ b/deployer/src/args.rs @@ -3,7 +3,8 @@ use std::{net::SocketAddr, path::PathBuf}; use clap::Parser; use fqdn::FQDN; use hyper::Uri; -use shuttle_common::{project::ProjectName, Port}; +use shuttle_common::project::ProjectName; +use tonic::transport::Endpoint; /// Program to handle the deploys for a single project /// Handling includes, building, testing, and running each service @@ -15,12 +16,8 @@ pub struct Args { pub state: String, /// Address to connect to the provisioning service - #[clap(long)] - pub provisioner_address: String, - - /// Port provisioner is running on - #[clap(long, default_value = "5000")] - pub provisioner_port: Port, + #[clap(long, default_value = "http://provisioner:5000")] + pub provisioner_address: Endpoint, /// FQDN where the proxy can be reached at #[clap(long)] diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs index 6cbf3edbb..ca7df03e7 100644 --- a/deployer/src/deployment/deploy_layer.rs +++ b/deployer/src/deployment/deploy_layer.rs @@ -411,7 +411,7 @@ mod tests { deploy_layer::LogType, gateway_client::BuildQueueClient, ActiveDeploymentsGetter, Built, DeploymentManager, Queued, }, - persistence::{SecretRecorder, State}, + persistence::{Secret, SecretGetter, SecretRecorder, State}, }; use super::{DeployLayer, Log, LogRecorder}; @@ -536,6 +536,18 @@ mod tests { } } + #[derive(Clone)] + struct StubSecretGetter; + + #[async_trait::async_trait] + impl SecretGetter for StubSecretGetter { + type Err = std::io::Error; + + async fn get_secrets(&self, _service_id: &Uuid) -> Result, Self::Err> { + Ok(Default::default()) + } + } + #[tokio::test(flavor = "multi_thread")] async fn deployment_to_be_queued() { let deployment_manager = get_deployment_manager().await; @@ -944,6 +956,7 @@ mod tests { .secret_recorder(RECORDER.clone()) .active_deployment_getter(StubActiveDeploymentGetter) .artifacts_path(PathBuf::from("/tmp")) + .secret_getter(StubSecretGetter) .runtime(get_runtime_client().await) .queue_client(StubBuildQueueClient) .build() diff --git a/deployer/src/deployment/mod.rs b/deployer/src/deployment/mod.rs index c069e1e7f..0048573ae 100644 --- a/deployer/src/deployment/mod.rs +++ b/deployer/src/deployment/mod.rs @@ -7,13 +7,13 @@ use std::path::PathBuf; pub use queue::Queued; pub use run::{ActiveDeploymentsGetter, Built}; -use shuttle_common::storage_manager::StorageManager; +use shuttle_common::storage_manager::ArtifactsStorageManager; use shuttle_proto::runtime::runtime_client::RuntimeClient; use tonic::transport::Channel; use tracing::{instrument, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::persistence::{SecretRecorder, State}; +use crate::persistence::{SecretGetter, SecretRecorder, State}; use tokio::sync::{broadcast, mpsc}; use uuid::Uuid; @@ -23,20 +23,22 @@ const QUEUE_BUFFER_SIZE: usize = 100; const RUN_BUFFER_SIZE: usize = 100; const KILL_BUFFER_SIZE: usize = 10; -pub struct DeploymentManagerBuilder { +pub struct DeploymentManagerBuilder { build_log_recorder: Option, secret_recorder: Option, active_deployment_getter: Option, artifacts_path: Option, runtime_client: Option>, + secret_getter: Option, queue_client: Option, } -impl DeploymentManagerBuilder +impl DeploymentManagerBuilder where LR: LogRecorder, SR: SecretRecorder, ADG: ActiveDeploymentsGetter, + SG: SecretGetter, QC: BuildQueueClient, { pub fn build_log_recorder(mut self, build_log_recorder: LR) -> Self { @@ -69,11 +71,18 @@ where self } + pub fn secret_getter(mut self, secret_getter: SG) -> Self { + self.secret_getter = Some(secret_getter); + + self + } + pub fn runtime(mut self, runtime_client: RuntimeClient) -> Self { self.runtime_client = Some(runtime_client); self } + /// Creates two Tokio tasks, one for building queued services, the other for /// executing/deploying built services. Two multi-producer, single consumer /// channels are also created which are for moving on-going service @@ -89,11 +98,12 @@ where let artifacts_path = self.artifacts_path.expect("artifacts path to be set"); let queue_client = self.queue_client.expect("a queue client to be set"); let runtime_client = self.runtime_client.expect("a runtime client to be set"); + let secret_getter = self.secret_getter.expect("a secret getter to be set"); let (queue_send, queue_recv) = mpsc::channel(QUEUE_BUFFER_SIZE); let (run_send, run_recv) = mpsc::channel(RUN_BUFFER_SIZE); let (kill_send, _) = broadcast::channel(KILL_BUFFER_SIZE); - let storage_manager = StorageManager::new(artifacts_path); + let storage_manager = ArtifactsStorageManager::new(artifacts_path); let run_send_clone = run_send.clone(); @@ -110,6 +120,7 @@ where runtime_client, kill_send.clone(), active_deployment_getter, + secret_getter, storage_manager.clone(), )); @@ -127,7 +138,7 @@ pub struct DeploymentManager { queue_send: QueueSender, run_send: RunSender, kill_send: KillSender, - storage_manager: StorageManager, + storage_manager: ArtifactsStorageManager, } /// ```no-test @@ -147,13 +158,14 @@ pub struct DeploymentManager { impl DeploymentManager { /// Create a new deployment manager. Manages one or more 'pipelines' for /// processing service building, loading, and deployment. - pub fn builder() -> DeploymentManagerBuilder { + pub fn builder() -> DeploymentManagerBuilder { DeploymentManagerBuilder { build_log_recorder: None, secret_recorder: None, active_deployment_getter: None, artifacts_path: None, runtime_client: None, + secret_getter: None, queue_client: None, } } @@ -179,7 +191,7 @@ impl DeploymentManager { } } - pub fn storage_manager(&self) -> StorageManager { + pub fn storage_manager(&self) -> ArtifactsStorageManager { self.storage_manager.clone() } } diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index c5ae79822..c3bcd1e52 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -3,7 +3,7 @@ use super::gateway_client::BuildQueueClient; use super::{Built, QueueReceiver, RunSender, State}; use crate::error::{Error, Result, TestError}; use crate::persistence::{LogLevel, SecretRecorder}; -use shuttle_common::storage_manager::StorageManager; +use shuttle_common::storage_manager::{ArtifactsStorageManager, StorageManager}; use cargo::util::interning::InternedString; use cargo_metadata::Message; @@ -36,7 +36,7 @@ pub async fn task( run_send: RunSender, log_recorder: impl LogRecorder, secret_recorder: impl SecretRecorder, - storage_manager: StorageManager, + storage_manager: ArtifactsStorageManager, queue_client: impl BuildQueueClient, ) { info!("Queue task started"); @@ -147,7 +147,7 @@ impl Queued { #[instrument(skip(self, storage_manager, log_recorder, secret_recorder), fields(id = %self.id, state = %State::Building))] async fn handle( self, - storage_manager: StorageManager, + storage_manager: ArtifactsStorageManager, log_recorder: impl LogRecorder, secret_recorder: impl SecretRecorder, ) -> Result { @@ -200,7 +200,7 @@ impl Queued { }); let project_path = project_path.canonicalize()?; - let so_path = build_deployment(self.id, &project_path, false, tx.clone()).await?; + let so_path = build_deployment(self.id, &project_path, tx.clone()).await?; if self.will_run_tests { info!( @@ -309,10 +309,9 @@ async fn extract_tar_gz_data(data: impl Read, dest: impl AsRef) -> Result< async fn build_deployment( deployment_id: Uuid, project_path: &Path, - wasm: bool, tx: crossbeam_channel::Sender, ) -> Result { - let runtime_path = build_crate(deployment_id, project_path, true, wasm, tx) + let runtime_path = build_crate(deployment_id, project_path, true, tx) .await .map_err(|e| Error::Build(e.into()))?; @@ -384,7 +383,7 @@ async fn run_pre_deploy_tests( /// Store 'so' file in the libs folder #[instrument(skip(storage_manager, so_path, id))] async fn store_lib( - storage_manager: &StorageManager, + storage_manager: &ArtifactsStorageManager, so_path: impl AsRef, id: &Uuid, ) -> Result<()> { @@ -399,7 +398,7 @@ async fn store_lib( mod tests { use std::{collections::BTreeMap, fs::File, io::Write, path::Path}; - use shuttle_common::storage_manager::StorageManager; + use shuttle_common::storage_manager::ArtifactsStorageManager; use tempdir::TempDir; use tokio::fs; use uuid::Uuid; @@ -529,7 +528,7 @@ ff0e55bda1ff01000000000000000000e0079c01ff12a55500280000", async fn store_lib() { let libs_dir = TempDir::new("lib-store").unwrap(); let libs_p = libs_dir.path(); - let storage_manager = StorageManager::new(libs_p.to_path_buf()); + let storage_manager = ArtifactsStorageManager::new(libs_p.to_path_buf()); let build_p = storage_manager.builds_path().unwrap(); diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index 4837e882f..7b14a4e7b 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -7,8 +7,9 @@ use std::{ use async_trait::async_trait; use opentelemetry::global; +use portpicker::pick_unused_port; use shuttle_common::project::ProjectName as ServiceName; -use shuttle_common::storage_manager::StorageManager; +use shuttle_common::storage_manager::ArtifactsStorageManager; use shuttle_proto::runtime::{runtime_client::RuntimeClient, LoadRequest, StartRequest}; use tokio::task::JoinError; @@ -18,17 +19,20 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use uuid::Uuid; use super::{KillReceiver, KillSender, RunReceiver, State}; -use crate::error::{Error, Result}; +use crate::{ + error::{Error, Result}, + persistence::SecretGetter, +}; -/// Run a task which takes runnable deploys from a channel and starts them up with a factory provided by the -/// abstract factory and a runtime logger provided by the logger factory +/// Run a task which takes runnable deploys from a channel and starts them up on our runtime /// A deploy is killed when it receives a signal from the kill channel pub async fn task( mut recv: RunReceiver, runtime_client: RuntimeClient, kill_send: KillSender, active_deployment_getter: impl ActiveDeploymentsGetter, - storage_manager: StorageManager, + secret_getter: impl SecretGetter, + storage_manager: ArtifactsStorageManager, ) { info!("Run task started"); @@ -39,12 +43,9 @@ pub async fn task( let kill_send = kill_send.clone(); let kill_recv = kill_send.subscribe(); + let secret_getter = secret_getter.clone(); let storage_manager = storage_manager.clone(); - // todo: this is the port the legacy runtime is hardcoded to start services on - let port = 7001; - - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); let _service_name = match ServiceName::from_str(&built.service_name) { Ok(name) => name, Err(err) => { @@ -82,8 +83,8 @@ pub async fn task( async move { if let Err(err) = built .handle( - addr, storage_manager, + secret_getter, runtime_client, kill_recv, old_deployments_killer, @@ -171,12 +172,12 @@ pub struct Built { } impl Built { - #[instrument(skip(self, storage_manager, runtime_client, kill_recv, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] + #[instrument(skip(self, storage_manager, secret_getter, runtime_client, kill_recv, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] #[allow(clippy::too_many_arguments)] async fn handle( self, - address: SocketAddr, - storage_manager: StorageManager, + storage_manager: ArtifactsStorageManager, + secret_getter: impl SecretGetter, runtime_client: RuntimeClient, kill_recv: KillReceiver, kill_old_deployments: impl futures::Future>, @@ -186,14 +187,32 @@ impl Built { ) -> Result<()> { let so_path = storage_manager.deployment_library_path(&self.id)?; + let port = match pick_unused_port() { + Some(port) => port, + None => { + return Err(Error::PrepareRun( + "could not find a free port to deploy service on".to_string(), + )) + } + }; + + let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); + kill_old_deployments.await?; info!("got handle for deployment"); // Execute loaded service + load( + self.service_name.clone(), + self.service_id, + so_path, + secret_getter, + runtime_client.clone(), + ) + .await; tokio::spawn(run( self.id, self.service_name, - so_path, runtime_client, address, kill_recv, @@ -204,26 +223,30 @@ impl Built { } } -#[instrument(skip(runtime_client, _kill_recv, _cleanup), fields(address = %_address, state = %State::Running))] -async fn run( - id: Uuid, +async fn load( service_name: String, + service_id: Uuid, so_path: PathBuf, + secret_getter: impl SecretGetter, mut runtime_client: RuntimeClient, - _address: SocketAddr, - _kill_recv: KillReceiver, - _cleanup: impl FnOnce(std::result::Result, JoinError>) - + Send - + 'static, ) { info!( "loading project from: {}", so_path.clone().into_os_string().into_string().unwrap() ); + let secrets = secret_getter + .get_secrets(&service_id) + .await + .unwrap() + .into_iter() + .map(|secret| (secret.key, secret.value)); + let secrets = HashMap::from_iter(secrets); + let load_request = tonic::Request::new(LoadRequest { path: so_path.into_os_string().into_string().unwrap(), service_name: service_name.clone(), + secrets, }); info!("loading service"); let response = runtime_client.load(load_request).await; @@ -231,10 +254,23 @@ async fn run( if let Err(e) = response { info!("failed to load service: {}", e); } +} +#[instrument(skip(runtime_client, _kill_recv, _cleanup), fields(state = %State::Running))] +async fn run( + id: Uuid, + service_name: String, + mut runtime_client: RuntimeClient, + address: SocketAddr, + _kill_recv: KillReceiver, + _cleanup: impl FnOnce(std::result::Result, JoinError>) + + Send + + 'static, +) { let start_request = tonic::Request::new(StartRequest { deployment_id: id.as_bytes().to_vec(), service_name, + port: address.port() as u32, }); info!("starting service"); @@ -245,14 +281,10 @@ async fn run( #[cfg(test)] mod tests { - use std::{ - net::{Ipv4Addr, SocketAddr}, - path::PathBuf, - process::Command, - time::Duration, - }; + use std::{path::PathBuf, process::Command, time::Duration}; - use shuttle_common::storage_manager::StorageManager; + use async_trait::async_trait; + use shuttle_common::storage_manager::ArtifactsStorageManager; use shuttle_proto::runtime::runtime_client::RuntimeClient; use tempdir::TempDir; use tokio::{ @@ -263,17 +295,20 @@ mod tests { use tonic::transport::Channel; use uuid::Uuid; - use crate::error::Error; + use crate::{ + error::Error, + persistence::{Secret, SecretGetter}, + }; use super::Built; const RESOURCES_PATH: &str = "tests/resources"; - fn get_storage_manager() -> StorageManager { + fn get_storage_manager() -> ArtifactsStorageManager { let tmp_dir = TempDir::new("shuttle_run_test").unwrap(); let path = tmp_dir.into_path(); - StorageManager::new(path) + ArtifactsStorageManager::new(path) } async fn kill_old_deployments() -> crate::error::Result<()> { @@ -286,6 +321,22 @@ mod tests { .unwrap() } + #[derive(Clone)] + struct StubSecretGetter; + + #[async_trait] + impl SecretGetter for StubSecretGetter { + type Err = std::io::Error; + + async fn get_secrets(&self, _service_id: &Uuid) -> Result, Self::Err> { + Ok(Default::default()) + } + } + + fn get_secret_getter() -> StubSecretGetter { + StubSecretGetter + } + // This test uses the kill signal to make sure a service does stop when asked to #[tokio::test] async fn can_be_killed() { @@ -305,12 +356,12 @@ mod tests { ); cleanup_send.send(()).unwrap(); }; - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8001); + let secret_getter = get_secret_getter(); built .handle( - addr, storage_manager, + secret_getter, get_runtime_client().await, kill_recv, kill_old_deployments(), @@ -350,12 +401,12 @@ mod tests { ); cleanup_send.send(()).unwrap(); }; - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8001); + let secret_getter = get_secret_getter(); built .handle( - addr, storage_manager, + secret_getter, get_runtime_client().await, kill_recv, kill_old_deployments(), @@ -389,12 +440,12 @@ mod tests { ); cleanup_send.send(()).unwrap(); }; - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8001); + let secret_getter = get_secret_getter(); built .handle( - addr, storage_manager, + secret_getter, get_runtime_client().await, kill_recv, kill_old_deployments(), @@ -416,12 +467,12 @@ mod tests { let (_kill_send, kill_recv) = broadcast::channel(1); let handle_cleanup = |_result| panic!("the service shouldn't even start"); - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8001); + let secret_getter = get_secret_getter(); let result = built .handle( - addr, storage_manager, + secret_getter, get_runtime_client().await, kill_recv, kill_old_deployments(), @@ -447,13 +498,13 @@ mod tests { let (_kill_send, kill_recv) = broadcast::channel(1); let handle_cleanup = |_result| panic!("no service means no cleanup"); - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8001); + let secret_getter = get_secret_getter(); let storage_manager = get_storage_manager(); let result = built .handle( - addr, storage_manager, + secret_getter, get_runtime_client().await, kill_recv, kill_old_deployments(), @@ -471,7 +522,7 @@ mod tests { ); } - fn make_so_and_built(crate_name: &str) -> (Built, StorageManager) { + fn make_so_and_built(crate_name: &str) -> (Built, ArtifactsStorageManager) { let crate_dir: PathBuf = [RESOURCES_PATH, crate_name].iter().collect(); Command::new("cargo") diff --git a/deployer/src/error.rs b/deployer/src/error.rs index e01766f1f..1adba5ae2 100644 --- a/deployer/src/error.rs +++ b/deployer/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { Build(#[source] Box), #[error("Load error: {0}")] Load(#[from] LoaderError), + #[error("Prepare to run error: {0}")] + PrepareRun(String), #[error("Run error: {0}")] Run(#[from] shuttle_service::Error), #[error("Pre-deployment test failure: {0}")] diff --git a/deployer/src/handlers/mod.rs b/deployer/src/handlers/mod.rs index 3f2a83298..848dbff7f 100644 --- a/deployer/src/handlers/mod.rs +++ b/deployer/src/handlers/mod.rs @@ -16,6 +16,7 @@ use opentelemetry_http::HeaderExtractor; use shuttle_common::backends::metrics::Metrics; use shuttle_common::models::secret; use shuttle_common::project::ProjectName; +use shuttle_common::storage_manager::StorageManager; use shuttle_common::LogItem; use shuttle_service::loader::clean_crate; use tower_http::auth::RequireAuthorizationLayer; diff --git a/deployer/src/lib.rs b/deployer/src/lib.rs index 7ba32683d..901b818e2 100644 --- a/deployer/src/lib.rs +++ b/deployer/src/lib.rs @@ -30,6 +30,7 @@ pub async fn start(persistence: Persistence, runtime_client: RuntimeClient secrets = 10; } message LoadResponse { @@ -34,16 +37,13 @@ message StartRequest { bytes deployment_id = 1; // Name of service to start string service_name = 2; + // Port to start the service on + uint32 port = 3; } message StartResponse { // Was the start successful bool success = 1; - - // todo: find a way to add optional flag here - // Optional port the service was started on - // This is likely to be None for bots - uint32 port = 2; } message StopRequest { diff --git a/proto/src/lib.rs b/proto/src/lib.rs index 6d820a561..73726c2e8 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -53,6 +53,36 @@ pub mod provisioner { } } + impl From for Option { + fn from(db_type: database_request::DbType) -> Self { + match db_type { + database_request::DbType::Shared(Shared { + engine: Some(engine), + }) => match engine { + shared::Engine::Postgres(_) => { + Some(database::Type::Shared(SharedEngine::Postgres)) + } + shared::Engine::Mongodb(_) => { + Some(database::Type::Shared(SharedEngine::MongoDb)) + } + }, + database_request::DbType::AwsRds(AwsRds { + engine: Some(engine), + }) => match engine { + aws_rds::Engine::Postgres(_) => { + Some(database::Type::AwsRds(AwsRdsEngine::Postgres)) + } + aws_rds::Engine::Mysql(_) => Some(database::Type::AwsRds(AwsRdsEngine::MySql)), + aws_rds::Engine::Mariadb(_) => { + Some(database::Type::AwsRds(AwsRdsEngine::MariaDB)) + } + }, + database_request::DbType::Shared(Shared { engine: None }) + | database_request::DbType::AwsRds(AwsRds { engine: None }) => None, + } + } + } + impl Display for aws_rds::Engine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -65,9 +95,26 @@ pub mod provisioner { } pub mod runtime { - use std::time::SystemTime; + use std::{ + env::temp_dir, + fs::OpenOptions, + io::Write, + path::PathBuf, + time::{Duration, SystemTime}, + }; + use anyhow::Context; + use chrono::DateTime; use prost_types::Timestamp; + use tokio::process; + use tonic::transport::{Channel, Endpoint}; + use tracing::info; + use uuid::Uuid; + + pub enum StorageManagerType { + Artifacts(PathBuf), + WorkingDir(PathBuf), + } tonic::include_proto!("runtime"); @@ -113,4 +160,115 @@ pub mod runtime { } } } + + impl From for shuttle_common::LogItem { + fn from(log: LogItem) -> Self { + Self { + id: Uuid::from_slice(&log.id).unwrap(), + timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()), + state: LogState::from_i32(log.state).unwrap().into(), + level: LogLevel::from_i32(log.level).unwrap().into(), + file: log.file, + line: log.line, + target: log.target, + fields: log.fields, + } + } + } + + impl From for shuttle_common::deployment::State { + fn from(state: LogState) -> Self { + match state { + LogState::Queued => Self::Queued, + LogState::Building => Self::Building, + LogState::Built => Self::Built, + LogState::Loading => Self::Loading, + LogState::Running => Self::Running, + LogState::Completed => Self::Completed, + LogState::Stopped => Self::Stopped, + LogState::Crashed => Self::Crashed, + LogState::Unknown => Self::Unknown, + } + } + } + + impl From for shuttle_common::log::Level { + fn from(level: LogLevel) -> Self { + match level { + LogLevel::Trace => Self::Trace, + LogLevel::Debug => Self::Debug, + LogLevel::Info => Self::Info, + LogLevel::Warn => Self::Warn, + LogLevel::Error => Self::Error, + } + } + } + + pub async fn start( + binary_bytes: &[u8], + wasm: bool, + storage_manager_type: StorageManagerType, + provisioner_address: &str, + ) -> anyhow::Result<(process::Child, runtime_client::RuntimeClient)> { + let runtime_flag = if wasm { "--axum" } else { "--legacy" }; + + let (storage_manager_type, storage_manager_path) = match storage_manager_type { + StorageManagerType::Artifacts(path) => ("artifacts", path), + StorageManagerType::WorkingDir(path) => ("working-dir", path), + }; + + let runtime_executable = get_runtime_executable(binary_bytes); + + let runtime = process::Command::new(runtime_executable) + .args([ + runtime_flag, + "--provisioner-address", + provisioner_address, + "--storage-manager-type", + storage_manager_type, + "--storage-manager-path", + &storage_manager_path.display().to_string(), + ]) + .spawn() + .context("spawning runtime process")?; + + // Sleep because the timeout below does not seem to work + // TODO: investigate why + tokio::time::sleep(Duration::from_secs(2)).await; + + info!("connecting runtime client"); + let conn = Endpoint::new("http://127.0.0.1:6001") + .context("creating runtime client endpoint")? + .connect_timeout(Duration::from_secs(5)); + + let runtime_client = runtime_client::RuntimeClient::connect(conn) + .await + .context("connecting runtime client")?; + + Ok((runtime, runtime_client)) + } + + fn get_runtime_executable(binary_bytes: &[u8]) -> PathBuf { + let tmp_dir = temp_dir(); + + let path = tmp_dir.join("shuttle-runtime"); + let mut open_options = OpenOptions::new(); + open_options.write(true).create(true).truncate(true); + + #[cfg(target_family = "unix")] + { + use std::os::unix::prelude::OpenOptionsExt; + + open_options.mode(0o755); + } + + let mut file = open_options + .open(&path) + .expect("to create runtime executable file"); + + file.write_all(binary_bytes) + .expect("to write out binary file"); + + path + } } diff --git a/provisioner/Cargo.toml b/provisioner/Cargo.toml index 91b1bff91..1ce62ca7b 100644 --- a/provisioner/Cargo.toml +++ b/provisioner/Cargo.toml @@ -18,7 +18,7 @@ rand = "0.8.5" sqlx = { version = "0.6.2", features = ["postgres", "runtime-tokio-native-tls"] } thiserror = { workspace = true } tokio = { version = "1.22.0", features = ["macros", "rt-multi-thread"] } -tonic = "0.8.3" +tonic = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 89fbf12b4..efc9f2713 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -16,7 +16,7 @@ rmp-serde = { version = "1.1.1" } thiserror = { workspace = true } tokio = { version = "=1.22.0", features = ["full"] } tokio-stream = "0.1.11" -tonic = "0.8.2" +tonic = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } uuid = { workspace = true, features = ["v4"] } diff --git a/runtime/Makefile b/runtime/Makefile index a706e8146..6472a425e 100644 --- a/runtime/Makefile +++ b/runtime/Makefile @@ -3,8 +3,8 @@ all: axum axum: - cd ../tmp/axum-wasm; cargo build --target wasm32-wasi - cp ../tmp/axum-wasm/target/wasm32-wasi/debug/shuttle_axum.wasm axum.wasm + cd ../tmp/axum-wasm-expanded; cargo build --target wasm32-wasi + cp ../tmp/axum-wasm-expanded/target/wasm32-wasi/debug/shuttle_axum_expanded.wasm axum.wasm test: axum cargo test --all-features -- --nocapture diff --git a/runtime/README.md b/runtime/README.md index e0ea4da9f..13205bc4a 100644 --- a/runtime/README.md +++ b/runtime/README.md @@ -1,5 +1,15 @@ # How to run +## The easy way +Both the legacy and next examples can be run using the local client: + +``` bash +cd path/to/example +cargo run --manifest ../../../Cargo.toml --bin cargo-shuttle -- run +``` + +When a more fine controlled testing is needed, use the instructions below. + ## axum-wasm Compile the wasm axum router: diff --git a/runtime/src/args.rs b/runtime/src/args.rs index 2e123f4e8..016121d4b 100644 --- a/runtime/src/args.rs +++ b/runtime/src/args.rs @@ -1,17 +1,36 @@ -use clap::Parser; +use std::path::PathBuf; + +use clap::{Parser, ValueEnum}; use tonic::transport::Endpoint; #[derive(Parser, Debug)] pub struct Args { /// Address to reach provisioner at - #[clap(long, default_value = "localhost:5000")] + #[arg(long, default_value = "http://localhost:5000")] pub provisioner_address: Endpoint, /// Is this runtime for a legacy service - #[clap(long, conflicts_with("axum"))] + #[arg(long, conflicts_with("axum"))] pub legacy: bool, /// Is this runtime for an axum-wasm service - #[clap(long, conflicts_with("legacy"))] + #[arg(long, conflicts_with("legacy"))] pub axum: bool, + + /// Type of storage manager to start + #[arg(long, value_enum)] + pub storage_manager_type: StorageManagerType, + + /// Path to use for storage manager + #[arg(long)] + pub storage_manager_path: PathBuf, +} + +#[derive(Clone, Debug, ValueEnum)] +pub enum StorageManagerType { + /// Use a deloyer artifacts directory + Artifacts, + + /// Use a local working directory + WorkingDir, } diff --git a/runtime/src/axum/mod.rs b/runtime/src/axum/mod.rs index e3dada547..68fbc07d7 100644 --- a/runtime/src/axum/mod.rs +++ b/runtime/src/axum/mod.rs @@ -5,7 +5,6 @@ use std::ops::DerefMut; use std::os::unix::prelude::RawFd; use std::path::{Path, PathBuf}; use std::str::FromStr; -use std::sync::Mutex; use async_trait::async_trait; use cap_std::os::unix::net::UnixStream; @@ -20,7 +19,7 @@ use shuttle_proto::runtime::{ SubscribeLogsRequest, }; use shuttle_service::ServiceName; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; use tonic::Status; use tracing::{error, trace}; @@ -33,7 +32,6 @@ extern crate rmp_serde as rmps; pub struct AxumWasm { router: std::sync::Mutex>, - port: Mutex>, kill_tx: std::sync::Mutex>>, } @@ -41,7 +39,6 @@ impl AxumWasm { pub fn new() -> Self { Self { router: std::sync::Mutex::new(None), - port: std::sync::Mutex::new(None), kill_tx: std::sync::Mutex::new(None), } } @@ -73,10 +70,10 @@ impl Runtime for AxumWasm { async fn start( &self, - _request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { - let port = 7002; - let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); + let StartRequest { port, .. } = request.into_inner(); + let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port as u16); let router = self.router.lock().unwrap().take().unwrap(); @@ -87,12 +84,7 @@ impl Runtime for AxumWasm { // TODO: split `into_server` up into build and run functions tokio::spawn(router.into_server(address, kill_rx)); - *self.port.lock().unwrap() = Some(port); - - let message = StartResponse { - success: true, - port: port as u32, - }; + let message = StartResponse { success: true }; Ok(tonic::Response::new(message)) } @@ -103,7 +95,9 @@ impl Runtime for AxumWasm { &self, _request: tonic::Request, ) -> Result, Status> { - todo!() + let (_tx, rx) = mpsc::channel(1); + + Ok(tonic::Response::new(ReceiverStream::new(rx))) } async fn stop( diff --git a/runtime/src/legacy/mod.rs b/runtime/src/legacy/mod.rs index 18e0f8147..c3d98c44f 100644 --- a/runtime/src/legacy/mod.rs +++ b/runtime/src/legacy/mod.rs @@ -1,4 +1,6 @@ use std::{ + collections::BTreeMap, + iter::FromIterator, net::{Ipv4Addr, SocketAddr}, ops::DerefMut, path::PathBuf, @@ -31,40 +33,53 @@ use crate::provisioner_factory::{AbstractFactory, AbstractProvisionerFactory}; mod error; -pub struct Legacy { +pub struct Legacy +where + S: StorageManager, +{ // Mutexes are for interior mutability so_path: Mutex>, - port: Mutex>, logs_rx: Mutex>>, logs_tx: Mutex>, provisioner_address: Endpoint, kill_tx: Mutex>>, + secrets: Mutex>>, + storage_manager: S, } -impl Legacy { - pub fn new(provisioner_address: Endpoint) -> Self { +impl Legacy +where + S: StorageManager, +{ + pub fn new(provisioner_address: Endpoint, storage_manager: S) -> Self { let (tx, rx) = mpsc::unbounded_channel(); Self { so_path: Mutex::new(None), - port: Mutex::new(None), logs_rx: Mutex::new(Some(rx)), logs_tx: Mutex::new(tx), kill_tx: Mutex::new(None), provisioner_address, + secrets: Mutex::new(None), + storage_manager, } } } #[async_trait] -impl Runtime for Legacy { +impl Runtime for Legacy +where + S: StorageManager + 'static, +{ async fn load(&self, request: Request) -> Result, Status> { - let so_path = request.into_inner().path; - trace!(so_path, "loading"); + let LoadRequest { path, secrets, .. } = request.into_inner(); + trace!(path, "loading"); - let so_path = PathBuf::from(so_path); + let so_path = PathBuf::from(path); *self.so_path.lock().unwrap() = Some(so_path); + *self.secrets.lock().unwrap() = Some(BTreeMap::from_iter(secrets.into_iter())); + let message = LoadResponse { success: true }; Ok(Response::new(message)) } @@ -73,9 +88,6 @@ impl Runtime for Legacy { &self, request: Request, ) -> Result, Status> { - let service_port = 7001; - let service_address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), service_port); - let provisioner_client = ProvisionerClient::connect(self.provisioner_address.clone()) .await .expect("failed to connect to provisioner"); @@ -91,21 +103,37 @@ impl Runtime for Legacy { }) .map_err(|err| Status::from_error(Box::new(err)))? .clone(); - - let storage_manager = StorageManager::new(so_path.clone()); + let secrets = self + .secrets + .lock() + .unwrap() + .as_ref() + .ok_or_else(|| -> error::Error { + error::Error::Start(anyhow!( + "trying to get secrets from a service that was not loaded" + )) + }) + .map_err(|err| Status::from_error(Box::new(err)))? + .clone(); let StartRequest { deployment_id, service_name, + port, } = request.into_inner(); + let service_address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port as u16); let service_name = ServiceName::from_str(service_name.as_str()) .map_err(|err| Status::from_error(Box::new(err)))?; - let deployment_id = Uuid::from_str(std::str::from_utf8(&deployment_id).unwrap()).unwrap(); + let deployment_id = Uuid::from_slice(&deployment_id).unwrap(); - let mut factory = - abstract_factory.get_factory(service_name, deployment_id, storage_manager); + let mut factory = abstract_factory.get_factory( + service_name, + deployment_id, + secrets, + self.storage_manager.clone(), + ); let logs_tx = self.logs_tx.lock().unwrap().clone(); @@ -114,7 +142,7 @@ impl Runtime for Legacy { trace!(%service_address, "starting"); let service = load_service(service_address, so_path, &mut factory, logger) .await - .unwrap(); + .map_err(|error| Status::internal(error.to_string()))?; let (kill_tx, kill_rx) = tokio::sync::oneshot::channel(); @@ -123,12 +151,7 @@ impl Runtime for Legacy { // start service as a background task with a kill receiver tokio::spawn(run_until_stopped(service, service_address, kill_rx)); - *self.port.lock().unwrap() = Some(service_port); - - let message = StartResponse { - success: true, - port: service_port as u32, - }; + let message = StartResponse { success: true }; Ok(Response::new(message)) } @@ -184,7 +207,7 @@ impl Runtime for Legacy { } /// Run the service until a stop signal is received -#[instrument(skip(service))] +#[instrument(skip(service, kill_rx))] async fn run_until_stopped( service: LoadedService, addr: SocketAddr, diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 81d7acaef..fc17f47c4 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -3,6 +3,6 @@ mod axum; mod legacy; pub mod provisioner_factory; -pub use args::Args; +pub use args::{Args, StorageManagerType}; pub use axum::AxumWasm; pub use legacy::Legacy; diff --git a/runtime/src/main.rs b/runtime/src/main.rs index 3c4f64cb5..b7d9f3c3e 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -4,8 +4,9 @@ use std::{ }; use clap::Parser; +use shuttle_common::storage_manager::{ArtifactsStorageManager, WorkingDirStorageManager}; use shuttle_proto::runtime::runtime_server::RuntimeServer; -use shuttle_runtime::{Args, AxumWasm, Legacy}; +use shuttle_runtime::{Args, AxumWasm, Legacy, StorageManagerType}; use tonic::transport::Server; use tracing::trace; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -33,9 +34,26 @@ async fn main() { Server::builder().http2_keepalive_interval(Some(Duration::from_secs(60))); let router = if args.legacy { - let legacy = Legacy::new(provisioner_address); - let svc = RuntimeServer::new(legacy); - server_builder.add_service(svc) + match args.storage_manager_type { + StorageManagerType::Artifacts => { + let legacy = Legacy::new( + provisioner_address, + ArtifactsStorageManager::new(args.storage_manager_path), + ); + + let svc = RuntimeServer::new(legacy); + server_builder.add_service(svc) + } + StorageManagerType::WorkingDir => { + let legacy = Legacy::new( + provisioner_address, + WorkingDirStorageManager::new(args.storage_manager_path), + ); + + let svc = RuntimeServer::new(legacy); + server_builder.add_service(svc) + } + } } else if args.axum { let axum = AxumWasm::default(); let svc = RuntimeServer::new(axum); diff --git a/runtime/src/provisioner_factory.rs b/runtime/src/provisioner_factory.rs index 28bd9d7c9..58cd34aa2 100644 --- a/runtime/src/provisioner_factory.rs +++ b/runtime/src/provisioner_factory.rs @@ -11,7 +11,7 @@ use tracing::{debug, info, trace}; use uuid::Uuid; /// Trait to make it easy to get a factory (service locator) for each service being started -pub trait AbstractFactory: Send + 'static { +pub trait AbstractFactory: Send + 'static { type Output: Factory; /// Get a factory for a specific service @@ -19,7 +19,8 @@ pub trait AbstractFactory: Send + 'static { &self, service_name: ServiceName, deployment_id: Uuid, - storage_manager: StorageManager, + secrets: BTreeMap, + storage_manager: S, ) -> Self::Output; } @@ -29,19 +30,24 @@ pub struct AbstractProvisionerFactory { provisioner_client: ProvisionerClient, } -impl AbstractFactory for AbstractProvisionerFactory { - type Output = ProvisionerFactory; +impl AbstractFactory for AbstractProvisionerFactory +where + S: StorageManager, +{ + type Output = ProvisionerFactory; fn get_factory( &self, service_name: ServiceName, deployment_id: Uuid, - storage_manager: StorageManager, + secrets: BTreeMap, + storage_manager: S, ) -> Self::Output { ProvisionerFactory::new( self.provisioner_client.clone(), service_name, deployment_id, + secrets, storage_manager, ) } @@ -54,21 +60,28 @@ impl AbstractProvisionerFactory { } /// A factory (service locator) which goes through the provisioner crate -pub struct ProvisionerFactory { +pub struct ProvisionerFactory +where + S: StorageManager, +{ service_name: ServiceName, deployment_id: Uuid, - storage_manager: StorageManager, + storage_manager: S, provisioner_client: ProvisionerClient, info: Option, - secrets: Option>, + secrets: BTreeMap, } -impl ProvisionerFactory { +impl ProvisionerFactory +where + S: StorageManager, +{ pub(crate) fn new( provisioner_client: ProvisionerClient, service_name: ServiceName, deployment_id: Uuid, - storage_manager: StorageManager, + secrets: BTreeMap, + storage_manager: S, ) -> Self { Self { provisioner_client, @@ -76,18 +89,21 @@ impl ProvisionerFactory { deployment_id, storage_manager, info: None, - secrets: None, + secrets, } } } #[async_trait] -impl Factory for ProvisionerFactory { +impl Factory for ProvisionerFactory +where + S: StorageManager + Sync + Send, +{ async fn get_db_connection_string( &mut self, db_type: database::Type, ) -> Result { - info!("Provisioning a {db_type} on the shuttle servers. This can take a while..."); + info!("Provisioning a {db_type}. This can take a while..."); if let Some(ref info) = self.info { debug!("A database has already been provisioned for this deployment, so reusing it"); @@ -119,12 +135,7 @@ impl Factory for ProvisionerFactory { } async fn get_secrets(&mut self) -> Result, shuttle_service::Error> { - if let Some(ref secrets) = self.secrets { - debug!("Returning previously fetched secrets"); - Ok(secrets.clone()) - } else { - todo!() - } + Ok(self.secrets.clone()) } fn get_service_name(&self) -> ServiceName { diff --git a/service/Cargo.toml b/service/Cargo.toml index 08512f6e7..523a16281 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -62,7 +62,9 @@ tokio = { version = "1.22.0", features = ["macros"] } uuid = { workspace = true, features = ["v4"] } [features] -codegen = ["shuttle-codegen"] +default = ["codegen"] + +codegen = ["shuttle-codegen/frameworks"] loader = ["cargo", "libloading"] web-actix-web = ["actix-web", "num_cpus"] diff --git a/service/src/error.rs b/service/src/error.rs index de1898af1..ba261ef46 100644 --- a/service/src/error.rs +++ b/service/src/error.rs @@ -8,8 +8,6 @@ pub enum Error { Io(#[from] std::io::Error), #[error("Database error: {0}")] Database(String), - #[error("Secret error: {0}")] - Secret(String), #[error("Panic occurred in shuttle_service::main`: {0}")] BuildPanic(String), #[error("Panic occurred in `Service::bind`: {0}")] @@ -19,12 +17,3 @@ pub enum Error { } pub type CustomError = anyhow::Error; - -// This is implemented manually as defining `Error::Database(#[from] mongodb::error::Error)` resulted in a -// segfault even with a feature guard. -#[cfg(feature = "mongodb-integration")] -impl From for Error { - fn from(e: mongodb::error::Error) -> Self { - Error::Database(e.to_string()) - } -} diff --git a/service/src/loader.rs b/service/src/loader.rs index b57bffc70..fd4f93cf7 100644 --- a/service/src/loader.rs +++ b/service/src/loader.rs @@ -111,7 +111,6 @@ pub async fn build_crate( deployment_id: Uuid, project_path: &Path, release_mode: bool, - wasm: bool, tx: Sender, ) -> anyhow::Result { let (read, write) = pipe::pipe(); @@ -145,14 +144,18 @@ pub async fn build_crate( let summary = current.manifest_mut().summary_mut(); make_name_unique(summary, deployment_id); - check_version(summary)?; + + let is_next = is_next(summary); + if !is_next { + check_version(summary)?; + } check_no_panic(&ws)?; - let opts = get_compile_options(&config, release_mode, wasm)?; + let opts = get_compile_options(&config, release_mode, is_next)?; let compilation = compile(&ws, &opts); let path = compilation?.cdylibs[0].path.clone(); - Ok(if wasm { + Ok(if is_next { Runtime::Next(path) } else { Runtime::Legacy(path) @@ -246,7 +249,7 @@ fn get_compile_options( }; opts.build_config.requested_kinds = vec![if wasm { - CompileKind::Target(CompileTarget::new("wasm32-unknown-unknown")?) + CompileKind::Target(CompileTarget::new("wasm32-wasi")?) } else { CompileKind::Host }]; @@ -292,6 +295,20 @@ fn make_name_unique(summary: &mut Summary, deployment_id: Uuid) { ); } +fn is_next(summary: &Summary) -> bool { + let features = if let Some(shuttle) = summary + .dependencies() + .iter() + .find(|dependency| dependency.package_name() == "shuttle-codegen") + { + shuttle.features() + } else { + &[] + }; + + features.contains(&InternedString::new("next")) +} + /// Check that the crate being build is compatible with this version of loader fn check_version(summary: &Summary) -> anyhow::Result<()> { let valid_version = VERSION.to_semver().unwrap(); diff --git a/service/tests/integration/build_crate.rs b/service/tests/integration/build_crate.rs index 432ad7751..6c7e289eb 100644 --- a/service/tests/integration/build_crate.rs +++ b/service/tests/integration/build_crate.rs @@ -6,15 +6,9 @@ use shuttle_service::loader::{build_crate, Runtime}; async fn not_shuttle() { let (tx, _) = crossbeam_channel::unbounded(); let project_path = format!("{}/tests/resources/not-shuttle", env!("CARGO_MANIFEST_DIR")); - let so_path = match build_crate( - Default::default(), - Path::new(&project_path), - false, - false, - tx, - ) - .await - .unwrap() + let so_path = match build_crate(Default::default(), Path::new(&project_path), false, tx) + .await + .unwrap() { Runtime::Legacy(path) => path, _ => unreachable!(), @@ -37,15 +31,9 @@ async fn not_shuttle() { async fn not_lib() { let (tx, _) = crossbeam_channel::unbounded(); let project_path = format!("{}/tests/resources/not-lib", env!("CARGO_MANIFEST_DIR")); - build_crate( - Default::default(), - Path::new(&project_path), - false, - false, - tx, - ) - .await - .unwrap(); + build_crate(Default::default(), Path::new(&project_path), false, tx) + .await + .unwrap(); } #[tokio::test(flavor = "multi_thread")] @@ -53,14 +41,7 @@ async fn not_cdylib() { let (tx, _) = crossbeam_channel::unbounded(); let project_path = format!("{}/tests/resources/not-cdylib", env!("CARGO_MANIFEST_DIR")); assert!(matches!( - build_crate( - Default::default(), - Path::new(&project_path), - false, - false, - tx - ) - .await, + build_crate(Default::default(), Path::new(&project_path), false, tx).await, Ok(Runtime::Legacy(_)) )); assert!(PathBuf::from(project_path) @@ -73,14 +54,7 @@ async fn is_cdylib() { let (tx, _) = crossbeam_channel::unbounded(); let project_path = format!("{}/tests/resources/is-cdylib", env!("CARGO_MANIFEST_DIR")); assert!(matches!( - build_crate( - Default::default(), - Path::new(&project_path), - false, - false, - tx - ) - .await, + build_crate(Default::default(), Path::new(&project_path), false, tx).await, Ok(Runtime::Legacy(_)) )); assert!(PathBuf::from(project_path) @@ -96,13 +70,7 @@ async fn not_found() { "{}/tests/resources/non-existing", env!("CARGO_MANIFEST_DIR") ); - build_crate( - Default::default(), - Path::new(&project_path), - false, - false, - tx, - ) - .await - .unwrap(); + build_crate(Default::default(), Path::new(&project_path), false, tx) + .await + .unwrap(); } diff --git a/tmp/axum-wasm-expanded/Cargo.toml b/tmp/axum-wasm-expanded/Cargo.toml new file mode 100644 index 000000000..90948b065 --- /dev/null +++ b/tmp/axum-wasm-expanded/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "shuttle-axum-expanded" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = [ "cdylib" ] + +[dependencies] +# most axum features can be enabled, but "tokio" and "ws" depend on socket2 +# via "hyper/tcp" which is not compatible with wasi +axum = { version = "0.6.0", default-features = false } +futures-executor = "0.3.21" +http = "0.2.7" +tower-service = "0.3.1" +rmp-serde = { version = "1.1.1" } + +[dependencies.shuttle-common] +path = "../../common" +features = ["axum-wasm"] +version = "0.8.0" diff --git a/tmp/axum-wasm-expanded/src/lib.rs b/tmp/axum-wasm-expanded/src/lib.rs new file mode 100644 index 000000000..62ca37c28 --- /dev/null +++ b/tmp/axum-wasm-expanded/src/lib.rs @@ -0,0 +1,96 @@ +use axum::{ + body::BoxBody, + extract::BodyStream, + response::{IntoResponse, Response}, +}; +use futures::TryStreamExt; + +pub fn handle_request(req: http::Request) -> axum::response::Response { + futures_executor::block_on(app(req)) +} + +async fn app(request: http::Request) -> axum::response::Response { + use tower_service::Service; + + let mut router = axum::Router::new() + .route("/hello", axum::routing::get(hello)) + .route("/goodbye", axum::routing::get(goodbye)) + .route("/uppercase", axum::routing::post(uppercase)); + + let response = router.call(request).await.unwrap(); + + response +} + +async fn hello() -> &'static str { + "Hello, World!" +} + +async fn goodbye() -> &'static str { + "Goodbye, World!" +} + +// Map the bytes of the body stream to uppercase and return the stream directly. +async fn uppercase(body: BodyStream) -> impl IntoResponse { + let chunk_stream = body.map_ok(|chunk| { + chunk + .iter() + .map(|byte| byte.to_ascii_uppercase()) + .collect::>() + }); + Response::new(axum::body::StreamBody::new(chunk_stream)) +} + +#[no_mangle] +#[allow(non_snake_case)] +pub extern "C" fn __SHUTTLE_Axum_call( + fd_3: std::os::wasi::prelude::RawFd, + fd_4: std::os::wasi::prelude::RawFd, + fd_5: std::os::wasi::prelude::RawFd, +) { + use axum::body::HttpBody; + use std::io::{Read, Write}; + use std::os::wasi::io::FromRawFd; + + println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}"); + + // file descriptor 3 for reading and writing http parts + let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) }; + + let reader = std::io::BufReader::new(&mut parts_fd); + + // deserialize request parts from rust messagepack + let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap(); + + // file descriptor 4 for reading http body into wasm + let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) }; + + let mut reader = std::io::BufReader::new(&mut body_read_stream); + let mut body_buf = Vec::new(); + reader.read_to_end(&mut body_buf).unwrap(); + + let body = axum::body::Body::from(body_buf); + + let request = wrapper + .into_request_builder() + .body(axum::body::boxed(body)) + .unwrap(); + + println!("inner router received request: {:?}", &request); + let res = handle_request(request); + + let (parts, mut body) = res.into_parts(); + + // wrap and serialize response parts as rmp + let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp(); + + // write response parts + parts_fd.write_all(&response_parts).unwrap(); + + // file descriptor 5 for writing http body to host + let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) }; + + // write body if there is one + if let Some(body) = futures_executor::block_on(body.data()) { + body_write_stream.write_all(body.unwrap().as_ref()).unwrap(); +} diff --git a/tmp/axum-wasm/Cargo.toml b/tmp/axum-wasm/Cargo.toml index 740001a64..02a0b8546 100644 --- a/tmp/axum-wasm/Cargo.toml +++ b/tmp/axum-wasm/Cargo.toml @@ -16,6 +16,11 @@ tower-service = "0.3.1" rmp-serde = { version = "1.1.1" } futures = "0.3.25" +[dependencies.shuttle-codegen] +path = "../../codegen" +features = ["next"] +version = "0.8.0" + [dependencies.shuttle-common] path = "../../common" features = ["axum-wasm"] diff --git a/tmp/axum-wasm/src/lib.rs b/tmp/axum-wasm/src/lib.rs index e0801ac9a..5be67e21c 100644 --- a/tmp/axum-wasm/src/lib.rs +++ b/tmp/axum-wasm/src/lib.rs @@ -1,97 +1,11 @@ -use axum::{ - body::BoxBody, - extract::BodyStream, - response::{IntoResponse, Response}, -}; -use futures::TryStreamExt; - -pub fn handle_request(req: http::Request) -> axum::response::Response { - futures_executor::block_on(app(req)) -} - -async fn app(request: http::Request) -> axum::response::Response { - use tower_service::Service; - - let mut router = axum::Router::new() - .route("/hello", axum::routing::get(hello)) - .route("/goodbye", axum::routing::get(goodbye)) - .route("/uppercase", axum::routing::post(uppercase)); - - let response = router.call(request).await.unwrap(); - - response -} - -async fn hello() -> &'static str { - "Hello, World!" -} - -async fn goodbye() -> &'static str { - "Goodbye, World!" -} - -// Map the bytes of the body stream to uppercase and return the stream directly. -async fn uppercase(body: BodyStream) -> impl IntoResponse { - let chunk_stream = body.map_ok(|chunk| { - chunk - .iter() - .map(|byte| byte.to_ascii_uppercase()) - .collect::>() - }); - Response::new(axum::body::StreamBody::new(chunk_stream)) -} - -#[no_mangle] -#[allow(non_snake_case)] -pub extern "C" fn __SHUTTLE_Axum_call( - fd_3: std::os::wasi::prelude::RawFd, - fd_4: std::os::wasi::prelude::RawFd, - fd_5: std::os::wasi::prelude::RawFd, -) { - use axum::body::HttpBody; - use std::io::{Read, Write}; - use std::os::wasi::io::FromRawFd; - - println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}"); - - // file descriptor 3 for reading and writing http parts - let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) }; - - let reader = std::io::BufReader::new(&mut parts_fd); - - // deserialize request parts from rust messagepack - let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap(); - - // file descriptor 4 for reading http body into wasm - let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) }; - - let mut reader = std::io::BufReader::new(&mut body_read_stream); - let mut body_buf = Vec::new(); - reader.read_to_end(&mut body_buf).unwrap(); - - let body = axum::body::Body::from(body_buf); - - let request = wrapper - .into_request_builder() - .body(axum::body::boxed(body)) - .unwrap(); - - println!("inner router received request: {:?}", &request); - let res = handle_request(request); - - let (parts, mut body) = res.into_parts(); - - // wrap and serialize response parts as rmp - let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp(); - - // write response parts - parts_fd.write_all(&response_parts).unwrap(); - - // file descriptor 5 for writing http body to host - let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) }; +shuttle_codegen::app! { + #[shuttle_codegen::endpoint(method = get, route = "/hello")] + async fn hello() -> &'static str { + "Hello, World!" + } - // write body if there is one - if let Some(body) = futures_executor::block_on(body.data()) { - body_write_stream.write_all(body.unwrap().as_ref()).unwrap(); + #[shuttle_codegen::endpoint(method = get, route = "/goodbye")] + async fn goodbye() -> &'static str { + "Goodbye, World!" } }