Skip to content
1 change: 1 addition & 0 deletions changelog.d/18899.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.
11 changes: 9 additions & 2 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,15 +2653,22 @@ def make_in_list_sql_clause(


# These overloads ensure that `columns` and `iterable` values have the same length.
# Suppress "Single overload definition, multiple required" complaint.
@overload # type: ignore[misc]
@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str],
iterable: Collection[Tuple[Any, Any]],
) -> Tuple[str, list]: ...


@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str, str],
iterable: Collection[Tuple[Any, Any, Any]],
) -> Tuple[str, list]: ...


def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, ...],
Expand Down
34 changes: 34 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


import itertools
import json
import logging
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple

Expand Down Expand Up @@ -62,6 +63,12 @@
# As above, but for invalidating room caches on room deletion
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"

# This cache takes a list of tuples as its first argument, which requires
# special handling.
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
"_get_e2e_cross_signing_signatures_for_device"
)

# How long between cache invalidation table cleanups, once we have caught up
# with the backlog.
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
Expand Down Expand Up @@ -270,6 +277,33 @@ def process_replication_rows(
# room membership.
#
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
elif (
row.cache_func
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
):
# "keys" is a list of strings, where each string is a
# JSON-encoded representation of the tuple keys, i.e.
# keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]']
#
# This is a side-effect of not being able to send nested
# information over replication.
for json_str in row.keys:
try:
user_id, device_id = json.loads(json_str)
except (json.JSONDecodeError, TypeError):
logger.error(
"Failed to deserialise cache key as valid JSON: %s",
json_str,
)
continue

# Invalidate each key.
#
# Note: .invalidate takes a tuple of arguments, hence the need
# to nest our tuple in another tuple.
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
((user_id, device_id),)
)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)

Expand Down
202 changes: 147 additions & 55 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#
#
import abc
import json
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -354,15 +355,17 @@ async def get_e2e_device_keys_and_signatures(
)

for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
cross_sigs_result = (
await self._get_e2e_cross_signing_signatures_for_devices(batch)
)

# add each cross-signing signature to the correct device in the result dict.
for user_id, key_id, device_id, signature in cross_sigs_result:
for (
user_id,
device_id,
), signature_list in cross_sigs_result.items():
target_device_result = result[user_id][device_id]

# We've only looked up cross-signatures for non-deleted devices with key
# data.
assert target_device_result is not None
Expand All @@ -373,7 +376,9 @@ async def get_e2e_device_keys_and_signatures(
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature

for key_id, signature in signature_list:
signing_user_signatures[key_id] = signature

log_kv(result)
return result
Expand Down Expand Up @@ -479,41 +484,83 @@ def get_e2e_device_keys_txn(

return result

def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@cached()
def _get_e2e_cross_signing_signatures_for_device(
self,
user_id_and_device_id: Tuple[str, str],
) -> Sequence[Tuple[str, str]]:
"""
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
See @cachedList for why a separate method is needed.
"""
raise NotImplementedError()

@cachedList(
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
list_name="device_query",
)
async def _get_e2e_cross_signing_signatures_for_devices(
self, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
"""Get cross-signing signatures for a given list of user IDs and devices.

Args:
An iterable containing tuples of (user ID, device ID).

Returns:
A mapping of results. The keys are the original (user_id, device_id)
tuple, while the value is the matching list of tuples of
(key_id, signature). The value will be an empty list if no
signatures exist for the device.

Returns signatures made by the owners of the devices.
Given this method is annotated with `@cachedList`, the return dict's
keys match the tuples within `device_query`, so that cache entries can
be computed from the corresponding values.

Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
As results are cached, the return type is immutable.
"""
signature_query_clauses = []
signature_query_params = []

for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
def _get_e2e_cross_signing_signatures_for_devices_txn(
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
self.database_engine,
columns=("target_user_id", "target_device_id", "user_id"),
iterable=[
(user_id, device_id, user_id) for user_id, device_id in device_query
],
)
signature_query_params.extend([user_id, device_id, user_id])

signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

txn.execute(signature_sql, signature_query_params)
return cast(
List[
Tuple[
str,
str,
str,
str,
]
],
txn.fetchall(),

signature_sql = f"""
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
"""

txn.execute(signature_sql, where_clause_params)

devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}

# `@cachedList` requires we return one key for every item in `device_query`.
# Pre-populate `devices_and_signatures` with each key so that none are missing.
#
# If any are missing, they will be cached as `None`, which is not
# what callers expected.
for user_id, device_id in device_query:
devices_and_signatures.setdefault((user_id, device_id), [])

# Populate the return dictionary with each found key_id and signature.
for user_id, key_id, target_device_id, signature in txn.fetchall():
signature_tuple = (key_id, signature)
devices_and_signatures[(user_id, target_device_id)].append(
signature_tuple
)

return devices_and_signatures

return await self.db_pool.runInteraction(
"_get_e2e_cross_signing_signatures_for_devices_txn",
_get_e2e_cross_signing_signatures_for_devices_txn,
device_query,
)

async def get_e2e_one_time_keys(
Expand Down Expand Up @@ -1772,26 +1819,71 @@ async def store_e2e_cross_signing_signatures(
user_id: the user who made the signatures
signatures: signatures to add
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)

def _store_e2e_cross_signing_signatures(
txn: LoggingTransaction,
signatures: "Iterable[SignatureListItem]",
) -> None:
self.db_pool.simple_insert_many_txn(
txn,
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
for item in signatures
],
)

to_invalidate = [
# Each entry is a tuple of arguments to
# `_get_e2e_cross_signing_signatures_for_device`, which
# itself takes a tuple. Hence the double-tuple.
((user_id, item.target_device_id),)
for item in signatures
],
desc="add_e2e_signing_key",
]

if to_invalidate:
# Invalidate the local cache of this worker.
for cache_key in to_invalidate:
txn.call_after(
self._get_e2e_cross_signing_signatures_for_device.invalidate,
cache_key,
)

# Stream cache invalidate keys over replication.
#
# We can only send a primitive per function argument across
# replication.
#
# Encode the array of strings as a JSON string, and we'll unpack
# it on the other side.
to_send = [
(json.dumps([user_id, item.target_device_id]),)
for item in signatures
Comment on lines +1873 to +1874
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not possible to json.dumps the entries of to_invalidate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can; it looks like this:

to_send = [
    (json.dumps(item[0]),)
    for item in to_invalidate
]

which seemed less clear as to what we're actually sending than what's there now.

]

self._send_invalidation_to_replication_bulk(
txn,
cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
key_tuples=to_send,
)

await self.db_pool.runInteraction(
"add_e2e_signing_key",
_store_e2e_cross_signing_signatures,
signatures,
)


Expand Down
7 changes: 5 additions & 2 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,12 @@ def cachedList(
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
the cache gets passed to the original function, which is expected to result
in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
original cache.

Note that any values in the input that end up being missing from both the
cache and the returned dictionary will be cached as `None`.

Args:
cached_method_name: The name of the single-item lookup method.
Expand Down
Loading