Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add a presence federation replication stream
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Apr 15, 2021
1 parent 8f56607 commit 5c63b65
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 17 deletions.
248 changes: 234 additions & 14 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import abc
import contextlib
import logging
from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -53,7 +54,9 @@
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
)
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
Expand Down Expand Up @@ -124,6 +127,8 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()

self.federation_queue = PresenceFederationQueue(hs, self)

self._busy_presence_enabled = hs.config.experimental.msc3026_enabled

active_presence = self.store.take_presence_startup_info()
Expand Down Expand Up @@ -245,9 +250,17 @@ async def update_external_syncs_clear(self, process_id):
"""
pass

async def process_replication_rows(self, token, rows):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Process presence stream rows received over replication."""
pass
await self.federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
)

def get_federation_queue(self) -> "PresenceFederationQueue":
"""Get the presence federation queue, if any."""
return self.federation_queue


class _NullContextManager(ContextManager[None]):
Expand All @@ -265,6 +278,7 @@ def __init__(self, hs):

self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence
self.state = hs.get_state_handler()

# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
Expand All @@ -273,6 +287,10 @@ def __init__(self, hs):
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()

self._federation = None
if hs.should_send_federation():
self._federation = hs.get_federation_sender()

# user_id -> last_sync_ms. Lists the users that have stopped syncing
# but we haven't notified the master of that yet
self.users_going_offline = {}
Expand Down Expand Up @@ -388,7 +406,14 @@ async def notify_from_replication(self, states, stream_id):
users=users_to_states.keys(),
)

async def process_replication_rows(self, token, rows):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
await super().process_replication_rows(stream_name, instance_name, token, rows)

if stream_name != PresenceStream.NAME:
return

