Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add support for stable MSC2858 API #9617

Merged
merged 4 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9617.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)).
8 changes: 4 additions & 4 deletions docs/openid.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Synapse config:
oidc_providers:
- idp_id: github
idp_name: Github
idp_brand: "org.matrix.github" # optional: styling hint for clients
idp_brand: "github" # optional: styling hint for clients
discover: false
issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED
Expand All @@ -252,7 +252,7 @@ oidc_providers:
oidc_providers:
- idp_id: google
idp_name: Google
idp_brand: "org.matrix.google" # optional: styling hint for clients
idp_brand: "google" # optional: styling hint for clients
issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
Expand Down Expand Up @@ -299,7 +299,7 @@ Synapse config:
oidc_providers:
- idp_id: gitlab
idp_name: Gitlab
idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
idp_brand: "gitlab" # optional: styling hint for clients
issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
Expand Down Expand Up @@ -334,7 +334,7 @@ Synapse config:
```yaml
- idp_id: facebook
idp_name: Facebook
idp_brand: "org.matrix.facebook" # optional: styling hint for clients
idp_brand: "facebook" # optional: styling hint for clients
discover: false
issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED
Expand Down
2 changes: 1 addition & 1 deletion docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ oidc_providers:
#
#- idp_id: github
# idp_name: Github
# idp_brand: org.matrix.github
# idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
Expand Down
13 changes: 11 additions & 2 deletions synapse/config/oidc_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
#
#- idp_id: github
# idp_name: Github
# idp_brand: org.matrix.github
# idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
Expand Down Expand Up @@ -272,7 +272,12 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
"idp_icon": {"type": "string"},
"idp_brand": {
"type": "string",
# MSC2758-style namespaced identifier
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
},
"idp_unstable_brand": {
"type": "string",
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
Expand Down Expand Up @@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"),
unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
Expand Down Expand Up @@ -512,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])

# Optional brand identifier for the unstable API (see MSC2858).
unstable_idp_brand = attr.ib(type=Optional[str])

# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)

Expand Down
1 change: 1 addition & 0 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, hs: "HomeServer"):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
self.unstable_idp_brand = None

self._sso_handler = hs.get_sso_handler()

Expand Down
3 changes: 3 additions & 0 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def __init__(
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand

# Optional brand identifier for the unstable API (see MSC2858).
self.unstable_idp_brand = provider.unstable_idp_brand

self._sso_handler = hs.get_sso_handler()

self._sso_handler.register_identity_provider(self)
Expand Down
1 change: 1 addition & 0 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
self.unstable_idp_brand = None

# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
Expand Down
5 changes: 5 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def idp_brand(self) -> Optional[str]:
"""Optional branding identifier"""
return None

@property
def unstable_idp_brand(self) -> Optional[str]:
"""Optional brand identifier for the unstable API (see MSC2858)."""
return None

@abc.abstractmethod
async def handle_redirect_request(
self,
Expand Down
39 changes: 34 additions & 5 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
Expand Down Expand Up @@ -94,11 +96,21 @@ def on_GET(self, request: SynapseRequest):
flows.append({"type": LoginRestServlet.CAS_TYPE})

if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
sso_flow = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
idp,
)
for idp in self._sso_handler.get_identity_providers().values()
],
} # type: JsonDict

if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
# support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [
_get_auth_flow_dict_for_idp(idp)
_get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values()
]

Expand Down Expand Up @@ -331,22 +343,38 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
return result


def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
def _get_auth_flow_dict_for_idp(
idp: SsoIdentityProvider, use_unstable_brands: bool = False
) -> JsonDict:
"""Return an entry for the login flow dict

Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login

Args:
idp: the identity provider to describe
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
# use the stable brand identifier if the unstable identifier isn't defined.
if use_unstable_brands and idp.unstable_idp_brand:
e["brand"] = idp.unstable_idp_brand
return e


class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to already be a list, we could probably update the type hint of client_patterns instead of casting here. 🤷

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might not be a list in some tests?

re.compile(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use client_patterns here too?

Copy link
Member Author

@richvdh richvdh Mar 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, good question. I think it's slightly more verbose to call client_patterns (especially with the cast to list), but I could do so.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only argument for changing it is that if we ever added another path to all our stable endpoints this would happen automatically. (Also client_patterns automatically registers under unstable/ too, so might make sense for consistency -- although I suspect that's less of a feature and more of a bug that it does that.)

I think it is OK to leave it if we don't find any of the above compelling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. I think most of client_patterns's behaviour is more bug than feature tbh. I feel like we should be explicit when we add new paths.

"^"
+ CLIENT_API_PREFIX
+ "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
)
]

def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
Expand All @@ -364,7 +392,8 @@ def __init__(self, hs: "HomeServer"):
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
# expose additional endpoint for MSC2858 support
# expose additional endpoint for MSC2858 support: backwards-compat support
# for clients which don't yet support the stable endpoints.
http_server.register_paths(
"GET",
client_patterns(
Expand Down
43 changes: 27 additions & 16 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,16 @@ def test_get_login_flows(self):
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)

expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.sso"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS
expected_flow_types = [
"m.login.cas",
"m.login.sso",
"m.login.token",
"m.login.password",
] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]

self.assertCountEqual(channel.json_body["flows"], expected_flows)
self.assertCountEqual(
[f["type"] for f in channel.json_body["flows"]], expected_flow_types
)

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
Expand Down Expand Up @@ -636,22 +638,25 @@ def test_multi_sso_redirect_to_unknown(self):
)
self.assertEqual(channel.code, 400, channel.result)

def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request(True, "xxx")
channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request(False, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_msc2858_redirect_to_oidc(self):
"""Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
Expand All @@ -660,6 +665,12 @@ def test_client_idp_redirect_to_oidc(self):
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
Expand Down