Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Await un-partial-stating after a partial-state join #12399

Merged
merged 7 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12399.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: Implement a tracking mechanism to allow functions to wait for full room state to arrive.
1 change: 1 addition & 0 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ async def update_state_for_partial_state_event(
)
return
squahtx marked this conversation as resolved.
Show resolved Hide resolved
await self._store.update_state_for_partial_state_event(event, context)
self._state_store.notify_event_un_partial_stated(event.event_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not particularly fussy, but 'un-partial-stated' made me grin; I would have called this completed? 😀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe? it would have to be notify_event_state_completed or something, to distinguish from other types of completion. Having a single concept which we refer to throughout the codebase as "partial state" seems clearer to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might have a point there, even if 'un-partial' is a little non-obvious, once you notice what it means it makes sense.


async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
Expand Down
10 changes: 9 additions & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,15 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
"""Checks which of the given events have partial state"""
"""Checks which of the given events have partial state

Args:
event_ids: the events we want to check for partial state.

Returns:
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
result = await self.db_pool.simple_select_many_batch(
table="partial_state_events",
column="event_id",
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def _update_state_for_partial_state_event_txn(
)

# TODO(faster_joins): need to do something about workers here
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
Expand Down
5 changes: 5 additions & 0 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -542,6 +543,10 @@ class StateGroupStorage:

def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)

def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)

async def get_state_group_delta(
self, state_group: int
Expand Down
116 changes: 116 additions & 0 deletions synapse/storage/util/partial_state_events_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import defaultdict
from typing import Collection, Dict, Set

from twisted.internet import defer
from twisted.internet.defer import Deferred

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import unwrapFirstError

logger = logging.getLogger(__name__)


class PartialStateEventsTracker:
"""Keeps track of which events have partial state, after a partial-state join"""
squahtx marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, store: EventsWorkerStore):
self._store = store
self._observers: Dict[str, Set[Deferred]] = defaultdict(set)
richvdh marked this conversation as resolved.
Show resolved Hide resolved

def notify_un_partial_stated(self, event_id: str) -> None:
"""Notify that we now have full state for a given event

Called by the state-resynchronization loop whenever we resynchronize the state
for a particular event. Unblocks any callers to await_full_state() for that
event.

Args:
event_id: the event that now has full state.
"""
observers = self._observers.pop(event_id, None)
if not observers:
return
logger.info(
"Notifying %i things waiting for un-partial-stating of event %s",
len(observers),
event_id,
)
with PreserveLoggingContext():
for o in observers:
o.callback(None)

async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.

Args:
event_ids: the list of event ids that we want full state for
"""
# first try the happy path: if there are no partial-state events, we can return
# quickly
partial_state_event_ids = [
ev
for ev, p in (await self._store.get_partial_state_events(event_ids)).items()
if p
]

if not partial_state_event_ids:
return

logger.info(
"Awaiting un-partial-stating of events %s",
partial_state_event_ids,
stack_info=True,
)

# create an observer for each lazy-joined event
observers: Dict[str, Deferred[None]] = {
event_id: Deferred() for event_id in partial_state_event_ids
}
for event_id, observer in observers.items():
self._observers[event_id].add(observer)

try:
# some of them may have been un-lazy-joined between us checking the db and
# registering the observer, in which case we'd wait forever for the
# notification. Call back the observers now.
for event_id, partial in (
await self._store.get_partial_state_events(observers.keys())
).items():
if not partial:
observers[event_id].callback(None)
richvdh marked this conversation as resolved.
Show resolved Hide resolved

await make_deferred_yieldable(
defer.gatherResults(
observers.values(),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
logger.info("Events %s all un-partial-stated", observers.keys())
finally:
# remove any observers we created. This should happen when the notification
# is received, but that might not happen for two reasons:
# (a) we're bailing out early on an exception (including us being
# cancelled during the await)
# (b) the event got de-lazy-joined before we set up the observer.
for event_id, observer in observers.items():
observer_set = self._observers.get(event_id)
if observer_set:
observer_set.discard(observer)
if not observer_set:
del self._observers[event_id]
13 changes: 13 additions & 0 deletions tests/storage/util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
94 changes: 94 additions & 0 deletions tests/storage/util/test_partial_state_events_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from unittest import mock

from twisted.internet.defer import CancelledError, ensureDeferred

from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker

from tests.unittest import TestCase


class PartialStateEventsTrackerTestCase(TestCase):
def setUp(self) -> None:
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}

async def get_partial_state_events(events):
return {e: self._events_dict[e] for e in events}

self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
self.mock_store.get_partial_state_events.side_effect = get_partial_state_events

self.tracker = PartialStateEventsTracker(self.mock_store)

def test_does_not_block_for_full_state_events(self):
self._events_dict = {"event1": False, "event2": False}

self.successResultOf(
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

self.mock_store.get_partial_state_events.assert_called_once_with(
["event1", "event2"]
)

def test_blocks_for_partial_state_events(self):
self._events_dict = {"event1": True, "event2": False}

d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))

# there should be no result yet
self.assertNoResult(d)

# notifying that the event has been de-partial-stated should unblock
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)

def test_un_partial_state_race(self):
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}

async def get_partial_state_events(events):
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
return res

self.mock_store.get_partial_state_events.side_effect = get_partial_state_events

self.successResultOf(
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)

def test_cancellation(self):
self._events_dict = {"event1": True, "event2": False}

d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
self.assertNoResult(d1)

d2 = ensureDeferred(self.tracker.await_full_state(["event1"]))
self.assertNoResult(d2)

d1.cancel()
self.assertFailure(d1, CancelledError)

# d2 should still be waiting!
self.assertNoResult(d2)

self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)