diff --git a/changelog.d/10018.misc b/changelog.d/10018.misc new file mode 100644 index 000000000000..eaf9f64867a6 --- /dev/null +++ b/changelog.d/10018.misc @@ -0,0 +1 @@ +Reduce memory usage when verifying signatures on large numbers of events at once. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5f18ef77489d..6fc07129788f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -17,7 +17,7 @@ import logging import urllib from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple import attr from signedjson.key import ( @@ -42,6 +42,8 @@ SynapseError, ) from synapse.config.key import TrustedKeyServer +from synapse.events import EventBase +from synapse.events.utils import prune_event_dict from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, @@ -69,7 +71,11 @@ class VerifyJsonRequest: Attributes: server_name: The name of the server to verify against. - json_object: The JSON object to verify. + get_json_object: A callback to fetch the JSON object to verify. + A callback is used to allow deferring the creation of the JSON + object to verify until needed, e.g. for events we can defer + creating the redacted copy. This reduces the memory usage when + there are large numbers of in flight requests. minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) @@ -88,14 +94,50 @@ class VerifyJsonRequest: """ server_name = attr.ib(type=str) - json_object = attr.ib(type=JsonDict) + get_json_object = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) request_name = attr.ib(type=str) - key_ids = attr.ib(init=False, type=List[str]) + key_ids = attr.ib(type=List[str]) key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) - def __attrs_post_init__(self): - self.key_ids = signature_ids(self.json_object, self.server_name) + @staticmethod + def from_json_object( + server_name: str, + json_object: JsonDict, + minimum_valid_until_ms: int, + request_name: str, + ): + """Create a VerifyJsonRequest to verify all signatures on a signed JSON + object for the given server. + """ + key_ids = signature_ids(json_object, server_name) + return VerifyJsonRequest( + server_name, + lambda: json_object, + minimum_valid_until_ms, + request_name=request_name, + key_ids=key_ids, + ) + + @staticmethod + def from_event( + server_name: str, + event: EventBase, + minimum_valid_until_ms: int, + ): + """Create a VerifyJsonRequest to verify all signatures on an event + object for the given server. + """ + key_ids = list(event.signatures.get(server_name, [])) + return VerifyJsonRequest( + server_name, + # We defer creating the redacted json object, as it uses a lot more + # memory than the Event object itself. + lambda: prune_event_dict(event.room_version, event.get_pdu_json()), + minimum_valid_until_ms, + request_name=event.event_id, + key_ids=key_ids, + ) class KeyLookupError(ValueError): @@ -147,8 +189,13 @@ def verify_json_for_server( Deferred[None]: completes if the the object was correctly signed, otherwise errbacks with an error """ - req = VerifyJsonRequest(server_name, json_object, validity_time, request_name) - requests = (req,) + request = VerifyJsonRequest.from_json_object( + server_name, + json_object, + validity_time, + request_name, + ) + requests = (request,) return make_deferred_yieldable(self._verify_objects(requests)[0]) def verify_json_objects_for_server( @@ -175,10 +222,41 @@ def verify_json_objects_for_server( logcontext. """ return self._verify_objects( - VerifyJsonRequest(server_name, json_object, validity_time, request_name) + VerifyJsonRequest.from_json_object( + server_name, json_object, validity_time, request_name + ) for server_name, json_object, validity_time, request_name in server_and_json ) + def verify_events_for_server( + self, server_and_events: Iterable[Tuple[str, EventBase, int]] + ) -> List[defer.Deferred]: + """Bulk verification of signatures on events. + + Args: + server_and_events: + Iterable of `(server_name, event, validity_time)` tuples. + + `server_name` is which server we are verifying the signature for + on the event. + + `event` is the event that we'll verify the signatures of for + the given `server_name`. + + `validity_time` is a timestamp at which the signing key must be + valid. + + Returns: + List: for each input triplet, a deferred indicating success + or failure to verify each event's signature for the given + server_name. The deferreds run their callbacks in the sentinel + logcontext. + """ + return self._verify_objects( + VerifyJsonRequest.from_event(server_name, event, validity_time) + for server_name, event, validity_time in server_and_events + ) + def _verify_objects( self, verify_requests: Iterable[VerifyJsonRequest] ) -> List[defer.Deferred]: @@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: with PreserveLoggingContext(): _, key_id, verify_key = await verify_request.key_ready - json_object = verify_request.json_object + json_object = verify_request.get_json_object() try: verify_signed_json(json_object, server_name, verify_key) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 949dcd46141f..3fe496dcd330 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -137,11 +137,7 @@ def errback(failure: Failure, pdu: EventBase): return deferreds -class PduToCheckSig( - namedtuple( - "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] - ) -): +class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass @@ -184,7 +180,6 @@ def _check_sigs_on_pdus( pdus_to_check = [ PduToCheckSig( pdu=p, - redacted_pdu_json=prune_event(p).get_pdu_json(), sender_domain=get_domain_from_id(p.sender), deferreds=[], ) @@ -195,13 +190,12 @@ def _check_sigs_on_pdus( # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( p.sender_domain, - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_sender ] @@ -230,13 +224,12 @@ def sender_err(e, pdu_to_check): if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( get_domain_from_id(p.pdu.event_id), - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_event_id ]