Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Require type hints in the handlers module #10831

Merged
merged 14 commits into from
Sep 20, 2021
1 change: 1 addition & 0 deletions changelog.d/10831.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to handlers.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.rest.*]
disallow_untyped_defs = True

Expand Down
4 changes: 2 additions & 2 deletions synapse/config/password_auth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, List, Tuple, Type

from synapse.util.module_loader import load_module

Expand All @@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"

def read_config(self, config, **kwargs):
self.password_providers: List[Any] = []
self.password_providers: List[Tuple[Type, Any]] = []
providers = []

# We want to be backwards compatible with the old `ldap_config`
Expand Down
14 changes: 10 additions & 4 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Optional

from synapse.api.ratelimiting import Ratelimiter
from synapse.types import Requester

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -63,16 +64,21 @@ def __init__(self, hs: "HomeServer"):

self.event_builder_factory = hs.get_event_builder_factory()

async def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.

Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction (bool): Whether this is a room admin/moderator
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Any, List, Tuple

from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_account_data_stream_id()

async def get_new_events(
self, user: UserID, from_key: int, **kwargs
self, user: UserID, from_key: int, **kwargs: Any
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this kwargs for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that get_new_events generally takes some random arguments which most things don't care about. I did verify that there are other arguments going into this call, I'll double check if we can tighten this up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be cleaned up, but I'd prefer to do it in a separate PR if that's OK!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I resolved the other threads about this to keep the conversation in one spot.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, take this to another PR if you prefer.

) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def register_account_validity_callbacks(
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
Expand Down Expand Up @@ -165,7 +165,7 @@ async def is_user_expired(self, user_id: str) -> bool:

return False

async def on_user_registration(self, user_id: str):
async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.

Args:
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union

from prometheus_client import Counter

Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, hs: "HomeServer"):
self.current_max = 0
self.is_processing = False

def notify_interested_services(self, max_token: RoomStreamToken):
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.

Pushing is done asynchronously, so this method won't block for any
Expand All @@ -82,7 +82,7 @@ def notify_interested_services(self, max_token: RoomStreamToken):
self._notify_interested_services(max_token)

@wrap_as_background_process("notify_interested_services")
async def _notify_interested_services(self, max_token: RoomStreamToken):
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
Expand All @@ -100,7 +100,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)

async def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
Expand All @@ -116,9 +116,9 @@ async def handle_event(event):

if not self.started_scheduler:

async def start_scheduler():
async def start_scheduler() -> None:
try:
return await self.scheduler.start()
await self.scheduler.start()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scheduler.start doesn't return anything so no need to pass that back.

except Exception:
logger.error("Application Services Failure")

Expand All @@ -137,7 +137,7 @@ async def start_scheduler():
"appservice_sender"
).observe((now - ts) / 1000)

async def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)

Expand Down Expand Up @@ -184,7 +184,7 @@ def notify_interested_services_ephemeral(
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
):
) -> None:
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.

Expand Down Expand Up @@ -226,7 +226,7 @@ async def _notify_interested_services_ephemeral(
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
Expand Down
45 changes: 24 additions & 21 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
Expand Down Expand Up @@ -439,7 +440,7 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:

return ui_auth_types

def get_enabled_auth_types(self):
def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types

Returns the UI-Auth types which are supported by the homeserver's current
Expand Down Expand Up @@ -702,7 +703,7 @@ async def get_session_data(
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def _expire_old_sessions(self):
async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
Expand Down Expand Up @@ -1352,7 +1353,7 @@ async def validate_short_term_login_token(
await self.auth.check_auth_blocking(res.user_id)
return res

async def delete_access_token(self, access_token: str):
async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token

Args:
Expand Down Expand Up @@ -1381,7 +1382,7 @@ async def delete_access_tokens_for_user(
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
) -> None:
"""Invalidate access tokens belonging to a user

Args:
Expand Down Expand Up @@ -1409,7 +1410,7 @@ async def delete_access_tokens_for_user(

async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
):
) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
Expand Down Expand Up @@ -1480,7 +1481,7 @@ async def hash(self, password: str) -> str:
Hashed password.
"""

def _do_hash():
def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

Expand All @@ -1504,7 +1505,7 @@ async def validate_hash(
Whether self.hash(password) == stored_hash.
"""

def _do_validate_hash(checked_hash: bytes):
def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)

Expand Down Expand Up @@ -1581,7 +1582,7 @@ async def complete_sso_login(
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
):
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request

Args:
Expand Down Expand Up @@ -1627,7 +1628,7 @@ def _complete_sso_login(
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
):
) -> None:
"""
The synchronous portion of complete_sso_login.

Expand Down Expand Up @@ -1726,17 +1727,17 @@ def _expire_sso_extra_attributes(self) -> None:
del self._extra_attributes[user_id]

@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)


@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
hs = attr.ib()
hs: "HomeServer"

def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand Down Expand Up @@ -1816,15 +1817,17 @@ class PasswordProvider:
"""

@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)

def __init__(self, pp, module_api: ModuleApi):
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api

Expand All @@ -1838,7 +1841,7 @@ def __init__(self, pp, module_api: ModuleApi):
if g:
self._supported_login_types.update(g())

def __str__(self):
def __str__(self) -> str:
return str(self._pp)

def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
Expand Down Expand Up @@ -1876,19 +1879,19 @@ async def check_auth(
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None

g = getattr(self._pp, "check_auth", None)
if not g:
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await g(username, login_type, login_dict)
result = await check_auth(username, login_type, login_dict)

# Check if the return value is a str or a tuple
if isinstance(result, str):
Expand Down
18 changes: 8 additions & 10 deletions synapse/handlers/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""

def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description

def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error


@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
username: str
attributes: Dict[str, List[Optional[str]]]


class CasHandler:
Expand Down Expand Up @@ -133,11 +133,9 @@ async def _validate_ticket(
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code)
Comment on lines 135 to +138
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a real bug, it was passing a tuple instead of a string.

raise CasError("server_error", description) from e

return self._parse_cas_response(body)
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(self, hs: "HomeServer"):

hs.get_distributor().observe("user_left_room", self.user_left_room)

def _check_device_name_length(self, name: Optional[str]):
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.

Expand Down
Loading