states = [
UserPresenceState(
row.user_id,
Expand All @@ -408,6 +433,20 @@ async def process_replication_rows(self, token, rows):
stream_id = token
await self.notify_from_replication(states, stream_id)

# Handle poking the local federation sender, if there is one.
if not self._federation:
return

hosts_and_states = await get_interested_remotes(
self.store,
self.presence_router,
states,
self.state,
)

for destinations, states in hosts_and_states:
self._federation.send_presence_to_destinations(states, destinations)

def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
user_id
Expand Down Expand Up @@ -463,11 +502,14 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence

self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()

federation_registry = hs.get_federation_registry()

federation_registry.register_edu_handler("m.presence", self.incoming_presence)
Expand Down Expand Up @@ -680,7 +722,17 @@ async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
if to_federation_ping:
federation_presence_out_counter.inc(len(to_federation_ping))

await self._push_to_remotes(to_federation_ping.values())
hosts_and_states = await get_interested_remotes(
self.store,
self.presence_router,
list(to_federation_ping.values()),
self.state,
)

for destinations, states in hosts_and_states:
self.federation_queue.send_presence_to_destinations(
states, destinations
)

async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
Expand Down Expand Up @@ -920,14 +972,12 @@ async def _persist_and_notify(self, states):
users=[UserID.from_string(u) for u in users_to_states],
)

await self._push_to_remotes(states)

async def _push_to_remotes(self, states):
"""Sends state updates to remote servers.
# We only need to tell the local federation sender, if any, that new
# presence has happened. Other federation senders will get notified via
# the presence replication stream.
if not self.federation_sender:
return

Args:
states (list(UserPresenceState))
"""
hosts_and_states = await get_interested_remotes(
self.store,
self.presence_router,
Expand All @@ -936,7 +986,7 @@ async def _push_to_remotes(self, states):
)

for destinations, states in hosts_and_states:
self.federation.send_presence_to_destinations(states, destinations)
self.federation_sender.send_presence_to_destinations(states, destinations)

async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server."""
Expand Down Expand Up @@ -1174,7 +1224,7 @@ async def _handle_state_delta(self, deltas):

# Send out user presence updates for each destination
for destination, user_state_set in presence_destinations.items():
self.federation.send_presence_to_destinations(
self.federation_queue.send_presence_to_destinations(
destinations=[destination], states=user_state_set
)

Expand Down Expand Up @@ -1819,3 +1869,173 @@ async def get_interested_remotes(
hosts_and_states.append(([host], states))

return hosts_and_states


class PresenceFederationQueue:
    """Handles sending ad hoc presence updates over federation, which are *not*
    due to state updates (that get handled via the presence stream), e.g.
    federation pings and sending existing present states to newly joined hosts.

    Only the last N minutes will be queued, so if a federation sender instance
    is down for longer then some updates will be dropped. This is OK as presence
    is ephemeral, and so it will self correct eventually.

    Each call to `send_presence_to_destinations` that gets queued is assigned
    the next stream ID, so the queue holds exactly one entry per stream ID in
    increasing order. `get_replication_rows` relies on that invariant to locate
    entries by counting back from the end of the queue.
    """

    # How long to keep entries in the queue for. Workers that are down for
    # longer than this duration will miss out on older updates.
    _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000

    # How often to check if we can expire entries from the queue.
    _CLEAR_ITEMS_EVERY_MS = 60 * 1000

    def __init__(self, hs: "HomeServer", presence_handler: "BasePresenceHandler"):
        """
        Args:
            hs: The homeserver, used to look up the clock, notifier, instance
                name, federation sender and worker config.
            presence_handler: The presence handler that owns this queue; used
                to look up current presence states when processing rows
                received over replication.
        """
        self._clock = hs.get_clock()
        self._notifier = hs.get_notifier()
        self._instance_name = hs.get_instance_name()
        self._presence_handler = presence_handler
        self._repl_client = ReplicationGetStreamUpdates.make_client(hs)

        # Should we keep a queue of recent presence updates? We only bother if
        # another process may be handling federation sending.
        self._queue_presence_updates = True

        # The federation sender if this instance is a federation sender.
        self._federation = None

        if hs.should_send_federation():
            self._federation = hs.get_federation_sender()

            # We don't bother queuing up presence states if only this instance
            # is sending federation.
            if hs.config.worker.federation_shard_config.instances == [
                self._instance_name
            ]:
                self._queue_presence_updates = False

        # The queue of recently queued updates as tuples of: `(timestamp,
        # stream_id, destinations, user_ids)`. We don't store the full states
        # for efficiency, and remote workers will already have the full states
        # cached.
        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]

        # The stream ID that will be assigned to the next queued entry.
        self._next_id = 1

        # Map from instance name to current token
        self._current_tokens = {}  # type: Dict[str, int]

        if self._queue_presence_updates:
            self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)

    def _clear_queue(self):
        """Clear out older entries from the queue."""
        clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS

        # The queue is sorted by timestamp, so we can bisect to find the right
        # place to purge before. Note that we are searching using a 1-tuple with
        # the time, which does The Right Thing since the queue is a tuple where
        # the first item is a timestamp.
        index = bisect(self._queue, (clear_before,))
        self._queue = self._queue[index:]

    def send_presence_to_destinations(
        self, states: "Collection[UserPresenceState]", destinations: Collection[str]
    ) -> None:
        """Send the presence states to the given destinations.

        Will forward to the local federation sender (if there is one) and queue
        to send over replication (if there are other federation sender
        instances).

        Args:
            states: The presence states to send. Only the `user_id` of each
                state is stored in the queue.
            destinations: The remote hosts to send the states to.
        """

        if self._federation:
            self._federation.send_presence_to_destinations(states, destinations)

        if not self._queue_presence_updates:
            return

        now = self._clock.time_msec()

        stream_id = self._next_id
        self._next_id += 1

        self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))

        self._notifier.notify_replication()

    def get_current_token(self, instance_name: str) -> int:
        """Get the current stream token for the given instance.

        For this instance that is the ID of the most recently queued entry
        (zero if nothing has been queued). For other instances it is the last
        token seen in `process_replication_rows`, defaulting to zero.
        """
        if instance_name == self._instance_name:
            return self._next_id - 1
        else:
            return self._current_tokens.get(instance_name, 0)

    async def get_replication_rows(
        self,
        instance_name: str,
        from_token: int,
        upto_token: int,
        target_row_count: int,
    ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
        """Get all the updates between the two tokens.

        We return rows in the form of `(destination, user_id)` to keep the size
        of each row bounded (rather than returning the sets in a row).

        Args:
            instance_name: The instance whose stream is being queried. If it is
                not this instance the request is forwarded over replication.
            from_token: Exclusive lower bound; only entries with a stream ID
                strictly greater than this are returned.
            upto_token: Inclusive upper bound on the stream IDs returned.
            target_row_count: Once more than this many rows have been
                collected we stop and flag the result as limited. Entries are
                never split, so the result may exceed this count.

        Returns:
            A tuple of `(rows, new_token, limited)` where `rows` is a list of
            `(stream_id, (destination, user_id))`, `new_token` is the stream ID
            of the last entry included (or `upto_token` if no entries were
            included) and `limited` is whether we stopped early due to
            `target_row_count`.
        """
        if instance_name != self._instance_name:
            # If not local we query over replication.
            result = await self._repl_client(
                instance_name=instance_name,
                stream_name=PresenceFederationStream.NAME,
                from_token=from_token,
                upto_token=upto_token,
            )
            return result["updates"], result["upto_token"], result["limited"]

        # If the caller is already at (or beyond) the latest stream ID there is
        # nothing to return. This must be special-cased: the index computed
        # below would be `0`, which slices the *entire* queue rather than
        # nothing.
        if from_token >= self._next_id - 1:
            return [], upto_token, False

        # We can find the correct position in the queue by noting that there is
        # exactly one entry per stream ID, and that the last entry has an ID of
        # `self._next_id - 1`, so we can count backwards from the end.
        #
        # Since we are returning entries in the range `from_token < stream_id
        # <= upto_token` we look for the index of stream ID `from_token + 1`.
        #
        # Since the start of the queue is periodically truncated we need to
        # handle the case where `from_token` stream ID has already been dropped.
        start_idx = max(from_token + 1 - self._next_id, -len(self._queue))

        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
        limited = False
        new_id = upto_token
        for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
            if stream_id <= from_token:
                # Paranoia: never resend an entry the caller has already seen.
                continue

            if stream_id > upto_token:
                break

            new_id = stream_id

            to_send.extend(
                (stream_id, (destination, user_id))
                for destination in destinations
                for user_id in user_ids
            )

            if len(to_send) > target_row_count:
                limited = True
                break

        return to_send, new_id, limited

    async def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ):
        """Handle rows received over replication.

        Rows for streams other than the presence federation stream are
        ignored. Otherwise we record the sending instance's token and, if this
        instance is a federation sender, send the current presence state of
        each listed user to the associated destination.

        Args:
            stream_name: Name of the stream the rows belong to.
            instance_name: The instance that generated the rows.
            token: The stream token the rows bring us up to.
            rows: The rows; each is read for its `destination` and `user_id`
                attributes.
        """
        if stream_name != PresenceFederationStream.NAME:
            return

        # We keep track of the current tokens
        self._current_tokens[instance_name] = token

        # If we're a federation sender we pull out the presence states to send
        # and forward them on.
        if not self._federation:
            return

        # Group by destination so that we do one state lookup and one send per
        # remote host.
        hosts_to_users = {}  # type: Dict[str, Set[str]]
        for row in rows:
            hosts_to_users.setdefault(row.destination, set()).add(row.user_id)

        for host, user_ids in hosts_to_users.items():
            states = await self._presence_handler.current_state_for_users(user_ids)
            self._federation.send_presence_to_destinations(states.values(), [host])
8 changes: 5 additions & 3 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
AccountDataStream,
DeviceListsStream,
GroupServerStream,
PresenceStream,
PushersStream,
PushRulesStream,
ReceiptsStream,
Expand Down Expand Up @@ -191,8 +190,6 @@ async def on_rdata(
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
else:
await self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == PresenceStream.NAME:
await self._presence_handler.process_replication_rows(token, rows)
elif stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
Expand Down Expand Up @@ -221,6 +218,10 @@ async def on_rdata(
membership=row.data.membership,
)

await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)

# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
Expand Down Expand Up @@ -338,6 +339,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
self._hs = hs
self._presence_handler = hs.get_presence_handler()

# We need to make a temporary value to ensure that mypy picks up the
# right type. We know we should have a federation sender instance since
Expand Down
Loading

0 comments on commit 5c63b65

Please sign in to comment.