diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7791b289e259..1d32760a7bcf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -452,77 +452,27 @@ async def get_rules( self.room_push_rule_cache_metrics.inc_hits() return self.data.rules_by_user - self.room_push_rule_cache_metrics.inc_misses() - - ret_rules_by_user = {} - missing_member_event_ids = {} - if state_group and self.data.state_group == context.prev_group: - # If we have a simple delta then we can reuse most of the previous - # results. - ret_rules_by_user = self.data.rules_by_user - current_state_ids = context.delta_ids - - push_rules_delta_state_cache_metric.inc_hits() - else: - current_state_ids = await context.get_current_state_ids() - push_rules_delta_state_cache_metric.inc_misses() - # Ensure the state IDs exist. - assert current_state_ids is not None - - push_rules_state_size_counter.inc(len(current_state_ids)) - - logger.debug( - "Looking for member changes in %r %r", state_group, current_state_ids + local_users = await self.store.get_local_users_in_room( + self.room_id, on_invalidate=self.invalidate_all_cb ) - # Loop through to see which member events we've seen and have rules - # for and which we need to fetch - for key in current_state_ids: - typ, user_id = key - if typ != EventTypes.Member: - continue - - if user_id in self.data.uninteresting_user_set: - continue - - if not self.is_mine_id(user_id): - self.data.uninteresting_user_set.add(user_id) - continue - - if self.store.get_if_app_services_interested_in_user(user_id): - self.data.uninteresting_user_set.add(user_id) - continue + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + if self.is_mine_id(event.state_key): + local_users = list(local_users) + local_users.append(event.state_key) - event_id = current_state_ids[key] + ret_rules_by_user = await self.store.bulk_get_push_rules( + local_users, on_invalidate=self.invalidate_all_cb + ) - res = self.data.member_map.get(event_id, None) - if res: - if res.membership == Membership.JOIN: - rules = self.data.rules_by_user.get(res.user_id, None) - if rules: - ret_rules_by_user[res.user_id] = rules - continue + logger.info("Users in room: %s", local_users) - # If a user has left a room we remove their push rule. If they - # joined then we re-add it later in _update_rules_with_member_event_ids - ret_rules_by_user.pop(user_id, None) - missing_member_event_ids[user_id] = event_id - - if missing_member_event_ids: - # If we have some member events we haven't seen, look them up - # and fetch push rules for them if appropriate. - logger.debug("Found new member events %r", missing_member_event_ids) - await self._update_rules_with_member_event_ids( - ret_rules_by_user, missing_member_event_ids, state_group, event - ) - else: - # The push rules didn't change but lets update the cache anyway - self.update_cache( - self.data.sequence, - members={}, # There were no membership changes - rules_by_user=ret_rules_by_user, - state_group=state_group, - ) + self.update_cache( + self.data.sequence, + members={}, # There were no membership changes + rules_by_user=ret_rules_by_user, + state_group=state_group, + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -530,67 +480,6 @@ async def get_rules( ) return ret_rules_by_user - async def _update_rules_with_member_event_ids( - self, - ret_rules_by_user: Dict[str, list], - member_event_ids: Dict[str, str], - state_group: Optional[int], - event: EventBase, - ) -> None: - """Update the partially filled rules_by_user dict by fetching rules for - any newly joined users in the `member_event_ids` list. - - Args: - ret_rules_by_user: Partially filled dict of push rules. Gets - updated with any new rules. - member_event_ids: Dict of user id to event id for membership events - that have happened since the last time we filled rules_by_user - state_group: The state group we are currently computing push rules - for. Used when updating the cache. - event: The event we are currently computing push rules for. - """ - sequence = self.data.sequence - - members = await self.store.get_membership_from_event_ids( - member_event_ids.values() - ) - - # If the event is a join event then it will be in current state events - # map but not in the DB, so we have to explicitly insert it. - if event.type == EventTypes.Member: - for event_id in member_event_ids.values(): - if event_id == event.event_id: - members[event_id] = EventIdMembership( - user_id=event.state_key, membership=event.membership - ) - - if logger.isEnabledFor(logging.DEBUG): - logger.debug("Found members %r: %r", self.room_id, members.values()) - - joined_user_ids = { - entry.user_id - for entry in members.values() - if entry and entry.membership == Membership.JOIN - } - - logger.debug("Joined: %r", joined_user_ids) - - # Previously we only considered users with pushers or read receipts in that - # room. We can't do this anymore because we use push actions to calculate unread - # counts, which don't rely on the user having pushers or sent a read receipt into - # the room. Therefore we just need to filter for local users here. - user_ids = list(filter(self.is_mine_id, joined_user_ids)) - - rules_by_user = await self.store.bulk_get_push_rules( - user_ids, on_invalidate=self.invalidate_all_cb - ) - - ret_rules_by_user.update( - item for item in rules_by_user.items() if item[0] is not None - ) - - self.update_cache(sequence, members, ret_rules_by_user, state_group) - def update_cache( self, sequence: int, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 1653a6a9b694..a07d48f66c95 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -217,6 +217,7 @@ def _invalidate_caches_for_event( if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self.get_invited_rooms_for_local_user.invalidate((state_key,)) + self.get_local_users_in_room.invalidate((room_id,)) if relates_to: self.get_relations_for_event.invalidate((relates_to,)) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 17e35cf63e68..46a004ce7985 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1766,6 +1766,10 @@ def _store_room_members_txn( self.store.get_invited_rooms_for_local_user.invalidate, (event.state_key,), ) + txn.call_after( + self.store.get_local_users_in_room.invalidate, + (event.room_id,), + ) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 31bc8c56011a..ac6f568e6852 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -444,6 +444,15 @@ def _get_rooms_for_local_user_where_membership_is_txn( return results + @cached() + async def get_local_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( + table="local_current_membership", + keyvalues={"room_id": room_id, "membership": Membership.JOIN}, + retcol="user_id", + desc="get_local_users_in_room", + ) + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: