Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IDP token introspection #229

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions stytch/b2b/api/idp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from __future__ import annotations

from typing import Any, Dict, Optional

import jwt

from stytch.b2b.models.sessions import AuthorizationCheck
from stytch.consumer.models.idp import IDPTokenClaims, IDPTokenResponse
from stytch.core.api_base import ApiBase
from stytch.core.http.client import AsyncClient, SyncClient
from stytch.shared import jwt_helpers, rbac_local
from stytch.shared.policy_cache import PolicyCache


class IDP:
def __init__(
self,
api_base: ApiBase,
sync_client: SyncClient,
async_client: AsyncClient,
jwks_client: jwt.PyJWKClient,
project_id: str,
policy_cache: PolicyCache,
) -> None:
self.api_base = api_base
self.sync_client = sync_client
self.async_client = async_client
self.jwks_client = jwks_client
self.project_id = project_id
self.non_custom_claim_keys = [
"aud",
"exp",
"iat",
"iss",
"jti",
"nbf",
"sub",
"active",
"client_id",
"request_id",
"scope",
"status_code",
"token_type",
]
self.policy_cache = policy_cache

def introspect_token_network(
self,
token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
authorization_check: Optional[AuthorizationCheck] = None,
) -> Optional[IDPTokenClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
data: Dict[str, Any] = {
"token": token,
"client_id": client_id,
"token_type_hint": token_type_hint,
}
if client_secret is not None:
data["client_secret"] = client_secret

url = self.api_base.url_for(
f"/v1/public/{self.project_id}/oauth2/introspect", data
)
res = self.sync_client.post_form(url, data, headers)
jwtResponse = IDPTokenResponse.from_json(res.response.status_code, res.json)
custom_claims = {
k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys
}
if not jwtResponse.active:
return None

scope = jwtResponse.scope

if authorization_check is not None:
rbac_local.perform_scope_authorization_check(
policy=self.policy_cache.get(),
token_scopes=scope.split(),
authorization_check=authorization_check,
)

return IDPTokenClaims(
subject=jwtResponse.sub,
scope=jwtResponse.scope,
audience=jwtResponse.aud,
expires_at=jwtResponse.exp,
issued_at=jwtResponse.iat,
issuer=jwtResponse.iss,
not_before=jwtResponse.nbf,
token_type=jwtResponse.token_type,
custom_claims=custom_claims,
)

async def introspect_token_network_async(
self,
token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
authorization_check: Optional[AuthorizationCheck] = None,
) -> Optional[IDPTokenClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
data: Dict[str, Any] = {
"token": token,
"client_id": client_id,
"token_type_hint": token_type_hint,
}
if client_secret is not None:
data["client_secret"] = client_secret

url = self.api_base.url_for(
f"/v1/public/{self.project_id}/oauth2/introspect", data
)
res = await self.async_client.post_form(url, data, headers)
jwtResponse = IDPTokenResponse.from_json(res.response.status, res.json)
custom_claims = {
k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys
}
if not jwtResponse.active:
return None

scope = jwtResponse.scope

if authorization_check is not None:
rbac_local.perform_scope_authorization_check(
policy=self.policy_cache.get(),
token_scopes=scope.split(),
authorization_check=authorization_check,
)

return IDPTokenClaims(
subject=jwtResponse.sub,
scope=jwtResponse.scope,
audience=jwtResponse.aud,
expires_at=jwtResponse.exp,
issued_at=jwtResponse.iat,
issuer=jwtResponse.iss,
not_before=jwtResponse.nbf,
token_type=jwtResponse.token_type,
custom_claims=custom_claims,
)

def introspect_access_token_local(
self,
access_token: str,
client_id: str,
authorization_check: Optional[AuthorizationCheck] = None,
) -> Optional[IDPTokenClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
"""
_scope_claim = "scope"
generic_claims = jwt_helpers.authenticate_jwt_local(
project_id=self.project_id,
jwks_client=self.jwks_client,
jwt=access_token,
custom_audience=client_id,
custom_issuer=f"https://stytch.com/{self.project_id}",
)
if generic_claims is None:
return None

custom_claims = {
k: v for k, v in generic_claims.untyped_claims.items() if k != _scope_claim
}

scope = generic_claims.untyped_claims[_scope_claim]

if authorization_check is not None:
rbac_local.perform_scope_authorization_check(
policy=self.policy_cache.get(),
token_scopes=scope.split(),
authorization_check=authorization_check,
)

return IDPTokenClaims(
subject=generic_claims.reserved_claims["sub"],
scope=scope,
custom_claims=custom_claims,
audience=generic_claims.reserved_claims["aud"],
expires_at=generic_claims.reserved_claims["exp"],
issued_at=generic_claims.reserved_claims["iat"],
issuer=generic_claims.reserved_claims["iss"],
not_before=generic_claims.reserved_claims["nbf"],
token_type="access_token",
)
9 changes: 9 additions & 0 deletions stytch/b2b/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jwt

from stytch.b2b.api.discovery import Discovery
from stytch.b2b.api.idp import IDP
from stytch.b2b.api.impersonation import Impersonation
from stytch.b2b.api.magic_links import MagicLinks
from stytch.b2b.api.oauth import OAuth
Expand Down Expand Up @@ -148,6 +149,14 @@ def __init__(
sync_client=self.sync_client,
async_client=self.async_client,
)
self.idp = IDP(
api_base=self.api_base,
sync_client=self.sync_client,
async_client=self.async_client,
jwks_client=self.jwks_client,
project_id=project_id,
policy_cache=policy_cache,
)

def get_jwks_client(self, project_id: str) -> jwt.PyJWKClient:
data = {"project_id": project_id}
Expand Down
13 changes: 13 additions & 0 deletions stytch/b2b/models/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,28 @@ class PolicyRole(pydantic.BaseModel):
permissions: List[PolicyRolePermission]


class PolicyScopePermission(pydantic.BaseModel):
resource_id: str
actions: List[str]


class PolicyScope(pydantic.BaseModel):
scope: str
description: str
permissions: List[PolicyScopePermission]


class Policy(pydantic.BaseModel):
"""
Fields:
- roles: An array of [Role objects](https://stytch.com/docs/b2b/api/rbac-role-object).
- resources: An array of [Resource objects](https://stytch.com/docs/b2b/api/rbac-resource-object).
- scopes: (no documentation yet)
""" # noqa

roles: List[PolicyRole]
resources: List[PolicyResource]
scopes: List[PolicyScope]


class PolicyResponse(ResponseBase):
Expand Down
Loading
Loading