Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Type annotations for test_v2 (#12985)
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Jun 9, 2022
1 parent 04ca3a5 commit 97053c9
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 58 deletions.
1 change: 1 addition & 0 deletions changelog.d/12985.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `tests.state.test_v2`.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ exclude = (?x)
|tests/rest/media/v1/test_media_storage.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
|tests/test_metrics.py
|tests/test_server.py
|tests/test_state.py
Expand Down Expand Up @@ -115,6 +114,9 @@ disallow_untyped_defs = False
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True

[mypy-tests.state.test_profile]
disallow_untyped_defs = True

[mypy-tests.storage.test_profile]
disallow_untyped_defs = True

Expand Down
57 changes: 42 additions & 15 deletions synapse/state/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,73 @@
import logging
from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
overload,
)

from typing_extensions import Literal
from typing_extensions import Literal, Protocol

import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock

logger = logging.getLogger(__name__)


class Clock(Protocol):
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
# We only ever sleep(0) though, so that other async functions can make forward
# progress without waiting for stateres to complete.
def sleep(self, duration_ms: float) -> Awaitable[None]:
...


class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...

def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
...


# We want to await to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# awaiting to reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100


__all__ = [
"resolve_events_with_store",
]


async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Expand Down Expand Up @@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Expand Down Expand Up @@ -243,9 +270,9 @@ async def _get_power_level_for_sender(

async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[StateMap[str]],
state_sets: Sequence[Mapping[Any, str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Expand Down Expand Up @@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
Expand Down Expand Up @@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
Expand Down Expand Up @@ -501,7 +528,7 @@ async def _iterative_auth_checks(
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Expand Down Expand Up @@ -570,7 +597,7 @@ async def _mainline_sort(
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Expand Down Expand Up @@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event(
event: EventBase,
mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> int:
"""Get the mainline depths for the given event based on the mainline map
Expand Down Expand Up @@ -683,7 +710,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: Literal[False] = False,
) -> EventBase:
...
Expand All @@ -694,7 +721,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: Literal[True],
) -> Optional[EventBase]:
...
Expand All @@ -704,7 +731,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
Expand Down
Loading

0 comments on commit 97053c9

Please sign in to comment.