Skip to content

Commit

Permalink
feat(auth): added billing backend support (#1289)
Browse files Browse the repository at this point in the history
* misc(auth): add async-stripe dep

* feat(auth): added update tier handler

The update tier accounts for changing the tier to Pro, requiring a
payload with a completed checkout session in case of upgrading to Pro.

* feat(auth): check subscription validity when getting user

* feat(common): user Response with subscription id if any

* test(auth): added integration tests for pro upgrade/downgrade

* auth: added DEVELOPING notes and STRIPE_SECRET_KEY in Makefile

* fix(auth): CI for auth fails because of memory preassure

* auth: address O review comment on builder

* circleci: added STRIPE_SECRET_KEY deploy images param

* auth: change how we decode

* auth: remove debug event

* DEVELOPING: adjust testing Pro tier notes

* auth: fix field name

* auth: async-stripe relies on rustls and installed certs

* auth: added executable rights to prepare.sh

* auth: update first the mirrors

* auth: add non-interactive flag

* auth: fix prepare_args passing

* auth: bump 0.28.1 && add explicit package version

* common: bump to 0.28.1
iulianbarbu authored Oct 5, 2023

Unverified

This user has not yet uploaded their public signing key.
1 parent bdbf92f commit b37b03f
Showing 26 changed files with 1,279 additions and 76 deletions.
8 changes: 7 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -393,6 +393,9 @@ jobs:
logger-postgres-uri:
description: "URI used to connect to the logger RDS postgres database."
type: string
stripe-secret-key:
description: "Stripe secret key used to connect a client to Stripe backend"
type: string
production:
description: "Push and deploy to production"
type: boolean
@@ -422,6 +425,7 @@ jobs:
HONEYCOMB_API_KEY=${<< parameters.honeycomb-api-key >>} \
DEPLOYS_API_KEY=${<< parameters.deploys-api-key >>} \
LOGGER_POSTGRES_URI=${<< parameters.logger-postgres-uri >>} \
STRIPE_SECRET_KEY=${<< parameters.stripe-secret-key >>} \
make deploy
- when:
condition: << parameters.production >>
@@ -705,7 +709,6 @@ workflows:
resource_class:
- medium
crate:
- shuttle-auth
- shuttle-proto
- shuttle-resource-recorder
- test-workspace-member-with-integration:
@@ -716,6 +719,7 @@ workflows:
resource_class:
- large
crate:
- shuttle-auth
- shuttle-runtime
- shuttle-service
- test-workspace-member-with-integration:
@@ -774,6 +778,7 @@ workflows:
honeycomb-api-key: DEV_HONEYCOMB_API_KEY
deploys-api-key: DEV_DEPLOYS_API_KEY
logger-postgres-uri: DEV_LOGGER_POSTGRES_URI
stripe-secret-key: DEV_STRIPE_SECRET_KEY
requires:
- build-and-push-unstable
release:
@@ -848,6 +853,7 @@ workflows:
honeycomb-api-key: PROD_HONEYCOMB_API_KEY
deploys-api-key: PROD_DEPLOYS_API_KEY
logger-postgres-uri: PROD_LOGGER_POSTGRES_URI
stripe-secret-key: PROD_STRIPE_SECRET_KEY
ssh-fingerprint: 6a:c5:33:fe:5b:c9:06:df:99:64:ca:17:0d:32:18:2e
ssh-config-script: production-ssh-config.sh
ssh-host: shuttle.prod.internal
276 changes: 237 additions & 39 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions Containerfile
Original file line number Diff line number Diff line change
@@ -59,6 +59,9 @@ RUN cargo build \
#### AUTH
FROM docker.io/library/debian:bookworm-20230904-slim AS shuttle-auth
ARG CARGO_PROFILE
ARG prepare_args
COPY auth/prepare.sh /prepare.sh
RUN /prepare.sh "${prepare_args}"
COPY --from=chef-builder /build/target/${CARGO_PROFILE}/shuttle-auth /usr/local/bin
ENTRYPOINT ["/usr/local/bin/shuttle-auth"]
FROM shuttle-auth AS shuttle-auth-dev
7 changes: 7 additions & 0 deletions DEVELOPING.md
Original file line number Diff line number Diff line change
@@ -287,6 +287,13 @@ Finally, configure Docker Compose. You can either

If you are using `nftables`, even with `iptables-nft`, it may be necessary to install and configure the [nftables CNI plugins](https://github.com/greenpau/cni-plugins)

## Testing the Pro tier

We use Stripe to start Pro subscriptions and verify them with a Stripe client that needs a secret key. The `STRIPE_SECRET_KEY` environment variable
should be set to test upgrading a user to Pro tier, or to use a Pro tier feature with cargo-shuttle CLI. On a local environment, that requires
setting up a Stripe account and generating a test API key. Auth can still be initialised and used without a Stripe secret key, but it will fail
when retrieving a user, and when we'll verify the subscription validity.
## Running Tests
Shuttle has reasonable test coverage - and we are working on improving this
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ DOCKER_SOCK?=/var/run/docker.sock
POSTGRES_PASSWORD?=postgres
MONGO_INITDB_ROOT_USERNAME?=mongodb
MONGO_INITDB_ROOT_PASSWORD?=password
STRIPE_SECRET_KEY?=""


ifeq ($(PROD),true)
@@ -135,6 +136,7 @@ DOCKER_COMPOSE_ENV=\
CONTAINER_REGISTRY=$(CONTAINER_REGISTRY)\
MONGO_INITDB_ROOT_USERNAME=$(MONGO_INITDB_ROOT_USERNAME)\
MONGO_INITDB_ROOT_PASSWORD=$(MONGO_INITDB_ROOT_PASSWORD)\
STRIPE_SECRET_KEY=$(STRIPE_SECRET_KEY)\
DD_ENV=$(DD_ENV)\
USE_TLS=$(USE_TLS)\
COMPOSE_PROFILES=$(COMPOSE_PROFILES)\
4 changes: 3 additions & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shuttle-auth"
version.workspace = true
version = "0.28.1"
edition.workspace = true
license.workspace = true
repository.workspace = true
@@ -23,6 +23,7 @@ sqlx = { workspace = true, features = [
"runtime-tokio-rustls",
"migrate",
] }
async-stripe = { version = "0.25.1", default-features = false, features = ["checkout", "runtime-tokio-hyper-rustls"] }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["full"] }
@@ -39,3 +40,4 @@ axum-extra = { version = "0.7.1", features = ["cookie"] }
hyper = { workspace = true }
serde_json = { workspace = true }
tower = { workspace = true, features = ["util"] }
portpicker = { workspace = true }
1 change: 1 addition & 0 deletions auth/migrations/0001_add_subscription_column.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE users ADD COLUMN subscription_id TEXT;
9 changes: 9 additions & 0 deletions auth/prepare.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env sh

##############################################################################################
# This file is run by Containerfile for extra preparation steps for this crate's final image #
##############################################################################################

# We're using rustls for the async-stripe crate and that needs certificates installed.
apt-get update
apt install -y ca-certificates
21 changes: 17 additions & 4 deletions auth/src/api/builder.rs
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ use crate::{

use super::handlers::{
convert_cookie, convert_key, get_public_key, get_user, health_check, logout, post_user,
put_user_reset_key, refresh_token,
put_user_reset_key, refresh_token, update_user_tier,
};

pub type UserManagerState = Arc<Box<dyn UserManagement>>;
@@ -53,6 +53,7 @@ pub struct ApiBuilder {
router: Router<RouterState>,
pool: Option<SqlitePool>,
session_layer: Option<SessionLayer<MemoryStore>>,
stripe_client: Option<stripe::Client>,
}

impl Default for ApiBuilder {
@@ -71,7 +72,10 @@ impl ApiBuilder {
.route("/auth/refresh", post(refresh_token))
.route("/public-key", get(get_public_key))
.route("/users/:account_name", get(get_user))
.route("/users/:account_name/:account_tier", post(post_user))
.route(
"/users/:account_name/:account_tier",
post(post_user).put(update_user_tier),
)
.route("/users/reset-api-key", put(put_user_reset_key))
.route_layer(from_extractor::<Metrics>())
.layer(
@@ -90,6 +94,7 @@ impl ApiBuilder {
router,
pool: None,
session_layer: None,
stripe_client: None,
}
}

@@ -112,11 +117,19 @@ impl ApiBuilder {
self
}

pub fn with_stripe_client(mut self, stripe_client: stripe::Client) -> Self {
self.stripe_client = Some(stripe_client);
self
}

pub fn into_router(self) -> Router {
let pool = self.pool.expect("an sqlite pool is required");
let session_layer = self.session_layer.expect("a session layer is required");

let user_manager = UserManager { pool };
let stripe_client = self.stripe_client.expect("a stripe client is required");
let user_manager = UserManager {
pool,
stripe_client,
};
let key_manager = EdDsaManager::new();

let state = RouterState {
27 changes: 26 additions & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ use axum_sessions::extractors::{ReadableSession, WritableSession};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use shuttle_common::{claims::Claim, models::user};
use stripe::CheckoutSession;
use tracing::instrument;

use super::{
@@ -39,14 +40,38 @@ pub(crate) async fn post_user(
Ok(Json(user.into()))
}

#[instrument(skip(user_manager))]
pub(crate) async fn update_user_tier(
_: Admin,
State(user_manager): State<UserManagerState>,
Path((account_name, account_tier)): Path<(AccountName, AccountTier)>,
payload: Option<Json<CheckoutSession>>,
) -> Result<(), Error> {
if account_tier == AccountTier::Pro {
match payload {
Some(Json(checkout_session)) => {
user_manager
.upgrade_to_pro(&account_name, checkout_session)
.await?;
}
None => return Err(Error::MissingCheckoutSession),
}
} else {
user_manager
.update_tier(&account_name, account_tier)
.await?;
};

Ok(())
}

pub(crate) async fn put_user_reset_key(
session: ReadableSession,
State(user_manager): State<UserManagerState>,
key: Option<Key>,
) -> Result<(), Error> {
let account_name = match session.get::<String>("account_name") {
Some(account_name) => account_name.into(),

None => match key {
Some(key) => user_manager.get_user_by_key(key.into()).await?.name,
None => return Err(Error::Unauthorized),
4 changes: 4 additions & 0 deletions auth/src/args.rs
Original file line number Diff line number Diff line change
@@ -24,6 +24,10 @@ pub struct StartArgs {
/// Address to bind to
#[arg(long, default_value = "127.0.0.1:8000")]
pub address: SocketAddr,

/// Stripe client secret key
#[arg(long, default_value = "")]
pub stripe_secret_key: String,
}

#[derive(clap::Args, Debug, Clone)]
12 changes: 12 additions & 0 deletions auth/src/error.rs
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ use axum::Json;

use serde::{ser::SerializeMap, Serialize};
use shuttle_common::models::error::ApiError;
use stripe::StripeError;

#[derive(Debug, thiserror::Error)]
pub enum Error {
@@ -21,6 +22,14 @@ pub enum Error {
Database(#[from] sqlx::Error),
#[error(transparent)]
UnexpectedError(#[from] anyhow::Error),
#[error("Missing checkout session.")]
MissingCheckoutSession,
#[error("Incomplete checkout session.")]
IncompleteCheckoutSession,
#[error("Interacting with stripe resulted in error: {0}.")]
StripeError(#[from] StripeError),
#[error("Missing subscription ID from the checkout session.")]
MissingSubscriptionId,
}

impl Serialize for Error {
@@ -42,6 +51,9 @@ impl IntoResponse for Error {
Error::Forbidden => StatusCode::FORBIDDEN,
Error::Unauthorized | Error::KeyMissing => StatusCode::UNAUTHORIZED,
Error::Database(_) | Error::UserNotFound => StatusCode::NOT_FOUND,
Error::MissingCheckoutSession
| Error::MissingSubscriptionId
| Error::IncompleteCheckoutSession => StatusCode::BAD_REQUEST,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};

1 change: 1 addition & 0 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {
let router = api::ApiBuilder::new()
.with_sqlite_pool(pool)
.with_sessions()
.with_stripe_client(stripe::Client::new(args.stripe_secret_key))
.into_router();

info!(address=%args.address, "Binding to and listening at address");
214 changes: 190 additions & 24 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
@@ -12,14 +12,23 @@ use shuttle_common::{
claims::{Scope, ScopeBuilder},
ApiKey,
};
use sqlx::{query, Row, SqlitePool};
use tracing::{debug, trace, Span};
use sqlx::{query, sqlite::SqliteRow, FromRow, Row, SqlitePool};
use tracing::{debug, error, trace, Span};

use crate::{api::UserManagerState, error::Error};
use stripe::{
CheckoutSession, CheckoutSessionStatus, Expandable, SubscriptionId, SubscriptionStatus,
};

#[async_trait]
pub trait UserManagement: Send + Sync {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error>;
async fn upgrade_to_pro(
&self,
name: &AccountName,
checkout_session_metadata: CheckoutSession,
) -> Result<(), Error>;
async fn update_tier(&self, name: &AccountName, tier: AccountTier) -> Result<(), Error>;
async fn get_user(&self, name: AccountName) -> Result<User, Error>;
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error>;
async fn reset_key(&self, name: AccountName) -> Result<(), Error>;
@@ -28,6 +37,7 @@ pub trait UserManagement: Send + Sync {
#[derive(Clone)]
pub struct UserManager {
pub pool: SqlitePool,
pub stripe_client: stripe::Client,
}

#[async_trait]
@@ -42,33 +52,104 @@ impl UserManagement for UserManager {
.execute(&self.pool)
.await?;

Ok(User::new(name, key, tier))
Ok(User::new(name, key, tier, None))
}

async fn get_user(&self, name: AccountName) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE account_name = ?1")
.bind(&name)
.fetch_optional(&self.pool)
// Update user tier to pro and update the subscription id.
async fn upgrade_to_pro(
&self,
name: &AccountName,
checkout_session_metadata: CheckoutSession,
) -> Result<(), Error> {
// Update the user tier and store the subscription id. We expect the checkout session to be
// completed when it is sent. In case of incomplete checkout sessions, auth backend will not
// fulfill the request.
if checkout_session_metadata
.status
.filter(|inner| inner == &CheckoutSessionStatus::Complete)
.is_some()
{
// Extract the checkout session status if any, otherwise return with error.
let subscription_id = checkout_session_metadata
.subscription
.map(|s| match s {
Expandable::Id(id) => id.to_string(),
Expandable::Object(obj) => obj.id.to_string(),
})
.ok_or(Error::MissingSubscriptionId)?;

// Update the user account tier and subscription_id.
let rows_affected = query(
"UPDATE users SET account_tier = ?1, subscription_id = ?2 WHERE account_name = ?3",
)
.bind(AccountTier::Pro)
.bind(subscription_id)
.bind(name)
.execute(&self.pool)
.await?
.map(|row| User {
name,
key: row.try_get("key").unwrap(),
account_tier: row.try_get("account_tier").unwrap(),
})
.ok_or(Error::UserNotFound)
.rows_affected();

// In case no rows were updated, this means the account doesn't exist.
if rows_affected > 0 {
Ok(())
} else {
Err(Error::UserNotFound)
}
} else {
Err(Error::IncompleteCheckoutSession)
}
}

async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE key = ?1")
.bind(&key)
.fetch_optional(&self.pool)
// Update tier leaving the subscription_id untouched.
async fn update_tier(&self, name: &AccountName, tier: AccountTier) -> Result<(), Error> {
let rows_affected = query("UPDATE users SET account_tier = ?1 WHERE account_name = ?2")
.bind(tier)
.bind(name)
.execute(&self.pool)
.await?
.map(|row| User {
name: row.try_get("account_name").unwrap(),
key,
account_tier: row.try_get("account_tier").unwrap(),
})
.ok_or(Error::UserNotFound)
.rows_affected();

if rows_affected > 0 {
Ok(())
} else {
Err(Error::UserNotFound)
}
}

async fn get_user(&self, name: AccountName) -> Result<User, Error> {
let mut user: User =
sqlx::query_as("SELECT account_name, key, account_tier, subscription_id FROM users WHERE account_name = ?")
.bind(&name)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;

// Sync the user tier based on the subscription validity, if any.
if let Err(err) = user.sync_tier(self).await {
error!("failed syncing account");
return Err(err);
} else {
debug!("synced account");
}

Ok(user)
}

async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
let mut user: User = sqlx::query_as(
"SELECT account_name, key, account_tier, subscription_id FROM users WHERE key = ?",
)
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserNotFound)?;

// Sync the user tier based on the subscription validity, if any.
if user.sync_tier(self).await? {
debug!("synced account");
}

Ok(user)
}

async fn reset_key(&self, name: AccountName) -> Result<(), Error> {
@@ -94,19 +175,79 @@ pub struct User {
pub name: AccountName,
pub key: ApiKey,
pub account_tier: AccountTier,
pub subscription_id: Option<SubscriptionId>,
}

impl User {
pub fn is_admin(&self) -> bool {
self.account_tier == AccountTier::Admin
}

pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self {
pub fn new(
name: AccountName,
key: ApiKey,
account_tier: AccountTier,
subscription_id: Option<SubscriptionId>,
) -> Self {
Self {
name,
key,
account_tier,
subscription_id,
}
}

/// In case of an existing subscription, check if valid.
async fn subscription_is_valid(&self, client: &stripe::Client) -> Result<bool, Error> {
if let Some(subscription_id) = self.subscription_id.as_ref() {
let subscription = stripe::Subscription::retrieve(client, subscription_id, &[]).await?;
debug!("subscription: {:#?}", subscription);
return Ok(subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing);
}

Ok(false)
}

// Synchronize the tiers with the subscription validity.
async fn sync_tier(&mut self, user_manager: &UserManager) -> Result<bool, Error> {
let subscription_is_valid = self
.subscription_is_valid(&user_manager.stripe_client)
.await?;

if self.account_tier == AccountTier::Pro && !subscription_is_valid {
self.account_tier = AccountTier::PendingPaymentPro;
user_manager
.update_tier(&self.name, self.account_tier)
.await?;
return Ok(true);
}

if self.account_tier == AccountTier::PendingPaymentPro && subscription_is_valid {
self.account_tier = AccountTier::Pro;
user_manager
.update_tier(&self.name, self.account_tier)
.await?;
return Ok(true);
}

Ok(false)
}
}

impl FromRow<'_, SqliteRow> for User {
fn from_row(row: &SqliteRow) -> Result<Self, sqlx::Error> {
let x: &str = row.try_get("subscription_id").unwrap();
println!("{:?}", x);
Ok(User {
name: row.try_get("account_name").unwrap(),
key: row.try_get("key").unwrap(),
account_tier: row.try_get("account_tier").unwrap(),
subscription_id: row
.try_get("subscription_id")
.ok()
.and_then(|inner| SubscriptionId::from_str(inner).ok()),
})
}
}

@@ -142,6 +283,7 @@ impl From<User> for shuttle_common::models::user::Response {
name: user.name.to_string(),
key: user.key.as_ref().to_string(),
account_tier: user.account_tier.to_string(),
subscription_id: user.subscription_id.map(|inner| inner.to_string()),
}
}
}
@@ -188,6 +330,8 @@ where
pub enum AccountTier {
#[default]
Basic,
// A basic user that is pending a payment on the backend.
PendingPaymentPro,
Pro,
Team,
Admin,
@@ -299,6 +443,28 @@ mod tests {
);
}

#[test]
fn pending_payment_pro() {
let scopes: Vec<Scope> = AccountTier::PendingPaymentPro.into();

assert_eq!(
scopes,
vec![
Scope::Deployment,
Scope::DeploymentPush,
Scope::Logs,
Scope::Service,
Scope::ServiceCreate,
Scope::Project,
Scope::ProjectCreate,
Scope::Resources,
Scope::ResourcesWrite,
Scope::Secret,
Scope::SecretWrite,
]
);
}

#[test]
fn pro() {
let scopes: Vec<Scope> = AccountTier::Pro.into();
100 changes: 96 additions & 4 deletions auth/tests/api/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
use axum::{body::Body, response::Response, Router};
use hyper::http::{header::AUTHORIZATION, Request};
use std::{net::SocketAddr, str::FromStr};

use axum::{body::Body, extract::Path, response::Response, routing::get, Router};
use http::header::CONTENT_TYPE;
use hyper::{
http::{header::AUTHORIZATION, Request},
Server,
};
use serde_json::Value;
use shuttle_auth::{sqlite_init, ApiBuilder};
use sqlx::query;
use tower::ServiceExt;

use crate::stripe::MOCKED_SUBSCRIPTIONS;

pub(crate) const ADMIN_KEY: &str = "ndh9z58jttoes3qv";

pub(crate) struct TestApp {
pub router: Router,
pub mocked_stripe_server: MockedStripeServer,
}

/// Initialize a router with an in-memory sqlite database for each test.
pub(crate) async fn app() -> TestApp {
let sqlite_pool = sqlite_init("sqlite::memory:").await;

let mocked_stripe_server = MockedStripeServer::default();
// Insert an admin user for the tests.
query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind("admin")
@@ -26,9 +36,16 @@ pub(crate) async fn app() -> TestApp {
let router = ApiBuilder::new()
.with_sqlite_pool(sqlite_pool)
.with_sessions()
.with_stripe_client(stripe::Client::from_url(
mocked_stripe_server.uri.to_string().as_str(),
"",
))
.into_router();

TestApp { router }
TestApp {
router,
mocked_stripe_server,
}
}

impl TestApp {
@@ -51,6 +68,23 @@ impl TestApp {
self.send_request(request).await
}

pub async fn put_user(
&self,
name: &str,
tier: &str,
checkout_session: &'static str,
) -> Response {
let request = Request::builder()
.uri(format!("/users/{name}/{tier}"))
.method("PUT")
.header(AUTHORIZATION, format!("Bearer {ADMIN_KEY}"))
.header(CONTENT_TYPE, "application/json")
.body(Body::from(checkout_session))
.unwrap();

self.send_request(request).await
}

pub async fn get_user(&self, name: &str) -> Response {
let request = Request::builder()
.uri(format!("/users/{name}"))
@@ -61,3 +95,61 @@ impl TestApp {
self.send_request(request).await
}
}

#[derive(Clone)]
pub(crate) struct MockedStripeServer {
uri: http::Uri,
router: Router,
}

impl MockedStripeServer {
async fn subscription_retrieve_handler(
Path(subscription_id): Path<String>,
) -> axum::response::Response<String> {
let sessions = MOCKED_SUBSCRIPTIONS
.iter()
.filter(|sub| sub.contains(format!("\"id\": \"{}\"", subscription_id).as_str()))
.map(|sub| serde_json::from_str(sub).unwrap())
.collect::<Vec<Value>>();
if sessions.len() == 1 {
return Response::new(sessions[0].to_string());
}

Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body("subscription id not found".to_string())
.unwrap()
}

pub(crate) async fn serve(self) {
let address = &SocketAddr::from_str(
format!("{}:{}", self.uri.host().unwrap(), self.uri.port().unwrap()).as_str(),
)
.unwrap();
println!("serving on: {}", address);
Server::bind(address)
.serve(self.router.into_make_service())
.await
.unwrap_or_else(|_| panic!("Failed to bind to address: {}", self.uri));
}
}

impl Default for MockedStripeServer {
fn default() -> MockedStripeServer {
let router = Router::new().route(
"/v1/subscriptions/:subscription_id",
get(MockedStripeServer::subscription_retrieve_handler),
);
MockedStripeServer {
uri: http::Uri::from_str(
format!(
"http://127.0.0.1:{}",
portpicker::pick_unused_port().unwrap()
)
.as_str(),
)
.unwrap(),
router,
}
}
}
1 change: 1 addition & 0 deletions auth/tests/api/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod auth;
mod helpers;
mod session;
mod stripe;
mod users;
147 changes: 147 additions & 0 deletions auth/tests/api/stripe/active_subscription.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
pub(crate) const MOCKED_ACTIVE_SUBSCRIPTION: &str = r#"{
"id": "sub_1Nw8xOD8t1tt0S3DtwAuOVp6",
"object": "subscription",
"application": null,
"application_fee_percent": null,
"automatic_tax": {
"enabled": false
},
"billing_cycle_anchor": 1696102566,
"billing_thresholds": null,
"cancel_at": null,
"cancel_at_period_end": false,
"canceled_at": null,
"cancellation_details": {
"comment": null,
"feedback": null,
"reason": null
},
"collection_method": "charge_automatically",
"created": 1696102566,
"currency": "ron",
"current_period_end": 1698694566,
"current_period_start": 1696102566,
"customer": "cus_OjcBtb9CGkRN0Q",
"days_until_due": null,
"default_payment_method": "pm_1Nw8xND8t1tt0S3DdoPw8WzZ",
"default_source": null,
"default_tax_rates": [],
"description": null,
"discount": null,
"ended_at": null,
"items": {
"object": "list",
"data": [
{
"id": "si_OjcB0PrsQ861FB",
"object": "subscription_item",
"billing_thresholds": null,
"created": 1696102567,
"metadata": {},
"plan": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "plan",
"active": true,
"aggregate_usage": null,
"amount": 10000,
"amount_decimal": "10000",
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"interval": "month",
"interval_count": 1,
"livemode": false,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"tiers_mode": null,
"transform_usage": null,
"trial_period_days": null,
"usage_type": "licensed"
},
"price": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "price",
"active": true,
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"custom_unit_amount": null,
"livemode": false,
"lookup_key": null,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"recurring": {
"aggregate_usage": null,
"interval": "month",
"interval_count": 1,
"trial_period_days": null,
"usage_type": "licensed"
},
"tax_behavior": "unspecified",
"tiers_mode": null,
"transform_quantity": null,
"type": "recurring",
"unit_amount": 10000,
"unit_amount_decimal": "10000"
},
"quantity": 1,
"subscription": "sub_1Nw8xOD8t1tt0S3DtwAuOVp6",
"tax_rates": []
}
],
"has_more": false,
"total_count": 1,
"url": "/v1/subscription_items?subscription=sub_1Nw8xOD8t1tt0S3DtwAuOVp6"
},
"latest_invoice": "in_1Nw8xOD8t1tt0S3DU4YDQ8ok",
"livemode": false,
"metadata": {},
"next_pending_invoice_item_invoice": null,
"on_behalf_of": null,
"pause_collection": null,
"payment_settings": {
"payment_method_options": null,
"payment_method_types": null,
"save_default_payment_method": "off"
},
"pending_invoice_item_interval": null,
"pending_setup_intent": null,
"pending_update": null,
"plan": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "plan",
"active": true,
"aggregate_usage": null,
"amount": 10000,
"amount_decimal": "10000",
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"interval": "month",
"interval_count": 1,
"livemode": false,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"tiers_mode": null,
"transform_usage": null,
"trial_period_days": null,
"usage_type": "licensed"
},
"quantity": 1,
"schedule": null,
"start_date": 1696102566,
"status": "active",
"test_clock": null,
"transfer_data": null,
"trial_end": null,
"trial_settings": {
"end_behavior": {
"missing_payment_method": "create_invoice"
}
},
"trial_start": null
}
"#;
80 changes: 80 additions & 0 deletions auth/tests/api/stripe/completed_checkout_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
pub(crate) const MOCKED_COMPLETED_CHECKOUT_SESSION: &str = r#"{
"id": "cs_test_a1nmf3TXSDqYScpNLEroolP1ugCtk8Rx7kivUjYHLUdmjyJoociglcbN8q",
"object": "checkout.session",
"after_expiration": null,
"allow_promotion_codes": null,
"amount_subtotal": 10000,
"amount_total": 10000,
"automatic_tax": {
"enabled": false,
"status": null
},
"billing_address_collection": null,
"cancel_url": "https://example.com/cancel",
"client_reference_id": null,
"consent": null,
"consent_collection": null,
"created": 1696102521,
"currency": "ron",
"currency_conversion": null,
"custom_fields": [],
"custom_text": {
"shipping_address": null,
"submit": null,
"terms_of_service_acceptance": null
},
"customer": "cus_OjcBtb9CGkRN0Q",
"customer_creation": "always",
"customer_details": {
"address": {
"city": null,
"country": "RO",
"line1": null,
"line2": null,
"postal_code": null,
"state": null
},
"email": "iulian@shuttle.rs",
"name": "Iulian Barbu",
"phone": null,
"tax_exempt": "none",
"tax_ids": []
},
"customer_email": null,
"expires_at": 1696188921,
"invoice": "in_1Nw8xOD8t1tt0S3DU4YDQ8ok",
"invoice_creation": null,
"livemode": false,
"locale": null,
"metadata": {},
"mode": "subscription",
"payment_intent": null,
"payment_link": null,
"payment_method_collection": "always",
"payment_method_configuration_details": null,
"payment_method_options": null,
"payment_method_types": [
"card"
],
"payment_status": "paid",
"phone_number_collection": {
"enabled": false
},
"recovered_from": null,
"setup_intent": null,
"shipping_address_collection": null,
"shipping_cost": null,
"shipping_details": null,
"shipping_options": [],
"status": "complete",
"submit_type": null,
"subscription": "sub_1Nw8xOD8t1tt0S3DtwAuOVp6",
"success_url": "https://example.com/success?session_id={CHECKOUT_SESSION_ID}",
"total_details": {
"amount_discount": 0,
"amount_shipping": 0,
"amount_tax": 0
},
"url": null
}
"#;
69 changes: 69 additions & 0 deletions auth/tests/api/stripe/incomplete_checkout_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
pub(crate) const MOCKED_INCOMPLETE_CHECKOUT_SESSION: &str = r#"{
"id": "cs_test_a11rHy7qRTwFZuj4lBHso3Frq7CMZheZYcYqNXEFBV4oddxXFLx7bT911p",
"object": "checkout.session",
"after_expiration": null,
"allow_promotion_codes": false,
"amount_subtotal": 10000,
"amount_total": 10000,
"automatic_tax": {
"enabled": false,
"status": null
},
"billing_address_collection": "auto",
"cancel_url": "https://stripe.com",
"client_reference_id": null,
"consent": null,
"consent_collection": {
"promotions": "none",
"terms_of_service": "none"
},
"created": 1696098429,
"currency": "ron",
"currency_conversion": null,
"custom_fields": [],
"custom_text": {
"shipping_address": null,
"submit": null,
"terms_of_service_acceptance": null
},
"customer": null,
"customer_creation": "if_required",
"customer_details": null,
"customer_email": null,
"expires_at": 1696184829,
"invoice": null,
"invoice_creation": null,
"livemode": false,
"locale": "auto",
"metadata": {},
"mode": "subscription",
"payment_intent": null,
"payment_link": "plink_1Nw7sYD8t1tt0S3DHQRms10g",
"payment_method_collection": "always",
"payment_method_configuration_details": null,
"payment_method_options": null,
"payment_method_types": [
"card"
],
"payment_status": "unpaid",
"phone_number_collection": {
"enabled": false
},
"recovered_from": null,
"setup_intent": null,
"shipping_address_collection": null,
"shipping_cost": null,
"shipping_details": null,
"shipping_options": [],
"status": "open",
"submit_type": "auto",
"subscription": null,
"success_url": "https://stripe.com",
"total_details": {
"amount_discount": 0,
"amount_shipping": 0,
"amount_tax": 0
},
"url": "https://checkout.stripe.com/c/pay/cs_test_a11rHy7qRTwFZuj4lBHso3Frq7CMZheZYcYqNXEFBV4oddxXFLx7bT911p#fidkdWxOYHwnPyd1blpxYHZxWjA0S3NhbkhBPXE0cXE1VjZBMm4xSjBpMm9LVEFhczBBVjF8XVx1aTdAVlxiUGlyN0J1d2xjXTU2cXNoNExzbzYwS1VufDZOS0IwV1ZUQ290RjxycXxTVEpjNTVIZnZXdVdkUycpJ2N3amhWYHdzYHcnP3F3cGApJ2lkfGpwcVF8dWAnPyd2bGtiaWBabHFgaCcpJ2BrZGdpYFVpZGZgbWppYWB3dic%2FcXdwYHgl"
}
"#;
22 changes: 22 additions & 0 deletions auth/tests/api/stripe/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use self::{
active_subscription::MOCKED_ACTIVE_SUBSCRIPTION,
completed_checkout_session::MOCKED_COMPLETED_CHECKOUT_SESSION,
incomplete_checkout_session::MOCKED_INCOMPLETE_CHECKOUT_SESSION,
overdue_payment_checkout_session::MOCKED_OVERDUE_PAYMENT_CHECKOUT_SESSION,
past_due_subscription::MOCKED_PAST_DUE_SUBSCRIPTION,
};

mod active_subscription;
mod completed_checkout_session;
mod incomplete_checkout_session;
mod overdue_payment_checkout_session;
mod past_due_subscription;

pub(crate) const MOCKED_SUBSCRIPTIONS: &[&str] =
&[MOCKED_ACTIVE_SUBSCRIPTION, MOCKED_PAST_DUE_SUBSCRIPTION];

pub(crate) const MOCKED_CHECKOUT_SESSIONS: &[&str] = &[
MOCKED_COMPLETED_CHECKOUT_SESSION,
MOCKED_INCOMPLETE_CHECKOUT_SESSION,
MOCKED_OVERDUE_PAYMENT_CHECKOUT_SESSION,
];
72 changes: 72 additions & 0 deletions auth/tests/api/stripe/overdue_payment_checkout_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// This is a synthetic checkout session. It is used to simplify the code path for downgrading to `PendingPaymentPro` tier
// when user payment is overdue.

pub(crate) const MOCKED_OVERDUE_PAYMENT_CHECKOUT_SESSION: &str = r#"{
"id": "cs_test_a11rHy7qRTwFZuj4lBHso3Frq7CMZheZYcYqNXEFBV4oddxXFLx7bT911p",
"object": "checkout.session",
"after_expiration": null,
"allow_promotion_codes": false,
"amount_subtotal": 10000,
"amount_total": 10000,
"automatic_tax": {
"enabled": false,
"status": null
},
"billing_address_collection": "auto",
"cancel_url": "https://stripe.com",
"client_reference_id": null,
"consent": null,
"consent_collection": {
"promotions": "none",
"terms_of_service": "none"
},
"created": 1696098429,
"currency": "ron",
"currency_conversion": null,
"custom_fields": [],
"custom_text": {
"shipping_address": null,
"submit": null,
"terms_of_service_acceptance": null
},
"customer": null,
"customer_creation": "if_required",
"customer_details": null,
"customer_email": null,
"expires_at": 1696184829,
"invoice": null,
"invoice_creation": null,
"livemode": false,
"locale": "auto",
"metadata": {},
"mode": "subscription",
"payment_intent": null,
"payment_link": "plink_1Nw7sYD8t1tt0S3DHQRms10g",
"payment_method_collection": "always",
"payment_method_configuration_details": null,
"payment_method_options": null,
"payment_method_types": [
"card"
],
"payment_status": "unpaid",
"phone_number_collection": {
"enabled": false
},
"recovered_from": null,
"setup_intent": null,
"shipping_address_collection": null,
"shipping_cost": null,
"shipping_details": null,
"shipping_options": [],
"status": "complete",
"submit_type": "auto",
"subscription": "sub_1NwObED8t1tt0S3Dq0IYOEsa",
"success_url": "https://stripe.com",
"total_details": {
"amount_discount": 0,
"amount_shipping": 0,
"amount_tax": 0
},
"url": "https://checkout.stripe.com/c/pay/cs_test_a11rHy7qRTwFZuj4lBHso3Frq7CMZheZYcYqNXEFBV4oddxXFLx7bT911p#fidkdWxOYHwnPyd1blpxYHZxWjA0S3NhbkhBPXE0cXE1VjZBMm4xSjBpMm9LVEFhczBBVjF8XVx1aTdAVlxiUGlyN0J1d2xjXTU2cXNoNExzbzYwS1VufDZOS0IwV1ZUQ290RjxycXxTVEpjNTVIZnZXdVdkUycpJ2N3amhWYHdzYHcnP3F3cGApJ2lkfGpwcVF8dWAnPyd2bGtiaWBabHFgaCcpJ2BrZGdpYFVpZGZgbWppYWB3dic%2FcXdwYHgl"
}
"#;
147 changes: 147 additions & 0 deletions auth/tests/api/stripe/past_due_subscription.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
pub(crate) const MOCKED_PAST_DUE_SUBSCRIPTION: &str = r#"{
"id": "sub_1NwObED8t1tt0S3Dq0IYOEsa",
"object": "subscription",
"application": null,
"application_fee_percent": null,
"automatic_tax": {
"enabled": false
},
"billing_cycle_anchor": 1698930360,
"billing_thresholds": null,
"cancel_at": null,
"cancel_at_period_end": false,
"canceled_at": null,
"cancellation_details": {
"comment": null,
"feedback": null,
"reason": null
},
"collection_method": "send_invoice",
"created": 1698930360,
"currency": "ron",
"current_period_end": 1709384760,
"current_period_start": 1706879160,
"customer": "cus_OjsLL84gFdbFPP",
"days_until_due": 30,
"default_payment_method": null,
"default_source": null,
"default_tax_rates": [],
"description": null,
"discount": null,
"ended_at": null,
"items": {
"object": "list",
"data": [
{
"id": "si_OjsLE9Q9sTZMtM",
"object": "subscription_item",
"billing_thresholds": null,
"created": 1698930360,
"metadata": {},
"plan": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "plan",
"active": true,
"aggregate_usage": null,
"amount": 10000,
"amount_decimal": "10000",
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"interval": "month",
"interval_count": 1,
"livemode": false,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"tiers_mode": null,
"transform_usage": null,
"trial_period_days": null,
"usage_type": "licensed"
},
"price": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "price",
"active": true,
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"custom_unit_amount": null,
"livemode": false,
"lookup_key": null,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"recurring": {
"aggregate_usage": null,
"interval": "month",
"interval_count": 1,
"trial_period_days": null,
"usage_type": "licensed"
},
"tax_behavior": "unspecified",
"tiers_mode": null,
"transform_quantity": null,
"type": "recurring",
"unit_amount": 10000,
"unit_amount_decimal": "10000"
},
"quantity": 1,
"subscription": "sub_1NwObED8t1tt0S3Dq0IYOEsa",
"tax_rates": []
}
],
"has_more": false,
"total_count": 1,
"url": "/v1/subscription_items?subscription=sub_1NwObED8t1tt0S3Dq0IYOEsa"
},
"latest_invoice": "in_1NwOgYD8t1tt0S3DWcXcslkk",
"livemode": false,
"metadata": {},
"next_pending_invoice_item_invoice": null,
"on_behalf_of": null,
"pause_collection": null,
"payment_settings": {
"payment_method_options": null,
"payment_method_types": null,
"save_default_payment_method": "off"
},
"pending_invoice_item_interval": null,
"pending_setup_intent": null,
"pending_update": null,
"plan": {
"id": "price_1NvdmxD8t1tt0S3DBi2jTI92",
"object": "plan",
"active": true,
"aggregate_usage": null,
"amount": 10000,
"amount_decimal": "10000",
"billing_scheme": "per_unit",
"created": 1695982755,
"currency": "ron",
"interval": "month",
"interval_count": 1,
"livemode": false,
"metadata": {},
"nickname": null,
"product": "prod_Oj5yfmphYbZ8RE",
"tiers_mode": null,
"transform_usage": null,
"trial_period_days": null,
"usage_type": "licensed"
},
"quantity": 1,
"schedule": null,
"start_date": 1698930360,
"status": "past_due",
"test_clock": "clock_1NwOQ2D8t1tt0S3DShKPQWLB",
"transfer_data": null,
"trial_end": null,
"trial_settings": {
"end_behavior": {
"missing_payment_method": "create_invoice"
}
},
"trial_start": null
}
"#;
124 changes: 123 additions & 1 deletion auth/tests/api/users.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use crate::helpers::{self, app};
use std::time::Duration;

use crate::{
helpers::{self, app},
stripe::{MOCKED_CHECKOUT_SESSIONS, MOCKED_SUBSCRIPTIONS},
};
use axum::body::Body;
use hyper::http::{header::AUTHORIZATION, Request, StatusCode};
use serde_json::{self, Value};
@@ -104,6 +109,123 @@ async fn get_user() {
assert_eq!(user, persisted_user);
}

#[tokio::test]
async fn successful_upgrade_to_pro() {
let app = app().await;

// Wait for the mocked Stripe server to start.
tokio::task::spawn(app.mocked_stripe_server.clone().serve());
tokio::time::sleep(Duration::from_secs(1)).await;

// POST user first so one exists in the database.
let response = app.post_user("test-user", "basic").await;

assert_eq!(response.status(), StatusCode::OK);

let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let expected_user: Value = serde_json::from_slice(&body).unwrap();

let response = app
.put_user("test-user", "pro", MOCKED_CHECKOUT_SESSIONS[0])
.await;
assert_eq!(response.status(), StatusCode::OK);

let response = app.get_user("test-user").await;
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let actual_user: Value = serde_json::from_slice(&body).unwrap();

assert_eq!(
expected_user.as_object().unwrap().get("name").unwrap(),
actual_user.as_object().unwrap().get("name").unwrap()
);

assert_eq!(
expected_user.as_object().unwrap().get("key").unwrap(),
actual_user.as_object().unwrap().get("key").unwrap()
);

assert_eq!(
actual_user
.as_object()
.unwrap()
.get("account_tier")
.unwrap(),
"pro"
);

let mocked_subscription_obj: Value = serde_json::from_str(MOCKED_SUBSCRIPTIONS[0]).unwrap();
assert_eq!(
actual_user
.as_object()
.unwrap()
.get("subscription_id")
.unwrap(),
mocked_subscription_obj
.as_object()
.unwrap()
.get("id")
.unwrap()
);
}

#[tokio::test]
async fn unsuccessful_upgrade_to_pro() {
let app = app().await;

// Wait for the mocked Stripe server to start.
tokio::task::spawn(app.mocked_stripe_server.clone().serve());
tokio::time::sleep(Duration::from_secs(1)).await;

// POST user first so one exists in the database.
let response = app.post_user("test-user", "basic").await;
assert_eq!(response.status(), StatusCode::OK);

// Test upgrading to pro without a checkout session object.
let response = app.put_user("test-user", "pro", "").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);

// Test upgrading to pro with an incomplete checkout session object.
let response = app
.put_user("test-user", "pro", MOCKED_CHECKOUT_SESSIONS[1])
.await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}

#[tokio::test]
async fn downgrade_in_case_subscription_due_payment() {
let app = app().await;

// Wait for the mocked Stripe server to start.
tokio::task::spawn(app.mocked_stripe_server.clone().serve());
tokio::time::sleep(Duration::from_secs(1)).await;

// POST user first so one exists in the database.
let response = app.post_user("test-user", "basic").await;
assert_eq!(response.status(), StatusCode::OK);

// Test upgrading to pro with a checkout session that points to a due session.
let response = app
.put_user("test-user", "pro", MOCKED_CHECKOUT_SESSIONS[2])
.await;
assert_eq!(response.status(), StatusCode::OK);

// This get_user request should check the subscription status and return an accurate tier.
let response = app.get_user("test-user").await;
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let actual_user: Value = serde_json::from_slice(&body).unwrap();

assert_eq!(
actual_user
.as_object()
.unwrap()
.get("account_tier")
.unwrap(),
"pendingpaymentpro"
);
}

#[tokio::test]
async fn test_reset_key() {
let app = app().await;
2 changes: 1 addition & 1 deletion common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shuttle-common"
version = "0.28.0"
version = "0.28.1"
edition.workspace = true
license.workspace = true
repository.workspace = true
1 change: 1 addition & 0 deletions common/src/models/user.rs
Original file line number Diff line number Diff line change
@@ -5,4 +5,5 @@ pub struct Response {
pub name: String,
pub key: String,
pub account_tier: String,
pub subscription_id: Option<String>,
}
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@ services:
- "--state=/var/lib/shuttle-auth"
- "start"
- "--address=0.0.0.0:8000"
- "--stripe-secret-key=${STRIPE_SECRET_KEY}"
builder:
image: "${CONTAINER_REGISTRY}/builder:${BUILDER_TAG}"
depends_on:

0 comments on commit b37b03f

Please sign in to comment.