Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache user IDs instead of profile objects #13573

Merged
merged 7 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13573.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache user IDs instead of profiles to reduce cache memory usage. Contributed by Nick @ Beeper (@fizzadar).
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2421,10 +2421,10 @@ async def get_rooms_for_user_at(
joined_room.room_id, joined_room.event_pos.stream
)
)
users_in_room = await self.state.get_current_users_in_room(
user_ids_in_room = await self.state.get_current_user_ids_in_room(
joined_room.room_id, extrems
)
if user_id in users_in_room:
if user_id in user_ids_in_room:
joined_room_ids.add(joined_room.room_id)

return frozenset(joined_room_ids)
Expand Down
13 changes: 6 additions & 7 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -210,11 +209,11 @@ async def compute_state_after_events(
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, state_filter)

async def get_current_users_in_room(
async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
) -> Dict[str, ProfileInfo]:
) -> Set[str]:
"""
Get the users who are currently in a room.
Get the IDs of the users who are currently in a room.

Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
Expand All @@ -225,15 +224,15 @@ async def get_current_users_in_room(
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
Dictionary of user IDs to their profileinfo.
Set of user IDs in the room.
"""

assert latest_event_ids is not None

logger.debug("calling resolve_state_groups from get_current_users_in_room")
logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_users_from_state(room_id, state, entry)
return await self.store.get_joined_user_ids_from_state(room_id, state, entry)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
Expand Down
67 changes: 29 additions & 38 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,9 @@ async def get_mutual_rooms_between_users(

return shared_room_ids or frozenset()

async def get_joined_users_from_state(
async def get_joined_user_ids_from_state(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
) -> Set[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
Expand All @@ -848,25 +848,25 @@ async def get_joined_users_from_state(

assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
return await self._get_joined_user_ids_from_context(
room_id, state_group, state, context=state_entry
)

@cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
async def _get_joined_user_ids_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
event: Optional[EventBase] = None,
context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]:
) -> Set[str]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
assert state_group is not None

users_in_room = {}
users_in_room = set()
member_event_ids = [
e_id
for key, e_id in current_state_ids.items()
Expand All @@ -879,19 +879,19 @@ async def _get_joined_users_from_context(
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
prev_res = self._get_joined_users_from_context.cache.get_immediate(
prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
if prev_res and isinstance(prev_res, dict):
users_in_room = dict(prev_res)
if prev_res and isinstance(prev_res, set):
users_in_room = prev_res
member_event_ids = [
e_id
for key, e_id in context.delta_ids.items()
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
if etype == EventTypes.Member:
users_in_room.pop(state_key, None)
users_in_room.discard(state_key)

# We check if we have any of the member event ids in the event cache
# before we ask the DB
Expand All @@ -908,42 +908,41 @@ async def _get_joined_users_from_context(
ev_entry = event_map.get(event_id)
if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
avatar_url=ev_entry.event.content.get("avatar_url", None),
)
users_in_room.add(ev_entry.event.state_key)
else:
missing_member_event_ids.append(event_id)

if missing_member_event_ids:
event_to_memberships = await self._get_joined_profiles_from_event_ids(
event_to_memberships = await self._get_user_ids_from_membership_event_ids(
missing_member_event_ids
)
users_in_room.update(row for row in event_to_memberships.values() if row)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the removal of the if row is causing exception on my server, because there is an event_id being passed in that isn't in room_memberships. I'm currently trying to track it down.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, no, this is actually because we're filtering out non-joins, which @cachedList will return as None

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by #13600

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, thank you for fixing that so quick 🙏!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was excited to test the change on my server, so deployed it and saw everything light up in red :D

I am a bit scared that none of the tests spotted this....

users_in_room.update(event_to_memberships.values())

if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room[event.state_key] = ProfileInfo(
display_name=event.content.get("displayname", None),
avatar_url=event.content.get("avatar_url", None),
)
users_in_room.add(event.state_key)

return users_in_room

@cached(max_entries=10000)
def _get_joined_profile_from_event_id(
@cached(
max_entries=10000,
# This name matches the old function that has been replaced - the cache name
# is kept here to maintain backwards compatibility.
name="_get_joined_profile_from_event_id",
)
def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(
async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
) -> Dict[str, str]:
"""For a given set of member event_ids, check if they point to a join
event and if so return the associated user ID.

Expand All @@ -958,21 +957,13 @@ async def _get_joined_profiles_from_event_ids(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
desc="_get_joined_profiles_from_event_ids",
desc="_get_user_ids_from_membership_event_ids",
)

return {
row["event_id"]: (
row["user_id"],
ProfileInfo(
avatar_url=row["avatar_url"], display_name=row["display_name"]
),
)
for row in rows
}
return {row["event_id"]: row["user_id"] for row in rows}

@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
Expand Down Expand Up @@ -1131,12 +1122,12 @@ async def _get_joined_hosts(
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state(
joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, state, state_entry
)

cache.hosts_to_joined_users = {}
for user_id in joined_users:
for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

Expand Down
26 changes: 19 additions & 7 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def __init__(
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
name: Optional[str] = None,
):
self.orig = orig
self.name = name or orig.__name__

arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
Expand Down Expand Up @@ -211,7 +213,7 @@ def __init__(

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
cache_name=self.name,
max_size=self.max_entries,
)

Expand Down Expand Up @@ -241,7 +243,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:

wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
obj.__dict__[self.name] = wrapped

return wrapped

Expand Down Expand Up @@ -301,12 +303,14 @@ def __init__(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
name=name,
)

if tree and self.num_args < 2:
Expand All @@ -321,7 +325,7 @@ def __init__(

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
Expand Down Expand Up @@ -372,7 +376,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:
wrapped.cache = cache
wrapped.num_args = self.num_args

obj.__dict__[self.orig.__name__] = wrapped
obj.__dict__[self.name] = wrapped

return wrapped

Expand All @@ -393,6 +397,7 @@ def __init__(
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
):
"""
Args:
Expand All @@ -403,7 +408,7 @@ def __init__(
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
super().__init__(orig, num_args=num_args, uncached_args=None)
super().__init__(orig, num_args=num_args, uncached_args=None, name=name)

self.list_name = list_name

Expand Down Expand Up @@ -525,7 +530,7 @@ def errback_all(f: Failure) -> None:
else:
return defer.succeed(results)

obj.__dict__[self.orig.__name__] = wrapped
obj.__dict__[self.name] = wrapped

return wrapped

Expand Down Expand Up @@ -577,6 +582,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
Expand All @@ -587,13 +593,18 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)


def cachedList(
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
*,
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.

Expand Down Expand Up @@ -628,6 +639,7 @@ def batch_do_something(self, first_arg, second_args):
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
Expand Down