1
- // Copyright 2023 The Matrix.org Foundation C.I.C.
1
+ // Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
17
17
use async_trait:: async_trait;
18
18
use mas_data_model:: { CompatSession , CompatSessionState , Device , Session , SessionState , UserAgent } ;
19
19
use mas_storage:: {
20
- app_session:: { AppSession , AppSessionFilter , AppSessionRepository } ,
20
+ app_session:: { AppSession , AppSessionFilter , AppSessionRepository , AppSessionState } ,
21
+ compat:: CompatSessionFilter ,
22
+ oauth2:: OAuth2SessionFilter ,
21
23
Page , Pagination ,
22
24
} ;
23
25
use oauth2_types:: scope:: { Scope , ScopeToken } ;
24
26
use sea_query:: {
25
- Alias , ColumnRef , CommonTableExpression , Expr , PgFunc , PostgresQueryBuilder , Query , UnionType ,
27
+ Alias , ColumnRef , CommonTableExpression , Expr , PostgresQueryBuilder , Query , UnionType ,
26
28
} ;
27
29
use sea_query_binder:: SqlxBinder ;
28
30
use sqlx:: PgConnection ;
29
31
use ulid:: Ulid ;
30
- use uuid:: Uuid ;
31
32
32
33
use crate :: {
33
34
errors:: DatabaseInconsistencyError ,
35
+ filter:: StatementExt ,
34
36
iden:: { CompatSessions , OAuth2Sessions } ,
35
37
pagination:: QueryBuilderExt ,
36
38
DatabaseError , ExecuteExt ,
@@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
202
204
}
203
205
}
204
206
207
+ /// Split a [`AppSessionFilter`] into two separate filters: a
208
+ /// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
209
+ fn split_filter (
210
+ filter : AppSessionFilter < ' _ > ,
211
+ ) -> ( CompatSessionFilter < ' _ > , OAuth2SessionFilter < ' _ > ) {
212
+ let mut compat_filter = CompatSessionFilter :: new ( ) ;
213
+ let mut oauth2_filter = OAuth2SessionFilter :: new ( ) ;
214
+
215
+ if let Some ( user) = filter. user ( ) {
216
+ compat_filter = compat_filter. for_user ( user) ;
217
+ oauth2_filter = oauth2_filter. for_user ( user) ;
218
+ }
219
+
220
+ match filter. state ( ) {
221
+ Some ( AppSessionState :: Active ) => {
222
+ compat_filter = compat_filter. active_only ( ) ;
223
+ oauth2_filter = oauth2_filter. active_only ( ) ;
224
+ }
225
+ Some ( AppSessionState :: Finished ) => {
226
+ compat_filter = compat_filter. finished_only ( ) ;
227
+ oauth2_filter = oauth2_filter. finished_only ( ) ;
228
+ }
229
+ None => { }
230
+ }
231
+
232
+ if let Some ( device) = filter. device ( ) {
233
+ compat_filter = compat_filter. for_device ( device) ;
234
+ oauth2_filter = oauth2_filter. for_device ( device) ;
235
+ }
236
+
237
+ if let Some ( browser_session) = filter. browser_session ( ) {
238
+ compat_filter = compat_filter. for_browser_session ( browser_session) ;
239
+ oauth2_filter = oauth2_filter. for_browser_session ( browser_session) ;
240
+ }
241
+
242
+ ( compat_filter, oauth2_filter)
243
+ }
244
+
205
245
#[ async_trait]
206
246
impl < ' c > AppSessionRepository for PgAppSessionRepository < ' c > {
207
247
type Error = DatabaseError ;
@@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
220
260
filter : AppSessionFilter < ' _ > ,
221
261
pagination : Pagination ,
222
262
) -> Result < Page < AppSession > , Self :: Error > {
263
+ let ( compat_filter, oauth2_filter) = split_filter ( filter) ;
264
+
223
265
let mut oauth2_session_select = Query :: select ( )
224
266
. expr_as (
225
267
Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: OAuth2SessionId ) ) ,
@@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
269
311
AppSessionLookupIden :: LastActiveIp ,
270
312
)
271
313
. from ( OAuth2Sessions :: Table )
272
- . and_where_option ( filter. user ( ) . map ( |user| {
273
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: UserId ) ) . eq ( Uuid :: from ( user. id ) )
274
- } ) )
275
- . and_where_option ( filter. state ( ) . map ( |state| {
276
- if state. is_active ( ) {
277
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: FinishedAt ) ) . is_null ( )
278
- } else {
279
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: FinishedAt ) ) . is_not_null ( )
280
- }
281
- } ) )
282
- . and_where_option ( filter. browser_session ( ) . map ( |browser_session| {
283
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: UserSessionId ) )
284
- . eq ( Uuid :: from ( browser_session. id ) )
285
- } ) )
286
- . and_where_option ( filter. device ( ) . map ( |device| {
287
- Expr :: val ( device. to_scope_token ( ) . to_string ( ) ) . eq ( PgFunc :: any ( Expr :: col ( (
288
- OAuth2Sessions :: Table ,
289
- OAuth2Sessions :: ScopeList ,
290
- ) ) ) )
291
- } ) )
314
+ . apply_filter ( oauth2_filter)
292
315
. clone ( ) ;
293
316
294
317
let compat_session_select = Query :: select ( )
@@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
340
363
AppSessionLookupIden :: LastActiveIp ,
341
364
)
342
365
. from ( CompatSessions :: Table )
343
- . and_where_option ( filter. user ( ) . map ( |user| {
344
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: UserId ) ) . eq ( Uuid :: from ( user. id ) )
345
- } ) )
346
- . and_where_option ( filter. state ( ) . map ( |state| {
347
- if state. is_active ( ) {
348
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: FinishedAt ) ) . is_null ( )
349
- } else {
350
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: FinishedAt ) ) . is_not_null ( )
351
- }
352
- } ) )
353
- . and_where_option ( filter. browser_session ( ) . map ( |browser_session| {
354
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: UserSessionId ) )
355
- . eq ( Uuid :: from ( browser_session. id ) )
356
- } ) )
357
- . and_where_option ( filter. device ( ) . map ( |device| {
358
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: DeviceId ) ) . eq ( device. to_string ( ) )
359
- } ) )
366
+ . apply_filter ( compat_filter)
360
367
. clone ( ) ;
361
368
362
369
let common_table_expression = CommonTableExpression :: new ( )
@@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
397
404
err,
398
405
) ]
399
406
async fn count ( & mut self , filter : AppSessionFilter < ' _ > ) -> Result < usize , Self :: Error > {
407
+ let ( compat_filter, oauth2_filter) = split_filter ( filter) ;
400
408
let mut oauth2_session_select = Query :: select ( )
401
409
. expr ( Expr :: cust ( "1" ) )
402
410
. from ( OAuth2Sessions :: Table )
403
- . and_where_option ( filter. user ( ) . map ( |user| {
404
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: UserId ) ) . eq ( Uuid :: from ( user. id ) )
405
- } ) )
406
- . and_where_option ( filter. state ( ) . map ( |state| {
407
- if state. is_active ( ) {
408
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: FinishedAt ) ) . is_null ( )
409
- } else {
410
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: FinishedAt ) ) . is_not_null ( )
411
- }
412
- } ) )
413
- . and_where_option ( filter. browser_session ( ) . map ( |browser_session| {
414
- Expr :: col ( ( OAuth2Sessions :: Table , OAuth2Sessions :: UserSessionId ) )
415
- . eq ( Uuid :: from ( browser_session. id ) )
416
- } ) )
417
- . and_where_option ( filter. device ( ) . map ( |device| {
418
- Expr :: val ( device. to_scope_token ( ) . to_string ( ) ) . eq ( PgFunc :: any ( Expr :: col ( (
419
- OAuth2Sessions :: Table ,
420
- OAuth2Sessions :: ScopeList ,
421
- ) ) ) )
422
- } ) )
411
+ . apply_filter ( oauth2_filter)
423
412
. clone ( ) ;
424
413
425
414
let compat_session_select = Query :: select ( )
426
415
. expr ( Expr :: cust ( "1" ) )
427
416
. from ( CompatSessions :: Table )
428
- . and_where_option ( filter. user ( ) . map ( |user| {
429
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: UserId ) ) . eq ( Uuid :: from ( user. id ) )
430
- } ) )
431
- . and_where_option ( filter. state ( ) . map ( |state| {
432
- if state. is_active ( ) {
433
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: FinishedAt ) ) . is_null ( )
434
- } else {
435
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: FinishedAt ) ) . is_not_null ( )
436
- }
437
- } ) )
438
- . and_where_option ( filter. browser_session ( ) . map ( |browser_session| {
439
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: UserSessionId ) )
440
- . eq ( Uuid :: from ( browser_session. id ) )
441
- } ) )
442
- . and_where_option ( filter. device ( ) . map ( |device| {
443
- Expr :: col ( ( CompatSessions :: Table , CompatSessions :: DeviceId ) ) . eq ( device. to_string ( ) )
444
- } ) )
417
+ . apply_filter ( compat_filter)
445
418
. clone ( ) ;
446
419
447
420
let common_table_expression = CommonTableExpression :: new ( )
0 commit comments