Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Soft-delete upstream OAuth 2.0 providers on config sync
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Apr 3, 2024
1 parent 4e3823f commit cd0ec35
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 20 deletions.
81 changes: 66 additions & 15 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

//! Utilities to synchronize the configuration file with the database.
use std::collections::HashSet;
use std::collections::{BTreeMap, BTreeSet};

use mas_config::{ClientsConfig, UpstreamOAuth2Config};
use mas_keystore::Encrypter;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess};
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams},
Clock, Pagination, RepositoryAccess,
};
use mas_storage_pg::PgRepository;
use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection};
use tracing::{error, info, info_span, warn};
Expand Down Expand Up @@ -107,35 +110,83 @@ pub async fn config_sync(
let config_ids = upstream_oauth2_config
.providers
.iter()
.filter(|p| p.enabled)
.map(|p| p.id)
.collect::<HashSet<_>>();
.collect::<BTreeSet<_>>();

// Let's assume we have less than 1000 providers
let page = repo
.upstream_oauth_provider()
.list(
UpstreamOAuthProviderFilter::default(),
Pagination::first(1000),
)
.await?;

// A warning is probably enough
if page.has_next_page {
warn!(
"More than 1000 providers in the database, only the first 1000 will be considered"
);
}

let mut existing_enabled_ids = BTreeSet::new();
let mut existing_disabled = BTreeMap::new();
// Process the existing providers
for provider in page.edges {
if provider.enabled() {
if config_ids.contains(&provider.id) {
existing_enabled_ids.insert(provider.id);
} else {
// Provider is enabled in the database but not in the config
info!(%provider.id, "Disabling provider");

let provider = if dry_run {
provider
} else {
repo.upstream_oauth_provider()
.disable(clock, provider)
.await?
};

existing_disabled.insert(provider.id, provider);
}
} else {
existing_disabled.insert(provider.id, provider);
}
}

let existing = repo.upstream_oauth_provider().all_enabled().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for provider in to_delete {
info!(%provider.id, "Deleting provider");
for provider_id in existing_disabled.keys().copied() {
info!(provider.id = %provider_id, "Deleting provider");

if dry_run {
continue;
}

repo.upstream_oauth_provider().delete(provider).await?;
repo.upstream_oauth_provider()
.delete_by_id(provider_id)
.await?;
}
} else {
let len = to_delete.count();
let len = existing_disabled.len();
match len {
0 => {},
1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."),
1 => warn!("A provider is soft-deleted in the database. Run `mas-cli config sync --prune` to delete it."),
n => warn!("{n} providers are soft-deleted in the database. Run `mas-cli config sync --prune` to delete them."),
}
}

for provider in upstream_oauth2_config.providers {
if !provider.enabled {
continue;
}

let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) {
if existing_enabled_ids.contains(&provider.id) {
info!("Updating provider");
} else if existing_disabled.contains_key(&provider.id) {
info!("Enabling and updating provider");
} else {
info!("Adding provider");
}
Expand Down Expand Up @@ -224,10 +275,10 @@ pub async fn config_sync(
let config_ids = clients_config
.iter()
.map(|c| c.client_id)
.collect::<HashSet<_>>();
.collect::<BTreeSet<_>>();

let existing = repo.oauth2_client().all_static().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let existing_ids = existing.iter().map(|p| p.id).collect::<BTreeSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for client in to_delete {
Expand Down
15 changes: 15 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,24 @@ impl PkceMethod {
}
}

fn default_true() -> bool {
true
}

#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
*value
}

#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
/// Whether this provider is enabled.
///
/// Defaults to `true`
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
pub enabled: bool,

/// An internal unique identifier for this provider
#[schemars(
with = "String",
Expand Down
12 changes: 12 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ pub struct UpstreamOAuthProvider {
pub additional_authorization_parameters: Vec<(String, String)>,
}

impl PartialOrd for UpstreamOAuthProvider {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.id.cmp(&other.id))
}
}

impl Ord for UpstreamOAuthProvider {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.id.cmp(&other.id)
}
}

impl UpstreamOAuthProvider {
/// Returns `true` if the provider is enabled
#[must_use]
Expand Down
10 changes: 7 additions & 3 deletions crates/storage-pg/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
async fn disable(
&mut self,
clock: &dyn Clock,
upstream_oauth_provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error> {
mut upstream_oauth_provider: UpstreamOAuthProvider,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let disabled_at = clock.now();
let res = sqlx::query!(
r#"
Expand All @@ -531,7 +531,11 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.execute(&mut *self.conn)
.await?;

DatabaseError::ensure_affected_rows(&res, 1)
DatabaseError::ensure_affected_rows(&res, 1)?;

upstream_oauth_provider.disabled_at = Some(disabled_at);

Ok(upstream_oauth_provider)
}

#[tracing::instrument(
Expand Down
6 changes: 4 additions & 2 deletions crates/storage/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {

/// Disable an upstream OAuth provider
///
/// Returns the disabled provider
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
Expand All @@ -216,7 +218,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
&mut self,
clock: &dyn Clock,
provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error>;
) -> Result<UpstreamOAuthProvider, Self::Error>;

/// List [`UpstreamOAuthProvider`] with the given filter and pagination
///
Expand Down Expand Up @@ -281,7 +283,7 @@ repository_impl!(UpstreamOAuthProviderRepository:
&mut self,
clock: &dyn Clock,
provider: UpstreamOAuthProvider
) -> Result<(), Self::Error>;
) -> Result<UpstreamOAuthProvider, Self::Error>;

async fn list(
&mut self,
Expand Down
4 changes: 4 additions & 0 deletions docs/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,10 @@
"token_endpoint_auth_method"
],
"properties": {
"enabled": {
"description": "Whether this provider is enabled.\n\nDefaults to `true`",
"type": "boolean"
},
"id": {
"description": "A ULID as per https://github.com/ulid/spec",
"type": "string",
Expand Down

0 comments on commit cd0ec35

Please sign in to comment.