Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Handle cancellation in EventsWorkerStore._get_events_from_cache_or_db
Browse files Browse the repository at this point in the history
Multiple calls to `EventsWorkerStore._get_events_from_cache_or_db` can
reuse the same database fetch, which is initiated by the first call.
Ensure that cancelling the first call doesn't cancel the other calls
sharing the same database fetch.

Signed-off-by: Sean Quah <[email protected]>
  • Loading branch information
Sean Quah committed Apr 22, 2022
1 parent a50fb41 commit a1031cf
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 36 deletions.
1 change: 1 addition & 0 deletions changelog.d/0000.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implemented cancellation support in `EventsWorkerStore._get_events_from_cache_or_db`.
83 changes: 49 additions & 34 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
Expand Down Expand Up @@ -640,42 +640,57 @@ async def _get_events_from_cache_or_db(
missing_events_ids.difference_update(already_fetching_ids)

if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred

# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)

event_entry_map.update(missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
"""Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow
concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
"""
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
self._current_event_fetches[event_id] = fetching_deferred

with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)

with PreserveLoggingContext():
fetching_deferred.callback(missing_events)

return missing_events

# We must allow the database fetch to complete in the presence of
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
get_missing_events_from_db()
)
event_entry_map.update(missing_events)

if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
Expand Down
121 changes: 119 additions & 2 deletions tests/storage/databases/main/test_events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
import json
from contextlib import contextmanager
from typing import Generator
from typing import Generator, Tuple
from unittest import mock

from twisted.enterprise.adbapi import ConnectionPool
from twisted.internet.defer import ensureDeferred
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.room_versions import EventFormatVersions, RoomVersions
Expand Down Expand Up @@ -281,3 +282,119 @@ def test_recovery(self) -> None:

# This next event fetch should succeed
self.get_success(self.store.get_event(self.event_ids[0]))


class GetEventCancellationTestCase(unittest.HomeserverTestCase):
"""Test cancellation of `get_event` calls."""

servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
self.store: EventsWorkerStore = hs.get_datastores().main

self.user = self.register_user("user", "pass")
self.token = self.login(self.user, "pass")

self.room = self.helper.create_room_as(self.user, tok=self.token)

res = self.helper.send(self.room, tok=self.token)
self.event_id = res["event_id"]

# Reset the event cache so the tests start with it empty
self.store._get_event_cache.clear()

@contextmanager
def blocking_get_event_calls(
self,
) -> Generator[
Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
]:
"""Starts two concurrent `get_event` calls for the same event.
Both `get_event` calls will use the same database fetch, which will be blocked
at the time this function returns.
Returns:
A tuple containing:
* A `Deferred` that unblocks the database fetch.
* A cancellable `Deferred` for the first `get_event` call.
* A cancellable `Deferred` for the second `get_event` call.
"""
# Patch `DatabasePool.runWithConnection` to block.
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection

async def runWithConnection(*args, **kwargs):
await unblock
return await original_runWithConnection(*args, **kwargs)

with mock.patch.object(
self.store.db_pool,
"runWithConnection",
new=runWithConnection,
):
ctx1 = LoggingContext("get_event1")
ctx2 = LoggingContext("get_event2")

async def get_event(ctx: LoggingContext) -> None:
with ctx:
await self.store.get_event(self.event_id)

get_event1 = ensureDeferred(get_event(ctx1))
get_event2 = ensureDeferred(get_event(ctx2))

# Both `get_event` calls ought to be blocked.
self.assertNoResult(get_event1)
self.assertNoResult(get_event2)

yield unblock, get_event1, get_event2

# Confirm that the two `get_event` calls shared the same database fetch.
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)

def test_first_get_event_cancelled(self):
"""Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the
fetch to complete despite the cancellation. Furthermore, the first `get_event`
call must not abort before the fetch is complete, otherwise the fetch will be
using a finished logging context.
"""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the first `get_event` call.
get_event1.cancel()
# The first `get_event` call must not abort immediately, otherwise its
# logging context will be finished while it is still in use by the database
# fetch.
self.assertNoResult(get_event1)
# The second `get_event` call must not be cancelled.
self.assertNoResult(get_event2)

# Unblock the database fetch.
unblock.callback(None)
# A `CancelledError` should be raised out of the first `get_event` call.
exc = self.get_failure(get_event1, CancelledError).value
self.assertIsInstance(exc, CancelledError)
# The second `get_event` call should complete successfully.
self.get_success(get_event2)

def test_second_get_event_cancelled(self):
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.
get_event2.cancel()
# The first `get_event` call must not be cancelled.
self.assertNoResult(get_event1)
# The second `get_event` call gets cancelled immediately.
exc = self.get_failure(get_event2, CancelledError).value
self.assertIsInstance(exc, CancelledError)

# Unblock the database fetch.
unblock.callback(None)
# The first `get_event` call should complete successfully.
self.get_success(get_event1)

0 comments on commit a1031cf

Please sign in to comment.