Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up @cachedList #13591

Merged
Merged 10 commits into the target branch from the source branch
Aug 23, 2022
74 changes: 74 additions & 0 deletions synapse/util/caches/deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from twisted.internet import defer
from twisted.python.failure import Failure

from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
Expand Down Expand Up @@ -244,6 +245,25 @@ def set(
# we return a new Deferred which will be called before any subsequent observers.
return deferred

def set_bulk(
    self,
    keys: Collection[KT],
    callback: Optional[Callable[[], None]] = None,
) -> "CacheMultipleEntries[KT, VT]":
    """Bulk set API for use when fetching multiple keys at once from the DB.

    Called *before* starting the fetch from the DB, and the caller *must*
    call either `complete_bulk(..)` or `error_bulk(..)` on the return value.

    Args:
        keys: The keys about to be fetched. Each key gets the same shared
            pending entry inserted into `_pending_deferred_cache`.
        callback: Optional invalidation callback, registered as a *global*
            callback on the entry (fires when any of the keys is
            invalidated).

    Returns:
        The single `CacheMultipleEntries` shared by all of `keys`.
    """

    # One shared entry covers the whole batch; per-key deferreds are
    # derived from it lazily via `entry.deferred(key)`.
    entry = CacheMultipleEntries[KT, VT]()
    entry.add_global_callback(callback)

    for key in keys:
        self._pending_deferred_cache[key] = entry

    return entry

def _set_completed_callback(
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
Expand Down Expand Up @@ -366,3 +386,57 @@ def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:

def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
    """Return this entry's callbacks.

    The `key` argument is unused here: the entry keeps a single shared
    collection of callbacks rather than a per-key mapping.
    """
    callbacks = self._callbacks
    return callbacks


class CacheMultipleEntries(CacheEntry[KT, VT]):
    """Cache entry that is used for bulk lookups and insertions.

    A single instance is shared between all the keys of one bulk fetch
    (see `DeferredCache.set_bulk`). The underlying deferred resolves with
    a dict mapping key -> value; per-key deferreds are derived from it.
    """

    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]

    def __init__(self) -> None:
        # Lazily-created observable that resolves with the full result
        # dict for the batch (created on first `deferred(..)` call).
        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
        # Invalidation callbacks registered for individual keys.
        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
        # Invalidation callbacks that fire when *any* key is invalidated.
        self._global_callbacks: Set[Callable[[], None]] = set()

    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        """Return a deferred resolving with `key`'s value.

        NOTE: resolves with `None` if the bulk result dict does not
        contain `key` (the `dict.get` default).
        """
        if not self._deferred:
            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
        return self._deferred.observe().addCallback(lambda res: res.get(key))

    def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:
        """Register an invalidation callback for a single key. No-op if None."""
        if callback is None:
            return

        self._callbacks.setdefault(key, set()).add(callback)

    def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        """Return the callbacks for `key`, plus all global callbacks.

        Builds a fresh set on each call, so mutating the result does not
        affect the registered callbacks.
        """
        return self._callbacks.get(key, set()) | self._global_callbacks

    def add_global_callback(self, callback: Optional[Callable[[], None]]) -> None:
        """Add a callback for when any keys get invalidated. No-op if None."""
        if callback is None:
            return

        self._global_callbacks.add(callback)

    def complete_bulk(
        self,
        cache: DeferredCache[KT, VT],
        result: Dict[KT, VT],
    ) -> None:
        """Called when there is a result.

        Moves each key's value from the pending state into the main cache
        via `cache._set_completed_callback`, then resolves the bulk
        deferred (if any observer created it) with the full result dict.
        """
        for key, value in result.items():
            cache._set_completed_callback(value, self, key)

        if self._deferred:
            self._deferred.callback(result)

    def error_bulk(
        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
    ) -> None:
        """Called when bulk lookup failed.

        Propagates `failure` into the cache for each key via
        `cache._error_callback`, then errbacks the bulk deferred (if any
        observer created it).
        """
        for key in keys:
            cache._error_callback(failure, self, key)

        if self._deferred:
            self._deferred.errback(failure)
42 changes: 13 additions & 29 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,46 +471,30 @@ def arg_to_cache_key(arg: Hashable) -> Hashable:
missing.add(arg)

if missing:
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
cached_defers.append(
cache.set(key, deferred, callback=invalidate_callback)
)
cache_keys = [arg_to_cache_key(key) for key in missing]
cache_entry = cache.set_bulk(cache_keys, callback=invalidate_callback)

def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a dict.
# We can now update our own result map, and then resolve the
# observable deferreds in the cache.
for e, d1 in deferreds_map.items():
val = res.get(e, None)
# make sure we update the results map before running the
# deferreds, because as soon as we run the last deferred, the
# gatherResults() below will complete and return the result
# dict to our caller.
results[e] = val
d1.callback(val)
missing_results = {}
for key in missing:
val = res.get(key, None)

results[key] = val
missing_results[arg_to_cache_key(key)] = val

cache_entry.complete_bulk(cache, missing_results)

def errback_all(f: Failure) -> None:
# the wrapped function has failed. Propagate the failure into
# the cache, which will invalidate the entry, and cause the
# relevant cached_deferreds to fail, which will propagate the
# failure to our caller.
for d1 in deferreds_map.values():
d1.errback(f)
cache_entry.error_bulk(cache, cache_keys, f)

args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing

# dispatch the call, and attach the two handlers
defer.maybeDeferred(
missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
cached_defers.append(missing_d)

if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
Expand Down