Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/18635.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support arbitrary profile fields.
24 changes: 11 additions & 13 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDi

if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user)
extra_fields = {}
if self.hs.config.experimental.msc4133_enabled:
extra_fields = await self.store.get_profile_fields(target_user)
extra_fields = await self.store.get_profile_fields(target_user)

if (
profileinfo.display_name is None
Expand Down Expand Up @@ -550,16 +548,16 @@ async def on_profile_query(self, args: JsonDict) -> JsonDict:
# since then we send a null in the JSON response
if avatar_url is not None:
response["avatar_url"] = avatar_url
if self.hs.config.experimental.msc4133_enabled:
if just_field is None:
response.update(await self.store.get_profile_fields(user))
elif just_field not in (
ProfileFields.DISPLAYNAME,
ProfileFields.AVATAR_URL,
):
response[just_field] = await self.store.get_profile_field(
user, just_field
)

if just_field is None:
response.update(await self.store.get_profile_fields(user))
elif just_field not in (
ProfileFields.DISPLAYNAME,
ProfileFields.AVATAR_URL,
):
response[just_field] = await self.store.get_profile_field(
user, just_field
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
Expand Down
30 changes: 15 additions & 15 deletions synapse/rest/client/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,22 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"enabled": self.config.experimental.msc3664_enabled,
}

disallowed_profile_fields = []
response["capabilities"]["m.profile_fields"] = {"enabled": True}
if not self.config.registration.enable_set_displayname:
disallowed_profile_fields.append("displayname")
if not self.config.registration.enable_set_avatar_url:
disallowed_profile_fields.append("avatar_url")
if disallowed_profile_fields:
response["capabilities"]["m.profile_fields"]["disallowed"] = (
disallowed_profile_fields
)

# For transition from unstable to stable identifiers.
if self.config.experimental.msc4133_enabled:
response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = {
"enabled": True,
}

# Ensure this is consistent with the legacy m.set_displayname and
# m.set_avatar_url.
disallowed = []
if not self.config.registration.enable_set_displayname:
disallowed.append("displayname")
if not self.config.registration.enable_set_avatar_url:
disallowed.append("avatar_url")
if disallowed:
response["capabilities"]["uk.tcpip.msc4133.profile_fields"][
"disallowed"
] = disallowed
response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = response[
"capabilities"
]["m.profile_fields"]

if self.config.experimental.msc4267_enabled:
response["capabilities"]["org.matrix.msc4267.forget_forced_upon_leave"] = {
Expand Down
190 changes: 26 additions & 164 deletions synapse/rest/client/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,161 +57,6 @@ def _read_propagate(hs: "HomeServer", request: SynapseRequest) -> bool:
return propagate


class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
CATEGORY = "Event sending requests"

def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()

async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None

if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user

if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)

displayname = await self.profile_handler.get_displayname(user)

ret = {}
if displayname is not None:
ret["displayname"] = displayname

return 200, ret

async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)

content = parse_json_object_from_request(request)

try:
new_name = content["displayname"]
except Exception:
raise SynapseError(
400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM
)

propagate = _read_propagate(self.hs, request)

requester_suspended = (
await self.hs.get_datastores().main.get_user_suspended_status(
requester.user.to_string()
)
)

if requester_suspended:
raise SynapseError(
403,
"Updating displayname while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)

await self.profile_handler.set_displayname(
user, requester, new_name, is_admin, propagate=propagate
)

return 200, {}


class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
CATEGORY = "Event sending requests"

def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()

async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None

if self.hs.config.server.require_auth_for_profile_requests:
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user

if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)

avatar_url = await self.profile_handler.get_avatar_url(user)

ret = {}
if avatar_url is not None:
ret["avatar_url"] = avatar_url

return 200, ret

async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)

content = parse_json_object_from_request(request)
try:
new_avatar_url = content["avatar_url"]
except KeyError:
raise SynapseError(
400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM
)

propagate = _read_propagate(self.hs, request)

requester_suspended = (
await self.hs.get_datastores().main.get_user_suspended_status(
requester.user.to_string()
)
)

if requester_suspended:
raise SynapseError(
403,
"Updating avatar URL while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)

await self.profile_handler.set_avatar_url(
user, requester, new_avatar_url, is_admin, propagate=propagate
)

return 200, {}


class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
CATEGORY = "Event sending requests"
Expand Down Expand Up @@ -244,12 +89,19 @@ async def on_GET(
return 200, ret


class UnstableProfileFieldRestServlet(RestServlet):
class ProfileFieldRestServlet(RestServlet):
PATTERNS = [
*client_patterns(
"/profile/(?P<user_id>[^/]*)/(?P<field_name>displayname)", v1=True
),
*client_patterns(
"/profile/(?P<user_id>[^/]*)/(?P<field_name>avatar_url)", v1=True
),
re.compile(
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
)
r"^/_matrix/client/v3/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
),
]

CATEGORY = "Event sending requests"

def __init__(self, hs: "HomeServer"):
Expand Down Expand Up @@ -304,7 +156,10 @@ async def on_PUT(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

requester = await self.auth.get_user_by_req(request)
# Guest users are able to set their own displayname.
requester = await self.auth.get_user_by_req(
request, allow_guest=field_name == ProfileFields.DISPLAYNAME
)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)

Expand Down Expand Up @@ -366,7 +221,10 @@ async def on_DELETE(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

requester = await self.auth.get_user_by_req(request)
# Guest users are able to set their own displayname.
requester = await self.auth.get_user_by_req(
request, allow_guest=field_name == ProfileFields.DISPLAYNAME
)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)

Expand Down Expand Up @@ -413,11 +271,15 @@ async def on_DELETE(
return 200, {}


class UnstableProfileFieldRestServlet(ProfileFieldRestServlet):
re.compile(
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
)


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# The specific displayname / avatar URL / custom field endpoints *must* appear
# before their corresponding generic profile endpoint.
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
# The specific field endpoint *must* appear before the generic profile endpoint.
ProfileFieldRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
if hs.config.experimental.msc4133_enabled:
UnstableProfileFieldRestServlet(hs).register(http_server)
2 changes: 0 additions & 2 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
# Simplified sliding sync
"org.matrix.simplified_msc3575": msc3575_enabled,
# Arbitrary key-value profile fields.
"uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
# MSC4155: Invite filtering
"org.matrix.msc4155": self.config.experimental.msc4155_enabled,
},
Expand Down
12 changes: 12 additions & 0 deletions tests/rest/client/test_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
self.assertTrue(capabilities["m.profile_fields"]["enabled"])
self.assertEqual(
capabilities["m.profile_fields"]["disallowed"], ["displayname"]
)

@override_config({"enable_set_avatar_url": False})
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
Expand All @@ -141,6 +145,8 @@ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
self.assertTrue(capabilities["m.profile_fields"]["enabled"])
self.assertEqual(capabilities["m.profile_fields"]["disallowed"], ["avatar_url"])

@override_config(
{
Expand All @@ -159,6 +165,10 @@ def test_get_set_displayname_capabilities_displayname_disabled_msc4133(

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
self.assertTrue(capabilities["m.profile_fields"]["enabled"])
self.assertEqual(
capabilities["m.profile_fields"]["disallowed"], ["displayname"]
)
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
self.assertEqual(
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
Expand All @@ -180,6 +190,8 @@ def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> No

self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
self.assertTrue(capabilities["m.profile_fields"]["enabled"])
self.assertEqual(capabilities["m.profile_fields"]["disallowed"], ["avatar_url"])
self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
self.assertEqual(
capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
Expand Down
Loading
Loading