diff --git a/Cargo.toml b/Cargo.toml
index fb3a9ccda..ebcd4bb4f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,8 @@ exclude = [
"resources/persist",
"resources/secrets",
"resources/shared-db",
- "resources/static-folder"
+ "resources/static-folder",
+ "integrations"
]
[workspace.package]
diff --git a/codegen/src/shuttle_main/mod.rs b/codegen/src/shuttle_main/mod.rs
index 3b73f82a5..22698db62 100644
--- a/codegen/src/shuttle_main/mod.rs
+++ b/codegen/src/shuttle_main/mod.rs
@@ -4,7 +4,7 @@ use quote::{quote, ToTokens};
use syn::{
parenthesized, parse::Parse, parse2, parse_macro_input, parse_quote, punctuated::Punctuated,
spanned::Spanned, token::Paren, Attribute, Expr, FnArg, Ident, ItemFn, Pat, PatIdent, Path,
- ReturnType, Signature, Stmt, Token, Type, TypePath,
+ PathSegment, ReturnType, Signature, Stmt, Token, Type, TypePath,
};
pub(crate) fn r#impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
@@ -12,11 +12,10 @@ pub(crate) fn r#impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let loader = Loader::from_item_fn(&mut fn_decl);
+ let main_fn = MainFn::from_item_fn(&mut fn_decl);
+
let expanded = quote! {
- #[tokio::main]
- async fn main() {
- shuttle_runtime::start(loader).await;
- }
+ #main_fn
#loader
@@ -30,6 +29,11 @@ struct Loader {
fn_ident: Ident,
fn_inputs: Vec,
fn_return: TypePath,
+ import_path: PathSegment,
+}
+
+struct MainFn {
+ import_path: PathSegment,
}
#[derive(Debug, PartialEq)]
@@ -128,10 +132,16 @@ impl Loader {
.collect();
if let Some(type_path) = check_return_type(item_fn.sig.clone()) {
+ // We need the first segment of the path so we can import the codegen dependencies from it.
+ let Some(import_path) = type_path.path.segments.first().cloned() else {
+ return None;
+ };
+
Some(Self {
fn_ident: item_fn.sig.ident.clone(),
fn_inputs: inputs,
fn_return: type_path,
+ import_path,
})
} else {
None
@@ -139,6 +149,21 @@ impl Loader {
}
}
+impl MainFn {
+ pub(crate) fn from_item_fn(item_fn: &mut ItemFn) -> Option {
+ if let Some(type_path) = check_return_type(item_fn.sig.clone()) {
+ // We need the first segment of the path so we can import the codegen dependencies from it.
+ let Some(import_path) = type_path.path.segments.first().cloned() else {
+ return None;
+ };
+
+ Some(Self { import_path })
+ } else {
+ None
+ }
+ }
+}
+
fn check_return_type(signature: Signature) -> Option {
match signature.output {
ReturnType::Default => {
@@ -193,6 +218,8 @@ impl ToTokens for Loader {
let return_type = &self.fn_return;
+ let import_path = &self.import_path;
+
let mut fn_inputs: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
let mut fn_inputs_builder: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
let mut fn_inputs_builder_options: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
@@ -213,25 +240,25 @@ impl ToTokens for Loader {
None
} else {
Some(parse_quote!(
- use shuttle_service::ResourceBuilder;
+ use #import_path::shuttle_runtime::ResourceBuilder;
))
};
let loader = quote! {
- async fn loader(
- mut #factory_ident: shuttle_runtime::ProvisionerFactory,
- logger: shuttle_runtime::Logger,
+ async fn loader(
+ mut #factory_ident: #import_path::shuttle_runtime::ProvisionerFactory,
+ logger: #import_path::shuttle_runtime::Logger,
) -> #return_type {
- use shuttle_service::Context;
- use shuttle_service::tracing_subscriber::prelude::*;
+ use #import_path::shuttle_runtime::Context;
+ use #import_path::shuttle_runtime::tracing_subscriber::prelude::*;
#extra_imports
let filter_layer =
- shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
- .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
+ #import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
+ .or_else(|_| #import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();
- shuttle_service::tracing_subscriber::registry()
+ #import_path::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
@@ -246,11 +273,27 @@ impl ToTokens for Loader {
}
}
+impl ToTokens for MainFn {
+ fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
+ let import_path = &self.import_path;
+
+ let main_fn = quote! {
+ #[tokio::main]
+ async fn main() {
+ #import_path::shuttle_runtime::start(loader).await;
+ }
+
+ };
+
+ main_fn.to_tokens(tokens);
+ }
+}
+
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use quote::quote;
- use syn::{parse_quote, Ident};
+ use syn::{parse_quote, Ident, PathSegment};
use super::{Builder, BuilderOptions, Input, Loader};
@@ -269,27 +312,30 @@ mod tests {
#[test]
fn output_with_return() {
+ let import_path: PathSegment = parse_quote!(shuttle_simple);
+
let input = Loader {
fn_ident: parse_quote!(simple),
fn_inputs: Vec::new(),
fn_return: parse_quote!(ShuttleSimple),
+ import_path,
};
let actual = quote!(#input);
let expected = quote! {
- async fn loader(
- mut _factory: shuttle_runtime::ProvisionerFactory,
- logger: shuttle_runtime::Logger,
+ async fn loader(
+ mut _factory: shuttle_simple::shuttle_runtime::ProvisionerFactory,
+ logger: shuttle_simple::shuttle_runtime::Logger,
) -> ShuttleSimple {
- use shuttle_service::Context;
- use shuttle_service::tracing_subscriber::prelude::*;
+ use shuttle_simple::shuttle_runtime::Context;
+ use shuttle_simple::shuttle_runtime::tracing_subscriber::prelude::*;
let filter_layer =
- shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
- .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
+ shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
+ .or_else(|_| shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();
- shuttle_service::tracing_subscriber::registry()
+ shuttle_simple::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
@@ -334,6 +380,8 @@ mod tests {
#[test]
fn output_with_inputs() {
+ let import_path: PathSegment = parse_quote!(shuttle_complex);
+
let input = Loader {
fn_ident: parse_quote!(complex),
fn_inputs: vec![
@@ -353,24 +401,25 @@ mod tests {
},
],
fn_return: parse_quote!(ShuttleComplex),
+ import_path,
};
let actual = quote!(#input);
let expected = quote! {
- async fn loader(
- mut factory: shuttle_runtime::ProvisionerFactory,
- logger: shuttle_runtime::Logger,
+ async fn loader(
+ mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory,
+ logger: shuttle_complex::shuttle_runtime::Logger,
) -> ShuttleComplex {
- use shuttle_service::Context;
- use shuttle_service::tracing_subscriber::prelude::*;
- use shuttle_service::ResourceBuilder;
+ use shuttle_complex::shuttle_runtime::Context;
+ use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*;
+ use shuttle_complex::shuttle_runtime::ResourceBuilder;
let filter_layer =
- shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
- .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
+ shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
+ .or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();
- shuttle_service::tracing_subscriber::registry()
+ shuttle_complex::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
@@ -460,6 +509,8 @@ mod tests {
#[test]
fn output_with_input_options() {
+ let import_path: PathSegment = parse_quote!(shuttle_complex);
+
let mut input = Loader {
fn_ident: parse_quote!(complex),
fn_inputs: vec![Input {
@@ -470,6 +521,7 @@ mod tests {
},
}],
fn_return: parse_quote!(ShuttleComplex),
+ import_path,
};
input.fn_inputs[0]
@@ -485,20 +537,20 @@ mod tests {
let actual = quote!(#input);
let expected = quote! {
- async fn loader(
- mut factory: shuttle_runtime::ProvisionerFactory,
- logger: shuttle_runtime::Logger,
+ async fn loader(
+ mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory,
+ logger: shuttle_complex::shuttle_runtime::Logger,
) -> ShuttleComplex {
- use shuttle_service::Context;
- use shuttle_service::tracing_subscriber::prelude::*;
- use shuttle_service::ResourceBuilder;
+ use shuttle_complex::shuttle_runtime::Context;
+ use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*;
+ use shuttle_complex::shuttle_runtime::ResourceBuilder;
let filter_layer =
- shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
- .or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
+ shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
+ .or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();
- shuttle_service::tracing_subscriber::registry()
+ shuttle_complex::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
diff --git a/integrations/README.md b/integrations/README.md
new file mode 100644
index 000000000..39272e508
--- /dev/null
+++ b/integrations/README.md
@@ -0,0 +1,5 @@
+## Service Integrations
+The list of supported frameworks for shuttle is always growing. If you feel we are missing a framework you would like, then feel to create a feature request for your desired framework.
+
+## Writing your own service integration
+Creating your own service integration is quite simple. You only need to implement the [`Service`](https://docs.rs/shuttle-service/latest/shuttle_service/trait.Service.html) trait for your framework.
diff --git a/integrations/shuttle-axum/Cargo.toml b/integrations/shuttle-axum/Cargo.toml
new file mode 100644
index 000000000..65213b78d
--- /dev/null
+++ b/integrations/shuttle-axum/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "shuttle-axum"
+version = "0.1.0"
+edition = "2021"
+
+[workspace]
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+async-trait = "0.1.56"
+axum = { version = "0.6.0" }
+shuttle-runtime = { path = "../../runtime", version = "0.1.0" }
diff --git a/integrations/shuttle-axum/src/lib.rs b/integrations/shuttle-axum/src/lib.rs
new file mode 100644
index 000000000..8d075586e
--- /dev/null
+++ b/integrations/shuttle-axum/src/lib.rs
@@ -0,0 +1,39 @@
+//! Shuttle service integration for the Axum web framework.
+//! ## Example
+//! ```rust,no_run
+//! use shuttle_axum::AxumService;
+//!
+//! async fn hello_world() -> &'static str {
+//! "Hello, world!"
+//! }
+//!
+//! #[shuttle_axum::main]
+//! async fn axum() -> shuttle_service::ShuttleAxum {
+//! let router = Router::new().route("/hello", get(hello_world));
+//!
+//! Ok(AxumService(router))
+//! }
+//! ```
+
+/// A wrapper type for `axum::Router` so we can implement `shuttle_runtime::Service` for it.
+pub struct AxumService(pub axum::Router);
+
+#[shuttle_runtime::async_trait]
+impl shuttle_runtime::Service for AxumService {
+ /// Takes the router that is returned by the user in their `shuttle_runtime::main` function
+ /// and binds to an address passed in by shuttle.
+ async fn bind(mut self, addr: std::net::SocketAddr) -> Result<(), shuttle_runtime::Error> {
+ axum::Server::bind(&addr)
+ .serve(self.0.into_make_service())
+ .await
+ .map_err(shuttle_runtime::CustomError::new)?;
+
+ Ok(())
+ }
+}
+
+/// The return type that should be returned from the `shuttle_runtime::main` function.
+pub type ShuttleAxum = Result;
+
+pub use shuttle_runtime;
+pub use shuttle_runtime::*;
diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs
index 233616493..44cf6a994 100644
--- a/runtime/src/lib.rs
+++ b/runtime/src/lib.rs
@@ -4,10 +4,16 @@ mod logger;
mod next;
mod provisioner_factory;
+pub use async_trait::async_trait;
pub use legacy::{start, Legacy};
pub use logger::Logger;
#[cfg(feature = "next")]
pub use next::{AxumWasm, NextArgs};
pub use provisioner_factory::ProvisionerFactory;
pub use shuttle_common::storage_manager::StorageManager;
-pub use shuttle_service::{main, Error, Service};
+pub use shuttle_service::{main, Error, ResourceBuilder, Service};
+
+pub type CustomError = anyhow::Error;
+pub use anyhow::Context;
+pub use tracing;
+pub use tracing_subscriber;