Skip to content

Commit

Permalink
feat: draft of extracting service integrations
Browse files Browse the repository at this point in the history
  • Loading branch information
oddgrd committed Mar 10, 2023
1 parent 4e1690d commit 9508d42
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 43 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ exclude = [
"resources/persist",
"resources/secrets",
"resources/shared-db",
"resources/static-folder"
"resources/static-folder",
"integrations"
]

[workspace.package]
Expand Down
134 changes: 93 additions & 41 deletions codegen/src/shuttle_main/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ use quote::{quote, ToTokens};
use syn::{
parenthesized, parse::Parse, parse2, parse_macro_input, parse_quote, punctuated::Punctuated,
spanned::Spanned, token::Paren, Attribute, Expr, FnArg, Ident, ItemFn, Pat, PatIdent, Path,
ReturnType, Signature, Stmt, Token, Type, TypePath,
PathSegment, ReturnType, Signature, Stmt, Token, Type, TypePath,
};

pub(crate) fn r#impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut fn_decl = parse_macro_input!(item as ItemFn);

let loader = Loader::from_item_fn(&mut fn_decl);

let main_fn = MainFn::from_item_fn(&mut fn_decl);

let expanded = quote! {
#[tokio::main]
async fn main() {
shuttle_runtime::start(loader).await;
}
#main_fn

#loader

Expand All @@ -30,6 +29,11 @@ struct Loader {
fn_ident: Ident,
fn_inputs: Vec<Input>,
fn_return: TypePath,
import_path: PathSegment,
}

struct MainFn {
import_path: PathSegment,
}

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -128,17 +132,38 @@ impl Loader {
.collect();

if let Some(type_path) = check_return_type(item_fn.sig.clone()) {
// We need the first segment of the path so we can import the codegen dependencies from it.
let Some(import_path) = type_path.path.segments.first().cloned() else {
return None;
};

Some(Self {
fn_ident: item_fn.sig.ident.clone(),
fn_inputs: inputs,
fn_return: type_path,
import_path,
})
} else {
None
}
}
}

impl MainFn {
pub(crate) fn from_item_fn(item_fn: &mut ItemFn) -> Option<Self> {
if let Some(type_path) = check_return_type(item_fn.sig.clone()) {
// We need the first segment of the path so we can import the codegen dependencies from it.
let Some(import_path) = type_path.path.segments.first().cloned() else {
return None;
};

Some(Self { import_path })
} else {
None
}
}
}

