diff --git a/changelog.d/13404.misc b/changelog.d/13404.misc new file mode 100644 index 000000000000..655be4061beb --- /dev/null +++ b/changelog.d/13404.misc @@ -0,0 +1 @@ +Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 2ba2b1527eac..3cbaaeb9dce3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -23,7 +23,6 @@ Dict, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -278,9 +277,8 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: ) try: - await self._process_received_pdu( - origin, pdu, state_ids=None, partial_state=None - ) + context = await self._state_handler.compute_event_context(pdu) + await self._process_received_pdu(origin, pdu, context) except PartialStateConflictError: # The room was un-partial stated while we were processing the PDU. # Try once more, with full state this time. @@ -288,9 +286,8 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: "Room %s was un-partial stated while processing the PDU, trying again.", room_id, ) - await self._process_received_pdu( - origin, pdu, state_ids=None, partial_state=None - ) + context = await self._state_handler.compute_event_context(pdu) + await self._process_received_pdu(origin, pdu, context) async def on_send_membership_event( self, origin: str, event: EventBase @@ -320,6 +317,7 @@ async def on_send_membership_event( The event and context of the event after inserting it into the room graph. Raises: + RuntimeError if any prev_events are missing SynapseError if the event is not accepted into the room PartialStateConflictError if the room was un-partial stated in between computing the state at the event and persisting it. The caller should @@ -380,7 +378,7 @@ async def on_send_membership_event( # need to. await self._event_creation_handler.cache_joined_hosts_for_event(event, context) - await self._check_for_soft_fail(event, None, origin=origin) + await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) return event, context @@ -538,36 +536,10 @@ async def update_state_for_partial_state_event( # # This is the same operation as we do when we receive a regular event # over federation. - state_ids, partial_state = await self._resolve_state_at_missing_prevs( + context = await self._compute_event_context_with_maybe_missing_prevs( destination, event ) - - # There are three possible cases for (state_ids, partial_state): - # * `state_ids` and `partial_state` are both `None` if we had all the - # prev_events. The prev_events may or may not have partial state and - # we won't know until we compute the event context. - # * `state_ids` is not `None` and `partial_state` is `False` if we were - # missing some prev_events (but we have full state for any we did - # have). We calculated the full state after the prev_events. - # * `state_ids` is not `None` and `partial_state` is `True` if we were - # missing some, but not all, prev_events. At least one of the - # prev_events we did have had partial state, so we calculated a partial - # state after the prev_events. - - context = None - if state_ids is not None and partial_state: - # the state after the prev events is still partial. We can't de-partial - # state the event, so don't bother building the event context. - pass - else: - # build a new state group for it if need be - context = await self._state_handler.compute_event_context( - event, - state_ids_before_event=state_ids, - partial_state=partial_state, - ) - - if context is None or context.partial_state: + if context.partial_state: # this can happen if some or all of the event's prev_events still have # partial state. We were careful to only pick events from the db without # partial-state prev events, so that implies that a prev event has @@ -833,26 +805,25 @@ async def _process_pulled_event( try: try: - state_ids, partial_state = await self._resolve_state_at_missing_prevs( + context = await self._compute_event_context_with_maybe_missing_prevs( origin, event ) await self._process_received_pdu( origin, event, - state_ids=state_ids, - partial_state=partial_state, + context, backfilled=backfilled, ) except PartialStateConflictError: # The room was un-partial stated while we were processing the event. # Try once more, with full state this time. - state_ids, partial_state = await self._resolve_state_at_missing_prevs( + context = await self._compute_event_context_with_maybe_missing_prevs( origin, event ) # We ought to have full state now, barring some unlikely race where we left and # rejoned the room in the background. - if state_ids is not None and partial_state: + if context.partial_state: raise AssertionError( f"Event {event.event_id} still has a partial resolved state " f"after room {event.room_id} was un-partial stated" @@ -861,8 +832,7 @@ async def _process_pulled_event( await self._process_received_pdu( origin, event, - state_ids=state_ids, - partial_state=partial_state, + context, backfilled=backfilled, ) except FederationError as e: @@ -871,15 +841,18 @@ async def _process_pulled_event( else: raise - async def _resolve_state_at_missing_prevs( + async def _compute_event_context_with_maybe_missing_prevs( self, dest: str, event: EventBase - ) -> Tuple[Optional[StateMap[str]], Optional[bool]]: - """Calculate the state at an event with missing prev_events. + ) -> EventContext: + """Build an EventContext structure for a non-outlier event whose prev_events may + be missing. - This is used when we have pulled a batch of events from a remote server, and - still don't have all the prev_events. + This is used when we have pulled a batch of events from a remote server, and may + not have all the prev_events. - If we already have all the prev_events for `event`, this method does nothing. + To build an EventContext, we need to calculate the state before the event. If we + already have all the prev_events for `event`, we can simply use the state after + the prev_events to calculate the state before `event`. Otherwise, the missing prevs become new backwards extremities, and we fall back to asking the remote server for the state after each missing `prev_event`, @@ -900,10 +873,7 @@ async def _resolve_state_at_missing_prevs( event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None, None`. Otherwise, returns a - tuple containing: - * the event ids of the state at `event`. - * a boolean indicating whether the state may be partial. + The event context. Raises: FederationError if we fail to get the state from the remote server after any @@ -917,7 +887,7 @@ async def _resolve_state_at_missing_prevs( missing_prevs = prevs - seen if not missing_prevs: - return None, None + return await self._state_handler.compute_event_context(event) logger.info( "Event %s is missing prev_events %s: calculating state for a " @@ -983,7 +953,9 @@ async def _resolve_state_at_missing_prevs( "We can't get valid state history.", affected=event_id, ) - return state_map, partial_state + return await self._state_handler.compute_event_context( + event, state_ids_before_event=state_map, partial_state=partial_state + ) async def _get_state_ids_after_missing_prev_event( self, @@ -1152,8 +1124,7 @@ async def _process_received_pdu( self, origin: str, event: EventBase, - state_ids: Optional[StateMap[str]], - partial_state: Optional[bool], + context: EventContext, backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1175,32 +1146,18 @@ async def _process_received_pdu( event: event to be persisted - state_ids: Normally None, but if we are handling a gap in the graph - (ie, we are missing one or more prev_events), the resolved state at the - event - - partial_state: - `True` if `state_ids` is partial and omits non-critical membership - events. - `False` if `state_ids` is the full state. - `None` if `state_ids` is not provided. In this case, the flag will be - calculated based on `event`'s prev events. + context: The `EventContext` to persist the event with. backfilled: True if this is part of a historical batch of events (inhibits notification to clients, and validation of device keys.) PartialStateConflictError: if the room was un-partial stated in between - computing the state at the event and persisting it. The caller should retry - exactly once in this case. + computing the state at the event and persisting it. The caller should + recompute `context` and retry exactly once when this happens. """ logger.debug("Processing event: %s", event) assert not event.internal_metadata.outlier - context = await self._state_handler.compute_event_context( - event, - state_ids_before_event=state_ids, - partial_state=partial_state, - ) try: await self._check_event_auth(origin, event, context) except AuthError as e: @@ -1212,7 +1169,7 @@ async def _process_received_pdu( # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - await self._check_for_soft_fail(event, state_ids, origin=origin) + await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context, backfilled) @@ -1773,7 +1730,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None: async def _check_for_soft_fail( self, event: EventBase, - state_ids: Optional[StateMap[str]], + context: EventContext, origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1784,7 +1741,7 @@ async def _check_for_soft_fail( Args: event - state_ids: The state at the event if we don't have all the event's prev events + context: The `EventContext` which we are about to persist the event with. origin: The host the event originates from. """ if await self._store.is_partial_state_room(event.room_id): @@ -1810,11 +1767,15 @@ async def _check_for_soft_fail( auth_types = auth_types_for_event(room_version_obj, event) # Calculate the "current state". - if state_ids is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, + seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids) + has_missing_prevs = bool(prev_event_ids - seen_event_ids) + if has_missing_prevs: + # We don't have all the prev_events of this event, which means we have a + # gap in the graph, and the new event is going to become a new backwards + # extremity. + # + # In this case we want to be a little careful as we might have been + # down for a while and have an incorrect view of the current state, # however we still want to do checks as gaps are easy to # maliciously manufacture. # @@ -1827,6 +1788,7 @@ async def _check_for_soft_fail( event.room_id, extrem_ids ) state_sets: List[StateMap[str]] = list(state_sets_d.values()) + state_ids = await context.get_prev_state_ids() state_sets.append(state_ids) current_state_ids = ( await self._state_resolution_handler.resolve_events_with_store( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 69834de0de7c..c355e4f98ab3 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -278,6 +278,10 @@ async def compute_event_context( flag will be calculated based on `event`'s prev events. Returns: The event context. + + Raises: + RuntimeError if `state_ids_before_event` is not provided and one or more + prev events are missing or outliers. """ assert not event.internal_metadata.is_outlier() @@ -432,6 +436,10 @@ async def resolve_state_groups_for_events( Returns: The resolved state + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie. they are outliers or unknown) """ logger.debug("resolve_state_groups event_ids %s", event_ids) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 20805c94fa19..1e35046e07cd 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -338,6 +338,10 @@ async def get_state_group_for_events( event_ids: events to get state groups for await_full_state: if true, will block if we do not yet have complete state at these events. + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie. they are outliers or unknown) """ if await_full_state: await self._partial_state_events_tracker.await_full_state(event_ids) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index fb06e5e81266..aea96a0986a6 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -280,16 +280,23 @@ def test_backfill_with_many_backward_extremities(self) -> None: # we poke this directly into _process_received_pdu, to avoid the # federation handler wanting to backfill the fake event. - self.get_success( - federation_event_handler._process_received_pdu( - self.OTHER_SERVER_NAME, + state_handler = self.hs.get_state_handler() + context = self.get_success( + state_handler.compute_event_context( event, - state_ids={ + state_ids_before_event={ (e.type, e.state_key): e.event_id for e in current_state }, partial_state=False, ) ) + self.get_success( + federation_event_handler._process_received_pdu( + self.OTHER_SERVER_NAME, + event, + context, + ) + ) # we should now have 8 backwards extremities. backwards_extremities = self.get_success(