Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gateway): add account_tier column #458

Merged
merged 2 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gateway/migrations/0001_add_top_level_account_perms.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE accounts ADD account_tier TEXT DEFAULT "basic" NOT NULL;
2 changes: 1 addition & 1 deletion gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async fn get_user(
Path(account_name): Path<AccountName>,
_: Admin,
) -> Result<AxumJson<user::Response>, Error> {
let user = service.user_from_account_name(account_name).await?;
let user = User::retrieve_from_account_name(&service, account_name).await?;

Ok(AxumJson(user.into()))
}
Expand Down
117 changes: 108 additions & 9 deletions gateway/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::Formatter;
use std::fmt::{Debug, Formatter};
use std::str::FromStr;
use std::sync::Arc;

Expand Down Expand Up @@ -58,8 +58,6 @@ impl Key {
}
}

const FALSE: fn() -> bool = || false;

/// A wrapper for a guard that verifies an API key is associated with a
/// valid user.
///
Expand All @@ -71,11 +69,113 @@ pub struct User {
pub name: AccountName,
pub key: Key,
pub projects: Vec<ProjectName>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
#[serde(default = "FALSE")]
pub permissions: Permissions,
}

impl User {
pub fn is_super_user(&self) -> bool {
self.permissions.is_super_user()
}

pub fn new_with_defaults(name: AccountName, key: Key) -> Self {
Self {
name,
key,
projects: Vec::new(),
permissions: Permissions::default(),
}
}

pub async fn retrieve_from_account_name(
svc: &GatewayService,
name: AccountName,
) -> Result<User, Error> {
let key = svc.key_from_account_name(&name).await?;
let permissions = svc.get_permissions(&name).await?;
let projects = svc.iter_user_projects(&name).await?.collect();
Ok(User {
name,
key,
projects,
permissions,
})
}

pub async fn retrieve_from_key(svc: &GatewayService, key: Key) -> Result<User, Error> {
let name = svc.account_name_from_key(&key).await?;
let permissions = svc.get_permissions(&name).await?;
let projects = svc.iter_user_projects(&name).await?.collect();
Ok(User {
name,
key,
projects,
permissions,
})
}
}

#[derive(Clone, Copy, Deserialize, PartialEq, Eq, Serialize, Debug, sqlx::Type)]
#[sqlx(rename_all = "lowercase")]
pub enum AccountTier {
Basic,
Pro,
Team,
}

#[derive(Default)]
pub struct PermissionsBuilder {
tier: Option<AccountTier>,
super_user: Option<bool>,
}

impl PermissionsBuilder {
pub fn super_user(mut self, is_super_user: bool) -> Self {
self.super_user = Some(is_super_user);
self
}

pub fn tier(mut self, tier: AccountTier) -> Self {
self.tier = Some(tier);
self
}

pub fn build(self) -> Permissions {
Permissions {
tier: self.tier.unwrap_or(AccountTier::Basic),
super_user: self.super_user.unwrap_or_default(),
}
}
}

#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
pub struct Permissions {
pub tier: AccountTier,
pub super_user: bool,
}

impl Default for Permissions {
fn default() -> Self {
Self {
tier: AccountTier::Basic,
super_user: false,
}
}
}

impl Permissions {
pub fn builder() -> PermissionsBuilder {
PermissionsBuilder::default()
}

pub fn tier(&self) -> &AccountTier {
&self.tier
}

pub fn is_super_user(&self) -> bool {
self.super_user
}
}

#[async_trait]
impl<B> FromRequest<B> for User
where
Expand All @@ -88,8 +188,7 @@ where
let Extension(service) = Extension::<Arc<GatewayService>>::from_request(req)
.await
.unwrap();
let user = service
.user_from_key(key)
let user = User::retrieve_from_key(&service, key)
.await
// Absord any error into `Unauthorized`
.map_err(|e| Error::source(ErrorKind::Unauthorized, e))?;
Expand Down Expand Up @@ -144,7 +243,7 @@ where