fn check_return_type(signature: Signature) -> Option<TypePath> {
match signature.output {
ReturnType::Default => {
Expand Down Expand Up @@ -193,6 +218,8 @@ impl ToTokens for Loader {

let return_type = &self.fn_return;

let import_path = &self.import_path;

let mut fn_inputs: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
let mut fn_inputs_builder: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
let mut fn_inputs_builder_options: Vec<_> = Vec::with_capacity(self.fn_inputs.len());
Expand All @@ -213,25 +240,25 @@ impl ToTokens for Loader {
None
} else {
Some(parse_quote!(
use shuttle_service::ResourceBuilder;
use #import_path::shuttle_runtime::ResourceBuilder;
))
};

let loader = quote! {
async fn loader<S: shuttle_runtime::StorageManager>(
mut #factory_ident: shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_runtime::Logger,
async fn loader<S: #import_path::shuttle_runtime::StorageManager>(
mut #factory_ident: #import_path::shuttle_runtime::ProvisionerFactory<S>,
logger: #import_path::shuttle_runtime::Logger,
) -> #return_type {
use shuttle_service::Context;
use shuttle_service::tracing_subscriber::prelude::*;
use #import_path::shuttle_runtime::Context;
use #import_path::shuttle_runtime::tracing_subscriber::prelude::*;
#extra_imports

let filter_layer =
shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
#import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| #import_path::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

shuttle_service::tracing_subscriber::registry()
#import_path::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
Expand All @@ -246,11 +273,27 @@ impl ToTokens for Loader {
}
}

impl ToTokens for MainFn {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let import_path = &self.import_path;

let main_fn = quote! {
#[tokio::main]
async fn main() {
#import_path::shuttle_runtime::start(loader).await;
}

};

main_fn.to_tokens(tokens);
}
}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use quote::quote;
use syn::{parse_quote, Ident};
use syn::{parse_quote, Ident, PathSegment};

use super::{Builder, BuilderOptions, Input, Loader};

Expand All @@ -269,27 +312,30 @@ mod tests {

#[test]
fn output_with_return() {
let import_path: PathSegment = parse_quote!(shuttle_simple);

let input = Loader {
fn_ident: parse_quote!(simple),
fn_inputs: Vec::new(),
fn_return: parse_quote!(ShuttleSimple),
import_path,
};

let actual = quote!(#input);
let expected = quote! {
async fn loader<S: shuttle_runtime::StorageManager>(
mut _factory: shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_runtime::Logger,
async fn loader<S: shuttle_simple::shuttle_runtime::StorageManager>(
mut _factory: shuttle_simple::shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_simple::shuttle_runtime::Logger,
) -> ShuttleSimple {
use shuttle_service::Context;
use shuttle_service::tracing_subscriber::prelude::*;
use shuttle_simple::shuttle_runtime::Context;
use shuttle_simple::shuttle_runtime::tracing_subscriber::prelude::*;

let filter_layer =
shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_simple::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

shuttle_service::tracing_subscriber::registry()
shuttle_simple::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
Expand Down Expand Up @@ -334,6 +380,8 @@ mod tests {

#[test]
fn output_with_inputs() {
let import_path: PathSegment = parse_quote!(shuttle_complex);

let input = Loader {
fn_ident: parse_quote!(complex),
fn_inputs: vec![
Expand All @@ -353,24 +401,25 @@ mod tests {
},
],
fn_return: parse_quote!(ShuttleComplex),
import_path,
};

let actual = quote!(#input);
let expected = quote! {
async fn loader<S: shuttle_runtime::StorageManager>(
mut factory: shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_runtime::Logger,
async fn loader<S: shuttle_complex::shuttle_runtime::StorageManager>(
mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_complex::shuttle_runtime::Logger,
) -> ShuttleComplex {
use shuttle_service::Context;
use shuttle_service::tracing_subscriber::prelude::*;
use shuttle_service::ResourceBuilder;
use shuttle_complex::shuttle_runtime::Context;
use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*;
use shuttle_complex::shuttle_runtime::ResourceBuilder;

let filter_layer =
shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

shuttle_service::tracing_subscriber::registry()
shuttle_complex::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
Expand Down Expand Up @@ -460,6 +509,8 @@ mod tests {

#[test]
fn output_with_input_options() {
let import_path: PathSegment = parse_quote!(shuttle_complex);

let mut input = Loader {
fn_ident: parse_quote!(complex),
fn_inputs: vec![Input {
Expand All @@ -470,6 +521,7 @@ mod tests {
},
}],
fn_return: parse_quote!(ShuttleComplex),
import_path,
};

input.fn_inputs[0]
Expand All @@ -485,20 +537,20 @@ mod tests {

let actual = quote!(#input);
let expected = quote! {
async fn loader<S: shuttle_runtime::StorageManager>(
mut factory: shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_runtime::Logger,
async fn loader<S: shuttle_complex::shuttle_runtime::StorageManager>(
mut factory: shuttle_complex::shuttle_runtime::ProvisionerFactory<S>,
logger: shuttle_complex::shuttle_runtime::Logger,
) -> ShuttleComplex {
use shuttle_service::Context;
use shuttle_service::tracing_subscriber::prelude::*;
use shuttle_service::ResourceBuilder;
use shuttle_complex::shuttle_runtime::Context;
use shuttle_complex::shuttle_runtime::tracing_subscriber::prelude::*;
use shuttle_complex::shuttle_runtime::ResourceBuilder;

let filter_layer =
shuttle_service::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_service::tracing_subscriber::EnvFilter::try_new("INFO"))
shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_from_default_env()
.or_else(|_| shuttle_complex::shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

shuttle_service::tracing_subscriber::registry()
shuttle_complex::shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.init();
Expand Down
5 changes: 5 additions & 0 deletions integrations/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## Service Integrations
The list of supported frameworks for shuttle is always growing. If you feel we are missing a framework you would like, then feel to create a feature request for your desired framework.

## Writing your own service integration
Creating your own service integration is quite simple. You only need to implement the [`Service`](https://docs.rs/shuttle-service/latest/shuttle_service/trait.Service.html) trait for your framework.
12 changes: 12 additions & 0 deletions integrations/shuttle-axum/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "shuttle-axum"
version = "0.1.0"
edition = "2021"

[workspace]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.56"
axum = { version = "0.6.0" }
shuttle-runtime = { path = "../../runtime", version = "0.1.0" }
39 changes: 39 additions & 0 deletions integrations/shuttle-axum/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! Shuttle service integration for the Axum web framework.
//! ## Example
//! ```rust,no_run
//! use shuttle_axum::AxumService;
//!
//! async fn hello_world() -> &'static str {
//! "Hello, world!"
//! }
//!
//! #[shuttle_axum::main]
//! async fn axum() -> shuttle_service::ShuttleAxum {
//! let router = Router::new().route("/hello", get(hello_world));
//!
//! Ok(AxumService(router))
//! }
//! ```
/// A wrapper type for `axum::Router` so we can implement `shuttle_runtime::Service` for it.
pub struct AxumService(pub axum::Router);

#[shuttle_runtime::async_trait]
impl shuttle_runtime::Service for AxumService {
/// Takes the router that is returned by the user in their `shuttle_runtime::main` function
/// and binds to an address passed in by shuttle.
async fn bind(mut self, addr: std::net::SocketAddr) -> Result<(), shuttle_runtime::Error> {
axum::Server::bind(&addr)
.serve(self.0.into_make_service())
.await
.map_err(shuttle_runtime::CustomError::new)?;

Ok(())
}
}

/// The return type that should be returned from the `shuttle_runtime::main` function.
pub type ShuttleAxum = Result<AxumService, shuttle_runtime::Error>;

pub use shuttle_runtime;
pub use shuttle_runtime::*;
8 changes: 7 additions & 1 deletion runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ mod logger;
mod next;
mod provisioner_factory;

pub use async_trait::async_trait;
pub use legacy::{start, Legacy};
pub use logger::Logger;
#[cfg(feature = "next")]
pub use next::{AxumWasm, NextArgs};
pub use provisioner_factory::ProvisionerFactory;
pub use shuttle_common::storage_manager::StorageManager;
pub use shuttle_service::{main, Error, Service};
pub use shuttle_service::{main, Error, ResourceBuilder, Service};

pub type CustomError = anyhow::Error;
pub use anyhow::Context;
pub use tracing;
pub use tracing_subscriber;

0 comments on commit 9508d42

Please sign in to comment.