Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Recheck if remote device is cached before requesting it #16252

Merged
merged 4 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16252.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug when using workers where Synapse could end up re-requesting the same remote device repeatedly.
22 changes: 16 additions & 6 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from synapse.types import (
JsonDict,
JsonMapping,
StrCollection,
StreamKeyType,
StreamToken,
Expand Down Expand Up @@ -982,7 +983,7 @@ def __init__(self, hs: "HomeServer"):

async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
Expand Down Expand Up @@ -1011,6 +1012,7 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self._notifier = hs.get_notifier()

self._remote_edu_linearizer = Linearizer(name="remote_device_list")
self._resync_linearizer = Linearizer(name="remote_device_resync")

# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[
Expand Down Expand Up @@ -1253,7 +1255,7 @@ async def _maybe_retry_device_resync(self) -> None:

async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
Expand All @@ -1273,9 +1275,11 @@ async def multi_user_device_resync(
failed = set()
# TODO(Perf): Actually batch these up
for user_id in user_ids:
user_result, user_failed = await self._user_device_resync_returning_failed(
user_id
)
async with self._resync_linearizer.queue(user_id):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is queueing the right thing here? Should we instead share the results via an ObservableDeferred? (If you have two requests for the same user in a row there's no reason to service the second if the first is ongoing?)

Edit: It looks like this is essentially what we're doing by checking the cache inside of _user_device_resync_returning_failed?


Will this make it harder to batch these? (Maybe not the immediate concern, but don't want to design ourselves into a corner.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these weren't device lists I would be more gung ho about batching these up. However, we need to make sure that we don't break any semantics. Specifically I'm worried about the situations where we get two requests for a remote users device lists, one that happens before we get poked about a new device and one after, if we coalesce those two remote calls we need to make sure that we don't return the old set of devices (retrieved by the first request) to the second request. IYSWIM....

I don't think this backs us into a corner; we can always change this down the line.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're not talking about the same batching -- I was talking about batching multiple users (based on the comment a few lines above this), but maybe that's tricky.

(
user_result,
user_failed,
) = await self._user_device_resync_returning_failed(user_id)
result[user_id] = user_result
if user_failed:
failed.add(user_id)
Expand All @@ -1287,7 +1291,7 @@ async def multi_user_device_resync(

async def _user_device_resync_returning_failed(
self, user_id: str
) -> Tuple[Optional[JsonDict], bool]:
) -> Tuple[Optional[JsonMapping], bool]:
"""Fetches all devices for a user and updates the device cache with them.

Args:
Expand All @@ -1300,6 +1304,12 @@ async def _user_device_resync_returning_failed(
e.g. due to a connection problem.
- True iff the resync failed and the device list should be marked as stale.
"""
# Check that we haven't gone and fetched the devices since we last
# checked if we needed to resync these device lists.
if await self.store.get_users_whose_devices_are_cached([user_id]):
cached = await self.store.get_cached_devices_for_user(user_id)
return cached, False
Comment on lines +1354 to +1358
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is safe because if we linearized on user ID, then request the same user again but have not received a request to invalidate that data (via a resync request) then it must still be valid?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, we already do this check before entering this block, so this is safe. The flow is:

  1. Client reader receives a request for keys.
  2. Client reader checks if keys are cached, if so returns them.
  3. If not, client reader queries master to fetch the keys.
  4. New Master linearizes requests for a users keys
  5. New Master checks if we now have the keys in the cache, if so returns.
  6. Fetches keys from the remote server, and caches them if we share a room with the remote user.

So this is step 5, which is safe because we're basically re-doing a check we have previously done


logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from synapse.http.server import HttpServer
from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.types import JsonDict, JsonMapping

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -82,7 +82,7 @@ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[o

async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
) -> Tuple[int, Dict[str, Optional[JsonMapping]]]:
user_ids: List[str] = content["user_ids"]

logger.info("Resync for %r", user_ids)
Expand Down
26 changes: 17 additions & 9 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,18 +760,10 @@ async def get_user_devices_from_cache(
mapping of user_id -> device_id -> device_info.
"""
unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
user_map = await self.get_device_list_last_stream_id_for_remotes(
list(unique_user_ids)
)

# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids_in_cache = await self.get_users_whose_devices_are_cached(
unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache

# First fetch all the users which all devices are to be returned.
Expand All @@ -793,6 +785,22 @@ async def get_user_devices_from_cache(

return user_ids_not_in_cache, results

async def get_users_whose_devices_are_cached(
    self, user_ids: StrCollection
) -> Set[str]:
    """Checks which of the given users we have cached the devices for."""
    stream_ids = await self.get_device_list_last_stream_id_for_remotes(user_ids)

    # A user flagged as needing a resync must be treated as uncached, even
    # if we still hold (now stale) device data for them.
    needs_resync = await self.get_user_ids_requiring_device_list_resync(user_ids)

    cached = {uid for uid, stream_id in stream_ids.items() if stream_id}
    return cached - needs_resync

@cached(num_args=2, tree=True)
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
content = await self.db_pool.simple_select_one_onecol(
Expand Down
Loading