From e4903fbb96fa8aaa294d6092faaa9352ab8c37a2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:16:36 -0500 Subject: [PATCH 01/13] Add missing type hints to test_mau and test_rust. --- mypy.ini | 6 ------ tests/test_mau.py | 35 ++++++++++++++++++++++------------- tests/test_rust.py | 2 +- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/mypy.ini b/mypy.ini index 1bdeb18d9496..6c9f659e36c1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,12 +78,6 @@ disallow_untyped_defs = False [mypy-tests.test_federation] disallow_untyped_defs = False -[mypy-tests.test_mau] -disallow_untyped_defs = False - -[mypy-tests.test_rust] -disallow_untyped_defs = False - [mypy-tests.test_test_utils] disallow_untyped_defs = False diff --git a/tests/test_mau.py b/tests/test_mau.py index f14fcb7db9ba..4e7665a22b93 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -14,12 +14,17 @@ """Tests REST events for /rooms paths.""" -from typing import List +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService from synapse.rest.client import register, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase): servlets = [register.register_servlets, sync.register_servlets] - def default_config(self): + def default_config(self) -> JsonDict: config = default_config("test") config.update( @@ -53,10 +58,12 @@ def default_config(self): return config - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_simple_deny_mau(self): + def test_simple_deny_mau(self) -> None: # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -75,7 +82,7 @@ def test_simple_deny_mau(self): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_as_ignores_mau(self): + def test_as_ignores_mau(self) -> None: """Test that application services can still create users when the MAU limit has been reached. This only works when application service user ip tracking is disabled. @@ -113,7 +120,7 @@ def test_as_ignores_mau(self): self.create_user("as_kermit4", token=as_token, appservice=True) - def test_allowed_after_a_month_mau(self): + def test_allowed_after_a_month_mau(self) -> None: # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -132,7 +139,7 @@ def test_allowed_after_a_month_mau(self): self.do_sync_for_user(token3) @override_config({"mau_trial_days": 1}) - def test_trial_delay(self): + def test_trial_delay(self) -> None: # We should be able to register more than the limit initially token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -165,7 +172,7 @@ def test_trial_delay(self): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) @override_config({"mau_trial_days": 1}) - def test_trial_users_cant_come_back(self): + def test_trial_users_cant_come_back(self) -> None: self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -216,7 +223,7 @@ def test_trial_users_cant_come_back(self): # max_mau_value should not matter {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} ) - def test_tracked_but_not_limited(self): + def test_tracked_but_not_limited(self) -> None: # Simply being able to create 2 users indicates that the # limit was not reached. token1 = self.create_user("kermit1") @@ -236,10 +243,10 @@ def test_tracked_but_not_limited(self): "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2}, } ) - def test_as_trial_days(self): + def test_as_trial_days(self) -> None: user_tokens: List[str] = [] - def advance_time_and_sync(): + def advance_time_and_sync() -> None: self.reactor.advance(24 * 60 * 61) for token in user_tokens: self.do_sync_for_user(token) @@ -300,7 +307,9 @@ def advance_time_and_sync(): }, ) - def create_user(self, localpart, token=None, appservice=False): + def create_user( + self, localpart: str, token: Optional[str] = None, appservice: bool = False + ) -> str: request_data = { "username": localpart, "password": "monkey", @@ -326,7 +335,7 @@ def create_user(self, localpart, token=None, appservice=False): return access_token - def do_sync_for_user(self, token): + def do_sync_for_user(self, token: str) -> None: channel = self.make_request("GET", "/sync", access_token=token) if channel.code != 200: diff --git a/tests/test_rust.py b/tests/test_rust.py index 55d8b6b28cb4..67443b628042 100644 --- a/tests/test_rust.py +++ b/tests/test_rust.py @@ -6,6 +6,6 @@ class RustTestCase(unittest.TestCase): """Basic tests to ensure that we can call into Rust code.""" - def test_basic(self): + def test_basic(self) -> None: result = sum_as_string(1, 2) self.assertEqual("3", result) From 5bad3c72bfe1c53b1d05fbd3672cd85315e3ca05 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:19:56 -0500 Subject: [PATCH 02/13] Add missing type hints to test_test_utils and test_types.. --- mypy.ini | 6 ------ tests/test_test_utils.py | 16 ++++++++-------- tests/test_types.py | 30 +++++++++++++++--------------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/mypy.ini b/mypy.ini index 6c9f659e36c1..679d4c07ce57 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,12 +78,6 @@ disallow_untyped_defs = False [mypy-tests.test_federation] disallow_untyped_defs = False -[mypy-tests.test_test_utils] -disallow_untyped_defs = False - -[mypy-tests.test_types] -disallow_untyped_defs = False - [mypy-tests.test_utils.*] disallow_untyped_defs = False diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index d04bcae0fabe..5cd698147e12 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -17,25 +17,25 @@ class MockClockTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() - def test_advance_time(self): + def test_advance_time(self) -> None: start_time = self.clock.time() self.clock.advance_time(20) self.assertEqual(20, self.clock.time() - start_time) - def test_later(self): + def test_later(self) -> None: invoked = [0, 0] - def _cb0(): + def _cb0() -> None: invoked[0] = 1 self.clock.call_later(10, _cb0) - def _cb1(): + def _cb1() -> None: invoked[1] = 1 self.clock.call_later(20, _cb1) @@ -51,15 +51,15 @@ def _cb1(): self.assertTrue(invoked[1]) - def test_cancel_later(self): + def test_cancel_later(self) -> None: invoked = [0, 0] - def _cb0(): + def _cb0() -> None: invoked[0] = 1 t0 = self.clock.call_later(10, _cb0) - def _cb1(): + def _cb1() -> None: invoked[1] = 1 self.clock.call_later(20, _cb1) diff --git a/tests/test_types.py b/tests/test_types.py index 111116938423..c491cc9a9661 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -43,34 +43,34 @@ def test_two_colons(self) -> None: class UserIDTestCase(unittest.HomeserverTestCase): - def test_parse(self): + def test_parse(self) -> None: user = UserID.from_string("@1234abcd:test") self.assertEqual("1234abcd", user.localpart) self.assertEqual("test", user.domain) self.assertEqual(True, self.hs.is_mine(user)) - def test_parse_rejects_empty_id(self): + def test_parse_rejects_empty_id(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("") - def test_parse_rejects_missing_sigil(self): + def test_parse_rejects_missing_sigil(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("alice:example.com") - def test_parse_rejects_missing_separator(self): + def test_parse_rejects_missing_separator(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("@alice.example.com") - def test_validation_rejects_missing_domain(self): + def test_validation_rejects_missing_domain(self) -> None: self.assertFalse(UserID.is_valid("@alice:")) - def test_build(self): + def test_build(self) -> None: user = UserID("5678efgh", "my.domain") self.assertEqual(user.to_string(), "@5678efgh:my.domain") - def test_compare(self): + def test_compare(self) -> None: userA = UserID.from_string("@userA:my.domain") userAagain = UserID.from_string("@userA:my.domain") userB = UserID.from_string("@userB:my.domain") @@ -80,43 +80,43 @@ def test_compare(self): class RoomAliasTestCase(unittest.HomeserverTestCase): - def test_parse(self): + def test_parse(self) -> None: room = RoomAlias.from_string("#channel:test") self.assertEqual("channel", room.localpart) self.assertEqual("test", room.domain) self.assertEqual(True, self.hs.is_mine(room)) - def test_build(self): + def test_build(self) -> None: room = RoomAlias("channel", "my.domain") self.assertEqual(room.to_string(), "#channel:my.domain") - def test_validate(self): + def test_validate(self) -> None: id_string = "#test:domain,test" self.assertFalse(RoomAlias.is_valid(id_string)) class MapUsernameTestCase(unittest.TestCase): - def testPassThrough(self): + def test_pass_througuh(self) -> None: self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") - def testUpperCase(self): + def test_upper_case(self) -> None: self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234") self.assertEqual( map_username_to_mxid_localpart("tEST_1234", case_sensitive=True), "t_e_s_t__1234", ) - def testSymbols(self): + def test_symbols(self) -> None: self.assertEqual( map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) - def testLeadingUnderscore(self): + def test_leading_underscore(self) -> None: self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234") - def testNonAscii(self): + def test_non_ascii(self) -> None: # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") From 6f49d7af6cadc0e8d16512da72a9bc9cb25cc8c4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:21:38 -0500 Subject: [PATCH 03/13] Add type hints to test_distributor and test_event_auth. --- mypy.ini | 6 ------ tests/test_distributor.py | 12 ++++++------ tests/test_event_auth.py | 32 +++++++++++++++++--------------- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/mypy.ini b/mypy.ini index 679d4c07ce57..70c106c668af 100644 --- a/mypy.ini +++ b/mypy.ini @@ -69,12 +69,6 @@ disallow_untyped_defs = False [mypy-tests.server_notices.test_resource_limits_server_notices] disallow_untyped_defs = False -[mypy-tests.test_distributor] -disallow_untyped_defs = False - -[mypy-tests.test_event_auth] -disallow_untyped_defs = False - [mypy-tests.test_federation] disallow_untyped_defs = False diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 31546ea52bb0..a248f1d27711 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -21,10 +21,10 @@ class DistributorTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dist = Distributor() - def test_signal_dispatch(self): + def test_signal_dispatch(self) -> None: self.dist.declare("alert") observer = Mock() @@ -33,7 +33,7 @@ def test_signal_dispatch(self): self.dist.fire("alert", 1, 2, 3) observer.assert_called_with(1, 2, 3) - def test_signal_catch(self): + def test_signal_catch(self) -> None: self.dist.declare("alarm") observers = [Mock() for i in (1, 2)] @@ -51,7 +51,7 @@ def test_signal_catch(self): self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) - def test_signal_prereg(self): + def test_signal_prereg(self) -> None: observer = Mock() self.dist.observe("flare", observer) @@ -60,8 +60,8 @@ def test_signal_prereg(self): observer.assert_called_with(4, 5) - def test_signal_undeclared(self): - def code(): + def test_signal_undeclared(self) -> None: + def code() -> None: self.dist.fire("notification") self.assertRaises(KeyError, code) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 0a7937f1cc72..2860564afc45 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -31,13 +31,13 @@ class _StubEventSourceStore: """A stub implementation of the EventSourceStore""" - def __init__(self): + def __init__(self) -> None: self._store: Dict[str, EventBase] = {} - def add_event(self, event: EventBase): + def add_event(self, event: EventBase) -> None: self._store[event.event_id] = event - def add_events(self, events: Iterable[EventBase]): + def add_events(self, events: Iterable[EventBase]) -> None: for event in events: self._store[event.event_id] = event @@ -59,7 +59,7 @@ async def get_events( class EventAuthTestCase(unittest.TestCase): - def test_rejected_auth_events(self): + def test_rejected_auth_events(self) -> None: """ Events that refer to rejected events in their auth events are rejected """ @@ -109,7 +109,7 @@ def test_rejected_auth_events(self): ) ) - def test_create_event_with_prev_events(self): + def test_create_event_with_prev_events(self) -> None: """A create event with prev_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -150,7 +150,7 @@ def test_create_event_with_prev_events(self): event_auth.check_state_independent_auth_rules(event_store, bad_event) ) - def test_duplicate_auth_events(self): + def test_duplicate_auth_events(self) -> None: """Events with duplicate auth_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -196,7 +196,7 @@ def test_duplicate_auth_events(self): event_auth.check_state_independent_auth_rules(event_store, bad_event2) ) - def test_unexpected_auth_events(self): + def test_unexpected_auth_events(self) -> None: """Events with excess auth_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -236,7 +236,7 @@ def test_unexpected_auth_events(self): event_auth.check_state_independent_auth_rules(event_store, bad_event) ) - def test_random_users_cannot_send_state_before_first_pl(self): + def test_random_users_cannot_send_state_before_first_pl(self) -> None: """ Check that, before the first PL lands, the creator is the only user that can send a state event. @@ -263,7 +263,7 @@ def test_random_users_cannot_send_state_before_first_pl(self): auth_events, ) - def test_state_default_level(self): + def test_state_default_level(self) -> None: """ Check that users above the state_default level can send state and those below cannot @@ -298,7 +298,7 @@ def test_state_default_level(self): auth_events, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" creator = "@creator:example.com" other = "@other:example.com" @@ -333,7 +333,7 @@ def test_alias_event(self): auth_events, ) - def test_msc2432_alias_event(self): + def test_msc2432_alias_event(self) -> None: """After MSC2432, alias events have no special behavior.""" creator = "@creator:example.com" other = "@other:example.com" @@ -366,7 +366,9 @@ def test_msc2432_alias_event(self): ) @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) - def test_notifications(self, room_version: RoomVersion, allow_modification: bool): + def test_notifications( + self, room_version: RoomVersion, allow_modification: bool + ) -> None: """ Notifications power levels get checked due to MSC2209. """ @@ -395,7 +397,7 @@ def test_notifications(self, room_version: RoomVersion, allow_modification: bool with self.assertRaises(AuthError): event_auth.check_state_dependent_auth_rules(pl_event, auth_events) - def test_join_rules_public(self): + def test_join_rules_public(self) -> None: """ Test joining a public room. """ @@ -460,7 +462,7 @@ def test_join_rules_public(self): auth_events.values(), ) - def test_join_rules_invite(self): + def test_join_rules_invite(self) -> None: """ Test joining an invite only room. """ @@ -835,7 +837,7 @@ def _power_levels_event( ) -def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: +def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase: data = { "room_id": TEST_ROOM_ID, **_maybe_get_event_id_dict_for_room_version(room_version), From 871105089a72ea5e022859c53cb959694ad86cb1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:26:51 -0500 Subject: [PATCH 04/13] Newsfragment --- changelog.d/15027.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/15027.misc diff --git a/changelog.d/15027.misc b/changelog.d/15027.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/15027.misc @@ -0,0 +1 @@ +Improve type hints. From ad38a8742168da02df0efd5aa2727c015d131117 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:35:01 -0500 Subject: [PATCH 05/13] Fix-up test_Federation. --- mypy.ini | 3 -- tests/test_federation.py | 80 ++++++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/mypy.ini b/mypy.ini index 70c106c668af..941f26689010 100644 --- a/mypy.ini +++ b/mypy.ini @@ -69,9 +69,6 @@ disallow_untyped_defs = False [mypy-tests.server_notices.test_resource_limits_server_notices] disallow_untyped_defs = False -[mypy-tests.test_federation] -disallow_untyped_defs = False - [mypy-tests.test_utils.*] disallow_untyped_defs = False diff --git a/tests/test_federation.py b/tests/test_federation.py index 80e5c590d836..ddb43c8c981a 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -12,53 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union from unittest.mock import Mock from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import FederationError from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json +from synapse.http.types import QueryParams from synapse.logging.context import LoggingContext -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination from tests import unittest -from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): - def setUp(self): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() - self.reactor = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.reactor) - self.homeserver = setup_test_homeserver( - self.addCleanup, - federation_http_client=self.http_client, - clock=self.hs_clock, - reactor=self.reactor, - ) + return self.setup_test_homeserver(federation_http_client=self.http_client) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: user_id = UserID("us", "test") our_user = create_requester(user_id) - room_creator = self.homeserver.get_room_creation_handler() + room_creator = self.hs.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) )[0]["room_id"] - self.store = self.homeserver.get_datastores().main + self.store = self.hs.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastores().main.get_latest_event_ids_in_room( - self.room_id - ) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -78,14 +73,16 @@ def setUp(self): } ) - self.handler = self.homeserver.get_federation_handler() - federation_event_handler = self.homeserver.get_federation_event_handler() + self.handler = self.hs.get_federation_handler() + federation_event_handler = self.hs.get_federation_event_handler() - async def _check_event_auth(origin, event, context): + async def _check_event_auth( + origin: Optional[str], event: EventBase, context: EventContext + ) -> None: pass federation_event_handler._check_event_auth = _check_event_auth - self.client = self.homeserver.get_federation_client() + self.client = self.hs.get_federation_client() self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( lambda dest, pdus, **k: succeed(pdus) ) @@ -104,16 +101,25 @@ async def _check_event_auth(origin, event, context): "$join:test.serv", ) - def test_cant_hide_direct_ancestors(self): + def test_cant_hide_direct_ancestors(self) -> None: """ If you send a message, you must be able to provide the direct prev_events that said event references. """ - async def post_json(destination, path, data, headers=None, timeout=0): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryParams] = None, + ) -> Union[JsonDict, list]: # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} + return {} self.http_client.post_json = post_json @@ -138,7 +144,7 @@ async def post_json(destination, path, data, headers=None, timeout=0): } ) - federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler = self.hs.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( federation_event_handler.on_receive_pdu("test.serv", lying_event), @@ -158,7 +164,7 @@ async def post_json(destination, path, data, headers=None, timeout=0): extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) self.assertEqual(extrem[0], "$join:test.serv") - def test_retry_device_list_resync(self): + def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and that stale device lists are retried periodically. """ @@ -171,24 +177,26 @@ def test_retry_device_list_resync(self): # When this function is called, increment the number of resync attempts (only if # we're querying devices for the right user ID), then raise a # NotRetryingDestination error to fail the resync gracefully. - def query_user_devices(destination, user_id): + def query_user_devices( + destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: if user_id == remote_user_id: self.resync_attempts += 1 raise NotRetryingDestination(0, 0, destination) # Register the mock on the federation client. - federation_client = self.homeserver.get_federation_client() + federation_client = self.hs.get_federation_client() federation_client.query_user_devices = Mock(side_effect=query_user_devices) # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastores().main + store = self.hs.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. - device_list_updater = self.homeserver.get_device_handler().device_list_updater + device_list_updater = self.hs.get_device_handler().device_list_updater self.get_success( device_list_updater.incoming_device_list_update( origin=remote_origin, @@ -218,7 +226,7 @@ def query_user_devices(destination, user_id): self.reactor.advance(30) self.assertEqual(self.resync_attempts, 2) - def test_cross_signing_keys_retry(self): + def test_cross_signing_keys_retry(self) -> None: """Tests that resyncing a device list correctly processes cross-signing keys from the remote server. """ @@ -227,7 +235,7 @@ def test_cross_signing_keys_retry(self): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" # Register mock device list retrieval on the federation client. - federation_client = self.homeserver.get_federation_client() + federation_client = self.hs.get_federation_client() federation_client.query_user_devices = Mock( return_value=make_awaitable( { @@ -252,7 +260,7 @@ def test_cross_signing_keys_retry(self): ) # Resync the device list. - device_handler = self.homeserver.get_device_handler() + device_handler = self.hs.get_device_handler() self.get_success( device_handler.device_list_updater.user_device_resync(remote_user_id), ) @@ -279,7 +287,7 @@ def test_cross_signing_keys_retry(self): class StripUnsignedFromEventsTestCase(unittest.TestCase): - def test_strip_unauthorized_unsigned_values(self): + def test_strip_unauthorized_unsigned_values(self) -> None: event1 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -296,7 +304,7 @@ def test_strip_unauthorized_unsigned_values(self): # Make sure unauthorized fields are stripped from unsigned self.assertNotIn("more warez", filtered_event.unsigned) - def test_strip_event_maintains_allowed_fields(self): + def test_strip_event_maintains_allowed_fields(self) -> None: event2 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -323,7 +331,7 @@ def test_strip_event_maintains_allowed_fields(self): self.assertIn("invite_room_state", filtered_event2.unsigned) self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) - def test_strip_event_removes_fields_based_on_event_type(self): + def test_strip_event_removes_fields_based_on_event_type(self) -> None: event3 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", From 75cb9deb44fa1c84de93d3d14d4385b236f318b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:35:58 -0500 Subject: [PATCH 06/13] Fix-up test_visibility. --- mypy.ini | 3 --- tests/test_visibility.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index 941f26689010..b260706e7ad2 100644 --- a/mypy.ini +++ b/mypy.ini @@ -72,9 +72,6 @@ disallow_untyped_defs = False [mypy-tests.test_utils.*] disallow_untyped_defs = False -[mypy-tests.test_visibility] -disallow_untyped_defs = False - [mypy-tests.unittest] disallow_untyped_defs = False diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0b9ad54540d..875e37988f5d 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -258,7 +258,7 @@ def _inject_outlier(self) -> EventBase: class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): - def test_out_of_band_invite_rejection(self): + def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. invite_pdu = { From 4908f0a97fe863b8d39a547dfebd1eef23d16d07 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:37:44 -0500 Subject: [PATCH 07/13] Partially fix tests.unittest. --- tests/unittest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest.py b/tests/unittest.py index fa92dd94eb6a..68e59a88dc0f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -315,7 +315,7 @@ def setUp(self) -> None: # This has to be a function and not just a Mock, because # `self.helper.auth_user_id` is temporarily reassigned in some tests - async def get_requester(*args, **kwargs) -> Requester: + async def get_requester(*args: Any, **kwargs: Any) -> Requester: assert self.helper.auth_user_id is not None return create_requester( user_id=UserID.from_string(self.helper.auth_user_id), From 84defca466aff84f072bb5d1665f0e1ef9ddb544 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:41:07 -0500 Subject: [PATCH 08/13] Handle tests.server_notices. --- mypy.ini | 6 ---- tests/server_notices/test_consent.py | 14 ++++---- .../test_resource_limits_server_notices.py | 35 +++++++++++-------- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/mypy.ini b/mypy.ini index b260706e7ad2..52d0107366ef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -63,12 +63,6 @@ disallow_untyped_defs = False [mypy-tests.scripts.test_new_matrix_user] disallow_untyped_defs = False -[mypy-tests.server_notices.test_consent] -disallow_untyped_defs = False - -[mypy-tests.server_notices.test_resource_limits_server_notices] -disallow_untyped_defs = False - [mypy-tests.test_utils.*] disallow_untyped_defs = False diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 58b399a04377..6540ed53f173 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -14,8 +14,12 @@ import os +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: tmpdir = self.mktemp() os.mkdir(tmpdir) @@ -53,15 +57,13 @@ def make_homeserver(self, reactor, clock): "room_name": "Server Notices", } - hs = self.setup_test_homeserver(config=config) - - return hs + return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("bob", "abc123") self.access_token = self.login("bob", "abc123") - def test_get_sync_message(self): + def test_get_sync_message(self) -> None: """ When user consent server notices are enabled, a sync will cause a notice to fire (in a room which the user is invited to). The notice contains diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index dadc6efcbf75..5b76383d760a 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -24,6 +24,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -33,7 +34,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> JsonDict: config = default_config("test") config.update( @@ -86,18 +87,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] @override_config({"hs_disabled": True}) - def test_maybe_send_server_notice_disabled_hs(self): + def test_maybe_send_server_notice_disabled_hs(self) -> None: """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @override_config({"limit_usage_by_mau": False}) - def test_maybe_send_server_notice_to_user_flag_off(self): + def test_maybe_send_server_notice_to_user_flag_off(self) -> None: """If mau limiting is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" self._rlsn._auth_blocking.check_auth_blocking = Mock( @@ -114,7 +115,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() - def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: """ Test when user has blocked notice, but notice ought to be there (NOOP) """ @@ -134,7 +135,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_add_blocked_notice(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: """ Test when user does not have blocked notice, but should have one """ @@ -147,7 +148,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self): # Would be better to check contents, but 2 calls == set blocking event self.assertEqual(self._send_notice.call_count, 2) - def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: """ Test when user does not have blocked notice, nor should they (NOOP) """ @@ -159,7 +160,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): + def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: """ Test when user is not part of the MAU cohort - this should not ever happen - but ... @@ -175,7 +176,9 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): self._send_notice.assert_not_called() @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( + self, + ) -> None: """ Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room @@ -191,7 +194,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): self.assertEqual(self._send_notice.call_count, 0) @override_config({"mau_limit_alerting": False}) - def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: """ Test that when a server is disabled, that MAU limit alerting is ignored. """ @@ -207,7 +210,9 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): self.assertEqual(self._send_notice.call_count, 2) @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( + self, + ) -> None: """ When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. @@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): sync.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: c = super().default_config() c["server_notices"] = { "system_mxid_localpart": "server", @@ -270,7 +275,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = "@user_id:test" - def test_server_notice_only_sent_once(self): + def test_server_notice_only_sent_once(self) -> None: self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( @@ -306,7 +311,7 @@ def test_server_notice_only_sent_once(self): self.assertEqual(count, 1) - def test_no_invite_without_notice(self): + def test_no_invite_without_notice(self) -> None: """Tests that a user doesn't get invited to a server notices room without a server notice being sent. @@ -328,7 +333,7 @@ def test_no_invite_without_notice(self): m.assert_called_once_with(user_id) - def test_invite_with_notice(self): + def test_invite_with_notice(self) -> None: """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. """ From b3773b97dffd87a2ffc28c698262dec87bef6322 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 14:05:08 -0500 Subject: [PATCH 09/13] Fix-up test_utils. --- mypy.ini | 3 --- tests/handlers/test_oidc.py | 4 ++-- tests/test_utils/__init__.py | 23 ++++++++++++++--------- tests/test_utils/event_injection.py | 8 ++++---- tests/test_utils/html_parsers.py | 6 +++--- tests/test_utils/logging_setup.py | 4 ++-- tests/test_utils/oidc.py | 10 +++++----- 7 files changed, 30 insertions(+), 28 deletions(-) diff --git a/mypy.ini b/mypy.ini index 52d0107366ef..dc6cfdb1810d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -63,9 +63,6 @@ disallow_untyped_defs = False [mypy-tests.scripts.test_new_matrix_user] disallow_untyped_defs = False -[mypy-tests.test_utils.*] -disallow_untyped_defs = False - [mypy-tests.unittest] disallow_untyped_defs = False diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index adddbd002f50..951caaa6b3ba 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -150,7 +150,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) - self.hs_patcher.start() + self.hs_patcher.start() # type: ignore[attr-defined] self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -170,7 +170,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return hs def tearDown(self) -> None: - self.hs_patcher.stop() + self.hs_patcher.stop() # type: ignore[attr-defined] return super().tearDown() def reset_mocks(self) -> None: diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index e62ebcc6a5a3..e7b02ebd56bc 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -20,12 +20,13 @@ import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, Tuple, TypeVar +from typing import Any, Awaitable, Callable, List, Optional, Tuple, TypeVar from unittest.mock import Mock import attr import zope.interface +from twisted.internet.interfaces import IProtocol from twisted.python.failure import Failure from twisted.web.client import ResponseDone from twisted.web.http import RESPONSES @@ -78,25 +79,29 @@ def setup_awaitable_errors() -> Callable[[], None]: unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook - def unraisablehook(unraisable): + def unraisablehook(unraisable: sys.UnraisableHookArgs) -> None: unraisable_exceptions.append(unraisable.exc_value) - def cleanup(): + def cleanup() -> None: """ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. """ sys.unraisablehook = orig_unraisablehook if unraisable_exceptions: - raise unraisable_exceptions.pop() + exc = unraisable_exceptions.pop() + assert exc is not None + raise exc sys.unraisablehook = unraisablehook return cleanup -def simple_async_mock(return_value=None, raises=None) -> Mock: +def simple_async_mock( + return_value: Optional[TV] = None, raises: Optional[Exception] = None +) -> Mock: # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): + async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: if raises: raise raises return return_value @@ -125,14 +130,14 @@ class FakeResponse: # type: ignore[misc] headers: Headers = attr.Factory(Headers) @property - def phrase(self): + def phrase(self) -> bytes: return RESPONSES.get(self.code, b"Unknown Status") @property - def length(self): + def length(self) -> int: return len(self.body) - def deliverBody(self, protocol): + def deliverBody(self, protocol: IProtocol) -> None: protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8027c7a856e2..1a50c2acf12a 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes @@ -32,7 +32,7 @@ async def inject_member_event( membership: str, target: Optional[str] = None, extra_content: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a membership event into a room.""" if target is None: @@ -57,7 +57,7 @@ async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a generic event into a room @@ -82,7 +82,7 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Tuple[EventBase, EventContext]: if room_version is None: room_version = await hs.get_datastores().main.get_room_version_id( diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index e878af5f12e7..189c697efbee 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -13,13 +13,13 @@ # limitations under the License. from html.parser import HTMLParser -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, NoReturn, Optional, Tuple class TestHtmlParser(HTMLParser): """A generic HTML page parser which extracts useful things from the HTML""" - def __init__(self): + def __init__(self) -> None: super().__init__() # a list of links found in the doc @@ -48,5 +48,5 @@ def handle_starttag( assert input_name self.hiddens[input_name] = attr_dict["value"] - def error(_, message): + def error(self, message: str) -> NoReturn: raise AssertionError(message) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 304c7b98c5c9..b522163a3444 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler): tx_log = twisted.logger.Logger() - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( @@ -33,7 +33,7 @@ def emit(self, record): ) -def setup_logging(): +def setup_logging() -> None: """Configure the python logging appropriately for the tests. (Logs will end up in _trial_temp.) diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index 1461d23ee823..d555b242555d 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -14,7 +14,7 @@ import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, ContextManager, Dict, List, Optional, Tuple from unittest.mock import Mock, patch from urllib.parse import parse_qs @@ -77,14 +77,14 @@ def __init__(self, clock: Clock, issuer: str): self._id_token_overrides: Dict[str, Any] = {} - def reset_mocks(self): + def reset_mocks(self) -> None: self.request.reset_mock() self.get_jwks_handler.reset_mock() self.get_metadata_handler.reset_mock() self.get_userinfo_handler.reset_mock() self.post_token_handler.reset_mock() - def patch_homeserver(self, hs: HomeServer): + def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]: """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. This patch should be used whenever the HS is expected to perform request to the @@ -188,7 +188,7 @@ def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: return self._sign(logout_token) - def id_token_override(self, overrides: dict): + def id_token_override(self, overrides: dict) -> ContextManager[dict]: """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -247,7 +247,7 @@ def buggy_endpoint( metadata: bool = False, token: bool = False, userinfo: bool = False, - ): + ) -> ContextManager[Dict[str, Mock]]: """A context which makes a set of endpoints return a 500 error. Args: From 15f410a039a8355c76650f63458ba45456ecae49 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 14:13:35 -0500 Subject: [PATCH 10/13] Fix test_new_matrix_user. --- mypy.ini | 3 --- tests/scripts/test_new_matrix_user.py | 25 ++++++++++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mypy.ini b/mypy.ini index dc6cfdb1810d..0e5c6ccf6121 100644 --- a/mypy.ini +++ b/mypy.ini @@ -60,9 +60,6 @@ disallow_untyped_defs = False [mypy-synapse.storage.database] disallow_untyped_defs = False -[mypy-tests.scripts.test_new_matrix_user] -disallow_untyped_defs = False - [mypy-tests.unittest] disallow_untyped_defs = False diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py index 22f99c6ab1ce..3285f2433ccf 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py @@ -12,29 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from unittest.mock import Mock, patch from synapse._scripts.register_new_matrix_user import request_registration +from synapse.types import JsonDict from tests.unittest import TestCase class RegisterTestCase(TestCase): - def test_success(self): + def test_success(self) -> None: """ The script will fetch a nonce, and then generate a MAC with it, and then post that MAC. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") @@ -70,12 +74,12 @@ def post(url, json=None, verify=None): # sys.exit shouldn't have been called. self.assertEqual(err_code, []) - def test_failure_nonce(self): + def test_failure_nonce(self) -> None: """ If the script fails to fetch a nonce, it throws an error and quits. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 404 r.reason = "Not Found" @@ -107,20 +111,23 @@ def get(url, verify=None): self.assertIn("ERROR! Received 404 Not Found", out) self.assertNotIn("Success!", out) - def test_failure_post(self): + def test_failure_post(self) -> None: """ The script will fetch a nonce, and then if the final POST fails, will report an error and quit. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") From 65906dd07a359bbd3250dd6b793f4ea624b396cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 14:15:51 -0500 Subject: [PATCH 11/13] Newsfragment --- changelog.d/15028.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/15028.misc diff --git a/changelog.d/15028.misc b/changelog.d/15028.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/15028.misc @@ -0,0 +1 @@ +Improve type hints. From 9a848cff8c0dec97ac67bb5bb53e34513925b9dc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 14:24:26 -0500 Subject: [PATCH 12/13] Unused import. --- tests/test_utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index e7b02ebd56bc..d3fcb7144797 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -20,7 +20,7 @@ import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, List, Optional, Tuple, TypeVar +from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar from unittest.mock import Mock import attr From ea95563c289bfe80402f574f061e5bb5110ddd3f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 14:57:23 -0500 Subject: [PATCH 13/13] Fix-up import. --- tests/test_utils/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index d3fcb7144797..e5dae670a70e 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -20,7 +20,7 @@ import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar from unittest.mock import Mock import attr @@ -35,6 +35,9 @@ from synapse.types import JsonDict +if TYPE_CHECKING: + from sys import UnraisableHookArgs + TV = TypeVar("TV") @@ -79,7 +82,7 @@ def setup_awaitable_errors() -> Callable[[], None]: unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook - def unraisablehook(unraisable: sys.UnraisableHookArgs) -> None: + def unraisablehook(unraisable: "UnraisableHookArgs") -> None: unraisable_exceptions.append(unraisable.exc_value) def cleanup() -> None: