Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to expiring cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Apr 2, 2021
1 parent 4609e58 commit e0c4206
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 54 deletions.
2 changes: 1 addition & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, hs: "HomeServer"):
max_len=1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
)
) # type: ExpiringCache[str, EventBase]

def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
) # type: ExpiringCache[str, Set[str]]

# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
Expand Down Expand Up @@ -760,7 +760,7 @@ async def _need_to_do_resync(
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]

extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

Expand Down
12 changes: 0 additions & 12 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
Expand Down Expand Up @@ -1292,17 +1291,6 @@ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]

# Recently seen stream ids. We don't bother keeping these in the DB,
# but it's useful to keep them around to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)

async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
Expand Down
10 changes: 6 additions & 4 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ def __init__(self, hs: "HomeServer"):
self.storage = hs.get_storage()
self.state_store = self.storage.state

# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]

async def wait_for_sync_for_user(
self,
Expand Down Expand Up @@ -733,8 +733,10 @@ async def compute_summary(

def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
) -> LruCache:
cache = self.lazy_loaded_members_cache.get(cache_key)
) -> LruCache[str, str]:
cache = self.lazy_loaded_members_cache.get(
cache_key
) # type: Optional[LruCache[str, str]]
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/media/v1/preview_url_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
)
) # type: ExpiringCache[str, ObservableDeferred]

if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
Expand Down
5 changes: 3 additions & 2 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Callable,
DefaultDict,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -515,7 +516,7 @@ def __init__(self, hs):
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]

#
# stuff for tracking time spent on state-res by room
Expand All @@ -536,7 +537,7 @@ async def resolve_state_groups(
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
) -> _StateCacheEntry:
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
Expand Down
83 changes: 51 additions & 32 deletions synapse/util/caches/expiringcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,50 @@

import logging
from collections import OrderedDict
from typing import Any, Generic, Optional, TypeVar, Union, overload

import attr
from typing_extensions import Literal

from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.caches import register_cache

logger = logging.getLogger(__name__)


SENTINEL = object()
SENTINEL = object() # type: Any


T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")

class ExpiringCache:

class ExpiringCache(Generic[KT, VT]):
def __init__(
self,
cache_name,
clock,
max_len=0,
expiry_ms=0,
reset_expiry_on_get=False,
iterable=False,
cache_name: str,
clock: Clock,
max_len: int = 0,
expiry_ms: int = 0,
reset_expiry_on_get: bool = False,
iterable: bool = False,
):
"""
Args:
cache_name (str): Name of this cache, used for logging.
clock (Clock)
max_len (int): Max size of dict. If the dict grows larger than this
cache_name: Name of this cache, used for logging.
clock: Clock used to timestamp entries and to schedule the looping
call that runs every `expiry_ms / 2`.
max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
expiry_ms (int): How long before an item is evicted from the cache
expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
Expand All @@ -62,7 +72,7 @@ def __init__(
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get

self._cache = OrderedDict()
self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]

self.iterable = iterable

Expand All @@ -79,12 +89,12 @@ def f():

self._clock.looping_call(f, self._expiry_ms / 2)

def __setitem__(self, key, value):
def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
self.evict()

def evict(self):
def evict(self) -> None:
# Evict if there are now too many items
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
Expand All @@ -93,7 +103,7 @@ def evict(self):
else:
self.metrics.inc_evictions()

def __getitem__(self, key):
def __getitem__(self, key: KT) -> VT:
try:
entry = self._cache[key]
self.metrics.inc_hits()
Expand All @@ -106,7 +116,7 @@ def __getitem__(self, key):

return entry.value

def pop(self, key, default=SENTINEL):
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache.
If the key isn't in the cache then `default` will be returned if
Expand All @@ -115,29 +125,40 @@ def pop(self, key, default=SENTINEL):
Identical functionality to `dict.pop(..)`.
"""

value = self._cache.pop(key, default)
value = self._cache.pop(key, SENTINEL)
# The key was not found.
if value is SENTINEL:
raise KeyError(key)
if default is SENTINEL:
raise KeyError(key)
return default

return value
return value.value

def __contains__(self, key):
def __contains__(self, key: KT) -> bool:
return key in self._cache

def get(self, key, default=None):
@overload
def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
...

@overload
def get(self, key: KT, default: T) -> Union[VT, T]:
...

def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try:
return self[key]
except KeyError:
return default

def setdefault(self, key, value):
def setdefault(self, key: KT, value: VT) -> VT:
try:
return self[key]
except KeyError:
self[key] = value
return value

def _prune_cache(self):
def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
Expand Down Expand Up @@ -166,7 +187,7 @@ def _prune_cache(self):
len(self),
)

def __len__(self):
def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
else:
Expand All @@ -190,9 +211,7 @@ def set_cache_factor(self, factor: float) -> bool:
return False


@attr.s(slots=True)
class _CacheEntry:
__slots__ = ["time", "value"]

def __init__(self, time, value):
self.time = time
self.value = value
time = attr.ib(type=int)
value = attr.ib()

0 comments on commit e0c4206

Please sign in to comment.