Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints for tests.unittest. #13397

Merged
merged 11 commits into from
Jul 27, 2022
1 change: 1 addition & 0 deletions changelog.d/13397.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adding missing type hints to tests.
12 changes: 2 additions & 10 deletions tests/handlers/test_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,17 +481,13 @@ def default_config(self) -> Dict[str, Any]:

return config

def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")

self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass")

return hs

def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
Expand Down Expand Up @@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):

servlets = [directory.register_servlets, room.register_servlets]

def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
room_id = self.helper.create_room_as(self.user_id)

channel = self.make_request(
Expand All @@ -588,8 +582,6 @@ def prepare(
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()

return hs

def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
Expand Down
6 changes: 4 additions & 2 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
Expand All @@ -1072,7 +1073,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
"sender": self.user2_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
bundled_aggregations["latest_event"],
)

return assert_thread
Expand Down Expand Up @@ -1112,6 +1113,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
Expand All @@ -1124,7 +1126,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
"sender": self.user_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
bundled_aggregations["latest_event"],
)
# Check the unsigned field on the latest event.
self.assert_dict(
Expand Down
4 changes: 3 additions & 1 deletion tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,10 @@ def test_get_state_cancellation(self) -> None:
)

self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# json_body is defined as JsonDict, but it can be any valid JSON.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
json_body: List[JsonDict] = channel.json_body # type: ignore[assignment]
self.assertCountEqual(
[state_event["type"] for state_event in channel.json_body],
[state_event["type"] for state_event in json_body],
{
"m.room.create",
"m.room.power_levels",
Expand Down
85 changes: 47 additions & 38 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Generic,
Iterable,
List,
NoReturn,
Optional,
Tuple,
Type,
Expand All @@ -39,7 +40,7 @@
import canonicaljson
import signedjson.key
import unpaddedbase64
from typing_extensions import Protocol
from typing_extensions import ParamSpec, Protocol

from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
Expand Down Expand Up @@ -67,7 +68,7 @@
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree

Expand All @@ -88,6 +89,9 @@
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)

P = ParamSpec("P")
R = TypeVar("R")


class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
Expand All @@ -97,7 +101,7 @@ def value(self) -> _ExcType:
...


def around(target):
def around(target: TV) -> Callable[[Callable[P, R]], None]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code.

Expand All @@ -106,11 +110,11 @@ def method_name(orig, *args, **kwargs):
return orig(*args, **kwargs)
"""

def _around(code):
def _around(code: Callable[P, R]) -> None:
name = code.__name__
orig = getattr(target, name)

clokep marked this conversation as resolved.
Show resolved Hide resolved
def new(*args, **kwargs):
def new(*args: P.args, **kwargs: P.kwargs) -> R:
return code(orig, *args, **kwargs)

setattr(target, name, new)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A minor nitpick: Is this strictly correct? From the perspective of type annotations only:

  • code is the wrapper function which is being decorated by @around(target).
  • The first argument to code is orig.
  • Therefore the first argument of new is also orig.
  • Therefore the call to code(...) on line +117 passes in orig twice. But that's not what the source code does.

I think code: Callable[Concatenate[object, P], R] might be more accurate (where object represents orig).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That might be more accurate, yes. I'll check that!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Therefore the first argument of new is also orig.

I don't think this is correct -- the code would fail in this case because setUp doesn't accept *args.

new needs to match the signature of orig, not of code. I believe this is what the current code does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're agreeing with each other, but I didn't express myself well. Perhaps instead of

Therefore the first argument of new is also orig.

I should have said

mypy can deduce that the first argument of new has type type(orig).


Putting that aside for the moment, however, I entirely agree with your summary here:

new needs to match the signature of orig, not of code

but I don't think this is what the current set of annotations mean. At present, both new and code are of type Callable[P, R]. But this doesn't make sense, because code takes orig as an extra argument!

I still think that code: Callable[Concatenate[object, P], R] is correct here. Have I managed to convince you?

(You could try writing/casting orig: Callable[P, R] = ...; I think mypy would detect the inconsistency I've tried to describe above.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe 603d70a fixes this?

Expand All @@ -131,7 +135,7 @@ def __init__(self, methodName: str):
level = getattr(method, "loglevel", getattr(self, "loglevel", None))

@around(self)
def setUp(orig):
def setUp(orig: Callable[[], R]) -> R:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
Expand All @@ -144,7 +148,7 @@ def setUp(orig):
if level is not None and old_level != level:

@around(self)
def tearDown(orig):
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
Expand All @@ -158,7 +162,7 @@ def tearDown(orig):
return orig()

@around(self)
def tearDown(orig):
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
Expand All @@ -167,7 +171,7 @@ def tearDown(orig):

return ret

def assertObjectHasAttributes(self, attrs, obj):
def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual."""
for key in attrs.keys():
Expand All @@ -178,44 +182,44 @@ def assertObjectHasAttributes(self, attrs, obj):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e

def assert_dict(self, required, actual):
def assert_dict(self, required: dict, actual: dict) -> None:
"""Does a partial assert of a dict.

Args:
required (dict): The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked.
required: The keys and value which MUST be in 'actual'.
actual: The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEqual(
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)


def DEBUG(target):
def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.DEBUG
target.loglevel = logging.DEBUG # type: ignore[attr-defined]
return target


def INFO(target):
def INFO(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.INFO
target.loglevel = logging.INFO # type: ignore[attr-defined]
return target


def logcontext_clean(target):
def logcontext_clean(target: TV) -> TV:
"""A decorator which marks the TestCase or method as 'logcontext_clean'

... ie, any logcontext errors should cause a test failure
"""

def logcontext_error(msg):
def logcontext_error(msg: str) -> NoReturn:
raise AssertionError("logcontext error: %s" % (msg))

patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
return patcher(target)
return patcher(target) # type: ignore[call-overload]


class HomeserverTestCase(TestCase):
Expand Down Expand Up @@ -255,7 +259,7 @@ def __init__(self, methodName: str):
method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None)

def setUp(self):
def setUp(self) -> None:
"""
Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then
Expand Down Expand Up @@ -306,15 +310,21 @@ def setUp(self):
)
)

