Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Use dynamic filters on app sessions by reusing the OAuth/compat sessi…
Browse files Browse the repository at this point in the history
…ons filters
  • Loading branch information
sandhose committed Jul 16, 2024
1 parent 79ff9df commit 26ef5c3
Showing 1 changed file with 51 additions and 78 deletions.
129 changes: 51 additions & 78 deletions crates/storage-pg/src/app_session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -17,20 +17,22 @@
use async_trait::async_trait;
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
use mas_storage::{
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
Page, Pagination,
};
use oauth2_types::scope::{Scope, ScopeToken};
use sea_query::{
Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType,
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;

use crate::{
errors::DatabaseInconsistencyError,
filter::StatementExt,
iden::{CompatSessions, OAuth2Sessions},
pagination::QueryBuilderExt,
DatabaseError, ExecuteExt,
Expand Down Expand Up @@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
}
}

/// Split a [`AppSessionFilter`] into two separate filters: a
/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
fn split_filter(
filter: AppSessionFilter<'_>,
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
let mut compat_filter = CompatSessionFilter::new();
let mut oauth2_filter = OAuth2SessionFilter::new();

if let Some(user) = filter.user() {
compat_filter = compat_filter.for_user(user);
oauth2_filter = oauth2_filter.for_user(user);
}

match filter.state() {
Some(AppSessionState::Active) => {
compat_filter = compat_filter.active_only();
oauth2_filter = oauth2_filter.active_only();
}
Some(AppSessionState::Finished) => {
compat_filter = compat_filter.finished_only();
oauth2_filter = oauth2_filter.finished_only();
}
None => {}
}

if let Some(device) = filter.device() {
compat_filter = compat_filter.for_device(device);
oauth2_filter = oauth2_filter.for_device(device);
}

if let Some(browser_session) = filter.browser_session() {
compat_filter = compat_filter.for_browser_session(browser_session);
oauth2_filter = oauth2_filter.for_browser_session(browser_session);
}

(compat_filter, oauth2_filter)
}

#[async_trait]
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
type Error = DatabaseError;
Expand All @@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
filter: AppSessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<AppSession>, Self::Error> {
let (compat_filter, oauth2_filter) = split_filter(filter);

let mut oauth2_session_select = Query::select()
.expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
Expand Down Expand Up @@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
AppSessionLookupIden::LastActiveIp,
)
.from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
}))
.apply_filter(oauth2_filter)
.clone();

let compat_session_select = Query::select()
Expand Down Expand Up @@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
AppSessionLookupIden::LastActiveIp,
)
.from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
}))
.apply_filter(compat_filter)
.clone();

let common_table_expression = CommonTableExpression::new()
Expand Down Expand Up @@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
err,
)]
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
let (compat_filter, oauth2_filter) = split_filter(filter);
let mut oauth2_session_select = Query::select()
.expr(Expr::cust("1"))
.from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
}))
.apply_filter(oauth2_filter)
.clone();

let compat_session_select = Query::select()
.expr(Expr::cust("1"))
.from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
}))
.apply_filter(compat_filter)
.clone();

let common_table_expression = CommonTableExpression::new()
Expand Down

0 comments on commit 26ef5c3

Please sign in to comment.