1
1
from typing import Optional , Union
2
2
3
- from sqlalchemy import and_ , case , func , null , or_ , select
3
+ from sqlalchemy import Integer , case , func , or_ , select , union
4
4
from sqlalchemy .sql .expression import literal
5
5
from strawberry .dataloader import DataLoader
6
6
from typing_extensions import TypeAlias
@@ -20,74 +20,50 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
20
20
def __init__ (self , db : DbSessionFactory ) -> None :
21
21
super ().__init__ (
22
22
load_fn = self ._load_fn ,
23
- max_batch_size = 200 , # needed to prevent the size of the query from getting too large
23
+ # Setting max_batch_size to prevent the size of the query from getting too large.
24
+ # The maximum number of terms in a compound SELECT is set by SQLITE_MAX_COMPOUND_SELECT, which defaults to 500.
25
+ # This is needed because of the compound select query below used in transferring
26
+ # the input data to the database. SQLite in fact has better ways to transfer data,
27
+ # but unfortunately they're not made available in sqlalchemy yet.
28
+ max_batch_size = 200 ,
24
29
)
25
30
self ._db = db
26
31
27
32
async def _load_fn (self , keys : list [Key ]) -> list [Union [Result , NotFound ]]:
28
- example_and_version_ids = tuple (
29
- set (
30
- (example_id , version_id )
31
- for example_id , version_id in keys
32
- if version_id is not None
33
- )
34
- )
35
- versionless_example_ids = tuple (
36
- set (example_id for example_id , version_id in keys if version_id is None )
37
- )
38
- resolved_example_and_version_ids = (
39
- (
33
+ # sqlalchemy has limited SQLite support for VALUES, so build the keys as a compound SELECT (UNION) instead.
34
+ # For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
35
+ keys_subquery = union (
36
+ * (
40
37
select (
41
- models .DatasetExample .id .label ("example_id" ),
42
- models .DatasetVersion .id .label ("version_id" ),
43
- )
44
- .select_from (models .DatasetExample )
45
- .join (
46
- models .DatasetVersion ,
47
- onclause = literal (True ), # cross join
48
- )
49
- .where (
50
- or_ (
51
- * (
52
- and_ (
53
- models .DatasetExample .id == example_id ,
54
- models .DatasetVersion .id == version_id ,
55
- )
56
- for example_id , version_id in example_and_version_ids
57
- )
58
- )
38
+ literal (example_id , Integer ).label ("example_id" ),
39
+ literal (version_id , Integer ).label ("version_id" ),
59
40
)
41
+ for example_id , version_id in keys
60
42
)
61
- .union (
62
- select (
63
- models .DatasetExample .id .label ("example_id" ), null ().label ("version_id" )
64
- ).where (models .DatasetExample .id .in_ (versionless_example_ids ))
65
- )
66
- .subquery ()
67
- )
43
+ ).subquery ()
68
44
revision_ids = (
69
45
select (
70
- resolved_example_and_version_ids .c .example_id ,
71
- resolved_example_and_version_ids .c .version_id ,
46
+ keys_subquery .c .example_id ,
47
+ keys_subquery .c .version_id ,
72
48
func .max (models .DatasetExampleRevision .id ).label ("revision_id" ),
73
49
)
74
- .select_from (resolved_example_and_version_ids )
50
+ .select_from (keys_subquery )
75
51
.join (
76
52
models .DatasetExampleRevision ,
77
- onclause = resolved_example_and_version_ids .c .example_id
53
+ onclause = keys_subquery .c .example_id
78
54
== models .DatasetExampleRevision .dataset_example_id ,
79
55
)
80
56
.where (
81
57
or_ (
82
- resolved_example_and_version_ids .c .version_id .is_ (None ),
83
- models .DatasetExampleRevision .dataset_version_id
84
- <= resolved_example_and_version_ids .c .version_id ,
58
+ # This query gets the latest `revision_id` for each example:
59
+ # - If `version_id` is NOT given, it finds the maximum `revision_id`.
60
+ # - If `version_id` is given, it finds the highest `revision_id` whose
61
+ # `version_id` is less than or equal to the one specified.
62
+ keys_subquery .c .version_id .is_ (None ),
63
+ models .DatasetExampleRevision .dataset_version_id <= keys_subquery .c .version_id ,
85
64
)
86
65
)
87
- .group_by (
88
- resolved_example_and_version_ids .c .example_id ,
89
- resolved_example_and_version_ids .c .version_id ,
90
- )
66
+ .group_by (keys_subquery .c .example_id , keys_subquery .c .version_id )
91
67
).subquery ()
92
68
query = (
93
69
select (
0 commit comments