diff --git a/Cargo.toml b/Cargo.toml index fb3a9ccda..ebcd4bb4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,8 @@ exclude = [ "resources/persist", "resources/secrets", "resources/shared-db", - "resources/static-folder" + "resources/static-folder", + "integrations" ] [workspace.package] diff --git a/codegen/src/shuttle_main/mod.rs b/codegen/src/shuttle_main/mod.rs index 3b73f82a5..22698db62 100644 --- a/codegen/src/shuttle_main/mod.rs +++ b/codegen/src/shuttle_main/mod.rs @@ -4,7 +4,7 @@ use quote::{quote, ToTokens}; use syn::{ parenthesized, parse::Parse, parse2, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Paren, Attribute, Expr, FnArg, Ident, ItemFn, Pat, PatIdent, Path, - ReturnType, Signature, Stmt, Token, Type, TypePath, + PathSegment, ReturnType, Signature, Stmt, Token, Type, TypePath, }; pub(crate) fn r#impl(_attr: TokenStream, item: TokenStream) -> TokenStream { @@ -12,11 +12,10 @@ pub(crate) fn r#impl(_attr: TokenStream, item: TokenStream) -> TokenStream { let loader = Loader::from_item_fn(&mut fn_decl); + let main_fn = MainFn::from_item_fn(&mut fn_decl); + let expanded = quote! { - #[tokio::main] - async fn main() { - shuttle_runtime::start(loader).await; - } + #main_fn #loader @@ -30,6 +29,11 @@ struct Loader { fn_ident: Ident, fn_inputs: Vec, fn_return: TypePath, + import_path: PathSegment, +} + +struct MainFn { + import_path: PathSegment, } #[derive(Debug, PartialEq)] @@ -128,10 +132,16 @@ impl Loader { .collect(); if let Some(type_path) = check_return_type(item_fn.sig.clone()) { + // We need the first segment of the path so we can import the codegen dependencies from it. + let Some(import_path) = type_path.path.segments.first().cloned() else { + return None; + }; + Some(Self { fn_ident: item_fn.sig.ident.clone(), fn_inputs: inputs, fn_return: type_path, + import_path, }) } else { None @@ -139,6 +149,21 @@ impl Loader { } } +impl MainFn { + pub(crate) fn from_item_fn(item_fn: &mut ItemFn) -> Option { + if let Some(type_path) = check_return_type(item_fn.sig.clone()) { + // We need the first segment of the path so we can import the codegen dependencies from it. + let Some(import_path) = type_path.path.segments.first().cloned() else { + return None; + }; + + Some(Self { import_path }) + } else { + None + } + } +} + fn check_return_type(signature: Signature) -> Option { match signature.output { ReturnType::Default => { @@ -193,6 +218,8 @@ impl ToTokens for Loader { let return_type = &self.fn_return; + let import_path = &self.import_path; + let mut fn_inputs: Vec<_> = Vec::with_capacity(self.fn_inputs.len()); let mut fn_inputs_builder: Vec<_> = Vec::with_capacity(self.fn_inputs.len()); let mut fn_inputs_builder_options: Vec<_> = Vec::with_capacity(self.fn_inputs.len()); @@ -213,25 +240,25 @@ impl ToTokens for Loader { None } else { Some(parse_quote!( - use shuttle_service::ResourceBuilder; + use #import_path::shuttle_runtime::ResourceBuilder; )) }; let loader = quote! { - async fn loader( - mut #factory_ident: shuttle_runtime::ProvisionerFactory, - logger: shuttle_runtime::Logger, + async fn loader( + mut #factory_ident: #import_path::shuttle_runtime::ProvisionerFactory, + logger: #import_path::shuttle_runtime::Logger, ) -> #return_type { - use shuttle_service::Context; - use shuttle_service::tracing_subscriber::prelude::*; + use #import_path::shuttle_runtime::Context; + use #import_path::shuttle_runtime::tracing_subscriber::prelude::*; #extra_imports let filter_layer = - shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env() - .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO")) + #import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env() + .or_else(|_| #import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO")) .unwrap(); - shuttle_service::tracing_subscriber::registry() + #import_path::shuttle_runtime::tracing_subscriber::registry() .with(filter_layer) .with(logger) .init(); @@ -246,11 +273,27 @@ impl ToTokens for Loader { } } +impl ToTokens for MainFn { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let import_path = &self.import_path; + + let main_fn = quote! { + #[tokio::main] + async fn main() { + #import_path::shuttle_runtime::start(loader).await; + } + + }; + + main_fn.to_tokens(tokens); + } +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; use quote::quote; - use syn::{parse_quote, Ident}; + use syn::{parse_quote, Ident, PathSegment}; use super::{Builder, BuilderOptions, Input, Loader}; @@ -269,27 +312,30 @@ mod tests { #[test] fn output_with_return() { + let import_path: PathSegment = parse_quote!(shuttle_simple); + let input = Loader { fn_ident: parse_quote!(simple), fn_inputs: Vec::new(), fn_return: parse_quote!(ShuttleSimple), + import_path, }; let actual = quote!(#input); let expected = quote! { - async fn loader( - mut _factory: shuttle_runtime::ProvisionerFactory, - logger: shuttle_runtime::Logger, + async fn loader( + mut _factory: shuttle_simple::shuttle_runtime::ProvisionerFactory, + logger: shuttle_simple::shuttle_runtime::Logger, ) -> ShuttleSimple { - use shuttle_service::Context; - use shuttle_service::tracing_subscriber::prelude::*; + use shuttle_simple::shuttle_runtime::Context; + use shuttle_simple::shuttle_runtime::tracing_subscriber::prelude::*; let filter_layer = - shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env() - .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO")) + shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env() + .or_else(|_| shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO")) .unwrap(); - shuttle_service::tracing_subscriber::registry() + shuttle_simple::shuttle_runtime::tracing_subscriber::registry() .with(filter_layer) .with(logger) .init(); @@ -334,6 +380,8 @@ mod tests { #[test] fn output_with_inputs() { + let import_path: PathSegment = parse_quote!(shuttle_complex); + let input = Loader { fn_ident: parse_quote!(complex), fn_inputs: vec![ @@ -353,24 +401,25 @@ mod tests { }, ], fn_return: parse_quote!(ShuttleComplex), + import_path, }; let actual = quote!(#input); let expected = quote! { - async fn loader( - mut factory: shuttle_runtime::ProvisionerFactory, - logger: shuttle_runtime::Logger, + async fn loader( + mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory, + logger: shuttle_complex::shuttle_runtime::Logger, ) -> ShuttleComplex { - use shuttle_service::Context; - use shuttle_service::tracing_subscriber::prelude::*; - use shuttle_service::ResourceBuilder; + use shuttle_complex::shuttle_runtime::Context; + use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*; + use shuttle_complex::shuttle_runtime::ResourceBuilder; let filter_layer = - shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env() - .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO")) + shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env() + .or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO")) .unwrap(); - shuttle_service::tracing_subscriber::registry() + shuttle_complex::shuttle_runtime::tracing_subscriber::registry() .with(filter_layer) .with(logger) .init(); @@ -460,6 +509,8 @@ mod tests { #[test] fn output_with_input_options() { + let import_path: PathSegment = parse_quote!(shuttle_complex); + let mut input = Loader { fn_ident: parse_quote!(complex), fn_inputs: vec![Input { @@ -470,6 +521,7 @@ mod tests { }, }], fn_return: parse_quote!(ShuttleComplex), + import_path, }; input.fn_inputs[0] @@ -485,20 +537,20 @@ mod tests { let actual = quote!(#input); let expected = quote! { - async fn loader( - mut factory: shuttle_runtime::ProvisionerFactory, - logger: shuttle_runtime::Logger, + async fn loader( + mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory, + logger: shuttle_complex::shuttle_runtime::Logger, ) -> ShuttleComplex { - use shuttle_service::Context; - use shuttle_service::tracing_subscriber::prelude::*; - use shuttle_service::ResourceBuilder; + use shuttle_complex::shuttle_runtime::Context; + use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*; + use shuttle_complex::shuttle_runtime::ResourceBuilder; let filter_layer = - shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env() - .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO")) + shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env() + .or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO")) .unwrap(); - shuttle_service::tracing_subscriber::registry() + shuttle_complex::shuttle_runtime::tracing_subscriber::registry() .with(filter_layer) .with(logger) .init(); diff --git a/integrations/README.md b/integrations/README.md new file mode 100644 index 000000000..39272e508 --- /dev/null +++ b/integrations/README.md @@ -0,0 +1,5 @@ +## Service Integrations +The list of supported frameworks for shuttle is always growing. If you feel we are missing a framework you would like, then feel to create a feature request for your desired framework. + +## Writing your own service integration +Creating your own service integration is quite simple. You only need to implement the [`Service`](https://docs.rs/shuttle-service/latest/shuttle_service/trait.Service.html) trait for your framework. diff --git a/integrations/shuttle-axum/Cargo.toml b/integrations/shuttle-axum/Cargo.toml new file mode 100644 index 000000000..65213b78d --- /dev/null +++ b/integrations/shuttle-axum/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "shuttle-axum" +version = "0.1.0" +edition = "2021" + +[workspace] +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait = "0.1.56" +axum = { version = "0.6.0" } +shuttle-runtime = { path = "../../runtime", version = "0.1.0" } diff --git a/integrations/shuttle-axum/src/lib.rs b/integrations/shuttle-axum/src/lib.rs new file mode 100644 index 000000000..8d075586e --- /dev/null +++ b/integrations/shuttle-axum/src/lib.rs @@ -0,0 +1,39 @@ +//! Shuttle service integration for the Axum web framework. +//! ## Example +//! ```rust,no_run +//! use shuttle_axum::AxumService; +//! +//! async fn hello_world() -> &'static str { +//! "Hello, world!" +//! } +//! +//! #[shuttle_axum::main] +//! async fn axum() -> shuttle_service::ShuttleAxum { +//! let router = Router::new().route("/hello", get(hello_world)); +//! +//! Ok(AxumService(router)) +//! } +//! ``` + +/// A wrapper type for `axum::Router` so we can implement `shuttle_runtime::Service` for it. +pub struct AxumService(pub axum::Router); + +#[shuttle_runtime::async_trait] +impl shuttle_runtime::Service for AxumService { + /// Takes the router that is returned by the user in their `shuttle_runtime::main` function + /// and binds to an address passed in by shuttle. + async fn bind(mut self, addr: std::net::SocketAddr) -> Result<(), shuttle_runtime::Error> { + axum::Server::bind(&addr) + .serve(self.0.into_make_service()) + .await + .map_err(shuttle_runtime::CustomError::new)?; + + Ok(()) + } +} + +/// The return type that should be returned from the `shuttle_runtime::main` function. +pub type ShuttleAxum = Result; + +pub use shuttle_runtime; +pub use shuttle_runtime::*; diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 233616493..44cf6a994 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -4,10 +4,16 @@ mod logger; mod next; mod provisioner_factory; +pub use async_trait::async_trait; pub use legacy::{start, Legacy}; pub use logger::Logger; #[cfg(feature = "next")] pub use next::{AxumWasm, NextArgs}; pub use provisioner_factory::ProvisionerFactory; pub use shuttle_common::storage_manager::StorageManager; -pub use shuttle_service::{main, Error, Service}; +pub use shuttle_service::{main, Error, ResourceBuilder, Service}; + +pub type CustomError = anyhow::Error; +pub use anyhow::Context; +pub use tracing; +pub use tracing_subscriber;