Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update get_users_in_room mis-use to get hosts with dedicated get_current_hosts_in_room #13605

Merged
merged 9 commits into from
Aug 24, 2022
1 change: 1 addition & 0 deletions changelog.d/13605.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `get_users_in_room(room_id)` mis-use with dedicated `get_current_hosts_in_room(room_id)` function.
8 changes: 6 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.federation_sender = hs.get_federation_sender()
self._storage_controllers = hs.get_storage_controllers()

self.device_list_updater = DeviceListUpdater(hs, self)

Expand Down Expand Up @@ -694,8 +695,11 @@ async def _handle_new_device_update_async(self) -> None:

# Ignore any users that aren't ours
if self.hs.is_mine_id(user_id):
joined_user_ids = await self.store.get_users_in_room(room_id)
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts = set(
await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
)
hosts.discard(self.server_name)

# Check if we've already sent this update to some hosts
Expand Down
12 changes: 7 additions & 5 deletions synapse/handlers/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from synapse.appservice import ApplicationService
from synapse.module_api import NOT_SPAM
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
from synapse.types import JsonDict, Requester, RoomAlias

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -83,8 +83,9 @@ async def _create_association(
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
users = await self.store.get_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
servers = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)

if not servers:
raise SynapseError(400, "Failed to get server list")
Expand Down Expand Up @@ -287,8 +288,9 @@ async def get_association(self, room_alias: RoomAlias) -> JsonDict:
Codes.NOT_FOUND,
)

users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
servers_set = set(extra_servers) | set(servers)

# If this server is in the list of servers, return it first.
Expand Down
3 changes: 1 addition & 2 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,8 +2051,7 @@ async def get_interested_remotes(
)

for room_id, states in room_ids_to_states.items():
user_ids = await store.get_users_in_room(room_id)
hosts = {get_domain_from_id(user_id) for user_id in user_ids}
hosts = await store.get_current_hosts_in_room(room_id)
for host in hosts:
hosts_and_states.setdefault(host, set()).update(states)

Expand Down
7 changes: 4 additions & 3 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.types import JsonDict, Requester, StreamKeyType, UserID
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
Expand Down Expand Up @@ -362,8 +362,9 @@ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
)
return

users = await self.store.get_users_in_room(room_id)
domains = {get_domain_from_id(u) for u in users}
domains = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)

if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
Expand Down
17 changes: 12 additions & 5 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,24 @@ def default_config(self):
return c

def prepare(self, reactor, clock, hs):
# stub out `get_rooms_for_user` and `get_users_in_room` so that the
test_room_id = "!room:host1"

# stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id):
return defer.succeed({"!room:host1"})
return defer.succeed({test_room_id})

hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user

def get_users_in_room(room_id):
return defer.succeed({"@user2:host2"})
async def get_current_hosts_in_room(room_id):
if room_id == test_room_id:
return ["host2"]

# TODO: We should fail the test when we encounter an unxpected room ID.
# We can't just use `self.fail(...)` here because the app code is greedy
# with `Exception` and will catch it before the test can see it.

hs.get_datastores().main.get_users_in_room = get_users_in_room
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room

# whenever send_transaction is called, record the edu data
self.edus = []
Expand Down