From 5243aa0519a3f1254465e82325ffcad032243fb3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Jun 2021 13:27:11 +0100 Subject: [PATCH 01/10] Make checking sigs linear --- synapse/crypto/keyring.py | 46 ++--- synapse/federation/federation_base.py | 228 ++++++++---------------- synapse/federation/federation_client.py | 104 ++++++----- 3 files changed, 149 insertions(+), 229 deletions(-) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c840ffca7141..e5a4685ed49f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -233,41 +233,19 @@ def verify_json_objects_for_server( for server_name, json_object, validity_time in server_and_json ] - def verify_events_for_server( - self, server_and_events: Iterable[Tuple[str, EventBase, int]] - ) -> List[defer.Deferred]: - """Bulk verification of signatures on events. - - Args: - server_and_events: - Iterable of `(server_name, event, validity_time)` tuples. - - `server_name` is which server we are verifying the signature for - on the event. - - `event` is the event that we'll verify the signatures of for - the given `server_name`. - - `validity_time` is a timestamp at which the signing key must be - valid. - - Returns: - List: for each input triplet, a deferred indicating success - or failure to verify each event's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - return [ - run_in_background( - self.process_request, - VerifyJsonRequest.from_event( - server_name, - event, - validity_time, - ), + async def verify_event_for_server( + self, + server_name: str, + event: EventBase, + validity_time: int, + ) -> None: + await self.process_request( + VerifyJsonRequest.from_event( + server_name, + event, + validity_time, ) - for server_name, event, validity_time in server_and_events - ] + ) async def process_request(self, verify_request: VerifyJsonRequest) -> None: """Processes the `VerifyJsonRequest`. Raises if the object is not signed diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 3fe496dcd330..8069d904a2bc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,11 +14,6 @@ # limitations under the License. 
import logging from collections import namedtuple -from typing import Iterable, List - -from twisted.internet import defer -from twisted.internet.defer import Deferred, DeferredList -from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -28,11 +23,6 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict -from synapse.logging.context import ( - PreserveLoggingContext, - current_context, - make_deferred_yieldable, -) from synapse.types import JsonDict, get_domain_from_id logger = logging.getLogger(__name__) @@ -48,22 +38,14 @@ def __init__(self, hs): self.store = hs.get_datastore() self._clock = hs.get_clock() - def _check_sigs_and_hash( + async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase - ) -> Deferred: - return make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0] - ) - - def _check_sigs_and_hashes( - self, room_version: RoomVersion, pdus: List[EventBase] - ) -> List[Deferred]: - """Checks that each of the received events is correctly signed by the - sending server. + ) -> EventBase: + """Checks that event is correctly signed by the sending server. Args: - room_version: The room version of the PDUs - pdus: the events to be checked + room_version: The room version of the PDU + pdu: the event to be checked Returns: For each input event, a deferred which: @@ -73,77 +55,60 @@ def _check_sigs_and_hashes( * throws a SynapseError if the signature check failed. The deferreds run their callbacks in the sentinel """ - deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - - ctx = current_context() - - @defer.inlineCallbacks - def callback(_, pdu: EventBase): - with PreserveLoggingContext(ctx): - if not check_event_content_hash(pdu): - # let's try to distinguish between failures because the event was - # redacted (which are somewhat expected) vs actual ball-tampering - # incidents. - # - # This is just a heuristic, so we just assume that if the keys are - # about the same between the redacted and received events, then the - # received event was probably a redacted copy (but we then use our - # *actual* redacted copy to be on the safe side.) - redacted_event = prune_event(pdu) - if set(redacted_event.keys()) == set(pdu.keys()) and set( - redacted_event.content.keys() - ) == set(pdu.content.keys()): - logger.info( - "Event %s seems to have been redacted; using our redacted " - "copy", - pdu.event_id, - ) - else: - logger.warning( - "Event %s content has been tampered, redacting", - pdu.event_id, - ) - return redacted_event - - result = yield defer.ensureDeferred( - self.spam_checker.check_event_for_spam(pdu) + try: + await _check_sigs_on_pdu(self.keyring, room_version, pdu) + except Exception as e: + logger.warning( + "Signature check failed for %s: %s", + pdu.event_id, + e, + ) + raise + + if not check_event_content_hash(pdu): + # let's try to distinguish between failures because the event was + # redacted (which are somewhat expected) vs actual ball-tampering + # incidents. + # + # This is just a heuristic, so we just assume that if the keys are + # about the same between the redacted and received events, then the + # received event was probably a redacted copy (but we then use our + # *actual* redacted copy to be on the safe side.) 
+ redacted_event = prune_event(pdu) + if set(redacted_event.keys()) == set(pdu.keys()) and set( + redacted_event.content.keys() + ) == set(pdu.content.keys()): + logger.info( + "Event %s seems to have been redacted; using our redacted copy", + pdu.event_id, ) - - if result: - logger.warning( - "Event contains spam, redacting %s: %s", - pdu.event_id, - pdu.get_pdu_json(), - ) - return prune_event(pdu) - - return pdu - - def errback(failure: Failure, pdu: EventBase): - failure.trap(SynapseError) - with PreserveLoggingContext(ctx): + else: logger.warning( - "Signature check failed for %s: %s", + "Event %s content has been tampered, redacting", pdu.event_id, - failure.getErrorMessage(), ) - return failure + return redacted_event + + result = await self.spam_checker.check_event_for_spam(pdu) - for deferred, pdu in zip(deferreds, pdus): - deferred.addCallbacks( - callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] + if result: + logger.warning( + "Event contains spam, redacting %s: %s", + pdu.event_id, + pdu.get_pdu_json(), ) + return prune_event(pdu) - return deferreds + return pdu class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass -def _check_sigs_on_pdus( - keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase] -) -> List[Deferred]: +async def _check_sigs_on_pdu( + keyring: Keyring, room_version: RoomVersion, pdu: EventBase +): """Check that the given events are correctly signed Args: @@ -177,90 +142,47 @@ def _check_sigs_on_pdus( # let's start by getting the domain for each pdu, and flattening the event back # to JSON. - pdus_to_check = [ - PduToCheckSig( - pdu=p, - sender_domain=get_domain_from_id(p.sender), - deferreds=[], - ) - for p in pdus - ] - # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - - more_deferreds = keyring.verify_events_for_server( - [ - ( - p.sender_domain, - p.pdu, - p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, + if not _is_invite_via_3pid(pdu): + try: + await keyring.verify_event_for_server( + get_domain_from_id(pdu.sender), + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, ) - for p in pdus_to_check_sender - ] - ) - - def sender_err(e, pdu_to_check): - errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( - pdu_to_check.pdu.event_id, - pdu_to_check.sender_domain, - e.getErrorMessage(), - ) - raise SynapseError(403, errmsg, Codes.FORBIDDEN) - - for p, d in zip(pdus_to_check_sender, more_deferreds): - d.addErrback(sender_err, p) - p.deferreds.append(d) + except Exception as e: + errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( + pdu.event_id, + get_domain_from_id(pdu.sender), + e, + ) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) # now let's look for events where the sender's domain is different to the # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). 
- if room_version.event_format == EventFormatVersions.V1: - pdus_to_check_event_id = [ - p - for p in pdus_to_check - if p.sender_domain != get_domain_from_id(p.pdu.event_id) - ] - - more_deferreds = keyring.verify_events_for_server( - [ - ( - get_domain_from_id(p.pdu.event_id), - p.pdu, - p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - ) - for p in pdus_to_check_event_id - ] - ) - - def event_err(e, pdu_to_check): + if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id( + pdu.event_id + ): + try: + await keyring.verify_event_for_server( + get_domain_from_id(pdu.event_id), + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, + ) + except Exception as e: errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" - % (pdu_to_check.pdu.event_id, e.getErrorMessage()) + "event id %s: unable to verify signature for event id domain %s: %s" + % ( + pdu.event_id, + get_domain_from_id(pdu.event_id), + e, + ) ) raise SynapseError(403, errmsg, Codes.FORBIDDEN) - for p, d in zip(pdus_to_check_event_id, more_deferreds): - d.addErrback(event_err, p) - p.deferreds.append(d) - - # replace lists of deferreds with single Deferreds - return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] - - -def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred: - """Given a list of deferreds, either return the single deferred, - combine into a DeferredList, or return an already resolved deferred. - """ - if len(deferreds) > 1: - return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True) - elif len(deferreds) == 1: - return deferreds[0] - else: - return defer.succeed(None) - def _is_invite_via_3pid(event: EventBase) -> bool: return ( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index e0e9f5d0beeb..5055e41003e2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -35,9 +35,6 @@ import attr from prometheus_client import Counter -from twisted.internet import defer -from twisted.internet.defer import Deferred - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( CodeMessageException, @@ -56,10 +53,9 @@ from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util.async_helpers import yieldable_gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -386,51 +382,75 @@ async def _check_sigs_and_hash_and_fetch( Returns: A list of PDUs that have valid signatures and hashes. """ - deferreds = self._check_sigs_and_hashes(room_version, pdus) + valid_pdus = await yieldable_gather_results( + self._check_sigs_and_hash_and_fetch_one, + pdus, + origin=origin, + room_version=room_version, + outlier=outlier, + ) - async def handle_check_result(pdu: EventBase, deferred: Deferred): - try: - res = await make_deferred_yieldable(deferred) - except SynapseError: - res = None + if include_none: + return valid_pdus + else: + return [p for p in valid_pdus if p] - if not res: - # Check local db. 
- res = await self.store.get_event( - pdu.event_id, allow_rejected=True, allow_none=True - ) + async def _check_sigs_and_hash_and_fetch_one( + self, + pdu: EventBase, + origin: str, + room_version: RoomVersion, + outlier: bool = False, + ) -> Optional[EventBase]: + """Takes a PDU and checks its signatures and hashes. If the PDU fails + its signature check then we check if we have it in the database and if + not then request if from the originating server of that PDU. - pdu_origin = get_domain_from_id(pdu.sender) - if not res and pdu_origin != origin: - try: - res = await self.get_pdu( - destinations=[pdu_origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) - except SynapseError: - pass + If then PDU fails its content hash check then it is redacted. - if not res: - logger.warning( - "Failed to find copy of %s with valid signature", pdu.event_id - ) + Args: + origin + pdu + room_version + outlier: Whether the events are outliers or not + include_none: Whether to include None in the returned list + for events that have failed their checks - return res + Returns: + The PDU (possibly redacted) if it has valid signatures and hashes. + """ - handle = preserve_fn(handle_check_result) - deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] + res = None + try: + res = await self._check_sigs_and_hash(room_version, pdu) + except SynapseError: + pass + + if not res: + # Check local db. + res = await self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) - valid_pdus = await make_deferred_yieldable( - defer.gatherResults(deferreds2, consumeErrors=True) - ).addErrback(unwrapFirstError) + pdu_origin = get_domain_from_id(pdu.sender) + if not res and pdu_origin != origin: + try: + res = await self.get_pdu( + destinations=[pdu_origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, + ) + except SynapseError: + pass - if include_none: - return valid_pdus - else: - return [p for p in valid_pdus if p] + if not res: + logger.warning( + "Failed to find copy of %s with valid signature", pdu.event_id + ) + + return res async def get_event_auth( self, destination: str, room_id: str, event_id: str From ae6fd33d1c5248991400070cac4453332dd09fea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Jun 2021 14:12:53 +0100 Subject: [PATCH 02/10] Use concurrently_execute --- synapse/federation/federation_client.py | 34 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5055e41003e2..7ac347677bf1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,6 +21,7 @@ Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -55,7 +56,7 @@ from synapse.federation.transport.client import SendJoinResponse from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id -from synapse.util.async_helpers import yieldable_gather_results +from synapse.util.async_helpers import concurrently_execute, yieldable_gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -356,7 +357,7 @@ async def get_room_state_ids( async def _check_sigs_and_hash_and_fetch( self, origin: str, - pdus: List[EventBase], + pdus: Collection[EventBase], room_version: RoomVersion, outlier: bool = False, include_none: bool = 
False, @@ -691,8 +692,6 @@ async def send_request(destination) -> Dict[str, Any]: state = response.state auth_chain = response.auth_events - pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} - create_event = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): @@ -716,14 +715,29 @@ async def send_request(destination) -> Dict[str, Any]: % (create_room_version,) ) - valid_pdus = await self._check_sigs_and_hash_and_fetch( - destination, - list(pdus.values()), - outlier=True, - room_version=room_version, + logger.info( + "Processing from send_join %d events", len(state) + len(auth_chain) ) - valid_pdus_map = {p.event_id: p for p in valid_pdus} + # We now go and check the signatures and hashes for the event. Note + # that we limit how many events we process at a time to keep the + # memory overhead from exploding. + valid_pdus_map: Dict[str, EventBase] = {} + + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=destination, + outlier=True, + room_version=room_version, + ) + + if valid_pdu: + valid_pdus_map[valid_pdu.event_id] = valid_pdu + + await concurrently_execute( + _execute, itertools.chain(state, auth_chain), 10000 + ) # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. From e290fbac82d6f4245013ea680a64647c7e4e0258 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Jun 2021 09:27:40 +0100 Subject: [PATCH 03/10] Newsfile --- changelog.d/10117.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10117.feature diff --git a/changelog.d/10117.feature b/changelog.d/10117.feature new file mode 100644 index 000000000000..e137e142c638 --- /dev/null +++ b/changelog.d/10117.feature @@ -0,0 +1 @@ +Significantly reduce memory usage of joining large remote rooms. From e6d70e8f830f2ed566adffa5119b1cc583def9bd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 09:51:53 +0100 Subject: [PATCH 04/10] Fix up docstrings --- synapse/federation/federation_base.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 8069d904a2bc..08cc816170d4 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -48,13 +48,10 @@ async def _check_sigs_and_hash( pdu: the event to be checked Returns: - For each input event, a deferred which: - * returns the original event if the checks pass - * returns a redacted version of the event (if the signature + * the original event if the checks pass + * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed. - The deferreds run their callbacks in the sentinel - """ + * throws a SynapseError if the signature check failed.""" try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) except Exception as e: @@ -108,17 +105,15 @@ class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferr async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase -): +) -> None: """Check that the given events are correctly signed + Raise a SynapseError if the event wasn't correctly signed. 
+ Args: keyring: keyring object to do the checks room_version: the room version of the PDUs pdus: the events to be checked - - Returns: - A Deferred for each event in pdus, which will either succeed if - the signatures are valid, or fail (with a SynapseError) if not. """ # we want to check that the event is signed by: From 1ccc8aca5bf048a4e2fb55e25cc8f236ae03a391 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 09:52:28 +0100 Subject: [PATCH 05/10] Only catch SynapseErrors --- synapse/federation/federation_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 08cc816170d4..03e98c482e79 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -54,7 +54,7 @@ async def _check_sigs_and_hash( * throws a SynapseError if the signature check failed.""" try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) - except Exception as e: + except SynapseError as e: logger.warning( "Signature check failed for %s: %s", pdu.event_id, From 2a266ac1a1799f05112efc6de5407f2292dd8393 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 09:53:49 +0100 Subject: [PATCH 06/10] Fix missing comparison --- synapse/federation/federation_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 03e98c482e79..c066617b9233 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -160,7 +160,7 @@ async def _check_sigs_on_pdu( # (ie, the room version uses old-style non-hash event IDs). if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id( pdu.event_id - ): + ) != get_domain_from_id(pdu.sender): try: await keyring.verify_event_for_server( get_domain_from_id(pdu.event_id), From 3f7d8fcabd3847e6f2a871d4a4d7bb0c21ab99df Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 09:59:00 +0100 Subject: [PATCH 07/10] Add concurrency limit --- synapse/federation/federation_client.py | 32 +++++++++++++++---------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7ac347677bf1..bbd5bc759dc9 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,7 +56,7 @@ from synapse.federation.transport.client import SendJoinResponse from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id -from synapse.util.async_helpers import concurrently_execute, yieldable_gather_results +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -383,18 +383,26 @@ async def _check_sigs_and_hash_and_fetch( Returns: A list of PDUs that have valid signatures and hashes. """ - valid_pdus = await yieldable_gather_results( - self._check_sigs_and_hash_and_fetch_one, - pdus, - origin=origin, - room_version=room_version, - outlier=outlier, - ) - if include_none: - return valid_pdus - else: - return [p for p in valid_pdus if p] + # We limit how many PDUs we check at once, as if we try to do hundreds + # of thousands of PDUs at once we see large memory spikes. 
+ + valid_pdus = [] + + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=origin, + outlier=outlier, + room_version=room_version, + ) + + if valid_pdu or include_none: + valid_pdus.append(valid_pdu) + + await concurrently_execute(_execute, pdus, 10000) + + return valid_pdus async def _check_sigs_and_hash_and_fetch_one( self, From 76971befd10588a04a3ad751f634cbed5bd2d7c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 09:59:13 +0100 Subject: [PATCH 08/10] Remove unused parameter --- synapse/federation/federation_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index bbd5bc759dc9..1076ebc0367e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -360,7 +360,6 @@ async def _check_sigs_and_hash_and_fetch( pdus: Collection[EventBase], room_version: RoomVersion, outlier: bool = False, - include_none: bool = False, ) -> List[EventBase]: """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in @@ -377,8 +376,6 @@ async def _check_sigs_and_hash_and_fetch( pdu room_version outlier: Whether the events are outliers or not - include_none: Whether to include None in the returned list - for events that have failed their checks Returns: A list of PDUs that have valid signatures and hashes. @@ -397,7 +394,7 @@ async def _execute(pdu: EventBase) -> None: room_version=room_version, ) - if valid_pdu or include_none: + if valid_pdu: valid_pdus.append(valid_pdu) await concurrently_execute(_execute, pdus, 10000) From b3efd53e256586896e06c010e7d0a84eed0f4ffd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 10:13:44 +0100 Subject: [PATCH 09/10] Make `concurrently_execute` handle large limits better --- synapse/util/async_helpers.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5c55bb0125d9..c31e8ec3763b 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -160,8 +160,11 @@ def __repr__(self) -> str: ) +T = TypeVar("T") + + def concurrently_execute( - func: Callable, args: Iterable[Any], limit: int + func: Callable[[T], Any], args: Iterable[T], limit: int ) -> defer.Deferred: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -173,20 +176,26 @@ def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred[list]: Resolved when all function invocations have finished. + Deferred: Resolved when all function invocations have finished. """ it = iter(args) - async def _concurrently_execute_inner(): + async def _concurrently_execute_inner(value: T) -> None: try: while True: - await maybe_awaitable(func(next(it))) + await maybe_awaitable(func(value)) + value = next(it) except StopIteration: pass + # We use `zip` to handle the case where the number of args is less than the + # limit, avoiding needlessly spawning unnecessary background tasks. 
return make_deferred_yieldable( defer.gatherResults( - [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + [ + run_in_background(_concurrently_execute_inner, value) + for value, _ in zip(it, range(limit)) + ], consumeErrors=True, ) ).addErrback(unwrapFirstError) From 18be5569c26115e56a4b017ac6fa44d93efe1efa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 7 Jun 2021 18:05:14 +0100 Subject: [PATCH 10/10] Fix bug where we dropped values --- synapse/util/async_helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index c31e8ec3763b..061102c3c894 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -15,6 +15,7 @@ import collections import inspect +import itertools import logging from contextlib import contextmanager from typing import ( @@ -188,13 +189,14 @@ async def _concurrently_execute_inner(value: T) -> None: except StopIteration: pass - # We use `zip` to handle the case where the number of args is less than the - # limit, avoiding needlessly spawning unnecessary background tasks. + # We use `itertools.islice` to handle the case where the number of args is + # less than the limit, avoiding needlessly spawning unnecessary background + # tasks. return make_deferred_yieldable( defer.gatherResults( [ run_in_background(_concurrently_execute_inner, value) - for value, _ in zip(it, range(limit)) + for value in itertools.islice(it, limit) ], consumeErrors=True, )
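
For illustration only (not part of the commits above), a standalone sketch of the iterator behaviour that PATCH 10 ("Fix bug where we dropped values") corrects: `zip(it, range(limit))` pulls one extra item from the iterator before noticing that `range(limit)` is exhausted, and that item is silently lost, whereas `itertools.islice(it, limit)` consumes exactly `limit` items. The helper names below are hypothetical and exist only to demonstrate the difference.

    # Sketch of the zip-vs-islice gotcha behind PATCH 10. Hypothetical helpers;
    # only the consumption behaviour of zip/islice is being demonstrated.
    import itertools


    def seed_values_with_zip(args, limit):
        it = iter(args)
        # zip consumes limit + 1 items from `it` when more remain: the extra
        # item is taken before StopIteration from range(limit) and then lost.
        seeds = [value for value, _ in zip(it, range(limit))]
        return seeds, list(it)  # items the workers would later pull from `it`


    def seed_values_with_islice(args, limit):
        it = iter(args)
        # islice consumes exactly `limit` items; nothing is dropped.
        seeds = list(itertools.islice(it, limit))
        return seeds, list(it)


    if __name__ == "__main__":
        print(seed_values_with_zip(range(10), 3))
        # ([0, 1, 2], [4, 5, 6, 7, 8, 9])  -- note that 3 has vanished
        print(seed_values_with_islice(range(10), 3))
        # ([0, 1, 2], [3, 4, 5, 6, 7, 8, 9])

In `concurrently_execute` the dropped item would have been an argument that was never passed to `func`, which is why the final commit switches the seeding of the background workers to `itertools.islice`.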