Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gateway disconnect notifications settings #936

Open
wants to merge 17 commits into
base: dev
Choose a base branch
from
Open

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ tokio = { version = "1", features = [
"time",
] }
tokio-stream = "0.1"
tokio-util = "0.7"
tonic = { version = "0.11", features = ["gzip", "tls", "tls-roots"] }
tonic-health = "0.11"
totp-lite = { version = "2.0" }
Expand Down
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
packages = with pkgs; [
sqlx-cli
just
vtsls
];

# Specify the rust-src path (many editors rely on this)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE settings DROP COLUMN gateway_disconnect_notifications_enabled;
ALTER TABLE settings DROP COLUMN gateway_disconnect_notifications_inactivity_threshold;
ALTER TABLE settings DROP COLUMN gateway_disconnect_notifications_reconnect_notification_enabled;
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE settings ADD gateway_disconnect_notifications_enabled BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE settings ADD gateway_disconnect_notifications_inactivity_threshold INT4 NOT NULL DEFAULT 5;
ALTER TABLE settings ADD gateway_disconnect_notifications_reconnect_notification_enabled BOOLEAN NOT NULL DEFAULT FALSE;
2 changes: 1 addition & 1 deletion src/appstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl AppState {
debug!("WebHook triggered");
debug!("Retrieving webhooks");
if let Ok(webhooks) = WebHook::all_enabled(&pool, &msg).await {
info!("Found webhooks: {webhooks:#?}");
info!("Found webhooks: {webhooks:?}");
let (payload, event) = match msg {
AppEvent::UserCreated(user) => (json!(user), "user_created"),
AppEvent::UserModified(user) => (json!(user), "user_modified"),
Expand Down
23 changes: 14 additions & 9 deletions src/bin/defguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use std::{
use defguard::{
auth::failed_login::FailedLoginMap,
config::{Command, DefGuardConfig},
db::{init_db, AppEvent, GatewayEvent, Settings, User},
db::{
init_db, models::settings::initialize_current_settings, AppEvent, GatewayEvent, Settings,
User,
},
enterprise::{
license::{run_periodic_license_check, set_cached_license, License},
limits::update_counts,
Expand Down Expand Up @@ -88,6 +91,8 @@ async fn main() -> Result<(), anyhow::Error> {

// initialize default settings
Settings::init_defaults(&pool).await?;
// initialize global settings struct
initialize_current_settings(&pool).await?;

// read grpc TLS cert and key
let grpc_cert = config
Expand Down Expand Up @@ -118,14 +123,14 @@ async fn main() -> Result<(), anyhow::Error> {

// run services
tokio::select! {
res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_state), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"),
res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), failed_logins) => error!("Web server returned early: {res:#?}"),
res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:#?}"),
res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx) => error!("Periodic peer disconnect task returned early: {res:#?}"),
res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:#?}"),
res = run_periodic_license_check(&pool) => error!("Periodic license check task returned early: {res:#?}"),
res = run_utility_thread(&pool) => error!("Utility thread returned early: {res:#?}"),
res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_state), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:?}"),
res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), failed_logins) => error!("Web server returned early: {res:?}"),
res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:?}"),
res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx) => error!("Periodic peer disconnect task returned early: {res:?}"),
res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:?}"),
res = run_periodic_license_check(&pool) => error!("Periodic license check task returned early: {res:?}"),
res = run_utility_thread(&pool) => error!("Utility thread returned early: {res:?}"),
}
Ok(())
}
8 changes: 0 additions & 8 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,6 @@ pub struct DefGuardConfig {
#[arg(long, env = "DEFGUARD_PROXY_GRPC_CA")]
pub proxy_grpc_ca: Option<String>,

#[arg(
long,
env = "DEFGUARD_GATEWAY_DISCONNECTION_NOTIFICATION_TIMEOUT",
default_value = "10m"
)]
#[serde(skip_serializing)]
pub gateway_disconnection_notification_timeout: Duration,

