Skip to content

Commit b1eee6d

Browse files
authored
feat: add account_tier column (#458)
1 parent 83cbccd commit b1eee6d

File tree

5 files changed

+154
-62
lines changed

5 files changed

+154
-62
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE accounts ADD account_tier TEXT DEFAULT "basic" NOT NULL;

gateway/src/api/latest.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ async fn get_user(
4949
Path(account_name): Path<AccountName>,
5050
_: Admin,
5151
) -> Result<AxumJson<user::Response>, Error> {
52-
let user = service.user_from_account_name(account_name).await?;
52+
let user = User::retrieve_from_account_name(&service, account_name).await?;
5353

5454
Ok(AxumJson(user.into()))
5555
}

gateway/src/auth.rs

+108-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::fmt::Formatter;
1+
use std::fmt::{Debug, Formatter};
22
use std::str::FromStr;
33
use std::sync::Arc;
44

@@ -58,8 +58,6 @@ impl Key {
5858
}
5959
}
6060

61-
const FALSE: fn() -> bool = || false;
62-
6361
/// A wrapper for a guard that verifies an API key is associated with a
6462
/// valid user.
6563
///
@@ -71,11 +69,113 @@ pub struct User {
7169
pub name: AccountName,
7270
pub key: Key,
7371
pub projects: Vec<ProjectName>,
74-
#[serde(skip_serializing_if = "std::ops::Not::not")]
75-
#[serde(default = "FALSE")]
72+
pub permissions: Permissions,
73+
}
74+
75+
impl User {
76+
pub fn is_super_user(&self) -> bool {
77+
self.permissions.is_super_user()
78+
}
79+
80+
pub fn new_with_defaults(name: AccountName, key: Key) -> Self {
81+
Self {
82+
name,
83+
key,
84+
projects: Vec::new(),
85+
permissions: Permissions::default(),
86+
}
87+
}
88+
89+
pub async fn retrieve_from_account_name(
90+
svc: &GatewayService,
91+
name: AccountName,
92+
) -> Result<User, Error> {
93+
let key = svc.key_from_account_name(&name).await?;
94+
let permissions = svc.get_permissions(&name).await?;
95+
let projects = svc.iter_user_projects(&name).await?.collect();
96+
Ok(User {
97+
name,
98+
key,
99+
projects,
100+
permissions,
101+
})
102+
}
103+
104+
pub async fn retrieve_from_key(svc: &GatewayService, key: Key) -> Result<User, Error> {
105+
let name = svc.account_name_from_key(&key).await?;
106+
let permissions = svc.get_permissions(&name).await?;
107+
let projects = svc.iter_user_projects(&name).await?.collect();
108+
Ok(User {
109+
name,
110+
key,
111+
projects,
112+
permissions,
113+
})
114+
}
115+
}
116+
117+
#[derive(Clone, Copy, Deserialize, PartialEq, Eq, Serialize, Debug, sqlx::Type)]
118+
#[sqlx(rename_all = "lowercase")]
119+
pub enum AccountTier {
120+
Basic,
121+
Pro,
122+
Team,
123+
}
124+
125+
#[derive(Default)]
126+
pub struct PermissionsBuilder {
127+
tier: Option<AccountTier>,
128+
super_user: Option<bool>,
129+
}
130+
131+
impl PermissionsBuilder {
132+
pub fn super_user(mut self, is_super_user: bool) -> Self {
133+
self.super_user = Some(is_super_user);
134+
self
135+
}
136+
137+
pub fn tier(mut self, tier: AccountTier) -> Self {
138+
self.tier = Some(tier);
139+
self
140+
}
141+
142+
pub fn build(self) -> Permissions {
143+
Permissions {
144+
tier: self.tier.unwrap_or(AccountTier::Basic),
145+
super_user: self.super_user.unwrap_or_default(),
146+
}
147+
}
148+
}
149+
150+
#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
151+
pub struct Permissions {
152+
pub tier: AccountTier,
76153
pub super_user: bool,
77154
}
78155

