diff --git a/stytch/b2b/api/idp.py b/stytch/b2b/api/idp.py new file mode 100644 index 0000000..f620aac --- /dev/null +++ b/stytch/b2b/api/idp.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + +import jwt + +from stytch.b2b.models.sessions import AuthorizationCheck +from stytch.consumer.models.idp import IDPTokenClaims, IDPTokenResponse +from stytch.core.api_base import ApiBase +from stytch.core.http.client import AsyncClient, SyncClient +from stytch.shared import jwt_helpers, rbac_local +from stytch.shared.policy_cache import PolicyCache + + +class IDP: + def __init__( + self, + api_base: ApiBase, + sync_client: SyncClient, + async_client: AsyncClient, + jwks_client: jwt.PyJWKClient, + project_id: str, + policy_cache: PolicyCache, + ) -> None: + self.api_base = api_base + self.sync_client = sync_client + self.async_client = async_client + self.jwks_client = jwks_client + self.project_id = project_id + self.non_custom_claim_keys = [ + "aud", + "exp", + "iat", + "iss", + "jti", + "nbf", + "sub", + "active", + "client_id", + "request_id", + "scope", + "status_code", + "token_type", + ] + self.policy_cache = policy_cache + + def introspect_token_network( + self, + token: str, + client_id: str, + client_secret: Optional[str] = None, + token_type_hint: str = "access_token", + authorization_check: Optional[AuthorizationCheck] = None, + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + - client_secret: The secret of the client. + - token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'. + """ + headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + data: Dict[str, Any] = { + "token": token, + "client_id": client_id, + "token_type_hint": token_type_hint, + } + if client_secret is not None: + data["client_secret"] = client_secret + + url = self.api_base.url_for( + f"/v1/public/{self.project_id}/oauth2/introspect", data + ) + res = self.sync_client.post_form(url, data, headers) + jwtResponse = IDPTokenResponse.from_json(res.response.status_code, res.json) + custom_claims = { + k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys + } + if not jwtResponse.active: + return None + + scope = jwtResponse.scope + + if authorization_check is not None: + rbac_local.perform_scope_authorization_check( + policy=self.policy_cache.get(), + token_scopes=scope.split(), + authorization_check=authorization_check, + ) + + return IDPTokenClaims( + subject=jwtResponse.sub, + scope=jwtResponse.scope, + audience=jwtResponse.aud, + expires_at=jwtResponse.exp, + issued_at=jwtResponse.iat, + issuer=jwtResponse.iss, + not_before=jwtResponse.nbf, + token_type=jwtResponse.token_type, + custom_claims=custom_claims, + ) + + async def introspect_token_network_async( + self, + token: str, + client_id: str, + client_secret: Optional[str] = None, + token_type_hint: str = "access_token", + authorization_check: Optional[AuthorizationCheck] = None, + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + - client_secret: The secret of the client. + - token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'. + """ + headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + data: Dict[str, Any] = { + "token": token, + "client_id": client_id, + "token_type_hint": token_type_hint, + } + if client_secret is not None: + data["client_secret"] = client_secret + + url = self.api_base.url_for( + f"/v1/public/{self.project_id}/oauth2/introspect", data + ) + res = await self.async_client.post_form(url, data, headers) + jwtResponse = IDPTokenResponse.from_json(res.response.status, res.json) + custom_claims = { + k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys + } + if not jwtResponse.active: + return None + + scope = jwtResponse.scope + + if authorization_check is not None: + rbac_local.perform_scope_authorization_check( + policy=self.policy_cache.get(), + token_scopes=scope.split(), + authorization_check=authorization_check, + ) + + return IDPTokenClaims( + subject=jwtResponse.sub, + scope=jwtResponse.scope, + audience=jwtResponse.aud, + expires_at=jwtResponse.exp, + issued_at=jwtResponse.iat, + issuer=jwtResponse.iss, + not_before=jwtResponse.nbf, + token_type=jwtResponse.token_type, + custom_claims=custom_claims, + ) + + def introspect_access_token_local( + self, + access_token: str, + client_id: str, + authorization_check: Optional[AuthorizationCheck] = None, + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - access_token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + """ + _scope_claim = "scope" + generic_claims = jwt_helpers.authenticate_jwt_local( + project_id=self.project_id, + jwks_client=self.jwks_client, + jwt=access_token, + custom_audience=client_id, + custom_issuer=f"https://stytch.com/{self.project_id}", + ) + if generic_claims is None: + return None + + custom_claims = { + k: v for k, v in generic_claims.untyped_claims.items() if k != _scope_claim + } + + scope = generic_claims.untyped_claims[_scope_claim] + + if authorization_check is not None: + rbac_local.perform_scope_authorization_check( + policy=self.policy_cache.get(), + token_scopes=scope.split(), + authorization_check=authorization_check, + ) + + return IDPTokenClaims( + subject=generic_claims.reserved_claims["sub"], + scope=scope, + custom_claims=custom_claims, + audience=generic_claims.reserved_claims["aud"], + expires_at=generic_claims.reserved_claims["exp"], + issued_at=generic_claims.reserved_claims["iat"], + issuer=generic_claims.reserved_claims["iss"], + not_before=generic_claims.reserved_claims["nbf"], + token_type="access_token", + ) diff --git a/stytch/b2b/client.py b/stytch/b2b/client.py index 4ce5f81..be88ee9 100644 --- a/stytch/b2b/client.py +++ b/stytch/b2b/client.py @@ -11,6 +11,7 @@ import jwt from stytch.b2b.api.discovery import Discovery +from stytch.b2b.api.idp import IDP from stytch.b2b.api.impersonation import Impersonation from stytch.b2b.api.magic_links import MagicLinks from stytch.b2b.api.oauth import OAuth @@ -148,6 +149,14 @@ def __init__( sync_client=self.sync_client, async_client=self.async_client, ) + self.idp = IDP( + api_base=self.api_base, + sync_client=self.sync_client, + async_client=self.async_client, + jwks_client=self.jwks_client, + project_id=project_id, + policy_cache=policy_cache, + ) def get_jwks_client(self, project_id: str) -> jwt.PyJWKClient: data = {"project_id": project_id} diff --git a/stytch/b2b/models/rbac.py b/stytch/b2b/models/rbac.py index 6923e4e..fe37e1c 100644 --- a/stytch/b2b/models/rbac.py +++ b/stytch/b2b/models/rbac.py @@ -122,15 +122,28 @@ class PolicyRole(pydantic.BaseModel): permissions: List[PolicyRolePermission] +class PolicyScopePermission(pydantic.BaseModel): + resource_id: str + actions: List[str] + + +class PolicyScope(pydantic.BaseModel): + scope: str + description: str + permissions: List[PolicyScopePermission] + + class Policy(pydantic.BaseModel): """ Fields: - roles: An array of [Role objects](https://stytch.com/docs/b2b/api/rbac-role-object). - resources: An array of [Resource objects](https://stytch.com/docs/b2b/api/rbac-resource-object). + - scopes: (no documentation yet) """ # noqa roles: List[PolicyRole] resources: List[PolicyResource] + scopes: List[PolicyScope] class PolicyResponse(ResponseBase): diff --git a/stytch/consumer/api/idp.py b/stytch/consumer/api/idp.py new file mode 100644 index 0000000..db50923 --- /dev/null +++ b/stytch/consumer/api/idp.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + +import jwt + +from stytch.consumer.models.idp import IDPTokenClaims, IDPTokenResponse +from stytch.core.api_base import ApiBase +from stytch.core.http.client import AsyncClient, SyncClient +from stytch.shared import jwt_helpers + + +class IDP: + def __init__( + self, + api_base: ApiBase, + sync_client: SyncClient, + async_client: AsyncClient, + jwks_client: jwt.PyJWKClient, + project_id: str, + ) -> None: + self.api_base = api_base + self.sync_client = sync_client + self.async_client = async_client + self.jwks_client = jwks_client + self.project_id = project_id + self.non_custom_claim_keys = [ + "aud", + "exp", + "iat", + "iss", + "jti", + "nbf", + "sub", + "active", + "client_id", + "request_id", + "scope", + "status_code", + "token_type", + ] + + def introspect_token_network( + self, + token: str, + client_id: str, + client_secret: Optional[str] = None, + token_type_hint: str = "access_token", + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + - client_secret: The secret of the client. + - token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'. + """ + headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + data: Dict[str, Any] = { + "token": token, + "client_id": client_id, + "token_type_hint": token_type_hint, + } + if client_secret is not None: + data["client_secret"] = client_secret + + url = self.api_base.url_for( + f"/v1/public/{self.project_id}/oauth2/introspect", data + ) + res = self.sync_client.post_form(url, data, headers) + jwtResponse = IDPTokenResponse.from_json(res.response.status_code, res.json) + custom_claims = { + k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys + } + if not jwtResponse.active: + return None + + return IDPTokenClaims( + subject=jwtResponse.sub, + scope=jwtResponse.scope, + audience=jwtResponse.aud, + expires_at=jwtResponse.exp, + issued_at=jwtResponse.iat, + issuer=jwtResponse.iss, + not_before=jwtResponse.nbf, + token_type=jwtResponse.token_type, + custom_claims=custom_claims, + ) + + async def introspect_token_network_async( + self, + token: str, + client_id: str, + client_secret: Optional[str] = None, + token_type_hint: str = "access_token", + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + - client_secret: The secret of the client. + - token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'. + """ + headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + data: Dict[str, Any] = { + "token": token, + "client_id": client_id, + "token_type_hint": token_type_hint, + } + if client_secret is not None: + data["client_secret"] = client_secret + + url = self.api_base.url_for( + f"/v1/public/{self.project_id}/oauth2/introspect", data + ) + res = await self.async_client.post_form(url, data, headers) + jwtResponse = IDPTokenResponse.from_json(res.response.status, res.json) + custom_claims = { + k: v for k, v in res.json.items() if k not in self.non_custom_claim_keys + } + if not jwtResponse.active: + return None + + return IDPTokenClaims( + subject=jwtResponse.sub, + scope=jwtResponse.scope, + audience=jwtResponse.aud, + expires_at=jwtResponse.exp, + issued_at=jwtResponse.iat, + issuer=jwtResponse.iss, + not_before=jwtResponse.nbf, + token_type=jwtResponse.token_type, + custom_claims=custom_claims, + ) + + def introspect_access_token_local( + self, + access_token: str, + client_id: str, + ) -> Optional[IDPTokenClaims]: + """Introspects a token JWT from an authorization code response. + Access tokens are JWTs signed with the project's JWKs. Refresh tokens are opaque tokens. + Access tokens contain a standard set of claims as well as any custom claims generated from templates. + + Fields: + - access_token: The access token (or refresh token) to introspect. + - client_id: The ID of the client. + """ + _scope_claim = "scope" + generic_claims = jwt_helpers.authenticate_jwt_local( + project_id=self.project_id, + jwks_client=self.jwks_client, + jwt=access_token, + custom_audience=client_id, + custom_issuer=f"https://stytch.com/{self.project_id}", + ) + if generic_claims is None: + return None + + custom_claims = { + k: v for k, v in generic_claims.untyped_claims.items() if k != _scope_claim + } + + return IDPTokenClaims( + subject=generic_claims.reserved_claims["sub"], + scope=generic_claims.untyped_claims[_scope_claim], + custom_claims=custom_claims, + audience=generic_claims.reserved_claims["aud"], + expires_at=generic_claims.reserved_claims["exp"], + issued_at=generic_claims.reserved_claims["iat"], + issuer=generic_claims.reserved_claims["iss"], + not_before=generic_claims.reserved_claims["nbf"], + token_type="access_token", + ) diff --git a/stytch/consumer/client.py b/stytch/consumer/client.py index f1f6aa1..10eea41 100644 --- a/stytch/consumer/client.py +++ b/stytch/consumer/client.py @@ -12,6 +12,7 @@ from stytch.consumer.api.crypto_wallets import CryptoWallets from stytch.consumer.api.fraud import Fraud +from stytch.consumer.api.idp import IDP from stytch.consumer.api.m2m import M2M from stytch.consumer.api.magic_links import MagicLinks from stytch.consumer.api.oauth import OAuth @@ -114,6 +115,13 @@ def __init__( sync_client=self.sync_client, async_client=self.async_client, ) + self.idp = IDP( + api_base=self.api_base, + sync_client=self.sync_client, + async_client=self.async_client, + jwks_client=self.jwks_client, + project_id=project_id, + ) def get_jwks_client(self, project_id: str) -> jwt.PyJWKClient: data = {"project_id": project_id} diff --git a/stytch/consumer/models/idp.py b/stytch/consumer/models/idp.py new file mode 100644 index 0000000..c31aa5b --- /dev/null +++ b/stytch/consumer/models/idp.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List, Optional + +import pydantic + +from stytch.core.response_base import ResponseBase + + +class IDPTokenResponse(ResponseBase): + """Response type for `IDP.introspect_token_network`. + Fields: + - active: Whether or not this token is active. + - sub: Subject of this token. + - scope: A space-delimited string of scopes this token is granted. + - aud: Audience of this token. Usually the user or member ID, and any custom audience, if present. + - exp: Expiration of this access token, in Unix time. + - iat: The time this access token was issued. + - iss: The issuer of this access token. + - nbf: The time before which the token must not be accepted for processing. + """ # noqa + + active: bool + sub: Optional[str] = None + scope: Optional[str] = None + aud: Optional[List[str]] = [] + exp: Optional[int] = None + iat: Optional[int] = None + iss: Optional[str] = None + nbf: Optional[int] = None + token_type: Optional[str] = None + + +class IDPTokenClaims(pydantic.BaseModel): + """Response type for `IDP.introspect_token_network`. + Fields: + - subject: The subject (either user_id or member_id) that the token is intended for. + - scope: A space-delimited string of scopes this token is granted. + - custom_claims: A dict of custom claims of the token. + - audience: Audience of this token. Usually the user or member ID, and any custom audience, if present. + - expires_at: Expiration of this access token, in Unix time. + - issued_at: The time this access token was issued. + - issuer: The issuer of this access token. + - not_before: The time before which the token must not be accepted for processing. + - token_type: The type of this token - e.g. 'access_token' or 'refresh_token'. + """ # noqa + + subject: str + scope: Optional[str] + custom_claims: Optional[Dict[str, Any]] = None + audience: Optional[List[str]] + expires_at: Optional[int] + issued_at: Optional[int] + issuer: Optional[str] + not_before: Optional[int] + token_type: Optional[str] diff --git a/stytch/core/http/client.py b/stytch/core/http/client.py index 53e14aa..b98278a 100644 --- a/stytch/core/http/client.py +++ b/stytch/core/http/client.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import asyncio +import json from dataclasses import dataclass from typing import Any, Dict, Generic, Optional, TypeVar @@ -66,6 +67,17 @@ def post( resp = requests.post(url, json=json, headers=final_headers, auth=self.auth) return self._response_from_request(resp) + def post_form( + self, + url: str, + form: Optional[Dict[str, Any]], + headers: Optional[Dict[str, str]] = None, + ) -> ResponseWithJson: + final_headers = self.headers.copy() + final_headers.update(headers or {}) + resp = requests.post(url, data=form, headers=final_headers, auth=self.auth) + return self._response_from_request(resp) + def put( self, url: str, @@ -128,6 +140,16 @@ async def _response_from_request( resp_json = {} return ResponseWithJson(response=r, json=resp_json) + async def _response_from_post_form_request( + cls, r: aiohttp.ClientResponse + ) -> ResponseWithJson: + try: + content = await r.content.read() + resp_json = json.loads(content.decode()) + except Exception as e: + resp_json = {} + return ResponseWithJson(response=r, json=resp_json) + async def get( self, url: str, @@ -154,6 +176,19 @@ async def post( ) return await self._response_from_request(resp) + async def post_form( + self, + url: str, + form: Optional[Dict[str, Any]], + headers: Optional[Dict[str, str]] = None, + ) -> ResponseWithJson: + final_headers = self.headers.copy() + final_headers.update(headers or {}) + resp = await self._session.post( + url, data=form, headers=final_headers, auth=self.auth + ) + return await self._response_from_post_form_request(resp) + async def put( self, url: str, diff --git a/stytch/shared/jwt_helpers.py b/stytch/shared/jwt_helpers.py index 010791e..b8c7a87 100644 --- a/stytch/shared/jwt_helpers.py +++ b/stytch/shared/jwt_helpers.py @@ -19,6 +19,8 @@ def authenticate_jwt_local( jwt: str, max_token_age_seconds: Optional[int] = None, leeway: int = 0, + custom_audience: Optional[str] = None, + custom_issuer: Optional[str] = None, ) -> Optional[GenericClaims]: """Parse a JWT and verify the signature locally (without calling /authenticate in the API). @@ -32,8 +34,8 @@ def authenticate_jwt_local( The value for leeway is the maximum allowable difference in seconds when comparing timestamps. It defaults to zero. """ - jwt_audience = project_id - jwt_issuer = f"stytch.com/{project_id}" + jwt_audience = custom_audience if custom_audience else project_id + jwt_issuer = custom_issuer if custom_issuer else f"stytch.com/{project_id}" now = time.time() diff --git a/stytch/shared/rbac_local.py b/stytch/shared/rbac_local.py index 0ee0cf8..6c5bf50 100644 --- a/stytch/shared/rbac_local.py +++ b/stytch/shared/rbac_local.py @@ -56,3 +56,30 @@ def perform_authorization_check( # If we made it here, we didn't find a matching permission raise RBACPermissionError(authorization_check) + + +def perform_scope_authorization_check( + policy: Policy, + token_scopes: List[str], + authorization_check: AuthorizationCheck, +) -> None: + """Performs an authorization check against a policy and a set of scopes. If the check + succeeds, this method will return. If the check fails, a PermissionError will be + raised. + """ + for scope in policy.scopes: + if scope.scope in token_scopes: + for permission in scope.permissions: + has_matching_action = ( + "*" in permission.actions + or authorization_check.action in permission.actions + ) + has_matching_resource = ( + authorization_check.resource_id == permission.resource_id + ) + if has_matching_action and has_matching_resource: + # All good, we found a matching permission + return + + # If we made it here, we didn't find a matching permission + raise RBACPermissionError(authorization_check) diff --git a/stytch/shared/tests/test_rbac_local.py b/stytch/shared/tests/test_rbac_local.py index ec43ff1..fe08832 100644 --- a/stytch/shared/tests/test_rbac_local.py +++ b/stytch/shared/tests/test_rbac_local.py @@ -1,11 +1,12 @@ import unittest -from stytch.b2b.models.rbac import Policy, PolicyRole, PolicyRolePermission +from stytch.b2b.models.rbac import Policy, PolicyRole, PolicyRolePermission, PolicyScope, PolicyScopePermission from stytch.b2b.models.sessions import AuthorizationCheck from stytch.shared.rbac_local import ( RBACPermissionError, TenancyError, perform_authorization_check, + perform_scope_authorization_check, ) @@ -42,9 +43,34 @@ def setUp(self) -> None: PolicyRolePermission(actions=["write", "read"], resource_id="bar") ], ) + self.read_scope = PolicyScope( + scope="read:documents", + description="Read documents", + permissions=[ + PolicyScopePermission(actions=["read"], resource_id="foo"), + PolicyScopePermission(actions=["read"], resource_id="bar"), + ], + ) + self.write_scope = PolicyScope( + scope="write:documents", + description="Write documents", + permissions=[ + PolicyScopePermission(actions=["write", "read"], resource_id="foo"), + PolicyScopePermission(actions=["write", "read"], resource_id="bar"), + ], + ) + self.wildcard_scope = PolicyScope( + scope="wildcard:documents", + description="Wildcard documents", + permissions=[ + PolicyScopePermission(actions=["*"], resource_id="foo"), + PolicyScopePermission(actions=["*"], resource_id="bar"), + ], + ) self.policy = Policy( resources=[], roles=[self.admin, self.global_writer, self.global_reader, self.bar_writer], + scopes=[self.read_scope, self.write_scope, self.wildcard_scope], ) def test_perform_authorization_check(self) -> None: @@ -112,3 +138,56 @@ def test_perform_authorization_check(self) -> None: # Act perform_authorization_check(self.policy, roles, org_id, req) # Assertion is that no exception is raised + + def test_perform_scope_authorization_check(self) -> None: + with self.subTest("has matching action but not resource"): + with self.assertRaises(RBACPermissionError): + # Arrange + scopes = [self.write_scope.scope] + org_id = "my_org" + req = AuthorizationCheck( + organization_id=org_id, + resource_id="baz", + action="write", + ) + # Act + perform_scope_authorization_check(self.policy, scopes, req) + + with self.subTest("has matching resource but not action"): + with self.assertRaises(RBACPermissionError): + # Arrange + scopes = [self.read_scope.scope] + org_id = "my_org" + req = AuthorizationCheck( + organization_id=org_id, + resource_id="foo", + action="write", + ) + # Act + perform_scope_authorization_check(self.policy, scopes, req) + + with self.subTest("has matching resource and specific action"): + # Arrange + scopes = [self.write_scope.scope] + org_id = "my_org" + req = AuthorizationCheck( + organization_id=org_id, + resource_id="foo", + action="write", + ) + # Act + perform_scope_authorization_check(self.policy, scopes, req) + # Assertion is that no exception is raised + + with self.subTest("has matching resource and star action"): + # Arrange + scopes = [self.wildcard_scope.scope] + org_id = "my_org" + req = AuthorizationCheck( + organization_id=org_id, + resource_id="foo", + action="write", + ) + # Act + perform_scope_authorization_check(self.policy, scopes, req) + # Assertion is that no exception is raised