Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor _resolve_state_at_missing_prevs to return an EventContext #13404

1 change: 1 addition & 0 deletions changelog.d/13404.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead.
126 changes: 44 additions & 82 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Expand Down Expand Up @@ -278,19 +277,17 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
)

try:
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
logger.info(
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)

async def on_send_membership_event(
self, origin: str, event: EventBase
Expand Down Expand Up @@ -320,6 +317,7 @@ async def on_send_membership_event(
The event and context of the event after inserting it into the room graph.

Raises:
RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should
Expand Down Expand Up @@ -380,7 +378,7 @@ async def on_send_membership_event(
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)

await self._check_for_soft_fail(event, None, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have all the prev events here, otherwise we would have raised a RuntimeError when computing the event context. Thus the soft fail check will behave the same as before.

await self._run_push_actions_and_persist_event(event, context)
return event, context

Expand Down Expand Up @@ -538,36 +536,10 @@ async def update_state_for_partial_state_event(
#
# This is the same operation as we do when we receive a regular event
# over federation.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
destination, event
)

# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.

context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)

if context is None or context.partial_state:
Comment on lines -544 to -570
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lose the optimization here in exchange for simpler code.

if context.partial_state:
# this can happen if some or all of the event's prev_events still have
# partial state. We were careful to only pick events from the db without
# partial-state prev events, so that implies that a prev event has
Expand Down Expand Up @@ -833,26 +805,25 @@ async def _process_pulled_event(

try:
try:
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event
)
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event
)

# We ought to have full state now, barring some unlikely race where we left and
rejoined the room in the background.
if state_ids is not None and partial_state:
if context.partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
Expand All @@ -861,8 +832,7 @@ async def _process_pulled_event(
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context,
backfilled=backfilled,
)
except FederationError as e:
Expand All @@ -871,15 +841,18 @@ async def _process_pulled_event(
else:
raise

async def _resolve_state_at_missing_prevs(
async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
"""Calculate the state at an event with missing prev_events.
) -> EventContext:
"""Build an EventContext structure for a non-outlier event whose prev_events may
be missing.

This is used when we have pulled a batch of events from a remote server, and
still don't have all the prev_events.
This is used when we have pulled a batch of events from a remote server, and may
not have all the prev_events.

If we already have all the prev_events for `event`, this method does nothing.
To build an EventContext, we need to calculate the state before the event. If we
already have all the prev_events for `event`, we can simply use the state after
the prev_events to calculate the state before `event`.

Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
Expand All @@ -900,10 +873,7 @@ async def _resolve_state_at_missing_prevs(
event: an event to check for missing prevs.

Returns:
if we already had all the prev events, `None, None`. Otherwise, returns a
tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
The event context.

Raises:
FederationError if we fail to get the state from the remote server after any
Expand All @@ -917,7 +887,7 @@ async def _resolve_state_at_missing_prevs(
missing_prevs = prevs - seen

if not missing_prevs:
return None, None
return await self._state_handler.compute_event_context(event)

logger.info(
"Event %s is missing prev_events %s: calculating state for a "
Expand Down Expand Up @@ -983,7 +953,9 @@ async def _resolve_state_at_missing_prevs(
"We can't get valid state history.",
affected=event_id,
)
return state_map, partial_state
return await self._state_handler.compute_event_context(
event, state_ids_before_event=state_map, partial_state=partial_state
)

async def _get_state_ids_after_missing_prev_event(
self,
Expand Down Expand Up @@ -1152,8 +1124,7 @@ async def _process_received_pdu(
self,
origin: str,
event: EventBase,
state_ids: Optional[StateMap[str]],
partial_state: Optional[bool],
context: EventContext,
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
Expand All @@ -1175,32 +1146,18 @@ async def _process_received_pdu(

event: event to be persisted

state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event

partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
context: The `EventContext` to persist the event with.

backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)

PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry
exactly once in this case.
computing the state at the event and persisting it. The caller should
recompute `context` and retry exactly once when this happens.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier

context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
try:
await self._check_event_auth(origin, event, context)
except AuthError as e:
Expand All @@ -1212,7 +1169,7 @@ async def _process_received_pdu(
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)

await self._run_push_actions_and_persist_event(event, context, backfilled)

Expand Down Expand Up @@ -1773,7 +1730,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
async def _check_for_soft_fail(
self,
event: EventBase,
state_ids: Optional[StateMap[str]],
context: EventContext,
origin: str,
) -> None:
Comment on lines 1730 to 1735
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only makes sense to call _check_for_soft_fail prior to persisting an event, and EventContexts are intended to hold information relevant to persisting an event. So we aren't losing much by accepting an EventContext instead of a StateMap.

"""Checks if we should soft fail the event; if so, marks the event as
Expand All @@ -1784,7 +1741,7 @@ async def _check_for_soft_fail(

Args:
event
state_ids: The state at the event if we don't have all the event's prev events
context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from.
"""
if await self._store.is_partial_state_room(event.room_id):
Expand All @@ -1810,11 +1767,15 @@ async def _check_for_soft_fail(
auth_types = auth_types_for_event(room_version_obj, event)

# Calculate the "current state".
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state,
seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
has_missing_prevs = bool(prev_event_ids - seen_event_ids)
if has_missing_prevs:
# We don't have all the prev_events of this event, which means we have a
# gap in the graph, and the new event is going to become a new backwards
# extremity.
#
# In this case we want to be a little careful as we might have been
# down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
Expand All @@ -1827,6 +1788,7 @@ async def _check_for_soft_fail(
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
Expand Down
8 changes: 8 additions & 0 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ async def compute_event_context(
flag will be calculated based on `event`'s prev events.
Returns:
The event context.

Raises:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""

assert not event.internal_metadata.is_outlier()
Expand Down Expand Up @@ -432,6 +436,10 @@ async def resolve_state_groups_for_events(

Returns:
The resolved state

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ async def get_state_group_for_events(
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
Expand Down
15 changes: 11 additions & 4 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,23 @@ def test_backfill_with_many_backward_extremities(self) -> None:

# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
state_handler = self.hs.get_state_handler()
context = self.get_success(
state_handler.compute_event_context(
event,
state_ids={
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
event,
context,
)
)

# we should now have 8 backwards extremities.
backwards_extremities = self.get_success(
Expand Down