Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Claim local one-time-keys in bulk #16565

Merged
merged 15 commits into from
Oct 30, 2023
1 change: 1 addition & 0 deletions changelog.d/16565.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the performance of claiming encryption keys.
10 changes: 10 additions & 0 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,16 @@ async def claim_client_keys(destination: str) -> None:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
keys: the body of a /keys/upload request.

Returns a dictionary with one field:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

Expand Down
256 changes: 141 additions & 115 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -1110,7 +1111,7 @@ def get_device_stream_token(self) -> int:
...

async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str, int]]
self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
Expand All @@ -1120,131 +1121,63 @@ async def claim_e2e_one_time_keys(
query_list: An iterable of tuples of (user ID, device ID, algorithm).

Returns:
A tuple pf:
A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.

A copy of the input which has not been fulfilled.
A copy of the input which has not been fulfilled. The returned counts
may be less than the input counts. In this case, the returned counts
are the number of claims that were not fulfilled.
"""

@trace
def _claim_e2e_one_time_key_simple(
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.

Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""

sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
"""

txn.execute(sql, (user_id, device_id, algorithm, count))
otk_rows = list(txn)
if not otk_rows:
return []

self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]

@trace
def _claim_e2e_one_time_key_returning(
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.

Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""

# We can use RETURNING to do the fetch and DELETE in once step.
sql = """
DELETE FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
)
RETURNING key_id, key_json
"""

txn.execute(
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
otk_rows = list(txn)
if not otk_rows:
return []

self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]

results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
db_autocommit = True
else:
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False

claim_rows = await self.db_pool.runInteraction(
if isinstance(self.database_engine, PostgresEngine):
# If we can use execute_values we can use a single batch query
# in autocommit mode.
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
for user_id, device_id, algorithm, count in query_list:
unfulfilled_claim_counts[user_id, device_id, algorithm] = count

bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
device_id,
algorithm,
count,
db_autocommit=db_autocommit,
self._claim_e2e_one_time_keys_bulk,
query_list,
db_autocommit=True,
)
if claim_rows:

for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1

# Did we get enough OTKs?
count -= len(claim_rows)
if count:
missing.append((user_id, device_id, algorithm, count))
missing = [
(user, device, alg, count)
for (user, device, alg), count in unfulfilled_claim_counts.items()
if count > 0
]
else:
for user_id, device_id, algorithm, count in query_list:
claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_key_simple,
user_id,
device_id,
algorithm,
count,
db_autocommit=False,
)
if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
# Did we get enough OTKs?
count -= len(claim_rows)
if count:
missing.append((user_id, device_id, algorithm, count))

return results, missing

Expand Down Expand Up @@ -1302,6 +1235,99 @@ async def claim_e2e_fallback_keys(

return results

@trace
def _claim_e2e_one_time_key_simple(
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.

Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""

sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
"""

txn.execute(sql, (user_id, device_id, algorithm, count))
otk_rows = list(txn)
if not otk_rows:
return []

self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]

@trace
def _claim_e2e_one_time_keys_bulk(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.

Args:
query_list: Collection of tuples (user_id, device_id, algorithm, count)
as passed to claim_e2e_one_time_keys.

Returns:
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""
sql = """
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ?
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
DELETE FROM e2e_one_time_keys_json k
WHERE (user_id, device_id, algorithm, key_id) IN (
SELECT user_id, device_id, algorithm, key_id
FROM ranked_keys
WHERE r <= claim_count
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
)
RETURNING user_id, device_id, algorithm, key_id, key_json;
"""
otk_rows = cast(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)

seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
Comment on lines +1320 to +1327
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code never used to have this deduplication. If the same (user, device) showed up twice (either with multiple algorithms or claiming multiple keys) we'd send out more than one invalidation for that (user, device). I don't know how much this matters in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how much this matters in practice.

I suspect it is just inefficient, but doesn't matter too much?


return otk_rows


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
Expand Down
Loading
Loading