Skip to content
This repository was archived by the owner on Sep 10, 2024. It is now read-only.

Commit 1ca82a9

Browse files
committed
Use dynamic filters on app sessions by reusing the OAuth/compat sessions filters
1 parent 718b7f0 commit 1ca82a9

File tree

1 file changed

+51
-78
lines changed

1 file changed

+51
-78
lines changed

crates/storage-pg/src/app_session.rs

+51-78
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 The Matrix.org Foundation C.I.C.
1+
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -17,20 +17,22 @@
1717
use async_trait::async_trait;
1818
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
1919
use mas_storage::{
20-
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
20+
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
21+
compat::CompatSessionFilter,
22+
oauth2::OAuth2SessionFilter,
2123
Page, Pagination,
2224
};
2325
use oauth2_types::scope::{Scope, ScopeToken};
2426
use sea_query::{
25-
Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType,
27+
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
2628
};
2729
use sea_query_binder::SqlxBinder;
2830
use sqlx::PgConnection;
2931
use ulid::Ulid;
30-
use uuid::Uuid;
3132

3233
use crate::{
3334
errors::DatabaseInconsistencyError,
35+
filter::StatementExt,
3436
iden::{CompatSessions, OAuth2Sessions},
3537
pagination::QueryBuilderExt,
3638
DatabaseError, ExecuteExt,
@@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
202204
}
203205
}
204206

207+
/// Split a [`AppSessionFilter`] into two separate filters: a
208+
/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
209+
fn split_filter(
210+
filter: AppSessionFilter<'_>,
211+
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
212+
let mut compat_filter = CompatSessionFilter::new();
213+
let mut oauth2_filter = OAuth2SessionFilter::new();
214+
215+
if let Some(user) = filter.user() {
216+
compat_filter = compat_filter.for_user(user);
217+
oauth2_filter = oauth2_filter.for_user(user);
218+
}
219+
220+
match filter.state() {
221+
Some(AppSessionState::Active) => {
222+
compat_filter = compat_filter.active_only();
223+
oauth2_filter = oauth2_filter.active_only();
224+
}
225+
Some(AppSessionState::Finished) => {
226+
compat_filter = compat_filter.finished_only();
227+
oauth2_filter = oauth2_filter.finished_only();
228+
}
229+
None => {}
230+
}
231+
232+
if let Some(device) = filter.device() {
233+
compat_filter = compat_filter.for_device(device);
234+
oauth2_filter = oauth2_filter.for_device(device);
235+
}
236+
237+
if let Some(browser_session) = filter.browser_session() {
238+
compat_filter = compat_filter.for_browser_session(browser_session);
239+
oauth2_filter = oauth2_filter.for_browser_session(browser_session);
240+
}
241+
242+
(compat_filter, oauth2_filter)
243+
}
244+
205245
#[async_trait]
206246
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
207247
type Error = DatabaseError;
@@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
220260
filter: AppSessionFilter<'_>,
221261
pagination: Pagination,
222262
) -> Result<Page<AppSession>, Self::Error> {
263+
let (compat_filter, oauth2_filter) = split_filter(filter);
264+
223265
let mut oauth2_session_select = Query::select()
224266
.expr_as(
225267
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
@@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
269311
AppSessionLookupIden::LastActiveIp,
270312
)
271313
.from(OAuth2Sessions::Table)
272-
.and_where_option(filter.user().map(|user| {
273-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
274-
}))
275-
.and_where_option(filter.state().map(|state| {
276-
if state.is_active() {
277-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
278-
} else {
279-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
280-
}
281-
}))
282-
.and_where_option(filter.browser_session().map(|browser_session| {
283-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
284-
.eq(Uuid::from(browser_session.id))
285-
}))
286-
.and_where_option(filter.device().map(|device| {
287-
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
288-
OAuth2Sessions::Table,
289-
OAuth2Sessions::ScopeList,
290-
))))
291-
}))
314+
.apply_filter(oauth2_filter)
292315
.clone();
293316

294317
let compat_session_select = Query::select()
@@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
340363
AppSessionLookupIden::LastActiveIp,
341364
)
342365
.from(CompatSessions::Table)
343-
.and_where_option(filter.user().map(|user| {
344-
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
345-
}))
346-
.and_where_option(filter.state().map(|state| {
347-
if state.is_active() {
348-
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
349-
} else {
350-
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
351-
}
352-
}))
353-
.and_where_option(filter.browser_session().map(|browser_session| {
354-
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
355-
.eq(Uuid::from(browser_session.id))
356-
}))
357-
.and_where_option(filter.device().map(|device| {
358-
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
359-
}))
366+
.apply_filter(compat_filter)
360367
.clone();
361368

362369
let common_table_expression = CommonTableExpression::new()
@@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
397404
err,
398405
)]
399406
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
407+
let (compat_filter, oauth2_filter) = split_filter(filter);
400408
let mut oauth2_session_select = Query::select()
401409
.expr(Expr::cust("1"))
402410
.from(OAuth2Sessions::Table)
403-
.and_where_option(filter.user().map(|user| {
404-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
405-
}))
406-
.and_where_option(filter.state().map(|state| {
407-
if state.is_active() {
408-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
409-
} else {
410-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
411-
}
412-
}))
413-
.and_where_option(filter.browser_session().map(|browser_session| {
414-
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
415-
.eq(Uuid::from(browser_session.id))
416-
}))
417-
.and_where_option(filter.device().map(|device| {
418-
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
419-
OAuth2Sessions::Table,
420-
OAuth2Sessions::ScopeList,
421-
))))
422-
}))
411+
.apply_filter(oauth2_filter)
423412
.clone();
424413

425414
let compat_session_select = Query::select()
426415
.expr(Expr::cust("1"))
427416
.from(CompatSessions::Table)
428-
.and_where_option(filter.user().map(|user| {
429-
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
430-
}))
431-
.and_where_option(filter.state().map(|state| {
432-
if state.is_active() {
433-
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
434-
} else {
435-
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
436-
}
437-
}))
438-
.and_where_option(filter.browser_session().map(|browser_session| {
439-
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
440-
.eq(Uuid::from(browser_session.id))
441-
}))
442-
.and_where_option(filter.device().map(|device| {
443-
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
444-
}))
417+
.apply_filter(compat_filter)
445418
.clone();
446419

447420
let common_table_expression = CommonTableExpression::new()

0 commit comments

Comments
 (0)