From f73d84f13153acc06ee657c1f49be64a653d3475 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 May 2023 15:32:24 -0400 Subject: [PATCH 1/6] Improve type hints on cached descriptor --- synapse/util/caches/descriptors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 81df71a0c514..8514a75a1c2f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -220,7 +220,9 @@ def __init__( self.iterable = iterable self.prune_unread_entries = prune_unread_entries - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: + def __get__( + self, obj: Optional[Any], owner: Optional[Type] + ) -> Callable[..., "defer.Deferred[Any]"]: cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.name, max_entries=self.max_entries, @@ -232,7 +234,7 @@ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., An get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: + def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]": # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) From ea3071c72a2878313f7994c4e11fb870a31f71ca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 May 2023 15:32:39 -0400 Subject: [PATCH 2/6] Improve the mypy plugin to return deferreds for caches. --- scripts-dev/mypy_synapse_plugin.py | 29 ++++++- synapse/storage/databases/main/roommember.py | 2 +- tests/appservice/test_appservice.py | 82 +++++++------------- tests/storage/test_transactions.py | 11 ++- 4 files changed, 63 insertions(+), 61 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 2c377533c0fd..da19e0250cdb 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -18,10 +18,12 @@ from typing import Callable, Optional, Type +from mypy.erasetype import remove_instance_last_known_values from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self -from mypy.types import CallableType, NoneType, UnionType +from mypy.types import CallableType, Instance, NoneType, UnionType + class SynapsePlugin(Plugin): @@ -92,10 +94,35 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_names.append("on_invalidate") arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. + # Finally we ensure the return type is a Deferred. + if ( + isinstance(signature.ret_type, Instance) + and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred" + ): + # If it is already a Deferred, nothing to do. + ret_type = signature.ret_type + else: + # If a coroutine, wrap the coroutine's return type in a Deferred. + if ( + isinstance(signature.ret_type, Instance) + and signature.ret_type.type.fullname == "typing.Coroutine" + ): + ret_arg = signature.ret_type.args[2] + + # Otherwise, wrap the return type in a Deferred. + else: + ret_arg = signature.ret_type + + # This should be able to use ctx.api.lookup_typeinfo, but that doesn't seem + # to find the correct symbol. + sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] + ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)]) + signature = signature.copy_modified( arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + ret_type=ret_type, ) return signature diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e068f27a1079..ae9c201b87e8 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1099,7 +1099,7 @@ async def _get_joined_hosts( # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] + cache = await self._get_joined_hosts_cache(room_id) # If the state group in the cache matches, we already have the data we need. if state_entry.state_group == cache.state_group: diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index dee976356faa..66753c60c4b1 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Generator +from typing import Any, Generator from unittest.mock import Mock from twisted.internet import defer @@ -49,15 +49,13 @@ def setUp(self) -> None: @defer.inlineCallbacks def test_regex_user_id_prefix_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -65,15 +63,13 @@ def test_regex_user_id_prefix_match( @defer.inlineCallbacks def test_regex_user_id_prefix_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -81,17 +77,15 @@ def test_regex_user_id_prefix_no_match( @defer.inlineCallbacks def test_regex_room_member_is_checked( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -99,17 +93,15 @@ def test_regex_room_member_is_checked( @defer.inlineCallbacks def test_regex_room_id_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -117,25 +109,21 @@ def test_regex_room_id_match( @defer.inlineCallbacks def test_regex_room_id_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.assertFalse( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_regex_alias_match( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -145,10 +133,8 @@ def test_regex_alias_match( self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -192,7 +178,7 @@ def test_exclusive_room(self) -> None: @defer.inlineCallbacks def test_regex_alias_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -213,7 +199,7 @@ def test_regex_alias_no_match( @defer.inlineCallbacks def test_regex_multiple_matches( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -223,18 +209,14 @@ def test_regex_multiple_matches( self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_interested_in_self( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]: # make sure invites get through self.service.sender = "@appservice:name" self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) @@ -243,18 +225,14 @@ def test_interested_in_self( self.event.state_key = self.service.sender self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_member_list_match( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. self.store.get_local_users_in_room = simple_async_mock( @@ -265,10 +243,8 @@ def test_member_list_match( self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index db9ee9955e93..2fab84a52939 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -33,15 +33,14 @@ def test_get_set_transactions(self) -> None: destination retries, as well as testing tht we can set and get correctly. """ - d = self.store.get_destination_retry_timings("example.com") - r = self.get_success(d) + r = self.get_success(self.store.get_destination_retry_timings("example.com")) self.assertIsNone(r) - d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) - self.get_success(d) + self.get_success( + self.store.set_destination_retry_timings("example.com", 1000, 50, 100) + ) - d = self.store.get_destination_retry_timings("example.com") - r = self.get_success(d) + r = self.get_success(self.store.get_destination_retry_timings("example.com")) self.assertEqual( DestinationRetryTimings( From 6310a361f13781333c316ddd763b59d747b34232 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 May 2023 15:34:00 -0400 Subject: [PATCH 3/6] Newsfragment --- changelog.d/15658.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/15658.misc diff --git a/changelog.d/15658.misc b/changelog.d/15658.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/15658.misc @@ -0,0 +1 @@ +Improve type hints. From 4b3308ce974eaa0c3aa568da0fd3e4050d597c0e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 May 2023 16:07:10 -0400 Subject: [PATCH 4/6] Lint --- scripts-dev/mypy_synapse_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index da19e0250cdb..019eb8f8a8a3 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -25,7 +25,6 @@ from mypy.types import CallableType, Instance, NoneType, UnionType - class SynapsePlugin(Plugin): def get_method_signature_hook( self, fullname: str From e9d02f472d128b845906858dfb0b7b053acd7cff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 24 May 2023 08:07:47 -0400 Subject: [PATCH 5/6] Handle Awaitable in addition to Coroutine. --- scripts-dev/mypy_synapse_plugin.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 019eb8f8a8a3..eaa5528d8680 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -101,15 +101,18 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: # If it is already a Deferred, nothing to do. ret_type = signature.ret_type else: - # If a coroutine, wrap the coroutine's return type in a Deferred. - if ( - isinstance(signature.ret_type, Instance) - and signature.ret_type.type.fullname == "typing.Coroutine" - ): - ret_arg = signature.ret_type.args[2] - - # Otherwise, wrap the return type in a Deferred. - else: + ret_arg = None + if isinstance(signature.ret_type, Instance): + # If a coroutine, wrap the coroutine's return type in a Deferred. + if signature.ret_type.type.fullname == "typing.Coroutine": + ret_arg = signature.ret_type.args[2] + + # If an awaitable, wrap the awaitable's final value in a Deferred. + elif signature.ret_type.type.fullname == "typing.Awaitable": + ret_arg = signature.ret_type.args[0] + + # Otherwise, wrap the return value in a Deferred. + if ret_arg is None: ret_arg = signature.ret_type # This should be able to use ctx.api.lookup_typeinfo, but that doesn't seem From 38ce596f5b828c0b89919dd97b772314b085ecf2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 24 May 2023 08:11:17 -0400 Subject: [PATCH 6/6] Expand comments. --- scripts-dev/mypy_synapse_plugin.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index eaa5528d8680..8058e9c993b1 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -115,8 +115,11 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: if ret_arg is None: ret_arg = signature.ret_type - # This should be able to use ctx.api.lookup_typeinfo, but that doesn't seem - # to find the correct symbol. + # This should be able to use ctx.api.named_generic_type, but that doesn't seem + # to find the correct symbol for anything more than 1 module deep. + # + # modules is not part of CheckerPluginInterface. The following is a combination + # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo. sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])