Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve type hints for cached decorator #15658

Merged
merged 6 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15658.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
28 changes: 27 additions & 1 deletion scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

from typing import Callable, Optional, Type

from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType
from mypy.types import CallableType, Instance, NoneType, UnionType


class SynapsePlugin(Plugin):
Expand Down Expand Up @@ -92,10 +93,35 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.

# Finally we ensure the return type is a Deferred.
if (
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
):
# If it is already a Deferred, nothing to do.
ret_type = signature.ret_type
else:
# If a coroutine, wrap the coroutine's return type in a Deferred.
if (
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "typing.Coroutine"
clokep marked this conversation as resolved.
Show resolved Hide resolved
):
ret_arg = signature.ret_type.args[2]

# Otherwise, wrap the return type in a Deferred.
else:
ret_arg = signature.ret_type

# This should be able to use ctx.api.lookup_typeinfo, but that doesn't seem
# to find the correct symbol.
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be 100% wrong, but mypy's documentation on plugins is lacking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave a comment explaining what remove_instance_last_known_values is doing (or at least, what breaks without it)?

Copy link
Member Author

@clokep clokep May 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just copied out of lookup_typeinfo (or things that calls), I'm not really sure what it is doing TBH. I guess my comment wasn't clear enough where this code comes from.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expanded the comment to say where this comes from.


signature = signature.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
ret_type=ret_type,
)

return signature
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ async def _get_joined_hosts(
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
cache = await self._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
Expand Down
6 changes: 4 additions & 2 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def __init__(
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
def __get__(
self, obj: Optional[Any], owner: Optional[Type]
) -> Callable[..., "defer.Deferred[Any]"]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.name,
max_entries=self.max_entries,
Expand All @@ -232,7 +234,7 @@ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., An
get_cache_key = self.cache_key_builder

@functools.wraps(self.orig)
def _wrapped(*args: Any, **kwargs: Any) -> Any:
def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
Expand Down
82 changes: 29 additions & 53 deletions tests/appservice/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Generator
from typing import Any, Generator
from unittest.mock import Mock

from twisted.internet import defer
Expand Down Expand Up @@ -49,93 +49,81 @@ def setUp(self) -> None:
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deferred[object] can't yield Deferred[bool], but using Deferred[bool] doesn't work with inlineCallbacks.

self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is already a Deferred, so we don't need to wrap it in ensureDeferred here.

)
)

@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_member_is_checked(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_id_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_id_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_alias_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -145,10 +133,8 @@ def test_regex_alias_match(
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
Expand Down Expand Up @@ -192,7 +178,7 @@ def test_exclusive_room(self) -> None:
@defer.inlineCallbacks
def test_regex_alias_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -213,7 +199,7 @@ def test_regex_alias_no_match(
@defer.inlineCallbacks
def test_regex_multiple_matches(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -223,18 +209,14 @@ def test_regex_multiple_matches(
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_interested_in_self(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
# make sure invites get through
self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
Expand All @@ -243,18 +225,14 @@ def test_interested_in_self(
self.event.state_key = self.service.sender
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_member_list_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock(
Expand All @@ -265,10 +243,8 @@ def test_member_list_match(
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
11 changes: 5 additions & 6 deletions tests/storage/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ def test_get_set_transactions(self) -> None:
destination retries, as well as testing tht we can set and get
correctly.
"""
d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
self.assertIsNone(r)

d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
self.get_success(
self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
)

d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
Comment on lines -37 to +43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get and set versions now properly have different return types, I just inlined them instead of having a d and a d2.


self.assertEqual(
DestinationRetryTimings(
Expand Down