Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update the process for mapping SAML2 users to matrix IDs #6037

Merged
merged 10 commits into from
Sep 24, 2019
1 change: 1 addition & 0 deletions changelog.d/6037.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make the process for mapping SAML2 users to matrix IDs more flexible.
26 changes: 26 additions & 0 deletions docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,32 @@ saml2_config:
#
#saml_session_lifetime: 5m

# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
#
#mxid_source_attribute: displayName

# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace

# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn



# Enable CAS for registration and login.
Expand Down
78 changes: 76 additions & 2 deletions synapse/config/saml2_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)

from ._base import Config, ConfigError

Expand All @@ -36,6 +42,14 @@ def read_config(self, config, **kwargs):

self.saml2_enabled = True

self.saml2_mxid_source_attribute = saml2_config.get(
"mxid_source_attribute", "uid"
)

self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)

import saml2.config

self.saml2_sp_config = saml2.config.SPConfig()
Expand All @@ -51,13 +65,26 @@ def read_config(self, config, **kwargs):
saml2_config.get("saml_session_lifetime", "5m")
)

mapping = saml2_config.get("mxid_mapping", "hexencode")
try:
self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
except KeyError:
raise ConfigError("%s is not a known mxid_mapping" % (mapping,))

def _default_saml_config_dict(self):
import saml2

public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")

required_attributes = {"uid", self.saml2_mxid_source_attribute}

optional_attributes = {"displayName"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason this doesn't match what was there before?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it was broken before.

if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes

metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
Expand All @@ -69,8 +96,9 @@ def _default_saml_config_dict(self):
(response_url, saml2.BINDING_HTTP_POST)
]
},
"required_attributes": ["uid"],
"optional_attributes": ["mail", "surname", "givenname"],
"required_attributes": list(required_attributes),
"optional_attributes": list(optional_attributes),
# "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
}
},
}
Expand Down Expand Up @@ -146,6 +174,52 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
# The default is 5 minutes.
#
#saml_session_lifetime: 5m

# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
#
#mxid_source_attribute: displayName

# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace

# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn
""" % {
"config_dir_path": config_dir_path
}


DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)


def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)

# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username


MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}
106 changes: 98 additions & 8 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.async_helpers import Linearizer

logger = logging.getLogger(__name__)

Expand All @@ -29,12 +31,26 @@ class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs)
self._registration_handler = hs.get_registration_handler()

self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._mxid_mapper = hs.config.saml2_mxid_mapper

# identifier for the external_ids table
self._auth_provider_id = "saml"

# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {}

self._clock = hs.get_clock()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)

def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
Expand All @@ -60,7 +76,7 @@ def handle_redirect_request(self, client_redirect_url):
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")

def handle_saml_response(self, request):
async def handle_saml_response(self, request):
"""Handle an incoming request to /_matrix/saml2/authn_response

Args:
Expand All @@ -77,6 +93,10 @@ def handle_saml_response(self, request):
# the dict.
self.expire_sessions()

user_id = await self._map_saml_response_to_user(resp_bytes)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)

async def _map_saml_response_to_user(self, resp_bytes):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
Expand All @@ -91,18 +111,88 @@ def handle_saml_response(self, request):
logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")

if "uid" not in saml2_auth.ava:
logger.info("SAML2 response: %s", saml2_auth.origxml)
logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)

try:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response")

try:
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)

self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)

username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]

return self._sso_auth_handler.on_successful_auth(
username, request, relay_state, user_display_name=displayName
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id

# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
self._grandfathered_mxid_source_attribute
and self._grandfathered_mxid_source_attribute in saml2_auth.ava
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname
).to_string()
logger.info(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)

users = await self._datastore.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id

# figure out a new mxid for this user
base_mxid_localpart = self._mxid_mapper(mxid_source)

suffix = 0
while True:
localpart = base_mxid_localpart + (str(suffix) if suffix else "")
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
break
suffix += 1
logger.info("Allocating mxid for new user with localpart %s", localpart)

registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayName
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id

def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
Expand Down
14 changes: 14 additions & 0 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
Expand Down Expand Up @@ -507,6 +508,19 @@ def on_successful_auth(
localpart=localpart, default_display_name=user_display_name
)

self.complete_sso_login(registered_user_id, request, client_redirect_url)

def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Having figured out a mxid for this user, complete the HTTP request

Args:
registered_user_id:
request:
client_redirect_url:
"""

login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)
Expand Down
41 changes: 41 additions & 0 deletions synapse/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from six.moves import range

from twisted.internet import defer
from twisted.internet.defer import Deferred

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
Expand Down Expand Up @@ -384,6 +385,26 @@ def f(txn):

return self.runInteraction("get_users_by_id_case_insensitive", f)

async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> str:
"""Look up a user by their external auth id

Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system

Returns:
str|None: the mxid of the user, or None if they are not known
"""
return await self._simple_select_one_onecol(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an async func and record_user_external_id isn't?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shrug. I can update one of them if you like.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, was just wondering if it was a typo or something. Doesn't need to change

table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
allow_none=True,
desc="get_user_by_external_id",
)

@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
Expand Down Expand Up @@ -1032,6 +1053,26 @@ def _register_user(
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))

def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Deferred:
"""Record a mapping from an external user id to a mxid

Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self._simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)

def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
Expand Down
Loading