Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update all stream IDs after processing replication rows #14723

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14723.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
3 changes: 3 additions & 0 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ async def on_rdata(
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
self.store.process_replication_position(stream_name, instance_name, token)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a quick comment here to say that process_replication_position must be called after process_replication_rows please? Just so that anyone looking at this function knows that this is a bit sensitive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows)
Expand Down
17 changes: 16 additions & 1 deletion synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,22 @@ def process_replication_rows( # noqa: B027 (no-op by design)
token: int,
rows: Iterable[Any],
) -> None:
pass
"""
Used by storage classes to invalidate caches based on incoming replication data. These
must not update any ID generators, use `process_replication_position`.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you mention that we should not be advancing stream ID generators here please? Just to spell out which function should be used for what.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def process_replication_position( # noqa: B027 (no-op by design)
self,
stream_name: str,
instance_name: str,
token: int,
) -> None:
"""
Used by storage classes to advance ID generators based on incoming replication data. This
is called after process_replication_rows such that caches are invalidated before any token
positions advance.
"""

def _invalidate_state_caches(
self, room_id: str, members_changed: Collection[str]
Expand Down
14 changes: 10 additions & 4 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,7 @@ def process_replication_rows(
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
if stream_name == AccountDataStream.NAME:
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
Expand All @@ -429,6 +426,15 @@ def process_replication_rows(

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
Expand Down
11 changes: 8 additions & 3 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ def process_replication_rows(
backfilled=True,
)
elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)

for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
Expand All @@ -182,6 +179,14 @@ def process_replication_rows(

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data

Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def process_replication_rows(
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()

Expand Down
11 changes: 9 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,21 @@ def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,7 @@ def process_replication_rows(
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
elif stream_name == UnPartialStatedEventStream.NAME:
if stream_name == UnPartialStatedEventStream.NAME:
for row in rows:
assert isinstance(row, UnPartialStatedEventStreamRow)

Expand All @@ -405,6 +401,15 @@ def process_replication_rows(

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_position(stream_name, instance_name, token)

async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
from the database due to a redaction.
Expand Down
8 changes: 7 additions & 1 deletion synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,14 @@ def process_replication_rows(
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def process_replication_rows(
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()

def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
super().process_replication_position(stream_name, instance_name, token)

async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,13 @@ def process_replication_rows(

return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
Expand Down
8 changes: 7 additions & 1 deletion synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,19 @@ def process_replication_rows(
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)


class TagsStore(TagsWorkerStore):
pass