Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints for tests.events #14904

Merged
merged 5 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14904.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
5 changes: 3 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ exclude = (?x)
|tests/api/test_auth.py
|tests/app/test_openid_listener.py
|tests/appservice/test_scheduler.py
|tests/events/test_presence_router.py
|tests/events/test_utils.py
|tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py
|tests/federation/transport/test_knocking.py
Expand Down Expand Up @@ -87,6 +85,9 @@ disallow_untyped_defs = True
[mypy-tests.crypto.*]
disallow_untyped_defs = True

[mypy-tests.events.*]
disallow_untyped_defs = True

[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

Expand Down
3 changes: 2 additions & 1 deletion synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,11 @@ def serialize_events(


_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]


def copy_and_fixup_power_levels_contents(
old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
old_power_levels: PowerLevelsContent,
) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way.

Expand Down
58 changes: 35 additions & 23 deletions tests/events/test_presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@

import attr

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EduTypes
from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, presence, room
from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock

from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
from tests.unittest import FederatingHomeserverTestCase, override_config


@attr.s
Expand All @@ -49,9 +53,7 @@ async def get_users_for_states(
}
return users_to_state

async def get_interested_users(
self, user_id: str
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS

Expand All @@ -71,9 +73,14 @@ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
# Initialise a typed config object
config = PresenceRouterTestConfig()

config.users_who_should_receive_all_presence = config_dict.get(
users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
assert isinstance(users_who_should_receive_all_presence, list)

config.users_who_should_receive_all_presence = (
users_who_should_receive_all_presence
)

return config

Expand All @@ -96,9 +103,7 @@ async def get_users_for_states(
}
return users_to_state

async def get_interested_users(
self, user_id: str
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS

Expand All @@ -118,9 +123,14 @@ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
# Initialise a typed config object
config = PresenceRouterTestConfig()

config.users_who_should_receive_all_presence = config_dict.get(
users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
assert isinstance(users_who_should_receive_all_presence, list)

config.users_who_should_receive_all_presence = (
users_who_should_receive_all_presence
)

return config

Expand All @@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
presence.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
Expand All @@ -153,7 +163,9 @@ def make_homeserver(self, reactor, clock):

return hs

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()

Expand All @@ -176,7 +188,7 @@ def default_config(self) -> JsonDict:
},
}
)
def test_receiving_all_presence_legacy(self):
def test_receiving_all_presence_legacy(self) -> None:
self.receiving_all_presence_test_body()

@override_config(
Expand All @@ -193,10 +205,10 @@ def test_receiving_all_presence_legacy(self):
],
}
)
def test_receiving_all_presence(self):
def test_receiving_all_presence(self) -> None:
self.receiving_all_presence_test_body()

def receiving_all_presence_test_body(self):
def receiving_all_presence_test_body(self) -> None:
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
Expand Down Expand Up @@ -302,7 +314,7 @@ def receiving_all_presence_test_body(self):
},
}
)
def test_send_local_online_presence_to_with_module_legacy(self):
def test_send_local_online_presence_to_with_module_legacy(self) -> None:
self.send_local_online_presence_to_with_module_test_body()

@override_config(
Expand All @@ -321,10 +333,10 @@ def test_send_local_online_presence_to_with_module_legacy(self):
],
}
)
def test_send_local_online_presence_to_with_module(self):
def test_send_local_online_presence_to_with_module(self) -> None:
self.send_local_online_presence_to_with_module_test_body()

def send_local_online_presence_to_with_module_test_body(self):
def send_local_online_presence_to_with_module_test_body(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
Expand Down Expand Up @@ -447,18 +459,18 @@ def send_local_online_presence_to_with_module_test_body(self):
continue

# EDUs can contain multiple presence updates
for presence_update in edu["content"]["push"]:
for presence_edu in edu["content"]["push"]:
# Check for presence updates that contain the user IDs we're after
found_users.add(presence_update["user_id"])
found_users.add(presence_edu["user_id"])

# Ensure that no offline states are being sent out
self.assertNotEqual(presence_update["presence"], "offline")
self.assertNotEqual(presence_edu["presence"], "offline")

self.assertEqual(found_users, expected_users)


def send_presence_update(
testcase: TestCase,
testcase: FederatingHomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
Expand All @@ -479,7 +491,7 @@ def send_presence_update(


def sync_presence(
testcase: TestCase,
testcase: FederatingHomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
Expand All @@ -500,7 +512,7 @@ def sync_presence(
requester = create_requester(user_id)
sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success(
testcase.sync_handler.wait_for_sync_for_user(
testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester, sync_config, since_token
)
)
Expand Down
17 changes: 12 additions & 5 deletions tests/events/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.test_utils.event_injection import create_event
Expand All @@ -27,15 +32,15 @@ class TestEventContext(unittest.HomeserverTestCase):
room.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()

self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)

def test_serialize_deserialize_msg(self):
def test_serialize_deserialize_msg(self) -> None:
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
Expand All @@ -51,7 +56,7 @@ def test_serialize_deserialize_msg(self):

self._check_serialize_deserialize(event, context)

def test_serialize_deserialize_state_no_prev(self):
def test_serialize_deserialize_state_no_prev(self) -> None:
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
Expand All @@ -67,7 +72,7 @@ def test_serialize_deserialize_state_no_prev(self):

self._check_serialize_deserialize(event, context)

def test_serialize_deserialize_state_prev(self):
def test_serialize_deserialize_state_prev(self) -> None:
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
Expand All @@ -84,7 +89,9 @@ def test_serialize_deserialize_state_prev(self):

self._check_serialize_deserialize(event, context)

def _check_serialize_deserialize(self, event, context):
def _check_serialize_deserialize(
self, event: EventBase, context: EventContext
) -> None:
serialized = self.get_success(context.serialize(event, self.store))

d_context = EventContext.deserialize(self._storage_controllers, serialized)
Expand Down
Loading