This repository has been archived by the owner on Mar 26, 2024. It is now read-only.

Update all stream IDs after processing replication rows (matrix-org#14723) (#52)

* Update all stream IDs after processing replication rows (matrix-org#14723)

This creates a new store method, `process_replication_position`, that is
called after `process_replication_rows`. Moving stream ID advances there
guarantees that any relevant cache invalidations have been applied before the
stream position is advanced (a minimal sketch of the resulting ordering
follows the commit message below).

This avoids race conditions where Python switches between threads midway
through `process_replication_rows` and, because of class resolution ordering,
stream IDs could be advanced before caches were invalidated.

See this comment/issue for further discussion:
	matrix-org#14158 (comment)
# Conflicts:
#	synapse/storage/databases/main/devices.py
#	synapse/storage/databases/main/events_worker.py

* Fix bad cherry-picking

* Remove leftover stream advance
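
For readers skimming the diff, the following is a minimal, self-contained sketch of the ordering this commit enforces. DemoWorkerStore and the toy on_rdata function are hypothetical stand-ins for illustration, not the actual Synapse classes:

from typing import Any, Dict, Iterable, List


class DemoWorkerStore:
    """Hypothetical stand-in for a Synapse worker store (illustration only)."""

    def __init__(self) -> None:
        self._cache: Dict[str, str] = {"alice": "stale value"}
        self._current_token = 0

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Only cache invalidation happens here; the token is NOT advanced yet.
        for row in rows:
            self._cache.pop(row, None)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        # The stream position advances only after the caches touched by the
        # rows above have been invalidated.
        self._current_token = max(self._current_token, token)


def on_rdata(
    store: DemoWorkerStore, stream: str, instance: str, token: int, rows: List[Any]
) -> None:
    # Mirrors the ordering added to ReplicationDataHandler.on_rdata below:
    # rows first, position second, so anything that observes the new token
    # can trust that stale cache entries are already gone.
    store.process_replication_rows(stream, instance, token, rows)
    store.process_replication_position(stream, instance, token)


on_rdata(DemoWorkerStore(), "caches", "master", token=7, rows=["alice"])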
Fizzadar authored Jan 17, 2023
1 parent 90878d6 commit c71199e
Showing 15 changed files with 115 additions and 66 deletions.
1 change: 1 addition & 0 deletions changelog.d/14723.bugfix
@@ -0,0 +1 @@
Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
3 changes: 3 additions & 0 deletions synapse/replication/tcp/client.py
@@ -148,6 +148,9 @@ async def on_rdata(
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
self.store.process_replication_position(stream_name, instance_name, token)

if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows)
17 changes: 16 additions & 1 deletion synapse/storage/_base.py
@@ -59,7 +59,22 @@ def process_replication_rows( # noqa: B027 (no-op by design)
token: int,
rows: Iterable[Any],
) -> None:
pass
"""
Used by storage classes to invalidate caches based on incoming replication data. These
must not update any ID generators; use `process_replication_position` for that.
"""

def process_replication_position( # noqa: B027 (no-op by design)
self,
stream_name: str,
instance_name: str,
token: int,
) -> None:
"""
Used by storage classes to advance ID generators based on incoming replication data. This
is called after process_replication_rows such that caches are invalidated before any token
positions advance.
"""

def _invalidate_state_caches(
self, room_id: str, members_changed: Collection[str]
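The docstrings above spell out the contract; because the main data store is composed of many worker-store mixins, each override of process_replication_position below ends with a super() call so the whole method resolution order gets walked. A rough sketch of that chaining, using hypothetical class names and print statements in place of real ID generators:

class BaseLikeStore:
    """Hypothetical terminus of the super() chain, playing the role of the base store."""

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        pass  # deliberate no-op: nothing left to advance


class AccountDataLikeStore(BaseLikeStore):
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == "account_data":
            print("advance account_data ID generator to", token)
        super().process_replication_position(stream_name, instance_name, token)


class DeviceListsLikeStore(BaseLikeStore):
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == "device_lists":
            print("advance device_lists ID generator to", token)
        super().process_replication_position(stream_name, instance_name, token)


class ComposedStore(AccountDataLikeStore, DeviceListsLikeStore):
    """Mirrors how the real data store mixes many worker stores together."""


# Every mixin in the MRO gets a chance to advance its generator for this stream.
ComposedStore().process_replication_position("device_lists", "master", 42)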
14 changes: 10 additions & 4 deletions synapse/storage/databases/main/account_data.py
@@ -415,10 +415,7 @@ def process_replication_rows(
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
if stream_name == AccountDataStream.NAME:
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
@@ -433,6 +430,15 @@ def process_replication_rows(

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
21 changes: 16 additions & 5 deletions synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ def process_replication_rows(
backfilled=True,
)
elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)

for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
@@ -182,6 +179,14 @@

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data

@@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
# TODO: Nothing to do here, handled in events_worker, cleanup?
pass
assert isinstance(data, EventsStreamCurrentStateRow)
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)

if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,))
else:
raise Exception("Unknown events stream row type %s" % (row.type,))

9 changes: 7 additions & 2 deletions synapse/storage/databases/main/deviceinbox.py
@@ -160,10 +160,15 @@ def process_replication_rows(
self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token
)
# Important that the ID gen advances after stream change caches
self._device_inbox_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()

13 changes: 9 additions & 4 deletions synapse/storage/databases/main/devices.py
@@ -163,15 +163,20 @@ def process_replication_rows(
) -> None:
if stream_name == DeviceListsStream.NAME:
self._invalidate_caches_for_devices(token, rows)
# Important that the ID gen advances after stream change caches
self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
# Important that the ID gen advances after stream change caches
self._device_list_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/event_federation.py
@@ -1187,7 +1187,7 @@ async def get_forward_extremities_for_room_at_stream_ordering(
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]

# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
43 changes: 3 additions & 40 deletions synapse/storage/databases/main/events_worker.py
@@ -249,22 +249,6 @@ def __init__(
prefilled_cache=curr_state_delta_prefill,
)

event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache",
min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max
)

if hs.config.worker.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
@@ -325,35 +309,14 @@ def get_chain_id_txn(txn: Cursor) -> int:
id_column="chain_id",
)

def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
# Process event stream replication rows, handling both the ID generators from the events
# worker store and the stream change caches in this store as the two are interlinked.
if stream_name == EventsStream.NAME:
for row in rows:
if row.type == EventsStreamEventRow.TypeId:
self._events_stream_cache.entity_has_changed(
row.data.room_id, token
)
if row.data.type == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(
row.data.state_key, token
)
if row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
# Important that the ID gen advances after stream change caches
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)

super().process_replication_rows(stream_name, instance_name, token, rows)
super().process_replication_position(stream_name, instance_name, token)

async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
8 changes: 7 additions & 1 deletion synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ def process_replication_rows(
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
8 changes: 7 additions & 1 deletion synapse/storage/databases/main/push_rule.py
@@ -148,12 +148,18 @@ def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()

def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
super().process_replication_position(stream_name, instance_name, token)

async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
10 changes: 7 additions & 3 deletions synapse/storage/databases/main/receipts.py
@@ -600,11 +600,15 @@ def process_replication_rows(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
# Important that the ID gen advances after stream change caches
self._receipts_id_gen.advance(instance_name, token)

return super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
18 changes: 18 additions & 0 deletions synapse/storage/databases/main/stream.py
@@ -71,6 +71,7 @@
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
@@ -396,6 +397,23 @@ def __init__(
# during startup which would cause one to die.
self._need_to_reset_federation_stream_positions = self._send_federation

events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache",
min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max
)

self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()

8 changes: 7 additions & 1 deletion synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ def process_replication_rows(
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)

super().process_replication_rows(stream_name, instance_name, token, rows)

def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)


class TagsStore(TagsWorkerStore):
pass
