Skip to content

Commit 7412bb9

Browse files
authored
fix: reduce query memory usage in DatasetExampleRevisionsDataLoader (#6116)
1 parent d9303c2 commit 7412bb9

File tree

1 file changed

+26
-50
lines changed

1 file changed

+26
-50
lines changed

src/phoenix/server/api/dataloaders/dataset_example_revisions.py

+26-50
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Optional, Union
22

3-
from sqlalchemy import and_, case, func, null, or_, select
3+
from sqlalchemy import Integer, case, func, or_, select, union
44
from sqlalchemy.sql.expression import literal
55
from strawberry.dataloader import DataLoader
66
from typing_extensions import TypeAlias
@@ -20,74 +20,50 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
2020
def __init__(self, db: DbSessionFactory) -> None:
2121
super().__init__(
2222
load_fn=self._load_fn,
23-
max_batch_size=200, # needed to prevent the size of the query from getting too large
23+
# Setting max_batch_size to prevent the size of the query from getting too large.
24+
# The maximum number of terms is SQLITE_MAX_COMPOUND_SELECT which defaults to 500.
25+
# This is needed because of the compound select query below used in transferring
26+
# the input data to the database. SQLite in fact has better ways to transfer data,
27+
# but unfortunately they're not made available in sqlalchemy yet.
28+
max_batch_size=200,
2429
)
2530
self._db = db
2631

2732
async def _load_fn(self, keys: list[Key]) -> list[Union[Result, NotFound]]:
28-
example_and_version_ids = tuple(
29-
set(
30-
(example_id, version_id)
31-
for example_id, version_id in keys
32-
if version_id is not None
33-
)
34-
)
35-
versionless_example_ids = tuple(
36-
set(example_id for example_id, version_id in keys if version_id is None)
37-
)
38-
resolved_example_and_version_ids = (
39-
(
33+
# sqlalchemy has limited SQLite support for VALUES, so use UNION ALL instead.
34+
# For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
35+
keys_subquery = union(
36+
*(
4037
select(
41-
models.DatasetExample.id.label("example_id"),
42-
models.DatasetVersion.id.label("version_id"),
43-
)
44-
.select_from(models.DatasetExample)
45-
.join(
46-
models.DatasetVersion,
47-
onclause=literal(True), # cross join
48-
)
49-
.where(
50-
or_(
51-
*(
52-
and_(
53-
models.DatasetExample.id == example_id,
54-
models.DatasetVersion.id == version_id,
55-
)
56-
for example_id, version_id in example_and_version_ids
57-
)
58-
)
38+
literal(example_id, Integer).label("example_id"),
39+
literal(version_id, Integer).label("version_id"),
5940
)
41+
for example_id, version_id in keys
6042
)
61-
.union(
62-
select(
63-
models.DatasetExample.id.label("example_id"), null().label("version_id")
64-
).where(models.DatasetExample.id.in_(versionless_example_ids))
65-
)
66-
.subquery()
67-
)
43+
).subquery()
6844
revision_ids = (
6945
select(
70-
resolved_example_and_version_ids.c.example_id,
71-
resolved_example_and_version_ids.c.version_id,
46+
keys_subquery.c.example_id,
47+
keys_subquery.c.version_id,
7248
func.max(models.DatasetExampleRevision.id).label("revision_id"),
7349
)
74-
.select_from(resolved_example_and_version_ids)
50+
.select_from(keys_subquery)
7551
.join(
7652
models.DatasetExampleRevision,
77-
onclause=resolved_example_and_version_ids.c.example_id
53+
onclause=keys_subquery.c.example_id
7854
== models.DatasetExampleRevision.dataset_example_id,
7955
)
8056
.where(
8157
or_(
82-
resolved_example_and_version_ids.c.version_id.is_(None),
83-
models.DatasetExampleRevision.dataset_version_id
84-
<= resolved_example_and_version_ids.c.version_id,
58+
# This query gets the latest `revision_id` for each example:
59+
# - If `version_id` is NOT given, it finds the maximum `revision_id`.
60+
# - If `version_id` is given, it finds the highest `revision_id` whose
61+
# `version_id` is less than or equal to the one specified.
62+
keys_subquery.c.version_id.is_(None),
63+
models.DatasetExampleRevision.dataset_version_id <= keys_subquery.c.version_id,
8564
)
8665
)
87-
.group_by(
88-
resolved_example_and_version_ids.c.example_id,
89-
resolved_example_and_version_ids.c.version_id,
90-
)
66+
.group_by(keys_subquery.c.example_id, keys_subquery.c.version_id)
9167
).subquery()
9268
query = (
9369
select(

0 commit comments

Comments
 (0)