Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Prefer make_awaitable over defer.succeed in tests #12505

Merged
merged 9 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12505.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
21 changes: 16 additions & 5 deletions synapse/logging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,11 +808,22 @@ def run_in_background( # type: ignore[misc]
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment (apart from being horrifying grammar) seems to be a bit outdated?

(I also wonder if we really need to support synchronous functions here these days. if not we can simplify all this with an assert isinstance(res, Awaitable))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll come up with some new words here.

Ditching support for synchronous functions is a good idea, but:

  • run_in_background is part of the module API so it's a little naughty to change it. However, I unknowingly did something similar to make_deferred_yieldable in Add missing type hints to synapse.logging.context #11556 by a while back and nobody's complained to my knowledge. So maybe that's fine?
  • run_in_background is called by preserve_fn, which also accepts synchronous functions. preserve_fn's used by @cached and @cachedList to wrap methods, some of which are still synchronous. I think restricting those decorators to async functions is going to turn out to be a rabbit hole.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, fair. In general, @cached and co seem to be overkill for a synchronous function and we should just use @lru_cache or something instead - but agreed this is a rabbithole we shouldn't get into right now.

if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)

return defer.succeed(res)
if isinstance(res, Awaitable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we flip this condition to reduce nesting?

if not isinstance(res, Awaitable):
    # `f` returned a plain value.
    return defer.succeed(res)

# now handle the completed awaitable case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

# `f` returned some kind of awaitable that is not a coroutine or `Deferred`.
# We assume that it is a completed awaitable, such as a `DoneAwaitable` or
# `Future` from `make_awaitable`, and await it manually.
iterator = res.__await__() # `__await__` returns an iterator...
try:
next(iterator)
raise ValueError(
f"Function {f} returned an unresolved awaitable: {res}"
)
except StopIteration as e:
# ...which raises a `StopIteration` once the awaitable is complete.
return defer.succeed(e.value)
else:
# `f` returned a plain value.
return defer.succeed(res)

if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_get_room_state(self):
)

# mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed(
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
clokep marked this conversation as resolved.
Show resolved Hide resolved
_mock_response(
{
"pdus": [
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_dont_send_device_updates_for_remote_users(self):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed(
make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
Expand Down
7 changes: 3 additions & 4 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import RoomEncryptionAlgorithms
Expand Down Expand Up @@ -704,7 +703,7 @@ def test_query_devices_remote_no_sync(self) -> None:
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_client_keys = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
Expand Down Expand Up @@ -777,14 +776,14 @@ def test_query_devices_remote_sync(self) -> None:
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
return_value=make_awaitable({"some_room_id"})
)

remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_user_devices = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
Expand Down
34 changes: 16 additions & 18 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock

from twisted.internet import defer

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
Expand Down Expand Up @@ -190,7 +188,7 @@ def password_only_auth_provider_login_test_body(self):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
Expand Down Expand Up @@ -226,13 +224,13 @@ def password_only_auth_provider_ui_auth_test_body(self):
self.get_success(module_api.register_user("u"))

# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()

# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
Expand All @@ -246,7 +244,7 @@ def password_only_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
Expand All @@ -260,7 +258,7 @@ def local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)

Expand All @@ -277,7 +275,7 @@ def local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
Expand Down Expand Up @@ -320,7 +318,7 @@ def no_local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
Expand All @@ -342,7 +340,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)

# log in twice, to get two devices
tok1 = self.login("localuser", "p")
Expand All @@ -359,7 +357,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
mock_password_provider.check_password.assert_not_called()

# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
Expand Down Expand Up @@ -413,7 +411,7 @@ def custom_auth_provider_login_test_body(self):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand All @@ -427,7 +425,7 @@ def custom_auth_provider_login_test_body(self):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
Expand Down Expand Up @@ -477,7 +475,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
Expand All @@ -490,7 +488,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
Expand All @@ -508,9 +506,9 @@ def test_custom_auth_provider_callback(self):
self.custom_auth_provider_callback_test_body()

def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
callback = Mock(return_value=make_awaitable(None))

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand Down Expand Up @@ -646,7 +644,7 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)

# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))

# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
Expand Down Expand Up @@ -98,7 +98,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
return_value=defer.succeed(None)
return_value=make_awaitable(None)
)

self.datastore.get_device_updates_by_remote = Mock(
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from unittest.mock import Mock, patch
from urllib.parse import quote

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand All @@ -30,6 +29,7 @@

from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config

Expand Down Expand Up @@ -439,7 +439,7 @@ def test_handle_user_deactivated_support_user(self) -> None:
)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand All @@ -454,7 +454,7 @@ def test_handle_user_deactivated_regular_user(self) -> None:
self.store.register_user(user_id=r_user_id, password_hash=None)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.presence import PresenceHandler
Expand All @@ -24,6 +23,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable


class PresenceTestCase(unittest.HomeserverTestCase):
Expand All @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
presence_handler.set_state.return_value = make_awaitable(None)

hs = self.setup_test_homeserver(
"red",
Expand Down
7 changes: 2 additions & 5 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from unittest.mock import Mock, call
from urllib import parse as urlparse

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand Down Expand Up @@ -1426,9 +1425,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
{}
)
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]

search_filter = {"generic_search_term": "foobar"}

Expand Down Expand Up @@ -1456,7 +1453,7 @@ def test_fallback(self) -> None:
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
make_awaitable({}),
)

search_filter = {"generic_search_term": "foobar"}
Expand Down
7 changes: 4 additions & 3 deletions tests/rest/client/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock


Expand All @@ -38,7 +39,7 @@ def setUp(self) -> None:

@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
Expand All @@ -47,7 +48,7 @@ def test_executes_given_function(self):

@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
Expand Down Expand Up @@ -130,7 +131,7 @@ def cb():

@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
Expand Down
Loading