#[command(subcommand)]
#[serde(skip_serializing)]
pub cmd: Option<Command>,
Expand Down
4 changes: 2 additions & 2 deletions src/db/models/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl Token {
&self,
transaction: &mut PgConnection,
) -> Result<String, TokenError> {
let settings = Settings::get_settings(&mut *transaction).await?;
let settings = Settings::get_current_settings();

// load configured content as template
let mut tera = Tera::default();
Expand All @@ -370,7 +370,7 @@ impl Token {
ip_address: &str,
device_info: Option<&str>,
) -> Result<String, TokenError> {
let settings = Settings::get_settings(&mut *transaction).await?;
let settings = Settings::get_current_settings();

// load configured content as template
let mut tera = Tera::default();
Expand Down
110 changes: 85 additions & 25 deletions src/db/models/settings.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,71 @@
use std::collections::HashMap;
use std::{
collections::HashMap,
sync::{RwLock, RwLockReadGuard},
};

use sqlx::{query, query_as, PgExecutor, PgPool, Type};
use struct_patch::Patch;
use thiserror::Error;

use crate::secret::SecretString;

#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Type, Debug)]
// wrap in `Option` since a static cannot be initialized with a non-const function
static SETTINGS: RwLock<Option<Settings>> = RwLock::new(None);

pub(crate) fn set_settings(new_settings: Settings) {
*SETTINGS
.write()
.expect("Failed to acquire lock on current settings.") = Some(new_settings);
}

pub(crate) fn get_settings() -> RwLockReadGuard<'static, Option<Settings>> {
SETTINGS
.read()
.expect("Failed to acquire lock on current settings.")
}

/// Initializes global `SETTINGS` struct at program startup
pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Error> {
debug!("Initializing global settings strut");
match Settings::get(pool).await? {
Some(settings) => {
set_settings(settings);
}
None => {
debug!("Settings not found in DB. Using default values to initialize global settings struct");
set_settings(Settings::default());
}
}
Ok(())
}

/// Helper function which stores updated `Settings` in the DB and also updates the global `SETTINGS` struct
pub async fn update_current_settings(
pool: &PgPool,
new_settings: Settings,
) -> Result<(), sqlx::Error> {
debug!("Updating current settings to: {new_settings:?}");
new_settings.save(pool).await?;
set_settings(new_settings);
Ok(())
}

#[derive(Error, Debug)]
pub enum SettingsValidationError {
#[error("Cannot enable gateway disconnect notifications. SMTP is not configured")]
CannotEnableGatewayNotifications,
}

#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Type, Debug, Default)]
#[sqlx(type_name = "smtp_encryption", rename_all = "lowercase")]
pub enum SmtpEncryption {
#[default]
None,
StartTls,
ImplicitTls,
}

#[derive(Clone, Debug, Deserialize, PartialEq, Patch, Serialize)]
#[derive(Clone, Debug, Deserialize, PartialEq, Patch, Serialize, Default)]
#[patch(attribute(derive(Deserialize, Serialize)))]
pub struct Settings {
// Modules
Expand Down Expand Up @@ -58,6 +110,10 @@ pub struct Settings {
// Whether to create a new account when users try to log in with external OpenID
pub openid_create_account: bool,
pub license: Option<String>,
// Gateway disconnect notifications
pub gateway_disconnect_notifications_enabled: bool,
pub gateway_disconnect_notifications_inactivity_threshold: i32,
pub gateway_disconnect_notifications_reconnect_notification_enabled: bool,
}

impl Settings {
Expand All @@ -78,13 +134,26 @@ impl Settings {
ldap_group_search_base, ldap_user_search_base, ldap_user_obj_class, \
ldap_group_obj_class, ldap_username_attr, ldap_groupname_attr, \
ldap_group_member_attr, ldap_member_attr, openid_create_account, \
license \
license, \
gateway_disconnect_notifications_enabled, gateway_disconnect_notifications_inactivity_threshold, gateway_disconnect_notifications_reconnect_notification_enabled \
FROM \"settings\" WHERE id = 1",
)
.fetch_optional(executor)
.await
}

/// Checks if given settings are correct
pub fn validate(&self) -> Result<(), SettingsValidationError> {
debug!("Validating settings: {self:?}");
// check if gateway disconnect notifications can be enabled, since it requires SMTP to be configured
if self.gateway_disconnect_notifications_enabled && !self.smtp_configured() {
warn!("Cannot enable gateway disconnect notifications. SMTP is not configured.");
return Err(SettingsValidationError::CannotEnableGatewayNotifications);
};

Ok(())
}

pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error>
where
E: PgExecutor<'e>,
Expand Down Expand Up @@ -123,7 +192,10 @@ impl Settings {
ldap_group_member_attr = $30, \
ldap_member_attr = $31, \
openid_create_account = $32, \
license = $33 \
license = $33, \
gateway_disconnect_notifications_enabled = $34, \
gateway_disconnect_notifications_inactivity_threshold = $35, \
gateway_disconnect_notifications_reconnect_notification_enabled = $36 \
WHERE id = 1",
self.openid_enabled,
self.wireguard_enabled,
Expand Down Expand Up @@ -157,35 +229,23 @@ impl Settings {
self.ldap_group_member_attr,
self.ldap_member_attr,
self.openid_create_account,
self.license
)
.execute(executor)
.await?;

Ok(())
}

pub(crate) async fn save_license<'e, E>(&self, executor: E) -> Result<(), sqlx::Error>
where
E: PgExecutor<'e>,
{
query!(
"UPDATE \"settings\" SET license = $1 WHERE id = 1",
self.license,
self.gateway_disconnect_notifications_enabled,
self.gateway_disconnect_notifications_inactivity_threshold,
self.gateway_disconnect_notifications_reconnect_notification_enabled
)
.execute(executor)
.await?;

Ok(())
}

pub async fn get_settings<'e, E>(executor: E) -> Result<Self, sqlx::Error>
where
E: PgExecutor<'e>,
{
let settings = Settings::get(executor).await?;
pub fn get_current_settings() -> Self {
// fetch global settings
let maybe_settings = get_settings().clone();

Ok(settings.expect("Settings not found"))
// panic if settings have not been initialized, since it should happen at startup
maybe_settings.expect("Global settings have not been initialized")
}

// Set default values for settings if not set yet.
Expand Down
2 changes: 1 addition & 1 deletion src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ pub(crate) async fn user_from_claims(
let sub = token_claims.subject().to_string();

// Handle logging in or creating user.
let settings = Settings::get_settings(pool).await?;
let settings = Settings::get_current_settings();
let user = match User::find_by_sub(pool, &sub)
.await
.map_err(|err| WebError::Authorization(err.to_string()))?
Expand Down
Loading
Loading