From 26ef5c38af7225b2d8561841379eb66d3c5fb259 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 16 Jul 2024 17:54:53 +0200 Subject: [PATCH] Use dynamic filters on app sessions by reusing the OAuth/compat sessions filters --- crates/storage-pg/src/app_session.rs | 129 +++++++++++---------------- 1 file changed, 51 insertions(+), 78 deletions(-) diff --git a/crates/storage-pg/src/app_session.rs b/crates/storage-pg/src/app_session.rs index 51a20a6ea..c0cbb8968 100644 --- a/crates/storage-pg/src/app_session.rs +++ b/crates/storage-pg/src/app_session.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. +// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,20 +17,22 @@ use async_trait::async_trait; use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent}; use mas_storage::{ - app_session::{AppSession, AppSessionFilter, AppSessionRepository}, + app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState}, + compat::CompatSessionFilter, + oauth2::OAuth2SessionFilter, Page, Pagination, }; use oauth2_types::scope::{Scope, ScopeToken}; use sea_query::{ - Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType, + Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType, }; use sea_query_binder::SqlxBinder; use sqlx::PgConnection; use ulid::Ulid; -use uuid::Uuid; use crate::{ errors::DatabaseInconsistencyError, + filter::StatementExt, iden::{CompatSessions, OAuth2Sessions}, pagination::QueryBuilderExt, DatabaseError, ExecuteExt, @@ -202,6 +204,44 @@ impl TryFrom for AppSession { } } +/// Split a [`AppSessionFilter`] into two separate filters: a +/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`]. +fn split_filter( + filter: AppSessionFilter<'_>, +) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) { + let mut compat_filter = CompatSessionFilter::new(); + let mut oauth2_filter = OAuth2SessionFilter::new(); + + if let Some(user) = filter.user() { + compat_filter = compat_filter.for_user(user); + oauth2_filter = oauth2_filter.for_user(user); + } + + match filter.state() { + Some(AppSessionState::Active) => { + compat_filter = compat_filter.active_only(); + oauth2_filter = oauth2_filter.active_only(); + } + Some(AppSessionState::Finished) => { + compat_filter = compat_filter.finished_only(); + oauth2_filter = oauth2_filter.finished_only(); + } + None => {} + } + + if let Some(device) = filter.device() { + compat_filter = compat_filter.for_device(device); + oauth2_filter = oauth2_filter.for_device(device); + } + + if let Some(browser_session) = filter.browser_session() { + compat_filter = compat_filter.for_browser_session(browser_session); + oauth2_filter = oauth2_filter.for_browser_session(browser_session); + } + + (compat_filter, oauth2_filter) +} + #[async_trait] impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { type Error = DatabaseError; @@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { filter: AppSessionFilter<'_>, pagination: Pagination, ) -> Result, Self::Error> { + let (compat_filter, oauth2_filter) = split_filter(filter); + let mut oauth2_session_select = Query::select() .expr_as( Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)), @@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { AppSessionLookupIden::LastActiveIp, ) .from(OAuth2Sessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() - } else { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.browser_session().map(|browser_session| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)) - .eq(Uuid::from(browser_session.id)) - })) - .and_where_option(filter.device().map(|device| { - Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col(( - OAuth2Sessions::Table, - OAuth2Sessions::ScopeList, - )))) - })) + .apply_filter(oauth2_filter) .clone(); let compat_session_select = Query::select() @@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { AppSessionLookupIden::LastActiveIp, ) .from(CompatSessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() - } else { - Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.browser_session().map(|browser_session| { - Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)) - .eq(Uuid::from(browser_session.id)) - })) - .and_where_option(filter.device().map(|device| { - Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string()) - })) + .apply_filter(compat_filter) .clone(); let common_table_expression = CommonTableExpression::new() @@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { err, )] async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result { + let (compat_filter, oauth2_filter) = split_filter(filter); let mut oauth2_session_select = Query::select() .expr(Expr::cust("1")) .from(OAuth2Sessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() - } else { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.browser_session().map(|browser_session| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)) - .eq(Uuid::from(browser_session.id)) - })) - .and_where_option(filter.device().map(|device| { - Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col(( - OAuth2Sessions::Table, - OAuth2Sessions::ScopeList, - )))) - })) + .apply_filter(oauth2_filter) .clone(); let compat_session_select = Query::select() .expr(Expr::cust("1")) .from(CompatSessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() - } else { - Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.browser_session().map(|browser_session| { - Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)) - .eq(Uuid::from(browser_session.id)) - })) - .and_where_option(filter.device().map(|device| { - Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string()) - })) + .apply_filter(compat_filter) .clone(); let common_table_expression = CommonTableExpression::new()