Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove not needed database updates in modify user admin API #10627

Merged
merged 5 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10627.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove not needed database updates in modify user admin API.
8 changes: 6 additions & 2 deletions docs/admin_api/user_admin_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ It returns a JSON body like the following:
"threepids": [
{
"medium": "email",
"address": "<user_mail_1>"
"address": "<user_mail_1>",
"added_at": 1586458409743,
"validated_at": 1586458409743
},
{
"medium": "email",
"address": "<user_mail_2>"
"address": "<user_mail_2>",
"added_at": 1586458409743,
"validated_at": 1586458409743
}
],
"avatar_url": "<avatar_url>",
Expand Down
55 changes: 35 additions & 20 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,18 @@ async def on_PUT(
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")

# convert into List[Tuple[str, str]]
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
new_external_ids = []
for external_id in external_ids:
new_external_ids.append(
(external_id["auth_provider"], external_id["external_id"])
)
new_external_ids = {
(external_id["auth_provider"], external_id["external_id"])
for external_id in external_ids
}

# convert List[Dict[str, str]] into Set[Tuple[str, str]]
if threepids is not None:
new_threepids = {
(threepid["medium"], threepid["address"]) for threepid in threepids
}

if user: # modify user
if "displayname" in body:
Expand All @@ -244,29 +249,39 @@ async def on_PUT(
)

if threepids is not None:
# remove old threepids from user
old_threepids = await self.store.user_get_threepids(user_id)
for threepid in old_threepids:
# get changed threepids (added and removed)
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = {
(threepid["medium"], threepid["address"])
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
del_threepids = cur_threepids - new_threepids

# remove old threepids
for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
user_id, threepid["medium"], threepid["address"], None
user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")

# add new threepids to user
# add new threepids
current_time = self.hs.get_clock().time_msec()
for threepid in threepids:
for medium, address in add_threepids:
await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time
user_id, medium, address, current_time
)

if external_ids is not None:
# get changed external_ids (added and removed)
cur_external_ids = await self.store.get_external_ids_by_user(user_id)
add_external_ids = set(new_external_ids) - set(cur_external_ids)
del_external_ids = set(cur_external_ids) - set(new_external_ids)
cur_external_ids = set(
await self.store.get_external_ids_by_user(user_id)
)
add_external_ids = new_external_ids - cur_external_ids
del_external_ids = cur_external_ids - new_external_ids

# remove old external_ids
for auth_provider, external_id in del_external_ids:
Expand Down Expand Up @@ -349,9 +364,9 @@ async def on_PUT(

if threepids is not None:
current_time = self.hs.get_clock().time_msec()
for threepid in threepids:
for medium, address in new_threepids:
await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time
user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
Expand All @@ -363,8 +378,8 @@ async def on_PUT(
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
device_display_name=address,
pushkey=address,
lang=None, # We don't know a user's language here
data={},
)
Expand Down
25 changes: 18 additions & 7 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,16 +754,18 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s
)
return user_id

def get_user_id_by_threepid_txn(self, txn, medium, address):
def get_user_id_by_threepid_txn(
self, txn, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid

Args:
txn (cursor):
medium (str): threepid medium e.g. email
address (str): threepid address e.g. [email protected]
medium: threepid medium e.g. email
address: threepid address e.g. [email protected]

Returns:
str|None: user id or None if no user id/threepid mapping exists
user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
Expand All @@ -776,22 +778,31 @@ def get_user_id_by_threepid_txn(self, txn, medium, address):
return ret["user_id"]
return None

async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
async def user_add_threepid(
self,
user_id: str,
medium: str,
address: str,
validated_at: int,
added_at: int,
) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)

async def user_get_threepids(self, user_id):
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)

async def user_delete_threepid(self, user_id, medium, address) -> None:
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
Expand Down
62 changes: 58 additions & 4 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,12 +1431,14 @@ def test_create_user(self):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
)
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
Expand Down Expand Up @@ -1676,18 +1678,53 @@ def test_set_threepid(self):
Test setting threepid for an other user.
"""

# Delete old and add new threepid to user
# Add two threepids to user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"threepids": [{"medium": "email", "address": "[email protected]"}]},
content={
"threepids": [
{"medium": "email", "address": "[email protected]"},
{"medium": "email", "address": "[email protected]"},
],
},
)

self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
sorted_result = sorted(
channel.json_body["threepids"], key=lambda k: k["address"]
)
self.assertEqual("email", sorted_result[0]["medium"])
self.assertEqual("[email protected]", sorted_result[0]["address"])
self.assertEqual("email", sorted_result[1]["medium"])
self.assertEqual("[email protected]", sorted_result[1]["address"])
self._check_fields(channel.json_body)

# Set a new and remove a threepid
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={
"threepids": [
{"medium": "email", "address": "[email protected]"},
{"medium": "email", "address": "[email protected]"},
],
},
)

self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][1]["address"])
self._check_fields(channel.json_body)

# Get user
channel = self.make_request(
Expand All @@ -1698,8 +1735,24 @@ def test_set_threepid(self):

self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][1]["address"])
self._check_fields(channel.json_body)

# Remove threepids
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"threepids": []},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)

def test_set_external_id(self):
"""
Expand Down Expand Up @@ -1778,6 +1831,7 @@ def test_set_external_id(self):

self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
channel.json_body["external_ids"],
[
Expand Down