From 0de844d0eca1a6d333012a76a09d024cad0d981f Mon Sep 17 00:00:00 2001 From: Damien Broka Date: Thu, 3 Nov 2022 11:41:12 +0000 Subject: [PATCH 1/2] feat(gateway): add account_tier column --- .../0001_add_top_level_account_perms.sql | 1 + gateway/src/api/latest.rs | 2 +- gateway/src/auth.rs | 117 ++++++++++++++++-- gateway/src/lib.rs | 4 +- gateway/src/service.rs | 83 ++++++------- 5 files changed, 149 insertions(+), 58 deletions(-) create mode 100644 gateway/migrations/0001_add_top_level_account_perms.sql diff --git a/gateway/migrations/0001_add_top_level_account_perms.sql b/gateway/migrations/0001_add_top_level_account_perms.sql new file mode 100644 index 000000000..15731aed2 --- /dev/null +++ b/gateway/migrations/0001_add_top_level_account_perms.sql @@ -0,0 +1 @@ +ALTER TABLE accounts ADD account_tier TEXT DEFAULT "basic" NOT NULL; diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index 8ebb68730..497c330df 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -49,7 +49,7 @@ async fn get_user( Path(account_name): Path, _: Admin, ) -> Result, Error> { - let user = service.user_from_account_name(account_name).await?; + let user = User::retrieve_from_account_name(&service, account_name).await?; Ok(AxumJson(user.into())) } diff --git a/gateway/src/auth.rs b/gateway/src/auth.rs index c5ea0b6ae..29ce85528 100644 --- a/gateway/src/auth.rs +++ b/gateway/src/auth.rs @@ -1,4 +1,4 @@ -use std::fmt::Formatter; +use std::fmt::{Debug, Formatter}; use std::str::FromStr; use std::sync::Arc; @@ -58,8 +58,6 @@ impl Key { } } -const FALSE: fn() -> bool = || false; - /// A wrapper for a guard that verifies an API key is associated with a /// valid user. /// @@ -71,11 +69,113 @@ pub struct User { pub name: AccountName, pub key: Key, pub projects: Vec, - #[serde(skip_serializing_if = "std::ops::Not::not")] - #[serde(default = "FALSE")] + pub permissions: Permissions, +} + +impl User { + pub fn is_super_user(&self) -> bool { + self.permissions.is_super_user() + } + + pub fn new_with_defaults(name: AccountName, key: Key) -> Self { + Self { + name, + key, + projects: Vec::new(), + permissions: Permissions::default(), + } + } + + pub async fn retrieve_from_account_name( + svc: &GatewayService, + name: AccountName, + ) -> Result { + let key = svc.key_from_account_name(&name).await?; + let permissions = svc.get_permissions(&name).await?; + let projects = svc.iter_user_projects(&name).await?.collect(); + Ok(User { + name, + key, + projects, + permissions, + }) + } + + pub async fn retrieve_from_key(svc: &GatewayService, key: Key) -> Result { + let name = svc.account_name_from_key(&key).await?; + let permissions = svc.get_permissions(&name).await?; + let projects = svc.iter_user_projects(&name).await?.collect(); + Ok(User { + name, + key, + projects, + permissions, + }) + } +} + +#[derive(Clone, Copy, Deserialize, PartialEq, Eq, Serialize, Debug, sqlx::Type)] +#[sqlx(rename_all = "lowercase")] +pub enum AccountTier { + Basic, + Pro, + Team, +} + +#[derive(Default)] +pub struct PermissionsBuilder { + tier: Option, + super_user: Option, +} + +impl PermissionsBuilder { + pub fn super_user(mut self, is_super_user: bool) -> Self { + self.super_user = Some(is_super_user); + self + } + + pub fn tier(mut self, tier: AccountTier) -> Self { + self.tier = Some(tier); + self + } + + pub fn build(self) -> Permissions { + Permissions { + tier: self.tier.unwrap_or(AccountTier::Basic), + super_user: self.super_user.unwrap_or_default(), + } + } +} + +#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)] +pub struct Permissions { + pub tier: AccountTier, pub super_user: bool, } +impl Default for Permissions { + fn default() -> Self { + Self { + tier: AccountTier::Basic, + super_user: false, + } + } +} + +impl Permissions { + pub fn builder() -> PermissionsBuilder { + PermissionsBuilder::default() + } + + pub fn tier(&self) -> &AccountTier { + &self.tier + } + + pub fn is_super_user(&self) -> bool { + self.super_user + } +} + #[async_trait] impl FromRequest for User where @@ -88,8 +188,7 @@ where let Extension(service) = Extension::>::from_request(req) .await .unwrap(); - let user = service - .user_from_key(key) + let user = User::retrieve_from_key(&service, key) .await // Absord any error into `Unauthorized` .map_err(|e| Error::source(ErrorKind::Unauthorized, e))?; @@ -144,7 +243,7 @@ where // Record current project for tracing purposes Span::current().record("account.project", &scope.to_string()); - if user.super_user || user.projects.contains(&scope) { + if user.is_super_user() || user.projects.contains(&scope) { Ok(Self { user, scope }) } else { Err(Error::from(ErrorKind::ProjectNotFound)) @@ -165,7 +264,7 @@ where async fn from_request(req: &mut RequestParts) -> Result { let user = User::from_request(req).await?; - if user.super_user { + if user.is_super_user() { Ok(Self { user }) } else { Err(Error::from(ErrorKind::Forbidden)) diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index d4ce14b96..687bec738 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -286,7 +286,7 @@ pub mod tests { use hyper::http::Uri; use hyper::{Body, Client as HyperClient, Request, Response, StatusCode}; use rand::distributions::{Alphanumeric, DistString, Distribution, Uniform}; - use shuttle_common::models::{project, service}; + use shuttle_common::models::{project, service, user}; use sqlx::SqlitePool; use tokio::sync::mpsc::channel; @@ -648,7 +648,7 @@ pub mod tests { let User { key, name, .. } = service.create_user("neo".parse().unwrap()).await.unwrap(); service.set_super_user(&name, true).await.unwrap(); - let User { key, .. } = api_client + let user::Response { key, .. } = api_client .request( Request::post("/users/trinity") .with_header(&Authorization::bearer(key.as_str()).unwrap()) diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 2626b2adb..b650e9f5a 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -23,7 +23,7 @@ use tracing::{debug, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::args::ContextArgs; -use crate::auth::{Key, User}; +use crate::auth::{Key, User, Permissions}; use crate::project::Project; use crate::task::TaskBuilder; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName}; @@ -304,30 +304,6 @@ impl GatewayService { Ok(control_key) } - pub async fn user_from_account_name(&self, name: AccountName) -> Result { - let key = self.key_from_account_name(&name).await?; - let super_user = self.is_super_user(&name).await?; - let projects = self.iter_user_projects(&name).await?.collect(); - Ok(User { - name, - key, - projects, - super_user, - }) - } - - pub async fn user_from_key(&self, key: Key) -> Result { - let name = self.account_name_from_key(&key).await?; - let super_user = self.is_super_user(&name).await?; - let projects = self.iter_user_projects(&name).await?.collect(); - Ok(User { - name, - key, - projects, - super_user, - }) - } - pub async fn create_user(&self, name: AccountName) -> Result { let key = Key::new_random(); query("INSERT INTO accounts (account_name, key) VALUES (?1, ?2)") @@ -347,38 +323,52 @@ impl GatewayService { // Otherwise this is internal err.into() })?; - Ok(User { - name, - key, - projects: Vec::default(), - super_user: false, - }) + Ok(User::new_with_defaults(name, key)) } - pub async fn is_super_user(&self, account_name: &AccountName) -> Result { - let is_super_user = query("SELECT super_user FROM accounts WHERE account_name = ?1") + pub async fn get_permissions(&self, account_name: &AccountName) -> Result { + let permissions = query("SELECT super_user, account_tier FROM accounts WHERE account_name = ?1") .bind(account_name) .fetch_optional(&self.db) .await? - .map(|row| row.try_get("super_user").unwrap()) - .unwrap_or(false); // defaults to `false` (i.e. not super user) - Ok(is_super_user) + .map(|row| { + Permissions::builder() + .super_user(row.try_get("super_user").unwrap()) + .tier(row.try_get("account_tier").unwrap()) + .build() + }) + .unwrap_or_default(); // defaults to `false` (i.e. not super user) + Ok(permissions) } pub async fn set_super_user( &self, account_name: &AccountName, - value: bool, + super_user: bool ) -> Result<(), Error> { query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2") - .bind(value) + .bind(super_user) + .bind(account_name) + .execute(&self.db) + .await?; + Ok(()) + } + + pub async fn set_permissions( + &self, + account_name: &AccountName, + permissions: &Permissions + ) -> Result<(), Error> { + query("UPDATE accounts SET super_user = ?1, account_tier = ?2 WHERE account_name = ?3") + .bind(&permissions.super_user) + .bind(&permissions.tier) .bind(account_name) .execute(&self.db) .await?; Ok(()) } - async fn iter_user_projects( + pub async fn iter_user_projects( &self, AccountName(account_name): &AccountName, ) -> Result, Error> { @@ -489,6 +479,7 @@ pub mod tests { use std::str::FromStr; use super::*; + use crate::auth::AccountTier; use crate::task::{self, TaskResult}; use crate::tests::{assert_err_kind, World}; use crate::{Error, ErrorKind}; @@ -501,34 +492,34 @@ pub mod tests { let account_name: AccountName = "test_user_123".parse()?; assert_err_kind!( - svc.user_from_account_name(account_name.clone()).await, + User::retrieve_from_account_name(&svc, account_name.clone()).await, ErrorKind::UserNotFound ); assert_err_kind!( - svc.user_from_key(Key::from_str("123").unwrap()).await, + User::retrieve_from_key(&svc, Key::from_str("123").unwrap()).await, ErrorKind::UserNotFound ); let user = svc.create_user(account_name.clone()).await?; assert_eq!( - svc.user_from_account_name(account_name.clone()).await?, + User::retrieve_from_account_name(&svc, account_name.clone()).await?, user ); - assert!(!svc.is_super_user(&account_name).await?); - let User { name, key, projects, - super_user, + permissions, } = user; assert!(projects.is_empty()); - assert!(!super_user); + assert!(!permissions.is_super_user()); + + assert_eq!(*permissions.tier(), AccountTier::Basic); assert_eq!(name, account_name); From 868711c7249b587ea2757c5ef0409cbc5d23d019 Mon Sep 17 00:00:00 2001 From: Damien Broka Date: Fri, 4 Nov 2022 13:08:58 +0000 Subject: [PATCH 2/2] fmt --- gateway/src/service.rs | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gateway/src/service.rs b/gateway/src/service.rs index b650e9f5a..eec11196f 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -23,7 +23,7 @@ use tracing::{debug, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::args::ContextArgs; -use crate::auth::{Key, User, Permissions}; +use crate::auth::{Key, Permissions, User}; use crate::project::Project; use crate::task::TaskBuilder; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName}; @@ -327,24 +327,25 @@ impl GatewayService { } pub async fn get_permissions(&self, account_name: &AccountName) -> Result { - let permissions = query("SELECT super_user, account_tier FROM accounts WHERE account_name = ?1") - .bind(account_name) - .fetch_optional(&self.db) - .await? - .map(|row| { - Permissions::builder() - .super_user(row.try_get("super_user").unwrap()) - .tier(row.try_get("account_tier").unwrap()) - .build() - }) - .unwrap_or_default(); // defaults to `false` (i.e. not super user) + let permissions = + query("SELECT super_user, account_tier FROM accounts WHERE account_name = ?1") + .bind(account_name) + .fetch_optional(&self.db) + .await? + .map(|row| { + Permissions::builder() + .super_user(row.try_get("super_user").unwrap()) + .tier(row.try_get("account_tier").unwrap()) + .build() + }) + .unwrap_or_default(); // defaults to `false` (i.e. not super user) Ok(permissions) } pub async fn set_super_user( &self, account_name: &AccountName, - super_user: bool + super_user: bool, ) -> Result<(), Error> { query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2") .bind(super_user) @@ -357,7 +358,7 @@ impl GatewayService { pub async fn set_permissions( &self, account_name: &AccountName, - permissions: &Permissions + permissions: &Permissions, ) -> Result<(), Error> { query("UPDATE accounts SET super_user = ?1, account_tier = ?2 WHERE account_name = ?3") .bind(&permissions.super_user)