156+
impl Default for Permissions {
157+
fn default() -> Self {
158+
Self {
159+
tier: AccountTier::Basic,
160+
super_user: false,
161+
}
162+
}
163+
}
164+
165+
impl Permissions {
166+
pub fn builder() -> PermissionsBuilder {
167+
PermissionsBuilder::default()
168+
}
169+
170+
pub fn tier(&self) -> &AccountTier {
171+
&self.tier
172+
}
173+
174+
pub fn is_super_user(&self) -> bool {
175+
self.super_user
176+
}
177+
}
178+
79179
#[async_trait]
80180
impl<B> FromRequest<B> for User
81181
where
@@ -88,8 +188,7 @@ where
88188
let Extension(service) = Extension::<Arc<GatewayService>>::from_request(req)
89189
.await
90190
.unwrap();
91-
let user = service
92-
.user_from_key(key)
191+
let user = User::retrieve_from_key(&service, key)
93192
.await
94193
// Absord any error into `Unauthorized`
95194
.map_err(|e| Error::source(ErrorKind::Unauthorized, e))?;
@@ -144,7 +243,7 @@ where
144243

145244
// Record current project for tracing purposes
146245
Span::current().record("account.project", &scope.to_string());
147-
if user.super_user || user.projects.contains(&scope) {
246+
if user.is_super_user() || user.projects.contains(&scope) {
148247
Ok(Self { user, scope })
149248
} else {
150249
Err(Error::from(ErrorKind::ProjectNotFound))
@@ -165,7 +264,7 @@ where
165264

166265
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
167266
let user = User::from_request(req).await?;
168-
if user.super_user {
267+
if user.is_super_user() {
169268
Ok(Self { user })
170269
} else {
171270
Err(Error::from(ErrorKind::Forbidden))

gateway/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ pub mod tests {
286286
use hyper::http::Uri;
287287
use hyper::{Body, Client as HyperClient, Request, Response, StatusCode};
288288
use rand::distributions::{Alphanumeric, DistString, Distribution, Uniform};
289-
use shuttle_common::models::{project, service};
289+
use shuttle_common::models::{project, service, user};
290290
use sqlx::SqlitePool;
291291
use tokio::sync::mpsc::channel;
292292

@@ -648,7 +648,7 @@ pub mod tests {
648648
let User { key, name, .. } = service.create_user("neo".parse().unwrap()).await.unwrap();
649649
service.set_super_user(&name, true).await.unwrap();
650650

651-
let User { key, .. } = api_client
651+
let user::Response { key, .. } = api_client
652652
.request(
653653
Request::post("/users/trinity")
654654
.with_header(&Authorization::bearer(key.as_str()).unwrap())

gateway/src/service.rs

+42-50
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use tracing::{debug, Span};
2323
use tracing_opentelemetry::OpenTelemetrySpanExt;
2424

2525
use crate::args::ContextArgs;
26-
use crate::auth::{Key, User};
26+
use crate::auth::{Key, Permissions, User};
2727
use crate::project::Project;
2828
use crate::task::TaskBuilder;
2929
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectName};
@@ -304,30 +304,6 @@ impl GatewayService {
304304
Ok(control_key)
305305
}
306306

307-
pub async fn user_from_account_name(&self, name: AccountName) -> Result<User, Error> {
308-
let key = self.key_from_account_name(&name).await?;
309-
let super_user = self.is_super_user(&name).await?;
310-
let projects = self.iter_user_projects(&name).await?.collect();
311-
Ok(User {
312-
name,
313-
key,
314-
projects,
315-
super_user,
316-
})
317-
}
318-
319-
pub async fn user_from_key(&self, key: Key) -> Result<User, Error> {
320-
let name = self.account_name_from_key(&key).await?;
321-
let super_user = self.is_super_user(&name).await?;
322-
let projects = self.iter_user_projects(&name).await?.collect();
323-
Ok(User {
324-
name,
325-
key,
326-
projects,
327-
super_user,
328-
})
329-
}
330-
331307
pub async fn create_user(&self, name: AccountName) -> Result<User, Error> {
332308
let key = Key::new_random();
333309
query("INSERT INTO accounts (account_name, key) VALUES (?1, ?2)")
@@ -347,38 +323,53 @@ impl GatewayService {
347323
// Otherwise this is internal
348324
err.into()
349325
})?;
350-
Ok(User {
351-
name,
352-
key,
353-
projects: Vec::default(),
354-
super_user: false,
355-
})
326+
Ok(User::new_with_defaults(name, key))
327+
}
328+
329+
pub async fn get_permissions(&self, account_name: &AccountName) -> Result<Permissions, Error> {
330+
let permissions =
331+
query("SELECT super_user, account_tier FROM accounts WHERE account_name = ?1")
332+
.bind(account_name)
333+
.fetch_optional(&self.db)
334+
.await?
335+
.map(|row| {
336+
Permissions::builder()
337+
.super_user(row.try_get("super_user").unwrap())
338+
.tier(row.try_get("account_tier").unwrap())
339+
.build()
340+
})
341+
.unwrap_or_default(); // defaults to `false` (i.e. not super user)
342+
Ok(permissions)
356343
}
357344

358-
pub async fn is_super_user(&self, account_name: &AccountName) -> Result<bool, Error> {
359-
let is_super_user = query("SELECT super_user FROM accounts WHERE account_name = ?1")
345+
pub async fn set_super_user(
346+
&self,
347+
account_name: &AccountName,
348+
super_user: bool,
349+
) -> Result<(), Error> {
350+
query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2")
351+
.bind(super_user)
360352
.bind(account_name)
361-
.fetch_optional(&self.db)
362-
.await?
363-
.map(|row| row.try_get("super_user").unwrap())
364-
.unwrap_or(false); // defaults to `false` (i.e. not super user)
365-
Ok(is_super_user)
353+
.execute(&self.db)
354+
.await?;
355+
Ok(())
366356
}
367357

368-
pub async fn set_super_user(
358+
pub async fn set_permissions(
369359
&self,
370360
account_name: &AccountName,
371-
value: bool,
361+
permissions: &Permissions,
372362
) -> Result<(), Error> {
373-
query("UPDATE accounts SET super_user = ?1 WHERE account_name = ?2")
374-
.bind(value)
363+
query("UPDATE accounts SET super_user = ?1, account_tier = ?2 WHERE account_name = ?3")
364+
.bind(&permissions.super_user)
365+
.bind(&permissions.tier)
375366
.bind(account_name)
376367
.execute(&self.db)
377368
.await?;
378369
Ok(())
379370
}
380371

381-
async fn iter_user_projects(
372+
pub async fn iter_user_projects(
382373
&self,
383374
AccountName(account_name): &AccountName,
384375
) -> Result<impl Iterator<Item = ProjectName>, Error> {
@@ -489,6 +480,7 @@ pub mod tests {
489480
use std::str::FromStr;
490481

491482
use super::*;
483+
use crate::auth::AccountTier;
492484
use crate::task::{self, TaskResult};
493485
use crate::tests::{assert_err_kind, World};
494486
use crate::{Error, ErrorKind};
@@ -501,34 +493,34 @@ pub mod tests {
501493
let account_name: AccountName = "test_user_123".parse()?;
502494

503495
assert_err_kind!(
504-
svc.user_from_account_name(account_name.clone()).await,
496+
User::retrieve_from_account_name(&svc, account_name.clone()).await,
505497
ErrorKind::UserNotFound
506498
);
507499

508500
assert_err_kind!(
509-
svc.user_from_key(Key::from_str("123").unwrap()).await,
501+
User::retrieve_from_key(&svc, Key::from_str("123").unwrap()).await,
510502
ErrorKind::UserNotFound
511503
);
512504

513505
let user = svc.create_user(account_name.clone()).await?;
514506

515507
assert_eq!(
516-
svc.user_from_account_name(account_name.clone()).await?,
508+
User::retrieve_from_account_name(&svc, account_name.clone()).await?,
517509
user
518510
);
519511

520-
assert!(!svc.is_super_user(&account_name).await?);
521-
522512
let User {
523513
name,
524514
key,
525515
projects,
526-
super_user,
516+
permissions,
527517
} = user;
528518

529519
assert!(projects.is_empty());
530520

531-
assert!(!super_user);
521+
assert!(!permissions.is_super_user());
522+
523+
assert_eq!(*permissions.tier(), AccountTier::Basic);
532524

533525
assert_eq!(name, account_name);
534526

0 commit comments

Comments
 (0)