Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use the chain cover index in get_auth_chain_ids #9576

Merged
merged 8 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9576.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve efficiency of calculating the auth chain in large rooms.
6 changes: 4 additions & 2 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ async def on_state_ids_request(

async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}

async def _on_context_state_request_compute(
Expand All @@ -460,7 +460,9 @@ async def _on_context_state_request_compute(
else:
pdus = (await self.state.get_current_state(room_id)).values()

auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
auth_chain = await self.store.get_auth_chain(
room_id, [pdu.event_id for pdu in pdus]
)

return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ async def send_invite(self, target_host, event):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)

Expand Down Expand Up @@ -1580,7 +1580,7 @@ async def on_send_join_request(self, origin, pdu):
prev_state_ids = await context.get_prev_state_ids()

state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(state_ids)
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)

state = await self.store.get_events(list(prev_state_ids.values()))

Expand Down Expand Up @@ -2219,7 +2219,7 @@ async def on_query_auth(

# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
room_id, list(event.auth_event_ids()), include_given=True
)

# TODO: Check if we would now reject event_id. If so we need to tell
Expand Down
148 changes: 145 additions & 3 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,46 +54,188 @@ def __init__(self, database: DatabasePool, db_conn, hs):
) # type: LruCache[str, List[Tuple[str, int]]]

async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.

Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result

Returns:
list of events
"""
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)

async def get_auth_chain_ids(
self,
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.

Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result

Returns:
An awaitable which resolve to a list of event_ids
list of event_ids
"""

# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
self._get_auth_chain_ids_using_cover_index_txn,
room_id,
event_ids,
include_given,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass

return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
)

def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs using the chain index."""

# First we look up the chain ID/sequence numbers for the given events.

initial_events = set(event_ids)

# All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str]

# A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int]

sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)

for event_id, chain_id, sequence_number in txn:
seen_events.add(event_id)
event_chains[chain_id] = max(
sequence_number, event_chains.get(chain_id, 0)
)

# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(seen_events)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)

# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""

# A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int]

# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
clokep marked this conversation as resolved.
Show resolved Hide resolved
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)

# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))

# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence less than that are in the
# auth chain.
if include_given:
results = initial_events
else:
results = set()

if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""

rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND sequence_number <= ?
"""
for chain_id, max_no in chains.items():
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)

return list(results)

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs.

This is used when we don't have a cover index for the room.
"""
if include_given:
results = set(event_ids)
else:
Expand Down
76 changes: 73 additions & 3 deletions tests/storage/test_event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def insert_event(txn, i, room_id):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])

@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"

# The silly auth graph we use to test the auth difference algorithm,
Expand Down Expand Up @@ -165,7 +164,7 @@ def test_auth_difference(self, use_chain_cover_index: bool):
"j": 1,
}

# Mark the room as not having a cover index
# Mark the room as maybe having a cover index.

def store_room(txn):
self.store.db_pool.simple_insert_txn(
Expand Down Expand Up @@ -222,6 +221,77 @@ def insert_event(txn):
)
)

return room_id

@parameterized.expand([(True,), (False,)])
def test_auth_chain_ids(self, use_chain_cover_index: bool):
room_id = self._setup_auth_chain(use_chain_cover_index)

# a and b have the same auth chain.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["a", "b"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])

# d and e have the same auth chain.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
self.assertEqual(auth_chain_ids, ["k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
self.assertEqual(auth_chain_ids, ["j"])

# j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
self.assertEqual(auth_chain_ids, [])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
self.assertEqual(auth_chain_ids, [])

# More complex input sequences.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])

auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["h", "i"])
)
self.assertCountEqual(auth_chain_ids, ["k", "j"])

# e gets returned even though include_given is false, but it is in the
# auth chain of b.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "e"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])

# Test include_given.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
)
self.assertCountEqual(auth_chain_ids, ["i", "j"])

@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
room_id = self._setup_auth_chain(use_chain_cover_index)

# Now actually test that various combinations give the right result:

difference = self.get_success(
Expand Down