Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Cancel the processing of key query requests when they time out. (#13680)
Browse files Browse the repository at this point in the history
  • Loading branch information
reivilibre authored Sep 7, 2022
1 parent c2fe48a commit d3d9ca1
Show file tree
Hide file tree
Showing 18 changed files with 110 additions and 20 deletions.
1 change: 1 addition & 0 deletions changelog.d/13680.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cancel the processing of key query requests when they time out.
5 changes: 5 additions & 0 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
trace,
)
from synapse.types import Requester, create_requester
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -118,6 +119,7 @@ async def check_user_in_room(
errcode=Codes.NOT_JOINED,
)

@cancellable
async def get_user_by_req(
self,
request: SynapseRequest,
Expand Down Expand Up @@ -166,6 +168,7 @@ async def get_user_by_req(
parent_span.set_tag("appservice_id", requester.app_service.id)
return requester

@cancellable
async def _wrapped_get_user_by_req(
self,
request: SynapseRequest,
Expand Down Expand Up @@ -281,6 +284,7 @@ async def validate_appservice_can_control_user_id(
403, "Application service has not registered this user (%s)" % user_id
)

@cancellable
async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
"""
Given a request, reads the request parameters to determine:
Expand Down Expand Up @@ -523,6 +527,7 @@ def has_access_token(request: Request) -> bool:
return bool(query_params) or bool(auth_headers)

@staticmethod
@cancellable
def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.
Expand Down
3 changes: 3 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination

Expand Down Expand Up @@ -124,6 +125,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:

return device

@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
Expand Down Expand Up @@ -163,6 +165,7 @@ async def get_device_changes_in_shared_rooms(

@trace
@measure_func("device.get_user_ids_changed")
@cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
Expand Down
40 changes: 24 additions & 16 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import Linearizer, delay_cancellation
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(self, hs: "HomeServer"):
)

@trace
@cancellable
async def query_devices(
self,
query_body: JsonDict,
Expand Down Expand Up @@ -208,22 +210,26 @@ async def query_devices(
r[user_id] = remote_queries[user_id]

# Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._query_devices_for_destination,
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
delay_cancellation(
defer.gatherResults(
[
run_in_background(
self._query_devices_for_destination,
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
)

ret = {"device_keys": results, "failures": failures}
Expand Down Expand Up @@ -347,6 +353,7 @@ async def _query_devices_for_destination(

return

@cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
Expand Down Expand Up @@ -393,6 +400,7 @@ async def get_cross_signing_keys_from_cache(
}

@trace
@cancellable
async def query_local_devices(
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
Expand Down
6 changes: 4 additions & 2 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken

from ._base import client_patterns, interactive_auth_handler
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -156,6 +156,7 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()

@cancellable
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(self, hs: "HomeServer"):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main

@cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -229,6 +230,7 @@ async def get_state_for_events(

@trace
@tag_args
@cancellable
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -350,6 +352,7 @@ def get_state_for_groups(

@trace
@tag_args
@cancellable
async def get_state_group_for_events(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -398,6 +401,7 @@ async def store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

@cancellable
async def get_current_state_ids(
self,
room_id: str,
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr

Expand Down Expand Up @@ -668,6 +669,7 @@ def get_device_stream_token(self) -> int:
...

@trace
@cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
Expand Down Expand Up @@ -743,6 +745,7 @@ def get_cached_device_list_changes(

return self._device_list_stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
Expand Down Expand Up @@ -1221,6 +1224,7 @@ async def _get_min_device_lists_changes_in_room(self) -> int:
desc="get_min_device_lists_changes_in_room",
)

@cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]:
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,6 +136,7 @@ async def get_e2e_device_keys_for_federation_query(
return now_stream_id, []

@trace
@cancellable
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]:
Expand Down Expand Up @@ -197,6 +199,7 @@ async def get_e2e_device_keys_and_signatures(
...

@trace
@cancellable
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
Expand Down Expand Up @@ -887,6 +890,7 @@ def _get_e2e_cross_signing_signatures_txn(

return keys

@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
Expand All @@ -902,7 +906,6 @@ async def get_e2e_cross_signing_keys_bulk(
keys were not found, either their user ID will not be in the dict,
or their user ID will map to None.
"""

result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

if from_user_id:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -976,6 +977,7 @@ def _get_min_depth_interaction(

return int(min_depth) if min_depth is not None else None

@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -339,6 +340,7 @@ async def get_event(
) -> Optional[EventBase]:
...

@cancellable
async def get_event(
self,
event_id: str,
Expand Down Expand Up @@ -433,6 +435,7 @@ async def get_events(

@trace
@tag_args
@cancellable
async def get_events_as_list(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -584,6 +587,7 @@ async def get_events_as_list(

return events

@cancellable
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, EventCacheEntry]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -770,6 +771,7 @@ def _get_users_server_still_shares_room_with_txn(
_get_users_server_still_shares_room_with_txn,
)

@cancellable
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -281,6 +282,7 @@ def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
)

# FIXME: how should this be cached?
@cancellable
async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -597,6 +598,7 @@ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:

return ret, key

@cancellable
async def get_membership_changes_for_user(
self,
user_id: str,
Expand Down
3 changes: 3 additions & 0 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -156,6 +157,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
"get_state_group_delta", _get_state_group_delta_txn
)

@cancellable
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
Expand Down Expand Up @@ -235,6 +237,7 @@ def _get_state_for_group_using_cache(

return state_filter.filter_state(state_dict_ids), not missing_types

@cancellable
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
Expand Down
Loading

0 comments on commit d3d9ca1

Please sign in to comment.