Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use inline type hints in tests/ (#10350)
Browse files Browse the repository at this point in the history
This PR is tantamount to running:

    python3.8 -m com2ann -v 6 tests/

(com2ann requires python 3.8 to run)
  • Loading branch information
ShadowJonathan authored Jul 13, 2021
1 parent 89cfc3d commit 9372971
Show file tree
Hide file tree
Showing 18 changed files with 62 additions and 63 deletions.
1 change: 1 addition & 0 deletions changelog.d/10350.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.
6 changes: 3 additions & 3 deletions tests/events/test_presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_receiving_all_presence(self):
)
self.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_one_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "boop")
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_send_local_online_presence_to_with_module(self):
presence_updates, _ = sync_presence(self, self.other_user_id)
self.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "I'm online!")
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_send_local_online_presence_to_with_module(self):
)
for call in calls:
call_args = call[0]
federation_transaction = call_args[0] # type: Transaction
federation_transaction: Transaction = call_args[0]

# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
Expand Down
16 changes: 8 additions & 8 deletions tests/module_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def test_sending_events_into_room(self):
"content": content,
"sender": user_id,
}
event = self.get_success(
event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
) # type: EventBase
)
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.message")
self.assertEqual(event.room_id, room_id)
Expand Down Expand Up @@ -136,9 +136,9 @@ def test_sending_events_into_room(self):
"sender": user_id,
"state_key": "",
}
event = self.get_success(
event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
) # type: EventBase
)
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.power_levels")
self.assertEqual(event.room_id, room_id)
Expand Down Expand Up @@ -281,7 +281,7 @@ def test_send_local_online_presence_to_federation(self):
)
for call in calls:
call_args = call[0]
federation_transaction = call_args[0] # type: Transaction
federation_transaction: Transaction = call_args[0]

# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
Expand Down Expand Up @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand Down Expand Up @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand All @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand Down
12 changes: 6 additions & 6 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
None
) # type: ServerReplicationStreamProtocol
)

# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
Expand Down Expand Up @@ -195,7 +195,7 @@ def assert_request_is_get_repl_stream_updates(
fetching updates for given stream.
"""

path = request.path # type: bytes # type: ignore
path: bytes = request.path # type: ignore
self.assertRegex(
path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
Expand All @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""

servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
servlets: List[Callable[[HomeServer, JsonResource], None]] = []

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(self, hs: HomeServer):
super().__init__(hs)

# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
self.received_rdata_rows: List[Tuple[str, int, Any]] = []

async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
Expand Down Expand Up @@ -484,7 +484,7 @@ def buildProtocol(self, addr):
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""

transport = None # type: Optional[FakeTransport]
transport: Optional[FakeTransport] = None

def __init__(self, server: FakeRedisPubSubServer):
self._server = server
Expand Down
14 changes: 7 additions & 7 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_update_function_huge_state_change(self):
)

# this is the point in the DAG where we make a fork
fork_point = self.get_success(
fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
)

events = [
self._inject_state_event(sender=OTHER_USER)
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_update_function_huge_state_change(self):
self.assertEqual(row.data.event_id, pl_event.event_id)

# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
state_rows: List[EventsStreamCurrentStateRow] = []
for stream_name, _, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
Expand Down Expand Up @@ -290,11 +290,11 @@ def test_update_function_state_row_limit(self):
)

# this is the point in the DAG where we make a fork
fork_point = self.get_success(
fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
)

events = [] # type: List[EventBase]
events: List[EventBase] = []
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_update_function_state_row_limit(self):
self.assertEqual(row.data.event_id, pl_events[i].event_id)

# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
state_rows: List[EventsStreamCurrentStateRow] = []
for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_receipt(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_receipt(self):
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))

row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_typing(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_reset(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)

Expand Down
2 changes: 1 addition & 1 deletion tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

logger = logging.getLogger(__name__)

test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None


class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ def test_send_event(self):
"content": content,
"sender": self.user_id,
}
event = self.get_success(
event: EventBase = self.get_success(
current_rules_module().module_api.create_and_send_event_into_room(
event_dict
)
) # type: EventBase
)

self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
Expand Down
14 changes: 6 additions & 8 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def test_get_msc2858_login_flows(self):
self.assertEqual(channel.code, 200, channel.result)

# stick the flows results in a dict by type
flow_results = {} # type: Dict[str, Any]
flow_results: Dict[str, Any] = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_multi_sso_redirect(self):
p.close()

# there should be a link for each href
returned_idps = [] # type: List[str]
returned_idps: List[str] = []
for link in p.links:
path, query = link.split("?", 1)
self.assertEqual(path, "pick_idp")
Expand Down Expand Up @@ -582,7 +582,7 @@ def test_login_via_oidc(self):
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
cookies = {} # type: Dict[str, str]
cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
Expand Down Expand Up @@ -874,9 +874,7 @@ def make_homeserver(self, reactor, clock):

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(
payload, secret, self.jwt_algorithm
) # type: Union[str, bytes]
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
Expand Down Expand Up @@ -1084,7 +1082,7 @@ def make_homeserver(self, reactor, clock):

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
Expand Down Expand Up @@ -1272,7 +1270,7 @@ def test_username_picker(self):
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")

# ... with a username_mapping_session cookie
cookies = {} # type: Dict[str,str]
cookies: Dict[str, str] = {}
channel.extract_cookies(cookies)
self.assertIn("username_mapping_session", cookies)
session_id = cookies["username_mapping_session"]
Expand Down
8 changes: 5 additions & 3 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class FakeChannel:
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
_producer: Optional[Union[IPullProducer, IPushProducer]] = None

@property
def json_body(self):
Expand Down Expand Up @@ -316,8 +316,10 @@ def __init__(self):

self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {} # type: Dict[str, str]
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque()

lookups = self.lookups

@implementer(IResolverSimple)
class FakeResolver:
Expand Down
4 changes: 1 addition & 3 deletions tests/storage/test_background_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.updates = (
self.hs.get_datastore().db_pool.updates
) # type: BackgroundUpdater
self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
Expand Down
6 changes: 3 additions & 3 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down Expand Up @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down Expand Up @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_branch_no_conflict(self):

self.store.register_events(graph.walk())

context_store = {} # type: dict[str, EventContext]
context_store: dict[str, EventContext] = {}

for event in graph.walk():
context = yield defer.ensureDeferred(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils/html_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def __init__(self):
super().__init__()

# a list of links found in the doc
self.links = [] # type: List[str]
self.links: List[str] = []

# the values of any hidden <input>s: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
self.hiddens: Dict[str, Optional[str]] = {}

# the values of any radio buttons: map from name to list of values
self.radios = {} # type: Dict[str, List[Optional[str]]]
self.radios: Dict[str, List[Optional[str]]] = {}

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
Expand Down
2 changes: 1 addition & 1 deletion tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def get_success_or_raise(self, d, by=0.0):
if not isinstance(deferred, Deferred):
return d

results = [] # type: list
results: list = []
deferred.addBoth(results.append)

self.pump(by=by)
Expand Down
Loading

0 comments on commit 9372971

Please sign in to comment.