diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6460eb99521c..9915b9288b01 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -54,6 +54,7 @@ ReplicationPresenceSetState, ) from synapse.replication.tcp.commands import ClearUserSyncsCommand +from synapse.replication.tcp.streams import PresenceStream from synapse.state import StateHandler from synapse.storage.databases.main import DataStore from synapse.types import Collection, JsonDict, UserID, get_domain_from_id @@ -253,8 +254,10 @@ async def update_external_syncs_clear(self, process_id): """ pass - async def process_replication_rows(self, token, rows): - """Process presence stream rows received over replication.""" + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + """Process streams received over replication.""" pass async def maybe_send_presence_to_interested_destinations( @@ -421,7 +424,14 @@ async def notify_from_replication(self, states, stream_id): # If this is a federation sender, notify about presence updates. await self.maybe_send_presence_to_interested_destinations(states) - async def process_replication_rows(self, token, rows): + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + await super().process_replication_rows(stream_name, instance_name, token, rows) + + if stream_name != PresenceStream.NAME: + return + states = [ UserPresenceState( row.user_id, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index ce5d651cb8c5..778fde1e9765 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -191,8 +191,6 @@ async def on_rdata( self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: await self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == PresenceStream.NAME: - await self._presence_handler.process_replication_rows(token, rows) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -221,6 +219,10 @@ async def on_rdata( membership=row.data.membership, ) + await self._presence_handler.process_replication_rows( + stream_name, instance_name, token, rows + ) + # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position.