Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Record the SSO Auth Provider in the login token #9510

Merged
merged 14 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,7 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s
async def complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
Expand All @@ -1415,6 +1416,9 @@ async def complete_sso_login(

Args:
registered_user_id: The registered user ID to complete SSO login for.
auth_provider_id: The id of the SSO Identity provider that was used for
login. This will be stored in the login token for future tracking in
prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
Expand All @@ -1436,6 +1440,7 @@ async def complete_sso_login(

self._complete_sso_login(
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
extra_attributes,
Expand All @@ -1446,6 +1451,7 @@ async def complete_sso_login(
def _complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
Expand All @@ -1472,7 +1478,7 @@ def _complete_sso_login(

# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
registered_user_id, auth_provider_id=auth_provider_id
)

# Append the login token to the original redirect URL (i.e. with its query
Expand Down Expand Up @@ -1578,13 +1584,20 @@ def generate_access_token(
return macaroon.serialize()

def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
if auth_provider_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_id = %s" % (auth_provider_id,)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A macaroon that doesn't have an auth_provider_id caveat allows clients to add any auth_provider_id themselves. I think we want to add a auth_provider_id = None (or equivalent) when auth_provider_id is None?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the advantage of storing this information in a macaroon instead of recording it internally in the database?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A macaroon that doesn't have an auth_provider_id caveat allows clients to add any auth_provider_id themselves. I think we want to add a auth_provider_id = None (or equivalent) when auth_provider_id is None?

we were actually setting it to something not-None on all code paths. I've updated the method signature to enforce this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the advantage of storing this information in a macaroon instead of recording it internally in the database?

it's just easier! To store it in the database, we'd have to add a table, do a bunch of sql to insert and select rows, and do a bunch more sql to delete rows once they expire. Not having to do the database queries seems more efficient too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's just easier!

/me looks at the size of this PR :p

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:-p

This PR is mostly moving existing stuff about. Storing it in the DB would be a bunch of new code.

return macaroon.serialize()

def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
Expand Down
2 changes: 2 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ async def complete_sso_login_request(

await self._auth_handler.complete_sso_login(
user_id,
auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
Expand Down Expand Up @@ -886,6 +887,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None:

await self._auth_handler.complete_sso_login(
user_id,
session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
Expand Down
31 changes: 27 additions & 4 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,26 @@ def record_user_external_id(
)

def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication"""
"""Generate a login token suitable for m.login.token authentication

Args:
user_id: gives the ID of the user that the token is for

duration_in_ms: the time that the token will be valid for

auth_provider_id: the ID of the SSO IdP that the user used to authenticate
to get this token, if any. This is encoded in the token so that
/login can report stats on number of successful logins by IdP.
"""
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, duration_in_ms
user_id,
duration_in_ms,
auth_provider_id=auth_provider_id,
)

@defer.inlineCallbacks
Expand Down Expand Up @@ -276,6 +291,7 @@ def complete_sso_login(
"""
self._auth_handler._complete_sso_login(
registered_user_id,
"<unknown>",
request,
client_redirect_url,
)
Expand All @@ -286,6 +302,7 @@ async def complete_sso_login_async(
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
Expand All @@ -299,9 +316,15 @@ async def complete_sso_login_async(
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
auth_provider_id: the ID of the SSO IdP which was used to log in. This
is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url, new_user=new_user
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
new_user=new_user,
)

@defer.inlineCallbacks
Expand Down
8 changes: 8 additions & 0 deletions tests/handlers/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def test_short_term_login_token_gives_user_id(self):
AuthError,
)

def test_short_term_login_token_gives_auth_provider(self):
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", auth_provider_id="my_idp"
)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)

def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
Expand Down