Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Await un-partial-stating after a partial-state join #12399

Merged
merged 7 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12399.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: Implement a tracking mechanism to allow functions to wait for full room state to arrive.
4 changes: 3 additions & 1 deletion docker/complement/conf/homeserver.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ experimental_features:
spaces_enabled: true
# Enable history backfilling support
msc2716_enabled: true
# server-side support for partial state in /send_join
# server-side support for partial state in /send_join responses
msc3706_enabled: true
# client-side support for partial state in /send_join responses
faster_joins: true
# Enable jump to date endpoint
msc3030_enabled: true

Expand Down
2 changes: 1 addition & 1 deletion scripts-dev/complement.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ docker build -t $COMPLEMENT_BASE_IMAGE -f "docker/complement/$COMPLEMENT_DOCKERF
# Run the tests!
echo "Images built; running complement"
cd "$COMPLEMENT_DIR"
go test -v -tags synapse_blacklist,msc2716,msc3030 -count=1 "$@" ./tests/...
go test -v -tags synapse_blacklist,msc2716,msc3030,faster_joins -count=1 "$@" ./tests/...
1 change: 1 addition & 0 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ async def update_state_for_partial_state_event(
)
return
squahtx marked this conversation as resolved.
Show resolved Hide resolved
await self._store.update_state_for_partial_state_event(event, context)
self._state_store.notify_event_un_partial_stated(event.event_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not particularly fussy, but 'un-partial-stated' made me grin; I would have called this completed? 😀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe? it would have to be notify_event_state_completed or something, to distinguish from other types of completion. Having a single concept which we refer to throughout the codebase as "partial state" seems clearer to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might have a point there, even if 'un-partial' is a little non-obvious, once you notice what it means it makes sense.


async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
Expand Down
10 changes: 9 additions & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,15 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
"""Checks which of the given events have partial state"""
"""Checks which of the given events have partial state

Args:
event_ids: the events we want to check for partial state.

Returns:
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
result = await self.db_pool.simple_select_many_batch(
table="partial_state_events",
column="event_id",
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def _update_state_for_partial_state_event_txn(
)

# TODO(faster_joins): need to do something about workers here
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
Expand Down
28 changes: 25 additions & 3 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -542,6 +543,10 @@ class StateGroupStorage:

def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)

def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)

async def get_state_group_delta(
self, state_group: int
Expand Down Expand Up @@ -579,7 +584,7 @@ async def get_state_groups_ids(
if not event_ids:
return {}

event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self._get_state_group_for_events(event_ids)
squahtx marked this conversation as resolved.
Show resolved Hide resolved

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
Expand Down Expand Up @@ -668,7 +673,7 @@ async def get_state_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self._get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -709,7 +714,7 @@ async def get_state_ids_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self._get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -785,6 +790,23 @@ def _get_state_for_groups(
groups, state_filter or StateFilter.all()
)

async def _get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the intention is to follow up with further PRs which relax this by setting await_full_state=False where it is known to be safe to do so.

) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group

Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at this event.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
Comment on lines +805 to +806
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appreciate the need for this, but this seems like it could block for a long time, potentially exceeding the timeout of a request.
I suppose the plan is to make enough endpoints not need to wait for full state that this won't matter much?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's part of the plan. Another part of the plan is to get cancellation working (#3528) so that we don't stack up requests. A final part of the plan is probably to change the api to do something different, but we're a way off that yet.


return await self.stores.main._get_state_group_for_events(event_ids)

async def store_state_group(
self,
event_id: str,
Expand Down
120 changes: 120 additions & 0 deletions synapse/storage/util/partial_state_events_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import defaultdict
from typing import Collection, Dict, Set

from twisted.internet import defer
from twisted.internet.defer import Deferred

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import unwrapFirstError

logger = logging.getLogger(__name__)


class PartialStateEventsTracker:
"""Keeps track of which events have partial state, after a partial-state join"""
squahtx marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, store: EventsWorkerStore):
self._store = store
# a map from event id to a set of Deferreds which are waiting for that event to be
# un-partial-stated.
self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
richvdh marked this conversation as resolved.
Show resolved Hide resolved

def notify_un_partial_stated(self, event_id: str) -> None:
"""Notify that we now have full state for a given event

Called by the state-resynchronization loop whenever we resynchronize the state
for a particular event. Unblocks any callers to await_full_state() for that
event.

Args:
event_id: the event that now has full state.
"""
observers = self._observers.pop(event_id, None)
if not observers:
return
logger.info(
"Notifying %i things waiting for un-partial-stating of event %s",
len(observers),
event_id,
)
with PreserveLoggingContext():
for o in observers:
o.callback(None)

async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.

Args:
event_ids: the list of event ids that we want full state for
"""
# first try the happy path: if there are no partial-state events, we can return
# quickly
partial_state_event_ids = [
ev
for ev, p in (await self._store.get_partial_state_events(event_ids)).items()
if p
]

if not partial_state_event_ids:
return

logger.info(
"Awaiting un-partial-stating of events %s",
partial_state_event_ids,
stack_info=True,
)

# create an observer for each lazy-joined event
observers: Dict[str, Deferred[None]] = {
event_id: Deferred() for event_id in partial_state_event_ids
}
for event_id, observer in observers.items():
self._observers[event_id].add(observer)

try:
# some of them may have been un-lazy-joined between us checking the db and
# registering the observer, in which case we'd wait forever for the
# notification. Call back the observers now.
for event_id, partial in (
await self._store.get_partial_state_events(observers.keys())
).items():
# there may have been a call to notify_un_partial_stated during the
# db query, so the observers may already have been called.
if not partial and not observers[event_id].called:
observers[event_id].callback(None)

await make_deferred_yieldable(
defer.gatherResults(
observers.values(),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
logger.info("Events %s all un-partial-stated", observers.keys())
finally:
# remove any observers we created. This should happen when the notification
# is received, but that might not happen for two reasons:
# (a) we're bailing out early on an exception (including us being
# cancelled during the await)
# (b) the event got de-lazy-joined before we set up the observer.
for event_id, observer in observers.items():
observer_set = self._observers.get(event_id)
if observer_set:
observer_set.discard(observer)
if not observer_set:
del self._observers[event_id]
13 changes: 13 additions & 0 deletions tests/storage/util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
117 changes: 117 additions & 0 deletions tests/storage/util/test_partial_state_events_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from unittest import mock

from twisted.internet.defer import CancelledError, ensureDeferred

from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker

from tests.unittest import TestCase


class PartialStateEventsTrackerTestCase(TestCase):
def setUp(self) -> None:
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}

async def get_partial_state_events(events):
return {e: self._events_dict[e] for e in events}

self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
self.mock_store.get_partial_state_events.side_effect = get_partial_state_events

self.tracker = PartialStateEventsTracker(self.mock_store)

def test_does_not_block_for_full_state_events(self):
self._events_dict = {"event1": False, "event2": False}

self.successResultOf(
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

self.mock_store.get_partial_state_events.assert_called_once_with(
["event1", "event2"]
)

def test_blocks_for_partial_state_events(self):
self._events_dict = {"event1": True, "event2": False}

d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))

# there should be no result yet
self.assertNoResult(d)

# notifying that the event has been de-partial-stated should unblock
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)

def test_un_partial_state_race(self):
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}

async def get_partial_state_events(events):
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
return res

self.mock_store.get_partial_state_events.side_effect = get_partial_state_events

self.successResultOf(
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

def test_un_partial_state_during_get_partial_state_events(self):
# we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events.

self._events_dict = {"event1": True, "event2": False}

async def get_partial_state_events1(events):
self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2
)
return {e: self._events_dict[e] for e in events}

async def get_partial_state_events2(events):
self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events}

self.mock_store.get_partial_state_events.side_effect = get_partial_state_events1

self.successResultOf(
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

def test_cancellation(self):
self._events_dict = {"event1": True, "event2": False}

d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
self.assertNoResult(d1)

d2 = ensureDeferred(self.tracker.await_full_state(["event1"]))
self.assertNoResult(d2)

d1.cancel()
self.assertFailure(d1, CancelledError)

# d2 should still be waiting!
self.assertNoResult(d2)

self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)