Fetch thread summaries for multiple events in a single query (#11752)
This should reduce database usage when fetching bundled aggregations,
as the number of individual queries (and round trips to the database)
is reduced.
clokep authored Feb 11, 2022
1 parent bb98c59 commit b65acea
Showing 3 changed files with 151 additions and 74 deletions.
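The core idea of the change, restated: instead of issuing one relations query per event, the store now asks for all event IDs at once and groups the results per parent. Below is a minimal before/after sketch of that round-trip reduction, using sqlite3 and a simplified stand-in schema (hypothetical table contents, not Synapse's real tables or helpers):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE event_relations (relates_to_id TEXT, event_id TEXT);
    INSERT INTO event_relations VALUES ('$a', '$a1'), ('$a', '$a2'), ('$b', '$b1');
    """
)
event_ids = ["$a", "$b", "$c"]

# Before: one query (and one round trip) per event.
per_event = {
    eid: conn.execute(
        "SELECT COUNT(*) FROM event_relations WHERE relates_to_id = ?", (eid,)
    ).fetchone()[0]
    for eid in event_ids
}

# After: a single query covering every event, grouped by parent in the database.
placeholders = ", ".join("?" for _ in event_ids)
batched = dict(
    conn.execute(
        "SELECT relates_to_id, COUNT(*) FROM event_relations "
        f"WHERE relates_to_id IN ({placeholders}) GROUP BY relates_to_id",
        event_ids,
    )
)

assert per_event["$a"] == batched["$a"] == 2  # same answers, fewer round trips

Events with no relations simply never appear in the batched map, which is why the new storage methods treat a missing key as "no threaded replies".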
1 change: 1 addition & 0 deletions changelog.d/11752.misc
@@ -0,0 +1 @@
Improve performance when fetching bundled aggregations for multiple events.
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events.py
@@ -1812,7 +1812,7 @@ def _handle_event_relations(
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender),
(parent_id, event.sender),
)

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
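For context on the one-line events.py change above: Synapse's cached storage methods are keyed on their argument tuple, so once get_thread_participated dropped its room_id argument, the key passed to .invalidate had to change from (parent_id, event.room_id, event.sender) to (parent_id, event.sender). The toy cache below is an illustrative stand-in (not Synapse's cache descriptors) showing why the key shapes must match:

from typing import Any, Dict, Tuple

class TupleKeyedCache:
    """Toy cache keyed on a cached method's positional-argument tuple."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[Any, ...], Any] = {}

    def set(self, key: Tuple[Any, ...], value: Any) -> None:
        self._entries[key] = value

    def invalidate(self, key: Tuple[Any, ...]) -> None:
        # Silently does nothing if the key shape doesn't match what was stored.
        self._entries.pop(key, None)

cache = TupleKeyedCache()
parent_id, room_id, sender = "$parent", "!room:example.org", "@alice:example.org"

# New key shape: (parent_id, sender).
cache.set((parent_id, sender), True)

# Invalidating with the old three-part key would miss the entry and leave it stale.
cache.invalidate((parent_id, room_id, sender))
assert (parent_id, sender) in cache._entries  # stale entry still present

# The corrected call in this commit uses the matching two-part key.
cache.invalidate((parent_id, sender))
assert (parent_id, sender) not in cache._entries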
222 changes: 149 additions & 73 deletions synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@
Iterable,
List,
Optional,
Set,
Tuple,
Union,
cast,
@@ -454,106 +455,175 @@ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
}

@cached()
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
raise NotImplementedError()

@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
event_ids: Summarize the threads related to these event IDs.
Returns:
The number of items in the thread and the most recent response, if any.
A map of event ID to the thread summary for that event. A missing event implies there
are no threaded replies.
Each summary includes the number of items in the thread and the most
recent response.
"""

def _get_thread_summary_txn(
def _get_thread_summaries_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the latest event ID in the thread.
) -> Tuple[Dict[str, int], Dict[str, str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events?
sql = """
SELECT event_id
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
# so ordering by topological ordering + stream ordering desc will
# ensure we get the latest event in the thread.
sql = """
SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
# SQLite uses a simplified query which returns all entries for a
thread. The first result for each thread is kept and subsequent
results for that thread are ignored.
sql = """
SELECT parent.event_id, child.event_id FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""

clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
args.append(RelationTypes.THREAD)

txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
txn.execute(sql % (clause,), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
if parent_event_id not in latest_event_ids:
latest_event_ids[parent_event_id] = child_event_id

latest_event_id = row[0]
# If no threads were found, bail.
if not latest_event_ids:
return {}, latest_event_ids

# Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
INNER JOIN events USING (event_id)
SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
relates_to_id = ?
AND room_id = ?
%s
AND relation_type = ?
GROUP BY parent.event_id
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = cast(Tuple[int], txn.fetchone())[0]

return count, latest_event_id
# Regenerate the arguments since only threads found above could
# possibly have any replies.
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
args.append(RelationTypes.THREAD)

txn.execute(sql % (clause,), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))

count, latest_event_id = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
return counts, latest_event_ids

counts, latest_event_ids = await self.db_pool.runInteraction(
"get_thread_summaries", _get_thread_summaries_txn
)

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]

# Map the event IDs to their thread summaries.
#
# There might not be a summary due to there not being a thread or
# due to the latest event not being known; either case is treated the same.
summaries = {}
for parent_event_id, latest_event_id in latest_event_ids.items():
latest_event = latest_events.get(latest_event_id)

summary = None
if latest_event:
summary = (counts[parent_event_id], latest_event)
summaries[parent_event_id] = summary

return count, latest_event
return summaries

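The interesting query technique in _get_thread_summaries above is picking the latest reply per thread in one pass: on PostgreSQL, DISTINCT ON (parent.event_id) combined with the descending ordering keeps only the newest child per parent, while on SQLite the query returns every child newest-first and the Python loop keeps the first row seen per parent. A small, self-contained sketch of the SQLite-side approach with made-up data (simplified schema, not Synapse's):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE replies (parent_id TEXT, child_id TEXT, stream_ordering INTEGER);
    INSERT INTO replies VALUES
        ('$root1', '$reply1', 1),
        ('$root1', '$reply2', 7),
        ('$root2', '$reply3', 4);
    """
)

rows = conn.execute(
    "SELECT parent_id, child_id FROM replies ORDER BY stream_ordering DESC"
)

latest_event_ids = {}
for parent_id, child_id in rows:
    # Rows arrive newest-first, so the first child seen per parent is the
    # latest reply; later rows for the same parent are ignored.
    latest_event_ids.setdefault(parent_id, child_id)

print(latest_event_ids)  # {'$root1': '$reply2', '$root2': '$reply3'}

The PostgreSQL branch gets the same "first row per parent wins" behaviour from DISTINCT ON, which is why its ORDER BY must lead with parent.event_id.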
@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()

This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requester.
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
This is separate from get_thread_summaries since that can be cached across
all users while this value is specific to the requester.
Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
event_ids: The thread related to these event IDs.
user_id: The user requesting the summary.
Returns:
True if the requesting user participated in the thread, otherwise false.
A map of event ID to a boolean indicating whether the requesting
user participated in that event's thread.
"""

def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
FROM event_relations
INNER JOIN events USING (event_id)
SELECT DISTINCT relates_to_id
FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
relates_to_id = ?
AND room_id = ?
%s
AND relation_type = ?
AND sender = ?
AND child.sender = ?
"""

txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
return bool(txn.fetchone())
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
args.extend((RelationTypes.THREAD, user_id))

return await self.db_pool.runInteraction(
txn.execute(sql % (clause,), args)
return {row[0] for row in txn.fetchall()}

participated_threads = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)

return {event_id: event_id in participated_threads for event_id in event_ids}

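A note on the decorator pattern used for both method pairs in this file: the single-ID method keeps its @cached() decorator but now just raises NotImplementedError, while a companion method decorated with @cachedList(cached_method_name=..., list_name=...) does the batched work and its results back-fill the per-ID cache. The sketch below approximates that behaviour with a plain dict cache so the control flow is visible; it is not Synapse's synapse.util.caches.descriptors implementation:

from typing import Collection, Dict, Optional

# Hypothetical stand-ins: a per-event cache plus a batch fetch that only hits
# the database for IDs not already cached, then back-fills the cache.
_summary_cache: Dict[str, Optional[tuple]] = {}

def _query_summaries_from_db(event_ids: Collection[str]) -> Dict[str, tuple]:
    # Placeholder for the single batched SQL query in _get_thread_summaries.
    return {eid: (1, f"latest-for-{eid}") for eid in event_ids if eid != "$no-thread"}

def get_thread_summaries(event_ids: Collection[str]) -> Dict[str, Optional[tuple]]:
    missing = [eid for eid in event_ids if eid not in _summary_cache]
    if missing:
        fetched = _query_summaries_from_db(missing)
        for eid in missing:
            # Cache "no thread" results too, so repeat lookups stay cheap.
            _summary_cache[eid] = fetched.get(eid)
    return {eid: _summary_cache[eid] for eid in event_ids}

print(get_thread_summaries(["$a", "$no-thread"]))  # hits the "database"
print(get_thread_summaries(["$a"]))                # served entirely from cache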
async def events_have_relations(
self,
parent_ids: List[str],
@@ -700,21 +770,6 @@ async def _get_bundled_aggregation_for_event(
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations.thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
current_user_participated=participated,
)

# Store the bundled aggregations in the event metadata for later use.
return aggregations

@@ -763,6 +818,27 @@ async def get_bundled_aggregations(
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit

# Fetch thread summaries.
if self._msc3440_enabled:
summaries = await self._get_thread_summaries(seen_event_ids)
# Only fetch thread participation for the limited set of events that
# have summaries.
participated = await self._get_threads_participated(
summaries.keys(), user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)

return results


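The final hunk replaces the removed per-event block from _get_bundled_aggregation_for_event with two batch lookups in get_bundled_aggregations: thread summaries for every seen event, then participation only for the events that actually had a summary. A compact sketch of that narrowing-and-merging step using plain dicts and a toy dataclass (hypothetical data, not Synapse's types):

from dataclasses import dataclass

@dataclass
class ThreadAggregation:
    latest_event: str
    count: int
    current_user_participated: bool

# Pretend results of the two batched storage calls.
summaries = {"$root1": (2, "$reply2"), "$root2": None}  # None: latest event unknown
participated = {"$root1": True, "$root2": False}        # keyed by summaries.keys()

results = {}
for event_id, summary in summaries.items():
    if summary:
        thread_count, latest_thread_event = summary
        results[event_id] = ThreadAggregation(
            latest_event=latest_thread_event,
            count=thread_count,
            # Participation was only fetched for events with summaries, so the
            # key is guaranteed to be present here.
            current_user_participated=participated[event_id],
        )

print(results)  # only $root1 gets a thread aggregation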