Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor _get_e2e_device_keys_txn to split large queries (#13956)
Browse files Browse the repository at this point in the history
Instead of running a single large query, run a single query for
user-only lookups and additional queries for batches of user device
lookups.

Resolves #13580.

Signed-off-by: Sean Quah <[email protected]>
  • Loading branch information
squahtx authored Oct 3, 2022
1 parent 061739d commit d65862c
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.d/13956.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.
60 changes: 60 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)


# These overloads ensure that `columns` and `iterable` values have the same length.
# Suppress "Single overload definition, multiple required" complaint.
@overload # type: ignore[misc]
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str],
iterable: Collection[Tuple[Any, Any]],
) -> Tuple[str, list]:
...


def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, ...],
iterable: Collection[Tuple[Any, ...]],
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given tuple of columns is in the iterable.
Args:
database_engine
columns: Names of the columns in the tuple.
iterable: The tuples to check the columns against.
Returns:
A tuple of SQL query and the args
"""
if len(columns) == 0:
# Should be unreachable due to mypy, as long as the overloads are set up right.
if () in iterable:
return "TRUE", []
else:
return "FALSE", []

if len(columns) == 1:
# Use `= ANY(?)` on postgres.
return make_in_list_sql_clause(
database_engine, next(iter(columns)), [values[0] for values in iterable]
)

# There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
# indices are not used when there are multiple columns. Instead, use an `IN`
# expression.
#
# `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
# `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
# Thus, the latter is chosen.

if len(iterable) == 0:
# A 0-length `VALUES` list is not allowed in sqlite or postgres.
# Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
# allowed in postgres.
return "FALSE", []

tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
return "(%s) IN (VALUES %s)" % (
",".join(column for column in columns),
",".join(tuple_sql for _ in iterable),
), [value for values in iterable for value in values]


KV = TypeVar("KV")


Expand Down
83 changes: 54 additions & 29 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
Expand Down Expand Up @@ -278,7 +279,7 @@ async def get_e2e_device_keys_and_signatures(
def _get_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
query_list: Collection[Tuple[str, str]],
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
Expand All @@ -288,49 +289,73 @@ def _get_e2e_device_keys_txn(
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
query_clauses: List[str] = []
query_params_list: List[List[object]] = []

if include_all_devices is False:
include_deleted_devices = False

if include_deleted_devices:
deleted_devices = set(query_list)

# Split the query list into queries for users and queries for particular
# devices.
user_list = []
user_device_list = []
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)

if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)

query_clauses.append(query_clause)

sql = (
"SELECT user_id, device_id, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s AND NOT d.hidden"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses),
)
if device_id is None:
user_list.append(user_id)
else:
user_device_list.append((user_id, device_id))

txn.execute(sql, query_params)
if user_list:
user_id_in_list_clause, user_args = make_in_list_sql_clause(
txn.database_engine, "user_id", user_list
)
query_clauses.append(user_id_in_list_clause)
query_params_list.append(user_args)

if user_device_list:
# Divide the device queries into batches, to avoid excessively large
# queries.
for user_device_batch in batch_iter(user_device_list, 1024):
(
user_device_id_in_list_clause,
user_device_args,
) = make_tuple_in_list_sql_clause(
txn.database_engine, ("user_id", "device_id"), user_device_batch
)
query_clauses.append(user_device_id_in_list_clause)
query_params_list.append(user_device_args)

result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, db_to_json(key_json) if key_json else None
for query_clause, query_params in zip(query_clauses, query_params_list):
sql = (
"SELECT user_id, device_id, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s AND NOT d.hidden"
) % (
"LEFT" if include_all_devices else "INNER",
query_clause,
)

txn.execute(sql, query_params)

for (user_id, device_id, display_name, key_json) in txn:
assert device_id is not None
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, db_to_json(key_json) if key_json else None
)

if include_deleted_devices:
for user_id, device_id in deleted_devices:
if device_id is None:
continue
result.setdefault(user_id, {})[device_id] = None

return result
Expand Down

0 comments on commit d65862c

Please sign in to comment.