Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert the registration handler to async/await. #7649

Merged
merged 1 commit into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7649.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert registration handler to async/await.
107 changes: 41 additions & 66 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
"""Contains functions for registering clients."""
import logging

from twisted.internet import defer

from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
Expand Down Expand Up @@ -75,8 +73,9 @@ def __init__(self, hs):

self.session_lifetime = hs.config.session_lifetime

@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
async def check_username(
self, localpart, guest_access_token=None, assigned_user_id=None
):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
Expand Down Expand Up @@ -113,13 +112,13 @@ def check_username(self, localpart, guest_access_token=None, assigned_user_id=No
Codes.INVALID_USERNAME,
)

users = yield self.store.get_users_by_id_case_insensitive(user_id)
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
Expand All @@ -137,8 +136,7 @@ def check_username(self, localpart, guest_access_token=None, assigned_user_id=No
except ValueError:
pass

@defer.inlineCallbacks
def register_user(
async def register_user(
self,
localpart=None,
password_hash=None,
Expand Down Expand Up @@ -169,18 +167,18 @@ def register_user(
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
Returns:
Deferred[str]: user_id
str: user_id
Raises:
SynapseError if there was a problem registering.
"""
yield self.check_registration_ratelimit(address)
self.check_registration_ratelimit(address)

# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
yield self.auth.check_auth_blocking(threepid=threepid)
await self.auth.check_auth_blocking(threepid=threepid)

if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
await self.check_username(localpart, guest_access_token=guest_access_token)

was_guest = guest_access_token is not None

Expand All @@ -194,7 +192,7 @@ def register_user(
elif default_display_name is None:
default_display_name = localpart

yield self.register_with_store(
await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
Expand All @@ -206,11 +204,9 @@ def register_user(
)

if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield defer.ensureDeferred(
self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)

else:
Expand All @@ -222,14 +218,14 @@ def register_user(
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")

localpart = yield self._generate_user_id()
localpart = await self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None:
default_display_name = localpart
try:
yield self.register_with_store(
await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
make_guest=make_guest,
Expand All @@ -252,7 +248,7 @@ def register_user(
user_id,
)
else:
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
await self._auto_join_rooms(user_id)
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
Expand All @@ -270,7 +266,7 @@ def register_user(
}

# Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None)
await self._register_email_threepid(user_id, threepid_dict, None)

return user_id

Expand Down Expand Up @@ -335,8 +331,7 @@ async def post_consent_actions(self, user_id):
"""
await self._auto_join_rooms(user_id)

@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
async def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
Expand All @@ -351,11 +346,9 @@ def appservice_register(self, user_localpart, as_token):

service_id = service.id if service.is_exclusive_user(user_id) else None

yield self.check_user_id_not_appservice_exclusive(
user_id, allowed_appservice=service
)
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)

yield self.register_with_store(
await self.register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
Expand Down Expand Up @@ -387,13 +380,12 @@ def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=Non
errcode=Codes.EXCLUSIVE,
)

@defer.inlineCallbacks
def _generate_user_id(self):
async def _generate_user_id(self):
if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
await self.store.find_next_generated_user_id_localpart()
)

id = self._next_generated_user_id
Expand Down Expand Up @@ -496,8 +488,9 @@ def register_with_store(
user_type=user_type,
)

@defer.inlineCallbacks
def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
async def register_device(
self, user_id, device_id, initial_display_name, is_guest=False
):
"""Register a device for a user and generate an access token.

The access token will be limited by the homeserver's session_lifetime config.
Expand All @@ -511,11 +504,11 @@ def register_device(self, user_id, device_id, initial_display_name, is_guest=Fal
is_guest (bool): Whether this is a guest account

Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
tuple[str, str]: Tuple of device ID and access token
"""

if self.hs.config.worker_app:
r = yield self._register_device_client(
r = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
Expand All @@ -531,7 +524,7 @@ def register_device(self, user_id, device_id, initial_display_name, is_guest=Fal
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime

device_id = yield self.device_handler.check_device_registered(
device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
Expand All @@ -540,10 +533,8 @@ def register_device(self, user_id, device_id, initial_display_name, is_guest=Fal
user_id, ["guest = true"]
)
else:
access_token = yield defer.ensureDeferred(
self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)

return (device_id, access_token)
Expand Down Expand Up @@ -594,8 +585,7 @@ async def _on_user_consented(self, user_id, consent_version):
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)

@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):
async def _register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier

Also adds an email pusher for the email address, if configured in the
Expand All @@ -608,22 +598,15 @@ def _register_email_threepid(self, user_id, threepid, token):
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
Returns:
defer.Deferred:
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return

yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)

# And we add an email pusher for them by default, but only
Expand All @@ -639,10 +622,10 @@ def _register_email_threepid(self, user_id, threepid, token):
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(token)
user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]

yield self.pusher_pool.add_pusher(
await self.pusher_pool.add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
Expand All @@ -654,17 +637,14 @@ def _register_email_threepid(self, user_id, threepid, token):
data={},
)

@defer.inlineCallbacks
def _register_msisdn_threepid(self, user_id, threepid):
async def _register_msisdn_threepid(self, user_id, threepid):
"""Add a phone number as a 3pid identifier

Must be called on master.

Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
Returns:
defer.Deferred:
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
Expand All @@ -675,11 +655,6 @@ def _register_msisdn_threepid(self, user_id, threepid):
return None
raise

yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
)
8 changes: 6 additions & 2 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,12 @@ def register_user(self, localpart, displayname=None, emails=[]):
Returns:
Deferred[str]: user_id
"""
return self._hs.get_registration_handler().register_user(
localpart=localpart, default_display_name=displayname, bind_emails=emails
return defer.ensureDeferred(
self._hs.get_registration_handler().register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
)
)

def register_device(self, user_id, device_id=None, initial_display_name=None):
Expand Down