Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor _get_e2e_device_keys_txn to split large queries #13956

Merged
merged 4 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13956.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.
60 changes: 60 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2431,6 +2431,66 @@ def make_in_list_sql_clause(
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)


# These overloads ensure that `columns` and `iterable` values have the same length.
# Suppress "Single overload definition, multiple required" complaint.
@overload # type: ignore[misc]
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str],
iterable: Collection[Tuple[Any, Any]],
) -> Tuple[str, list]:
...


def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, ...],
iterable: Collection[Tuple[Any, ...]],
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given tuple of columns is in the iterable.

Args:
database_engine
columns: Names of the columns in the tuple.
iterable: The tuples to check the columns against.

Returns:
A tuple of SQL query and the args
"""
if len(columns) == 0:
# Should be unreachable due to mypy, as long as the overloads are set up right.
if () in iterable:
return "TRUE", []
else:
return "FALSE", []

if len(columns) == 1:
# Use `= ANY(?)` on postgres.
return make_in_list_sql_clause(
database_engine, next(iter(columns)), [values[0] for values in iterable]
)

# There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
# indices are not used when there are multiple columns. Instead, use an `IN`
# expression.
#
# `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
# `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
# Thus, the latter is chosen.
Comment on lines +2477 to +2479
Copy link
Contributor Author

@squahtx squahtx Sep 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The query plan for IN (VALUES (?, ...), ...) looks like

                                                                               QUERY PLAN
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 Nested Loop  (cost=5.67..2480429.48 rows=56325687 width=504) (actual time=1.669..1.671 rows=0 loops=1)
   ->  Nested Loop  (cost=5.11..923.05 rows=2657681 width=535) (actual time=1.668..1.669 rows=0 loops=1)
         ->  HashAggregate  (cost=4.55..6.55 rows=200 width=64) (actual time=0.321..0.418 rows=260 loops=1)
               Group Key: "*VALUES*".column1, "*VALUES*".column2
               ->  Values Scan on "*VALUES*"  (cost=0.00..3.25 rows=260 width=64) (actual time=0.004..0.158 rows=260 loops=1)
         ->  Index Scan using e2e_device_keys_json_uniqueness on e2e_device_keys_json k  (cost=0.56..4.58 rows=1 width=471) (actual time=0.004..0.004 rows=0 loops=260)
               Index Cond: ((user_id = "*VALUES*".column1) AND (device_id = "*VALUES*".column2))
   ->  Index Scan using device_uniqueness on devices d  (cost=0.56..0.93 rows=1 width=67) (never executed)
         Index Cond: ((user_id = k.user_id) AND (device_id = k.device_id))
         Filter: (NOT hidden)
 Planning time: 3.435 ms
 Execution time: 1.748 ms
(12 rows)

When VALUES is omitted, the query plan becomes entertainingly large:

 Nested Loop  (cost=670.12..680.76 rows=5 width=504) (actual time=0.537..0.566 rows=0 loops=1)
   ->  Bitmap Heap Scan on devices d  (cost=669.56..676.18 rows=1 width=67) (actual time=0.536..0.565 rows=0 loops=1)
         Recheck Cond: (((user_id = 'a01'::text) AND (device_id = 'b'::text)) OR ((user_id = 'c01'::text) AND ...
         Filter: (NOT hidden)
         ->  BitmapOr  (cost=669.56..669.56 rows=2 width=0) (actual time=0.534..0.563 rows=0 loops=1)
               ->  Bitmap Index Scan on device_uniqueness  (cost=0.00..2.57 rows=1 width=0) (actual time=0.008..0.008 rows=0 loops=1)
                     Index Cond: ((user_id = 'a01'::text) AND (device_id = 'b'::text))
               ->  Bitmap Index Scan on device_uniqueness  (cost=0.00..2.57 rows=1 width=0) (actual time=0.002..0.002 rows=0 loops=1)
                     Index Cond: ((user_id = 'c01'::text) AND (device_id = 'd'::text))
               ->  Bitmap Index Scan on device_uniqueness  (cost=0.00..2.57 rows=1 width=0) (actual time=0.002..0.002 rows=0 loops=1)
                     Index Cond: ((user_id = 'e01'::text) AND (device_id = 'f'::text))
               ->  ...
   ->  Index Scan using e2e_device_keys_json_uniqueness on e2e_device_keys_json k  (cost=0.56..4.58 rows=1 width=471) (never executed)
         Index Cond: ((user_id = d.user_id) AND (device_id = d.device_id))
 Planning time: 12.403 ms
 Execution time: 3.989 ms
(529 rows)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AAAAAAAAARGH WHYYYYYYYYYYYYYYYYYYYYYYYY

Is this also true on pg14? :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't got access to a postgres 14 db with realistic data to test with.
It only happens with tuples by the way. A regular IN (...) with single values gets turned into =ANY.


if len(iterable) == 0:
# A 0-length `VALUES` list is not allowed in sqlite or postgres.
# Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
# allowed in postgres.
return "FALSE", []

tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
return "(%s) IN (VALUES %s)" % (
",".join(column for column in columns),
",".join(tuple_sql for _ in iterable),
), [value for values in iterable for value in values]


KV = TypeVar("KV")


Expand Down
83 changes: 54 additions & 29 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
Expand Down Expand Up @@ -278,7 +279,7 @@ async def get_e2e_device_keys_and_signatures(
def _get_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
query_list: Collection[Tuple[str, str]],
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
Expand All @@ -288,49 +289,73 @@ def _get_e2e_device_keys_txn(
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
query_clauses: List[str] = []
query_params_list: List[List[object]] = []

if include_all_devices is False:
include_deleted_devices = False

if include_deleted_devices:
deleted_devices = set(query_list)

# Split the query list into queries for users and queries for particular
# devices.
user_list = []
user_device_list = []
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)

if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)

query_clauses.append(query_clause)

sql = (
"SELECT user_id, device_id, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s AND NOT d.hidden"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses),
)
if device_id is None:
user_list.append(user_id)
else:
user_device_list.append((user_id, device_id))

txn.execute(sql, query_params)
if user_list:
user_id_in_list_clause, user_args = make_in_list_sql_clause(
txn.database_engine, "user_id", user_list
)
query_clauses.append(user_id_in_list_clause)
query_params_list.append(user_args)

if user_device_list:
# Divide the device queries into batches, to avoid excessively large
# queries.
for user_device_batch in batch_iter(user_device_list, 1024):
(
user_device_id_in_list_clause,
user_device_args,
) = make_tuple_in_list_sql_clause(
txn.database_engine, ("user_id", "device_id"), user_device_batch
)
query_clauses.append(user_device_id_in_list_clause)
query_params_list.append(user_device_args)

result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, db_to_json(key_json) if key_json else None
for query_clause, query_params in zip(query_clauses, query_params_list):
sql = (
"SELECT user_id, device_id, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s AND NOT d.hidden"
) % (
"LEFT" if include_all_devices else "INNER",
query_clause,
)

txn.execute(sql, query_params)

for (user_id, device_id, display_name, key_json) in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, db_to_json(key_json) if key_json else None
)

if include_deleted_devices:
for user_id, device_id in deleted_devices:
if device_id is None:
continue
result.setdefault(user_id, {})[device_id] = None

return result
Expand Down