Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Bulk-invalidate e2e cached queries after claiming keys (#16613)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Cloke <[email protected]>
  • Loading branch information
David Robertson and clokep authored Nov 9, 2023
1 parent f6aa047 commit 91587d4
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog.d/16613.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the performance of claiming encryption keys in multi-worker deployments.
2 changes: 1 addition & 1 deletion synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ async def simple_insert_many(
def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
keys: Sequence[str],
values: Collection[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
Expand Down
75 changes: 71 additions & 4 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ def _invalidate_cache_and_stream(
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_cache_and_stream_bulk(
self,
txn: LoggingTransaction,
cache_func: CachedFunction,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""A bulk version of _invalidate_cache_and_stream.
Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
for each key-tuple over replication.
This implementation is more efficient than a loop which repeatedly calls the
non-bulk version.
"""
if not key_tuples:
return

for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)

self._send_invalidation_to_replication_bulk(
txn, cache_func.__name__, key_tuples
)

def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
Expand Down Expand Up @@ -564,10 +588,6 @@ def _send_invalidation_to_replication(
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None

# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)

Expand All @@ -586,6 +606,53 @@ def _send_invalidation_to_replication(
},
)

def _send_invalidation_to_replication_bulk(
self,
txn: LoggingTransaction,
cache_name: str,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""Announce the invalidation of multiple (but not all) cache entries.
This is more efficient than repeated calls to the non-bulk version. It should
NOT be used to invalidating the entire cache: use
`_send_invalidation_to_replication` with keys=None.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
"""
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None

stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
ts = self._clock.time_msec()
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.db_pool.simple_insert_many_txn(
txn,
table="cache_invalidation_stream_by_instance",
keys=(
"stream_id",
"instance_name",
"cache_func",
"keys",
"invalidation_ts",
),
values=[
# We convert key_tuples to a list here because psycopg2 serialises
# lists as pq arrrays, but serialises tuples as "composite types".
# (We need an array because the `keys` column has type `[]text`.)
# See:
# https://www.psycopg.org/docs/usage.html#adapt-list
# https://www.psycopg.org/docs/usage.html#adapt-tuple
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
for stream_id, key_tuple in zip(stream_ids, key_tuples)
],
)

def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
Expand Down
26 changes: 12 additions & 14 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,13 +1237,11 @@ def _claim_e2e_fallback_keys_bulk_txn(
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

self._invalidate_cache_and_stream_bulk(
txn, self.get_e2e_unused_fallback_key_types, seen_user_device
)

return results

Expand Down Expand Up @@ -1376,14 +1374,14 @@ def _claim_e2e_one_time_keys_bulk(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)

seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device = {
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
}
self._invalidate_cache_and_stream_bulk(
txn,
self.count_e2e_one_time_keys,
seen_user_device,
)

return otk_rows

Expand Down
56 changes: 47 additions & 9 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

next_id = self._load_next_id_txn(txn)

txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._mark_ids_as_finished, [next_id])
txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
Expand All @@ -671,14 +671,50 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

return self._return_factor * next_id

def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
"""
Usage:
stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""

# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")

next_ids = self._load_next_mult_id_txn(txn, n)

txn.call_after(self._mark_ids_as_finished, next_ids)
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
# position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
)

return [self._return_factor * next_id for next_id in next_ids]

def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
"""These IDs have finished being processed so we should advance the
current position if possible.
"""

with self._lock:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
self._unfinished_ids.difference_update(next_ids)
self._finished_ids.update(next_ids)

new_cur: Optional[int] = None

Expand Down Expand Up @@ -727,7 +763,10 @@ def _mark_id_as_finished(self, next_id: int) -> None:
curr, new_cur, self._max_position_of_local_instance
)

self._add_persisted_position(next_id)
# TODO Can we call this for just the last position or somehow batch
# _add_persisted_position.
for next_id in next_ids:
self._add_persisted_position(next_id)

def get_current_token(self) -> int:
return self.get_persisted_upto_position()
Expand Down Expand Up @@ -933,8 +972,7 @@ async def __aexit__(
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
self.id_gen._mark_ids_as_finished(self.stream_ids)

self.notifier.notify_replication()

Expand Down
117 changes: 117 additions & 0 deletions tests/storage/databases/main/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, call

from synapse.storage.database import LoggingTransaction

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase


class CacheInvalidationTestCase(HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main

def test_bulk_invalidation(self) -> None:
master_invalidate = Mock()

self.store._get_cached_user_device.invalidate = master_invalidate

keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]

def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)

self.get_success(
self.store.db_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)

master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)


class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main

def test_bulk_invalidation_replicates(self) -> None:
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
master_invalidate = Mock()
worker_invalidate = Mock()

self.store._get_cached_user_device.invalidate = master_invalidate
worker = self.make_worker_hs("synapse.app.generic_worker")
worker_ds = worker.get_datastores().main
worker_ds._get_cached_user_device.invalidate = worker_invalidate

keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]

def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)

assert self.store._cache_id_gen is not None
initial_token = self.store._cache_id_gen.get_current_token()
self.get_success(
self.database_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
second_token = self.store._cache_id_gen.get_current_token()

self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))

self.get_success(
worker.get_replication_data_handler().wait_for_stream_position(
"master", "caches", second_token
)
)

master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
worker_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)

0 comments on commit 91587d4

Please sign in to comment.