// Record current project for tracing purposes
Span::current().record("account.project", &scope.to_string());
if user.super_user || user.projects.contains(&scope) {
if user.is_super_user() || user.projects.contains(&scope) {
Ok(Self { user, scope })
} else {
Err(Error::from(ErrorKind::ProjectNotFound))
Expand All @@ -165,7 +264,7 @@ where

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let user = User::from_request(req).await?;
if user.super_user {
if user.is_super_user() {
Ok(Self { user })
} else {
Err(Error::from(ErrorKind::Forbidden))
Expand Down
4 changes: 2 additions & 2 deletions gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ pub mod tests {
use hyper::http::Uri;
use hyper::{Body, Client as HyperClient, Request, Response, StatusCode};
use rand::distributions::{Alphanumeric, DistString, Distribution, Uniform};
use shuttle_common::models::{project, service};
use shuttle_common::models::{project, service, user};
use sqlx::SqlitePool;
use tokio::sync::mpsc::channel;

Expand Down Expand Up @@ -648,7 +648,7 @@ pub mod tests {
let User { key, name, .. } = service.create_user("neo".parse().unwrap()).await.unwrap();
service.set_super_user(&name, true).await.unwrap();

let User { key, .. } = api_client
let user::Response { key, .. } = api_client
.request(
Request::post("/users/trinity")
.with_header(&Authorization::bearer(key.as_str()).unwrap())
Expand Down
92 changes: 42 additions & 50 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use tracing::{debug, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::args::ContextArgs;
use crate::auth::{Key, User};
use crate::auth::{Key, Permissions, User};
use crate::project::Project;
use crate::task::TaskBuilder;
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName};
Expand Down Expand Up @@ -304,30 +304,6 @@ impl GatewayService {
Ok(control_key)
}

pub async fn user_from_account_name(&self, name: AccountName) -> Result<User, Error> {
let key = self.key_from_account_name(&name).await?;
let super_user = self.is_super_user(&name).await?;
let projects = self.iter_user_projects(&name).await?.collect();
Ok(User {
name,
key,
projects,
super_user,
})
}

pub async fn user_from_key(&self, key: Key) -> Result<User, Error> {
let name = self.account_name_from_key(&key).await?;
let super_user = self.is_super_user(&name).await?;
let projects = self.iter_user_projects(&name).await?.collect();
Ok(User {
name,
key,
projects,
super_user,
})
}

pub async fn create_user(&self, name: AccountName) -> Result<User, Error> {
let key = Key::new_random();
query("INSERT INTO accounts (account_name, key) VALUES (?1, ?2)")
Expand All @@ -347,38 +323,53 @@ impl GatewayService {
// Otherwise this is internal
err.into()
})?;
Ok(User {
name,
key,
projects: Vec::default(),
super_user: false,
})
Ok(User::new_with_defaults(name, key))
}

pub async fn get_permissions(&self, account_name: &AccountName) -> Result<Permissions, Error> {
let permissions =
query("SELECT super_user, account_tier FROM accounts WHERE account_name = ?1")
.bind(account_name)
.fetch_optional(&self.db)
.await?
.map(|row| {
Permissions::builder()
.super_user(row.try_get("super_user").unwrap())
.tier(row.try_get("account_tier").unwrap())
.build()
})
.unwrap_or_default(); // defaults to `false` (i.e. not super user)
Ok(permissions)
}

pub async fn is_super_user(&self, account_name: &AccountName) -> Result<bool, Error> {
let is_super_user = query("SELECT super_user FROM accounts WHERE account_name = ?1")
pub async fn set_super_user(
&self,
account_name: &AccountName,
super_user: bool,
) -> Result<(), Error> {
query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2")
.bind(super_user)
.bind(account_name)
.fetch_optional(&self.db)
.await?
.map(|row| row.try_get("super_user").unwrap())
.unwrap_or(false); // defaults to `false` (i.e. not super user)
Ok(is_super_user)
.execute(&self.db)
.await?;
Ok(())
}

pub async fn set_super_user(
pub async fn set_permissions(
&self,
account_name: &AccountName,
value: bool,
permissions: &Permissions,
) -> Result<(), Error> {
query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2")
.bind(value)
query("UPDATE accounts SET super_user = ?1, account_tier = ?2 WHERE account_name = ?3")
.bind(&permissions.super_user)
.bind(&permissions.tier)
.bind(account_name)
.execute(&self.db)
.await?;
Ok(())
}

async fn iter_user_projects(
pub async fn iter_user_projects(
&self,
AccountName(account_name): &AccountName,
) -> Result<impl Iterator<Item = ProjectName>, Error> {
Expand Down Expand Up @@ -489,6 +480,7 @@ pub mod tests {
use std::str::FromStr;

use super::*;
use crate::auth::AccountTier;
use crate::task::{self, TaskResult};
use crate::tests::{assert_err_kind, World};
use crate::{Error, ErrorKind};
Expand All @@ -501,34 +493,34 @@ pub mod tests {
let account_name: AccountName = "test_user_123".parse()?;

assert_err_kind!(
svc.user_from_account_name(account_name.clone()).await,
User::retrieve_from_account_name(&svc, account_name.clone()).await,
ErrorKind::UserNotFound
);

assert_err_kind!(
svc.user_from_key(Key::from_str("123").unwrap()).await,
User::retrieve_from_key(&svc, Key::from_str("123").unwrap()).await,
ErrorKind::UserNotFound
);

let user = svc.create_user(account_name.clone()).await?;

assert_eq!(
svc.user_from_account_name(account_name.clone()).await?,
User::retrieve_from_account_name(&svc, account_name.clone()).await?,
user
);

assert!(!svc.is_super_user(&account_name).await?);

let User {
name,
key,
projects,
super_user,
permissions,
} = user;

assert!(projects.is_empty());

assert!(!super_user);
assert!(!permissions.is_super_user());

assert_eq!(*permissions.tier(), AccountTier::Basic);

assert_eq!(name, account_name);

Expand Down