Skip to content
1 change: 1 addition & 0 deletions changelog.d/18899.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.
30 changes: 30 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
# As above, but for invalidating room caches on room deletion
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"

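# The cached method behind this name takes a single (user_id, device_id) tuple
# as its argument. Nested values cannot be sent over replication as-is, so
# invalidations for this cache require special handling.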
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
"_get_e2e_cross_signing_signatures_for_device"
)

# How long between cache invalidation table cleanups, once we have caught up
# with the backlog.
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
Expand Down Expand Up @@ -270,6 +276,30 @@ def process_replication_rows(
# room membership.
#
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
elif (
row.cache_func
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
):
# "keys" is a list of strings, where each string is a
# stringified representation of the tuple keys, i.e.
# keys: ['(@userid:domain,DEVICEID)','(@userid2:domain,DEVICEID2)']
#
# This is a side-effect of not being able to send nested information over replication.
for tuple_key in row.keys:
user_id, device_id = (
# Remove the leading and trailing parentheses.
tuple_key[1:-1]
# Split by comma
.split(",")
)

# Invalidate each key.
#
# Note: .invalidate takes a tuple of arguments, hence the need
# to nest our tuple in another tuple.
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
((user_id, device_id),)
)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)

Expand Down
192 changes: 136 additions & 56 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,15 +354,17 @@ async def get_e2e_device_keys_and_signatures(
)

for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
cross_sigs_result = (
await self._get_e2e_cross_signing_signatures_for_devices(batch)
)

# add each cross-signing signature to the correct device in the result dict.
for user_id, key_id, device_id, signature in cross_sigs_result:
for (
user_id,
device_id,
), signature_list in cross_sigs_result.items():
target_device_result = result[user_id][device_id]

# We've only looked up cross-signatures for non-deleted devices with key
# data.
assert target_device_result is not None
Expand All @@ -373,7 +375,15 @@ async def get_e2e_device_keys_and_signatures(
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature

if signature_list is None:
# There are no signatures for this user_id/device_id combination.
# We only bail out at this point, rather than before the code above, so
# that the "signatures" key still gets created, even if it ends up empty.
continue

for key_id, signature in signature_list:
signing_user_signatures[key_id] = signature

log_kv(result)
return result
Expand Down Expand Up @@ -479,41 +489,85 @@ def get_e2e_device_keys_txn(

return result

def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@cached()
def _get_e2e_cross_signing_signatures_for_device(
self,
user_id_and_device_id: Tuple[str, str],
) -> Sequence[Tuple[str, str]]:
"""
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
See @cachedList for why a separate method is needed.
"""
raise NotImplementedError()

@cachedList(
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
list_name="device_query",
)
async def _get_e2e_cross_signing_signatures_for_devices(
self, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Optional[Sequence[Tuple[str, str]]]]:
"""Get cross-signing signatures for a given list of user IDs and devices.

Args:
device_query: An iterable containing tuples of (user ID, device ID).

Returns:
A mapping of results. The keys are the original (user_id, device_id)
tuple, while the value is the matching list of tuples of
(key_id, signature). The value will be `None` instead if no
signatures exist for the device (this is a behaviour of
`@cachedList`).

Returns signatures made by the owners of the devices.
Given this method is annotated with `@cachedList`, the return dict's
keys match the tuples within `device_query`, so that cache entries can
be computed from the corresponding values.

Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
As results are cached, the return type is immutable.
"""
signature_query_clauses = []
signature_query_params = []

for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signature_query_params.extend([user_id, device_id, user_id])

signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

txn.execute(signature_sql, signature_query_params)
return cast(
List[
Tuple[
str,
str,
str,
str,
]
],
txn.fetchall(),
def _get_e2e_cross_signing_signatures_for_devices_txn(
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
signature_query_clauses = []
signature_query_params = []

for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signature_query_params.extend([user_id, device_id, user_id])

signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

txn.execute(signature_sql, signature_query_params)

devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}

# `@cachedList` requires we return one key for every item in `device_query`.
# Pre-populate `devices_and_signatures` with each key so that none are missing.
#
# If any are missing, they will be cached as `None`, which is not
# what callers expect.
for user_id, device_id in device_query:
devices_and_signatures.setdefault((user_id, device_id), [])

# Populate the return dictionary with each found key_id and signature.
for user_id, key_id, target_device_id, signature in txn.fetchall():
signature_tuple = (key_id, signature)
devices_and_signatures[(user_id, target_device_id)].append(
signature_tuple
)

return devices_and_signatures

return await self.db_pool.runInteraction(
"_get_e2e_cross_signing_signatures_for_devices_txn",
_get_e2e_cross_signing_signatures_for_devices_txn,
device_query,
)

async def get_e2e_one_time_keys(
Expand Down Expand Up @@ -1772,26 +1826,52 @@ async def store_e2e_cross_signing_signatures(
user_id: the user who made the signatures
signatures: signatures to add
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)

def _store_e2e_cross_signing_signatures(
txn: LoggingTransaction,
signatures: "Iterable[SignatureListItem]",
) -> None:
self.db_pool.simple_insert_many_txn(
txn,
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
for item in signatures
],
)

to_invalidate = [
# Each entry is a tuple of arguments to
# `_get_e2e_cross_signing_signatures_for_device`, which
# itself takes a tuple. Hence the double-tuple.
((user_id, item.target_device_id),)
for item in signatures
],
desc="add_e2e_signing_key",
]

if to_invalidate:
self._invalidate_cache_and_stream_bulk(
txn,
self._get_e2e_cross_signing_signatures_for_device,
to_invalidate,
)

await self.db_pool.runInteraction(
"add_e2e_signing_key",
_store_e2e_cross_signing_signatures,
signatures,
)


Expand Down
7 changes: 5 additions & 2 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,12 @@ def cachedList(
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
the cache gets passed to the original function, which is expected to result
in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
original cache.

Note that any values in the input that end up being missing from both the
cache and the returned dictionary will be cached as `None`.

Args:
cached_method_name: The name of the single-item lookup method.
Expand Down
Loading