From bdf7ed45884e3403a6996798ffa1c574cc1e578d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 22 Sep 2022 18:35:22 +0200 Subject: [PATCH] Refactor OIDC tests to better mimic an actual OIDC provider Instead of constantly mocking the internal methods of the OIDC handler, it now mocks HTTP responses Signed-off-by: Quentin Gliech --- synapse/handlers/oidc.py | 19 +- tests/handlers/test_oidc.py | 812 +++++++++++++++++++++++------------- 2 files changed, 529 insertions(+), 302 deletions(-) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 0262eda3d36b..22d866c976e7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -44,7 +44,7 @@ MacaroonInitException, MacaroonInvalidSignatureException, ) -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers @@ -95,10 +95,10 @@ class Token(TypedDict): access_token: str token_type: str - id_token: Optional[str] - refresh_token: Optional[str] + id_token: NotRequired[str] + refresh_token: NotRequired[str] expires_in: int - scope: Optional[str] + scope: NotRequired[str] #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but @@ -367,6 +367,7 @@ def __init__( provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -847,7 +848,9 @@ async def _verify_jwt( logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: @@ -862,7 +865,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: Returns: The decoded claims in the ID token. """ - id_token = token["id_token"] + id_token = token.get("id_token") # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy @@ -1294,8 +1297,8 @@ async def handle_backchannel_logout( # `user_id`. Hence, we have to iterate over the list of devices and log them out # one by one. for device in devices: - user_id = device["user_id"] - device_id = device["device_id"] + user_id: str = device["user_id"] + device_id: str = device["device_id"] # If the user_id associated with that device/session is not the one we got # out of the `sub` claim, skip that device and show log an error. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b756..8dd9ac1a2934 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,21 +13,33 @@ # limitations under the License. import json import os -from typing import Any, Dict +from typing import Any, Dict, List, Optional +from unittest import mock from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons +from typing_extensions import TypedDict +from twisted.internet.interfaces import IProtocol +from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor +from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse +from synapse.handlers.oidc import Token from synapse.handlers.sso import MappingException +from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import JsonDict, UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +58,6 @@ CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +72,9 @@ EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,25 +108,275 @@ async def map_user_attributes(self, userinfo, token, failures): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata +class AuthorizationGrant(TypedDict, total=True): + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + + +class FakeProvider: + """A fake OpenID Connect Provider.""" + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self.clock = clock + self.issuer = issuer + + # A code -> grant mapping + self.authorization_grants: Dict[str, AuthorizationGrant] = {} + # An access token -> grant mapping + self.sessions: Dict[str, AuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self.key = ECKey.generate_key(crv="P-256", is_private=True) + self.jwks = KeySet([ECKey.import_key(self.key.raw_key.public_key())]) + + self._id_token_overrides: Dict[str, Any] = {} + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, "response_types_supported": ["code"], "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self.jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self.sessions.get(access_token, None) + if session is None: + return None + return session["userinfo"] + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self.key).decode("utf-8") + + def generate_id_token(self, grant: AuthorizationGrant) -> str: + now = self.clock.time() + id_token = { + **grant["userinfo"], + "iss": self.issuer, + "aud": grant["client_id"], + "iat": now, + "nbf": now, + "exp": now + 600, } - elif url == JWKS_URI: - return {"keys": []} - return {} + nonce = grant.get("nonce", None) + if nonce is not None: + id_token["nonce"] = nonce + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + ) -> str: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + self.authorization_grants[code] = AuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + ) + return code + + def exchange_code(self, code: str) -> Optional[Token]: + grant = self.authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self.sessions[access_token] = grant + + token = Token( + token_type="Bearer", + access_token=access_token, + expires_in=3600, + scope=grant["scope"], + ) + + if "openid" in grant["scope"]: + token["id_token"] = self.generate_id_token(grant) + + return token + + +class FakeProviderHttpClient(SimpleHttpClient): + """A fake HTTP client, to handle OAuth request through the FakeProvider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks: Mock + get_metadata: Mock + get_userinfo: Mock + post_token: Mock + + def __init__(self, hs: "HomeServer", provider: FakeProvider): + super().__init__(hs) + self._provider = provider + + self.request = Mock(side_effect=self._request) + self.get_jwks = Mock(side_effect=self._get_jwks) + self.get_metadata = Mock(side_effect=self._get_metadata) + self.get_userinfo = Mock(side_effect=self._get_userinfo) + self.post_token = Mock(side_effect=self._post_token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = _mock_response(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata"] = Mock(return_value=buggy) + if token: + patches["post_token"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return _mock_json_response( + code=400, payload={"error": "invalid_request"} + ) + + params = parse_qs(data.decode("utf-8")) + + if uri == self._provider.token_endpoint: + return self.post_token(params) + + elif method == "GET": + if uri == self._provider.jwks_uri: + return self.get_jwks() + elif uri == self._provider.metadata_endpoint: + return self.get_metadata() + elif uri == self._provider.userinfo_endpoint: + return self.get_userinfo(access_token=access_token) + + return _mock_response(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return _mock_json_response(payload=self._provider.get_jwks()) + + def _get_metadata(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return _mock_json_response(payload=self._provider.get_metadata()) + + def _get_userinfo(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return _mock_response(code=401) + user_info = self._provider.get_userinfo(access_token) + if user_info is None: + return _mock_response(code=401) + + return _mock_json_response(payload=user_info) + + def _post_token(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return _mock_json_response(code=400, payload={"error": "invalid_request"}) + + grant = self._provider.exchange_code(code=code[0]) + if grant is None: + return _mock_json_response(code=400, payload={"error": "invalid_grant"}) + + return _mock_json_response(payload=dict(grant)) def _key_file_path() -> str: @@ -159,11 +415,11 @@ def default_config(self) -> Dict[str, Any]: return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_provider = FakeProvider(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.http_client = FakeProviderHttpClient(hs=hs, provider=self.fake_provider) + hs._proxied_http_client = self.http_client # type: ignore[attr-defined] self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +431,49 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def reset_mocks(self): + """Reset all the Mocks.""" + self.http_client.request.reset_mock() + self.http_client.get_jwks.reset_mock() + self.http_client.get_metadata.reset_mock() + self.http_client.get_userinfo.reset_mock() + self.http_client.post_token.reset_mock() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_provider.get_metadata() + metadata.update(values) + return patch.object(self.fake_provider, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + ) -> SynapseRequest: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code = self.fake_provider.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=CLIENT_ID, + redirect_uri=CALLBACK_URL, + nonce=nonce, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session) def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +497,54 @@ def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.get_metadata.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_provider.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_provider.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_provider.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_provider.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_provider.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.http_client.get_metadata.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.http_client.get_metadata.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.http_client.get_jwks.assert_called_once() + self.assertEqual(jwks, self.fake_provider.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.http_client.get_jwks.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - - # Throw if the JWKS uri is missing - original = self.provider.load_metadata + self.http_client.get_jwks.assert_called_once() - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +648,7 @@ def test_redirect_request(self) -> None: self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_provider.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,34 +713,20 @@ def test_callback(self) -> None: with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, "oidc", request, @@ -460,12 +735,12 @@ def test_callback(self) -> None: new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.http_client.post_token.assert_called_once() + self.http_client.get_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +750,64 @@ def test_callback(self) -> None: self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request = self.start_authorization(userinfo) + with self.fake_provider.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, "oidc", request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.http_client.post_token.assert_called_once() + self.http_client.get_userinfo.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() - self.get_success(self.handler.handle_oidc_callback(request)) + sid = "abcdefgh" + request = self.start_authorization(userinfo) + with self.fake_provider.id_token_override({"sid": sid}): + self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, "oidc", request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.http_client.post_token.assert_called_once() + self.http_client.get_userinfo.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request = self.start_authorization(userinfo) + with self.http_client.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request = self.start_authorization(userinfo) + with self.http_client.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +857,20 @@ def test_callback_session(self) -> None: ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.http_client.post_token.side_effect = None + self.http_client.post_token.return_value = _mock_json_response(payload=token) code = "code" ret = self.get_success(self.provider._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +880,8 @@ def test_exchange_code(self) -> None: self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.http_client.post_token.return_value = _mock_json_response( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +890,30 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.http_client.post_token.return_value = _mock_response( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.http_client.post_token.return_value = _mock_json_response( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.http_client.post_token.return_value = _mock_json_response( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.http_client.post_token.return_value = _mock_json_response( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,12 +937,13 @@ def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.http_client.post_token.side_effect = None + self.http_client.post_token.return_value = _mock_json_response(payload=token) code = "code" # advance the clock a bit before we start, so we aren't working with zero @@ -716,7 +957,7 @@ def test_exchange_code_jwt_key(self) -> None: # the request should have hit the token endpoint kwargs = self.http_client.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,12 +991,13 @@ def test_exchange_code_jwt_key(self) -> None: ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.http_client.post_token.side_effect = None + self.http_client.post_token.return_value = _mock_json_response(payload=token) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -764,7 +1006,7 @@ def test_exchange_code_no_auth(self) -> None: # the request should have hit the token endpoint kwargs = self.http_client.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +1029,19 @@ def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", "oidc", request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,15 +1050,13 @@ def test_extra_attributes(self) -> None: @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", "oidc", ANY, @@ -843,15 +1065,16 @@ def test_map_userinfo_to_user(self) -> None: new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", "oidc", ANY, @@ -860,7 +1083,7 @@ def test_map_userinfo_to_user(self) -> None: new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +1092,9 @@ def test_map_userinfo_to_user(self) -> None: store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,16 +1109,14 @@ def test_map_userinfo_to_existing_user(self) -> None: store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), "oidc", ANY, @@ -903,11 +1125,12 @@ def test_map_userinfo_to_existing_user(self) -> None: new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), "oidc", ANY, @@ -916,7 +1139,7 @@ def test_map_userinfo_to_existing_user(self) -> None: new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,8 +1150,9 @@ def test_map_userinfo_to_existing_user(self) -> None: "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), "oidc", ANY, @@ -937,7 +1161,7 @@ def test_map_userinfo_to_existing_user(self) -> None: new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +1178,9 @@ def test_map_userinfo_to_existing_user(self) -> None: "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,8 +1194,9 @@ def test_map_userinfo_to_existing_user(self) -> None: store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", "oidc", ANY, @@ -983,9 +1209,9 @@ def test_map_userinfo_to_existing_user(self) -> None: @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +1226,6 @@ def test_map_userinfo_to_invalid_localpart(self) -> None: ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,10 +1234,11 @@ def test_map_userinfo_to_user_retries(self) -> None: "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", "oidc", ANY, @@ -1023,7 +1247,7 @@ def test_map_userinfo_to_user_retries(self) -> None: new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +1263,9 @@ def test_map_userinfo_to_user_retries(self) -> None: "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1277,8 @@ def test_empty_localpart(self) -> None: "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1297,8 @@ def test_null_localpart(self) -> None: "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1311,14 @@ def test_null_localpart(self) -> None: ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,10 +1326,11 @@ def test_attribute_requirements(self) -> None: "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", "oidc", ANY, @@ -1124,18 +1350,17 @@ def test_attribute_requirements(self) -> None: ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", "oidc", ANY, @@ -1158,16 +1383,15 @@ def test_attribute_requirements_mismatch(self) -> None: Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1399,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1410,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1421,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1432,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1443,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1238,39 +1467,34 @@ def _generate_oidc_session_token( ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict +def _mock_response( + *, code: int = 200, headers: Optional[Headers] = None, body: bytes = b"" +) -> IResponse: + if headers is None: + headers = Headers() - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. + def deliver_body(p: IProtocol): + p.dataReceived(body) + p.connectionLost(Failure(ResponseDone())) - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), + response = mock.Mock( + code=code, + phrase=RESPONSES.get(code, b"Unknown Status"), + headers=headers, + length=len(body), + deliverBody=deliver_body, ) - request = _build_callback_request("code", state, session) + mock.seal(response) + return response - await handler.handle_oidc_callback(request) + +def _mock_json_response(*, payload: JsonDict, code: int = 200) -> IResponse: + body = json.dumps(payload).encode("utf-8") + return _mock_response( + code=code, + headers=Headers({"content-Type": ["application/json"]}), + body=body, + ) def _build_callback_request(