Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Handle local device list updates during partial join (#13934)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Sep 28, 2022
1 parent df8b91e commit 5f659d4
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 15 deletions.
1 change: 1 addition & 0 deletions changelog.d/13934.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correctly handle sending local device list updates to remote servers during a partial join.
84 changes: 82 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,90 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:
gone from partial to full state.
"""

# We defer to the device list updater implementation as we're on the
# right worker.
# We defer to the device list updater to handle pending remote device
# list updates.
await self.device_list_updater.handle_room_un_partial_stated(room_id)

# Replay local updates.
(
join_event_id,
device_lists_stream_id,
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
room_id
)

# Get the local device list changes that have happened in the room since
# we started joining. If there are no updates there's nothing left to do.
changes = await self.store.get_device_list_changes_in_room(
room_id, device_lists_stream_id
)
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
if not local_changes:
return

# Note: We have persisted the full state at this point, we just haven't
# cleared the `partial_room` flag.
join_state_ids = await self._state_storage.get_state_ids_for_event(
join_event_id, await_full_state=False
)
current_state_ids = await self.store.get_partial_current_state_ids(room_id)

# Now we need to work out all servers that might have been in the room
# at any point during our join.

# First we look for any membership states that have changed between the
# initial join and now...
all_keys = set(join_state_ids)
all_keys.update(current_state_ids)

potentially_changed_hosts = set()
for etype, state_key in all_keys:
if etype != EventTypes.Member:
continue

prev = join_state_ids.get((etype, state_key))
current = current_state_ids.get((etype, state_key))

if prev != current:
potentially_changed_hosts.add(get_domain_from_id(state_key))

# ... then we add all the hosts that are currently joined to the room...
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
potentially_changed_hosts.update(current_hosts_in_room)

# ... and finally we remove any hosts that we were told about, as we
# will have sent device list updates to those hosts when they happened.
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
potentially_changed_hosts.difference_update(known_hosts_at_join)

potentially_changed_hosts.discard(self.server_name)

if not potentially_changed_hosts:
# Nothing to do.
return

logger.info(
"Found %d changed hosts to send device list updates to",
len(potentially_changed_hosts),
)

for user_id, device_id in local_changes:
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)

# Notify things that device lists need to be sent out.
self.notifier.notify_replication()
for host in potentially_changed_hosts:
self.federation_sender.send_device_messages(host, immediate=False)


def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
Expand Down
55 changes: 42 additions & 13 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,33 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
"""Get all device list changes that happened in the room since the given
stream ID.
Returns:
Collection of user ID/device ID tuples of all devices that have
changed
"""

sql = """
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
WHERE room_id = ? AND stream_id > ?
"""

def get_device_list_changes_in_room_txn(
txn: LoggingTransaction,
) -> Collection[Tuple[str, str]]:
txn.execute(sql, (room_id, min_stream_id))
return cast(Collection[Tuple[str, str]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_device_list_changes_in_room",
get_device_list_changes_in_room_txn,
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
Expand Down Expand Up @@ -1946,14 +1973,15 @@ async def add_device_list_outbound_pokes(
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled.
Marks the associated row in `device_lists_changes_in_room` as handled,
if `stream_id` is provided.
"""

def add_device_list_outbound_pokes_txn(
Expand All @@ -1969,17 +1997,18 @@ def add_device_list_outbound_pokes_txn(
context=context,
)

self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if stream_id:
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)

if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
Expand Down
16 changes: 16 additions & 0 deletions synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,22 @@ async def is_partial_state_room(self, room_id: str) -> bool:

return entry is not None

async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.
"""

result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
)
return result["join_event_id"], result["device_lists_stream_id"]


class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
Expand Down

0 comments on commit 5f659d4

Please sign in to comment.