Speed up fetching latest stream positions via cache #17606

Merged: 3 commits, Aug 27, 2024
1 change: 1 addition & 0 deletions changelog.d/17606.misc
@@ -0,0 +1 @@
Speed up incremental syncs in sliding sync by adding some more caching.
6 changes: 6 additions & 0 deletions synapse/storage/databases/main/cache.py
@@ -313,6 +313,8 @@ def _invalidate_caches_for_event(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)

self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))
Contributor:

Seems like we should be invalidating this in other spots as well. For example, just trace where we bust the cache for `get_unread_event_push_actions_by_room_for_user`.

Member (Author):

Ugh, I think most of those end up calling each other? We should clean those up, but for now I've just added more invalidations.
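
For context on what these `_attempt_to_invalidate_cache` calls are protecting against, here is a minimal, self-contained sketch of the pattern. It is illustrative only, not the Synapse implementation, and the class and function names are made up for the example:

```python
# Simplified illustration only, not Synapse's real cache machinery: a per-room
# cache of the latest stream position has to be dropped whenever events are
# persisted to (or purged from) that room, or readers may see a stale position.

from typing import Dict, Optional


class MaxEventPosCache:
    def __init__(self) -> None:
        self._cache: Dict[str, int] = {}  # room_id -> latest stream position

    def get(self, room_id: str) -> Optional[int]:
        return self._cache.get(room_id)

    def set(self, room_id: str, stream_pos: int) -> None:
        self._cache[room_id] = stream_pos

    def invalidate(self, room_id: str) -> None:
        # Plays the role of the `_attempt_to_invalidate_cache("_get_max_event_pos",
        # (room_id,))` calls added in this diff.
        self._cache.pop(room_id, None)


def on_event_persisted(cache: MaxEventPosCache, room_id: str) -> None:
    # Invalidate rather than update, so the next reader recomputes from the DB.
    cache.invalidate(room_id)
```

The point of the discussion above is simply that every code path which changes a room's events needs to hit the equivalent of `invalidate()`, which is why the same call is added in several invalidation helpers below.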


# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,))
@@ -404,6 +406,8 @@ def _invalidate_caches_for_room_events(self, room_id: str) -> None:
)
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))

self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))

self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
@@ -476,6 +480,8 @@ def _invalidate_caches_for_room(self, room_id: str) -> None:
self._attempt_to_invalidate_cache("get_room_type", (room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))

self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))

# And delete state caches.

self._invalidate_state_caches_all(room_id)
137 changes: 70 additions & 67 deletions synapse/storage/databases/main/stream.py
@@ -50,6 +50,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Protocol,
Set,
@@ -80,7 +81,7 @@
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -1381,40 +1382,85 @@ async def bulk_get_last_event_pos_in_room_before_stream_ordering(
rooms
"""

# First we just get the latest positions for the room, as the vast
# majority of them will be before the given end token anyway. By doing
# this we can cache most rooms.
uncapped_results = await self._bulk_get_max_event_pos(room_ids)

# Check that the stream position for the rooms are from before the
# minimum position of the token. If not then we need to fetch more
# rows.
results: Dict[str, int] = {}
recheck_rooms: Set[str] = set()
min_token = end_token.stream
max_token = end_token.get_max_stream_pos()
for room_id, stream in uncapped_results.items():
if stream <= min_token:
results[room_id] = stream
else:
recheck_rooms.add(room_id)

if not recheck_rooms:
return results

# There shouldn't be many rooms that we need to recheck, so we do them
# one-by-one.
for room_id in recheck_rooms:
result = await self.get_last_event_pos_in_room_before_stream_ordering(
room_id, end_token
)
Contributor:

It seems like we could additionally reuse the previous bulk lookup for the `recheck_rooms` here now.

But this PR seems good as long as we've profiled that the new setup is faster than the previous one, given the nice cache skip for the majority of rooms anyway.

Member (Author):

Yeah, it mostly felt like a bunch of extra code to maintain that (should) basically never be hit.
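
To make the trade-off in this thread concrete, here is a small, self-contained sketch of the fast-path/recheck split. It is simplified and not the actual Synapse code; `fetch_capped_position` is a made-up stand-in for the slower per-room query that respects the token exactly:

```python
# Simplified sketch of the fast-path / recheck split discussed above
# (not the actual Synapse implementation).

from typing import Callable, Dict, Mapping, Optional, Set


def resolve_positions(
    cached_positions: Mapping[str, int],
    token_stream_pos: int,
    fetch_capped_position: Callable[[str], Optional[int]],
) -> Dict[str, int]:
    results: Dict[str, int] = {}
    recheck_rooms: Set[str] = set()

    # Fast path: a cached position at or below the token is already correct.
    for room_id, pos in cached_positions.items():
        if pos <= token_stream_pos:
            results[room_id] = pos
        else:
            recheck_rooms.add(room_id)

    # Slow path: expected to be hit for very few rooms, so a per-room lookup
    # (rather than another bulk query) keeps the code simple.
    for room_id in recheck_rooms:
        pos = fetch_capped_position(room_id)
        if pos is not None:
            results[room_id] = pos

    return results
```

The reviewer's suggestion would replace that final per-room loop with one more bulk query; the author's point is that `recheck_rooms` is expected to be almost always empty, so the simpler loop rarely runs.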

if result is not None:
results[room_id] = result[1].stream

return results

@cached()
async def _get_max_event_pos(self, room_id: str) -> int:
raise NotImplementedError()

@cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids")
async def _bulk_get_max_event_pos(
self, room_ids: StrCollection
) -> Mapping[str, int]:
"""Fetch the max position of a persisted event in the room."""

# We need to be careful not to return positions ahead of the current
# positions, so we get the current token now and cap our queries to it.
now_token = self.get_room_max_token()
max_pos = now_token.get_max_stream_pos()

results: Dict[str, int] = {}

# First, we check for the rooms in the stream change cache to see if we
# can just use the latest position from it.
missing_room_ids: Set[str] = set()
for room_id in room_ids:
stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
if stream_pos and stream_pos <= min_token:
if stream_pos is not None:
results[room_id] = stream_pos
else:
missing_room_ids.add(room_id)

if not missing_room_ids:
return results

# Next, we query the stream position from the DB. At first we fetch all
# positions less than the *max* stream pos in the token, then filter
# them down. We do this as a) this is a cheaper query, and b) the vast
# majority of rooms will have a latest token from before the min stream
# pos.

def bulk_get_last_event_pos_txn(
txn: LoggingTransaction, batch_room_ids: StrCollection
def bulk_get_max_event_pos_txn(
txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
# This query fetches the latest stream position in the rooms before
# the given max position.
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", batch_room_ids
self.database_engine, "room_id", batched_room_ids
)
sql = f"""
SELECT room_id, (
SELECT stream_ordering FROM events AS e
LEFT JOIN rejections USING (event_id)
WHERE e.room_id = r.room_id
AND stream_ordering <= ?
AND e.stream_ordering <= ?
AND NOT outlier
AND rejection_reason IS NULL
ORDER BY stream_ordering DESC
@@ -1423,72 +1469,29 @@ def bulk_get_last_event_pos_txn(
FROM rooms AS r
WHERE {clause}
"""
txn.execute(sql, [max_token] + args)
txn.execute(sql, [max_pos] + args)
return {row[0]: row[1] for row in txn}

recheck_rooms: Set[str] = set()
for batched in batch_iter(missing_room_ids, 1000):
result = await self.db_pool.runInteraction(
"bulk_get_last_event_pos_in_room_before_stream_ordering",
bulk_get_last_event_pos_txn,
batched,
for batched in batch_iter(room_ids, 1000):
batch_results = await self.db_pool.runInteraction(
"_bulk_get_max_event_pos", bulk_get_max_event_pos_txn, batched
)

# Check that the stream position for the rooms are from before the
# minimum position of the token. If not then we need to fetch more
# rows.
for room_id, stream in result.items():
if stream <= min_token:
results[room_id] = stream
for room_id, stream_ordering in batch_results.items():
if stream_ordering <= now_token.stream:
results.update(batch_results)
else:
recheck_rooms.add(room_id)

if not recheck_rooms:
return results

# For the remaining rooms we need to fetch all rows between the min and
# max stream positions in the end token, and filter out the rows that
# are after the end token.
#
# This query should be fast as the range between the min and max should
# be small.

def bulk_get_last_event_pos_recheck_txn(
txn: LoggingTransaction, batch_room_ids: StrCollection
) -> Dict[str, int]:
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", batch_room_ids
)
sql = f"""
SELECT room_id, instance_name, stream_ordering
FROM events
WHERE ? < stream_ordering AND stream_ordering <= ?
AND NOT outlier
AND rejection_reason IS NULL
AND {clause}
ORDER BY stream_ordering ASC
"""
txn.execute(sql, [min_token, max_token] + args)

# We take the max stream ordering that is less than the token. Since
# we ordered by stream ordering we just need to iterate through and
# take the last matching stream ordering.
txn_results: Dict[str, int] = {}
for row in txn:
room_id = row[0]
event_pos = PersistedEventPosition(row[1], row[2])
if not event_pos.persisted_after(end_token):
txn_results[room_id] = event_pos.stream

return txn_results

for batched in batch_iter(recheck_rooms, 1000):
recheck_result = await self.db_pool.runInteraction(
"bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
bulk_get_last_event_pos_recheck_txn,
batched,
# We now need to handle rooms where the above query returned a stream
# position that was potentially too new. This should happen very rarely
# so we just query the rooms one-by-one
for room_id in recheck_rooms:
result = await self.get_last_event_pos_in_room_before_stream_ordering(
room_id, now_token
)
results.update(recheck_result)
if result is not None:
results[room_id] = result[1].stream

return results
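
The `_get_max_event_pos` / `_bulk_get_max_event_pos` pair above follows Synapse's usual `@cached` / `@cachedList` pairing: the single-entry method is a cache-backed stub (its body just raises `NotImplementedError`), while the list variant is what actually fills and reads that cache in bulk. Here is a rough, generic sketch of that shape, with made-up names and none of the real decorator machinery from `synapse.util.caches.descriptors`:

```python
# Generic illustration of a cached-single / cached-list pairing (made-up names;
# not the real synapse.util.caches.descriptors behaviour): the bulk lookup
# serves what it can from a shared per-key cache, computes only the misses,
# and back-fills the cache so later lookups are answered from memory.

from typing import Callable, Dict, Iterable, List, Mapping


class BulkCachedLookup:
    def __init__(
        self, compute_many: Callable[[List[str]], Mapping[str, int]]
    ) -> None:
        self._cache: Dict[str, int] = {}
        self._compute_many = compute_many

    def get_many(self, keys: Iterable[str]) -> Dict[str, int]:
        results: Dict[str, int] = {}
        misses: List[str] = []
        for key in keys:
            if key in self._cache:
                results[key] = self._cache[key]
            else:
                misses.append(key)
        if misses:
            fetched = self._compute_many(misses)
            self._cache.update(fetched)  # back-fill for future single/bulk hits
            results.update(fetched)
        return results


if __name__ == "__main__":
    lookup = BulkCachedLookup(lambda keys: {k: 0 for k in keys})  # dummy query
    print(lookup.get_many(["!a:example.org", "!b:example.org"]))  # computes both
    print(lookup.get_many(["!a:example.org"]))  # served from the cache
```

In the PR itself, this is the mechanism that lets repeated incremental sliding-sync requests avoid re-querying the database for rooms whose latest position has not changed, with the invalidations added in cache.py keeping the cached entries honest.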