async def get_user_by_access_token(token=None, allow_guest=False):
async def get_user_by_access_token(
token: Optional[str] = None, allow_guest: bool = False
) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
"is_guest": False,
}

async def get_user_by_req(request, allow_guest=False):
async def get_user_by_req(
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
Expand All @@ -339,11 +349,11 @@ async def get_user_by_req(request, allow_guest=False):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)

def tearDown(self):
def tearDown(self) -> None:
# Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False

def wait_on_thread(self, deferred, timeout=10):
def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
"""
Wait until a Deferred is done, where it's waiting on a real thread.
"""
Expand Down Expand Up @@ -374,7 +384,7 @@ def make_homeserver(self, reactor, clock):
clock (synapse.util.Clock): The Clock, associated with the reactor.

Returns:
A homeserver (synapse.server.HomeServer) suitable for testing.
A homeserver suitable for testing.

Function to be overridden in subclasses.
"""
Expand Down Expand Up @@ -408,7 +418,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:
"/_synapse/admin": servlet_resource,
}

def default_config(self):
def default_config(self) -> JsonDict:
"""
Get a default HomeServer config dict.
"""
Expand All @@ -421,7 +431,9 @@ def default_config(self):

return config

def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
Expand Down Expand Up @@ -519,7 +531,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj

async def run_bg_updates():
async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))

Expand All @@ -538,11 +550,7 @@ def pump(self, by: float = 0.0) -> None:
"""
self.reactor.pump([by] * 100)

def get_success(
self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
Expand Down Expand Up @@ -755,7 +763,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

# poke the other server's signing key into the key store, so that we don't
Expand Down Expand Up @@ -879,7 +887,7 @@ def _auth_header_for_request(
)


def override_config(extra_config):
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
"""A decorator which can be applied to test functions to give additional HS config

For use
Expand All @@ -892,12 +900,13 @@ def test_foo(self):
...

Args:
extra_config(dict): Additional config settings to be merged into the default
extra_config: Additional config settings to be merged into the default
config dict before instantiating the test homeserver.
"""

def decorator(func):
func._extra_config = extra_config
def decorator(func: TV) -> TV:
# This attribute is being defined.
func._extra_config = extra_config # type: ignore[attr-defined]
return func

return decorator
Expand Down