Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Batch fetch annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Nov 21, 2022
1 parent dfb843e commit 41ef837
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 89 deletions.
1 change: 1 addition & 0 deletions changelog.d/14491.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations.
92 changes: 59 additions & 33 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)

import attr

Expand Down Expand Up @@ -259,13 +268,9 @@ async def redact_events_related_to(
e.msg,
)

async def get_annotations_for_event(
self,
event_id: str,
room_id: str,
limit: int = 5,
ignored_users: FrozenSet[str] = frozenset(),
) -> List[JsonDict]:
async def get_annotations_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
) -> Dict[str, List[JsonDict]]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
Expand All @@ -274,33 +279,53 @@ async def get_annotations_for_event(
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
ignored_users: The users ignored by the requesting user.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
A map of event IDs to a list of groups of annotations that match.
Each entry is a dict with `type`, `key` and `count` fields.
"""
# Get the base results for all users.
full_results = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id, limit
full_results = await self._main_store.get_aggregation_groups_for_events(
event_ids
)

# Avoid additional logic if there are no ignored users.
if not ignored_users:
return {
event_id: results
for event_id, results in full_results.items()
if results
}

# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
event_id, room_id, limit, ignored_users
[event_id for event_id, results in full_results.items() if results],
ignored_users,
)

filtered_results = []
for result in full_results:
key = (result["type"], result["key"])
if key in ignored_results:
result = result.copy()
result["count"] -= ignored_results[key]
if result["count"] <= 0:
continue
filtered_results.append(result)
filtered_results = {}
for event_id, results in full_results.items():
# If no annotations, skip.
if not results:
continue

# If there are not ignored results for this event, copy verbatim.
if event_id not in ignored_results:
filtered_results[event_id] = results
continue

# Otherwise, subtract out the ignored results.
filtered_results[event_id] = []
event_ignored_results = ignored_results[event_id]
for result in results:
key = (result["type"], result["key"])
if key in event_ignored_results:
result = result.copy()
result["count"] -= event_ignored_results[key]
if result["count"] <= 0:
continue
filtered_results[event_id].append(result)

return filtered_results

Expand Down Expand Up @@ -499,17 +524,18 @@ async def get_bundled_aggregations(
# (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD

# Fetch other relations per event.
for event in events_by_id.values():
# Fetch any annotations (ie, reactions) to bundle with this event.
annotations = await self.get_annotations_for_event(
event.event_id, event.room_id, ignored_users=ignored_users
)
# Fetch any annotations (ie, reactions) to bundle with this event.
annotations_by_event_id = await self.get_annotations_for_events(
events_by_id.keys(), ignored_users=ignored_users
)
for event_id, annotations in annotations_by_event_id.items():
if annotations:
results.setdefault(
event.event_id, BundledAggregations()
).annotations = {"chunk": annotations}
results.setdefault(event_id, BundledAggregations()).annotations = {
"chunk": annotations
}

# Fetch other relations per event.
for event in events_by_id.values():
# Fetch any references to bundle with this event.
references, next_token = await self.get_relations_for_event(
event.event_id,
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _do_execute(
):
return func(sql, *args, **kwargs)
except Exception as e:
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
logger.error("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
Expand Down
129 changes: 77 additions & 52 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -394,106 +395,130 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
)
return result is not None

@cached(tree=True)
async def get_aggregation_groups_for_event(
self, event_id: str, room_id: str, limit: int = 5
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
@cached()
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()

@cachedList(
cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
)
async def get_aggregation_groups_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[JsonDict]]]:
"""Get a list of annotations on the given events, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend
on an event.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
event_ids: Fetch events that relate to these event IDs.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
A map of event IDs to a list of groups of annotations that match.
Each entry is a dict with `type`, `key` and `count` fields.
"""

args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
limit,
]
clause, args = make_in_list_sql_clause(
self.database_engine, "relates_to_id", event_ids
)
args.append(RelationTypes.ANNOTATION)

sql = """
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
sql = f"""
SELECT
relates_to_id,
annotation.type,
aggregation_key,
COUNT(DISTINCT annotation.sender)
FROM events AS annotation
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE
{clause}
AND relation_type = ?
GROUP BY parent.event_id, annotation.type, aggregation_key
ORDER BY parent.event_id, COUNT(*) DESC
"""

def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> List[JsonDict]:
) -> Mapping[str, List[JsonDict]]:
txn.execute(sql, args)

return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
result: Dict[str, List[JsonDict]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
result.setdefault(event_id, []).append(
{"type": type, "key": key, "count": count}
)

return result

return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)

async def get_aggregation_groups_for_users(
self,
event_id: str,
room_id: str,
limit: int,
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
self, event_ids: Collection[str], users: FrozenSet[str] = frozenset()
) -> Dict[str, Dict[Tuple[str, str], int]]:
"""Fetch the partial aggregations for an event for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
event_ids: Fetch events that relate to these event IDs.
users: The users to fetch information for.
Returns:
A map of (event type, aggregation key) to a count of users.
A map of event ID to a map of (event type, aggregation key) to a
count of users.
"""

if not users:
return {}

args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]
events_sql, args = make_in_list_sql_clause(
self.database_engine, "relates_to_id", event_ids
)

users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "sender", users
self.database_engine, "annotation.sender", users
)
args.extend(users_args)
args.append(RelationTypes.ANNOTATION)

sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
SELECT
relates_to_id,
annotation.type,
aggregation_key,
COUNT(DISTINCT annotation.sender)
FROM events AS annotation
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE {events_sql} AND {users_sql} AND relation_type = ?
GROUP BY parent.event_id, annotation.type, aggregation_key
ORDER BY parent.event_id, COUNT(*) DESC
"""

def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
txn.execute(sql, args + [limit])
) -> Dict[str, Dict[Tuple[str, str], int]]:
txn.execute(sql, args)

return {(row[0], row[1]): row[2] for row in txn}
result: Dict[str, Dict[Tuple[str, str], int]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
result.setdefault(event_id, {})[(type, key)] = count

return result

return await self.db_pool.runInteraction(
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
Expand Down
2 changes: 1 addition & 1 deletion synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def cachedList(
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
in a map of key to value for each passed value. THe new results are stored in the
in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
Args:
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:

# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
Expand Down Expand Up @@ -1170,7 +1170,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
bundled_aggregations["latest_event"].get("unsigned"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)

def test_nested_thread(self) -> None:
"""
Expand Down

0 comments on commit 41ef837

Please sign in to comment.