Skip to content

Commit

Permalink
feat(oauth): claims verification
Browse files Browse the repository at this point in the history
  • Loading branch information
thekaveman committed Jan 28, 2025
1 parent 9027e86 commit 538e108
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 57 deletions.
37 changes: 37 additions & 0 deletions web/oauth/claims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

logger = logging.getLogger(__name__)


def process(userinfo: dict, expected_claims: list[str]) -> tuple[list[str], dict[str, str]]:
"""Process expected claims from the userinfo dict.
- Boolean claims comes back in userinfo like `{ "claim": "1" | "0" }` or `{ "claim": "true" }`
- Other claims come back in userinfo like `{ "claim": "value" }`
Returns a tuple `(claims: list[str], errors: dict[str, int])`
"""
claims = []
errors = {}

for claim in expected_claims:
claim_value = userinfo.get(claim)
if not claim_value:
logger.warning(f"userinfo did not contain: {claim}")
try:
claim_value = int(claim_value)
except (TypeError, ValueError):
pass
if isinstance(claim_value, int):
if claim_value == 1:
# if userinfo contains our claim and the flag is 1 (true), store the *claim*
claims.append(claim)
elif claim_value >= 10:
errors[claim] = claim_value
elif isinstance(claim_value, str):
if claim_value.lower() == "true":
claims.append(claim)
elif claim_value.lower() != "false":
claims.append(f"{claim}:{claim_value}")

return (claims, errors)
98 changes: 41 additions & 57 deletions web/oauth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,41 @@
from django.shortcuts import redirect
from django.urls import reverse

from . import redirects, session
from . import redirects
from . import claims
from .client import create_client, oauth
from .models import ClientConfig
from .routes import Routes
from .session import Session

logger = logging.getLogger(__name__)


def _oauth_client_config():
return ClientConfig.objects.first()


def _oauth_client_or_error_redirect(request: HttpRequest, config: ClientConfig):
def _client_or_error_redirect(request: HttpRequest):
"""Calls `web.oauth.client.create_client()`.
If a client is created successfully, return it; Otherwise, return a redirect response to OAuth system error.
"""
oauth_client = None
exception = None
session = Session(request)

try:
oauth_client = create_client(oauth, config)
except Exception as ex:
exception = ex

if not oauth_client and not exception:
exception = Exception(f"oauth_client not registered: {config.client_name}")
oauth_config = session.oauth_config
if not oauth_config:
raise Exception("No oauth_config in session")

if exception:
raise exception
scopes = session.oauth_scopes
oauth_client = create_client(oauth, oauth_config, scopes=scopes)
if not oauth_client:
raise Exception(f"oauth_client not registered: {oauth_config.client_name}")

return oauth_client


def authorize(request):
def authorize(request: HttpRequest):
"""View implementing OIDC token authorization."""
logger.debug(Routes.AUTHORIZE)

oauth_client_config = _oauth_client_config()
oauth_client_result = _oauth_client_or_error_redirect(request, oauth_client_config)
session = Session(request)
oauth_client_result = _client_or_error_redirect(request)

if hasattr(oauth_client_result, "authorize_access_token"):
# this looks like an oauth_client since it has the method we need
Expand Down Expand Up @@ -71,42 +66,31 @@ def authorize(request):
logger.debug("OIDC access token authorized")

# We store the id_token in the user's session. This is the minimal amount of information needed later to log the user out.
id_token = token["id_token"]

# We store the returned claim in case it can be used later.
claims = []
session.oauth_token = token["id_token"]
# We store the returned claims in case they can be used later.
expected_claims = session.oauth_claims_check.split(" ")
stored_claims = []

error_claim = {}

if claims:
userinfo = token.get("userinfo")

if userinfo:
for claim in claims:
claim_value = userinfo.get(claim)
# the claim comes back in userinfo like { "claim": "1" | "0" }
claim_value = int(claim_value) if claim_value else None
if claim_value is None:
logger.warning(f"userinfo did not contain: {claim}")
elif claim_value == 1:
# if userinfo contains our claim and the flag is 1 (true), store the *claim*
stored_claims.append(claim)
elif claim_value >= 10:
error_claim[claim] = claim_value

session.oauth_token(request, id_token)
session.oauth_claims(request, stored_claims)

return redirect(oauth_client_config.post_authorize_access_token_redirect)


def login(request):
error_claims = {}

if expected_claims:
userinfo = token.get("userinfo", {})
stored_claims, error_claims = claims.process(userinfo, expected_claims)
# if we found the eligibility claim
if session.oauth_claims_eligibility and session.oauth_claims_eligibility in stored_claims:
# store them and redirect to success
session.oauth_claims_verified = " ".join(stored_claims)
return redirect(session.oauth_redirect_success)
# else redirect to failure
if error_claims:
logger.error(error_claims)
return redirect(session.oauth_redirect_failure)


def login(request: HttpRequest):
"""View implementing OIDC authorize_redirect."""
logger.debug(Routes.LOGIN)

oauth_client_config = _oauth_client_config()
oauth_client_result = _oauth_client_or_error_redirect(request, oauth_client_config)
oauth_client_result = _client_or_error_redirect(request)

if hasattr(oauth_client_result, "authorize_redirect"):
# this looks like an oauth_client since it has the method we need
Expand Down Expand Up @@ -139,12 +123,12 @@ def login(request):
return result


def logout(request):
def logout(request: HttpRequest):
"""View handler for OIDC sign out."""
logger.debug(Routes.LOGOUT)

oauth_client_config = _oauth_client_config()
oauth_client_result = _oauth_client_or_error_redirect(request, oauth_client_config)
session = Session(request)
oauth_client_result = _client_or_error_redirect(request)

if hasattr(oauth_client_result, "load_server_metadata"):
# this looks like an oauth_client since it has the method we need
Expand All @@ -155,8 +139,8 @@ def logout(request):
return oauth_client_result

# overwrite the oauth session token, the user is signed out of the app
token = session.oauth_token(request)
session.logout(request)
token = session.oauth_token
session.logout()

route = reverse(Routes.POST_LOGOUT)
redirect_uri = redirects.generate_redirect_uri(request, route)
Expand Down

0 comments on commit 538e108

Please sign in to comment.