Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor chain fetching #17044

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17044.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.
162 changes: 66 additions & 96 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -279,64 +280,16 @@ def _get_auth_chain_ids_using_cover_index_txn(

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}

# Add all linked chains reachable from initial set of chains.
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

for links in self._get_chain_links(txn, set(event_chains.keys())):
for chain_id in links:
if chain_id not in event_chains:
continue

_materialize(chain_id, event_chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)

# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
Expand Down Expand Up @@ -380,6 +333,68 @@ def _get_auth_chain_ids_using_cover_index_txn(

return results

@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
Note: This may return links that are not reachable from the given
chains.
Returns a generator that produces dicts from origin chain ID to 3-tuple
of origin sequence number, target chain ID and target sequence number.
"""

# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

chains_to_fetch.difference_update(links)

yield links

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
Expand Down Expand Up @@ -564,61 +579,16 @@ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""

# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

links: Dict[int, List[Tuple[int, int, int]]] = {}

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)

# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
continue

_materialize(chain_id, chains[chain_id], links, chains)

chains_to_fetch.difference_update(chains)
seen_chains.update(chains)

# Now for each chain we figure out the maximum sequence number reachable
Expand Down
Loading