Skip to content

Commit

Permalink
Refactor chain fetching (#17044)
Browse files Browse the repository at this point in the history
Since these queries are duplicated in two places.
  • Loading branch information
erikjohnston committed Apr 2, 2024
1 parent fd48fc4 commit ec174d0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 96 deletions.
1 change: 1 addition & 0 deletions changelog.d/17044.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.
162 changes: 66 additions & 96 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -279,64 +280,16 @@ def _get_auth_chain_ids_using_cover_index_txn(

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}

# Add all linked chains reachable from initial set of chains.
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

for links in self._get_chain_links(txn, set(event_chains.keys())):
for chain_id in links:
if chain_id not in event_chains:
continue

_materialize(chain_id, event_chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)

# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
Expand Down Expand Up @@ -380,6 +333,68 @@ def _get_auth_chain_ids_using_cover_index_txn(

return results

@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
Note: This may return links that are not reachable from the given
chains.
Returns a generator that produces dicts from origin chain ID to 3-tuple
of origin sequence number, target chain ID and target sequence number.
"""

# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

chains_to_fetch.difference_update(links)

yield links

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
Expand Down Expand Up @@ -564,61 +579,16 @@ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
continue

_materialize(chain_id, chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)
seen_chains.update(chains)

# Now for each chain we figure out the maximum sequence number reachable
Expand Down

0 comments on commit ec174d0

Please sign in to comment.