Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Separate creating an event context from persisting it in the federati…
Browse files Browse the repository at this point in the history
…on handler (#9800)

This refactoring allows adding logic that uses the event context
before persisting it.
  • Loading branch information
clokep authored Apr 14, 2021
1 parent e8816c6 commit 936e698
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 67 deletions.
1 change: 1 addition & 0 deletions changelog.d/9800.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
178 changes: 113 additions & 65 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@

@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
"""Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes:
event: the received event
Expand Down Expand Up @@ -807,7 +807,10 @@ async def _process_received_pdu(
logger.debug("Processing event: %s", event)

try:
await self._handle_new_event(origin, event, state=state)
context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

Expand Down Expand Up @@ -1010,7 +1013,9 @@ async def backfill(
)

if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)

# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
Expand All @@ -1023,10 +1028,12 @@ async def backfill(
# non-outliers
assert not event.internal_metadata.is_outlier()

context = await self.state_handler.compute_event_context(event)

# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True)
await self._auth_and_persist_event(dest, event, context, backfilled=True)

return events

Expand Down Expand Up @@ -1360,7 +1367,7 @@ async def get_event(event_id: str):

event_infos.append(_NewEventInfo(event, None, auth))

await self._handle_new_events(
await self._auth_and_persist_events(
destination,
room_id,
event_infos,
Expand Down Expand Up @@ -1666,10 +1673,11 @@ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin

context = await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
context = await self._auth_and_persist_event(origin, event, context)

logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
"on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1878,10 +1886,11 @@ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:

event.internal_metadata.outlier = False

await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)

logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
"on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1989,16 +1998,47 @@ async def get_persisted_pdu(
async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)

async def _handle_new_event(
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
"""
Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
backfilled: True if the event was backfilled.
Returns:
The event context.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
)

try:
Expand All @@ -2022,7 +2062,7 @@ async def _handle_new_event(

return context

async def _handle_new_events(
async def _auth_and_persist_events(
self,
origin: str,
room_id: str,
Expand All @@ -2040,9 +2080,13 @@ async def _handle_new_events(
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._prep_event(
res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self._check_event_auth(
origin,
event,
res,
state=ev_info.state,
auth_events=ev_info.auth_events,
backfilled=backfilled,
Expand Down Expand Up @@ -2177,49 +2221,6 @@ async def _persist_auth_tree(
room_id, [(event, new_event_context)]
)

async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)

if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}

# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

context = await self.do_auth(origin, event, context, auth_events=auth_events)

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

return context

async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
) -> None:
Expand Down Expand Up @@ -2330,19 +2331,28 @@ async def on_get_missing_events(

return missing_events

async def do_auth(
async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
auth_events: MutableStateMap[EventBase],
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).
Args:
origin:
event:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Expand All @@ -2352,12 +2362,34 @@ async def do_auth(
server.
Also NB that this function adds entries to it.
If this is not provided, it is calculated from the previous state IDs.
backfilled: True if the event was backfilled.
Returns:
updated context object
The updated context object.
"""
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]

if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}

# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

try:
context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
Expand All @@ -2379,6 +2411,17 @@ async def do_auth(
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

return context

async def _update_auth_events_and_context_for_auth(
Expand All @@ -2388,7 +2431,7 @@ async def _update_auth_events_and_context_for_auth(
context: EventContext,
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
"""Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
Expand Down Expand Up @@ -2468,9 +2511,14 @@ async def _update_auth_events_and_context_for_auth(
e.internal_metadata.outlier = True

logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context = await self.state_handler.compute_event_context(e)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
)
await self._handle_new_event(origin, e, auth_events=auth)

if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
Expand Down
6 changes: 4 additions & 2 deletions tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def setUp(self):
)

self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
self.handler._check_event_auth = (
lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
Expand Down

0 comments on commit 936e698

Please sign in to comment.