Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IDP token introspection #229

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions stytch/b2b/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from stytch.b2b.api.sso import SSO
from stytch.b2b.api.totps import TOTPs
from stytch.consumer.api.fraud import Fraud
from stytch.consumer.api.idp import IDP
from stytch.consumer.api.m2m import M2M
from stytch.consumer.api.project import Project
from stytch.core.client_base import ClientBase
Expand Down Expand Up @@ -148,6 +149,13 @@ def __init__(
sync_client=self.sync_client,
async_client=self.async_client,
)
self.idp = IDP(
api_base=self.api_base,
sync_client=self.sync_client,
async_client=self.async_client,
jwks_client=self.jwks_client,
project_id=project_id,
)

def get_jwks_client(self, project_id: str) -> jwt.PyJWKClient:
data = {"project_id": project_id}
Expand Down
204 changes: 204 additions & 0 deletions stytch/consumer/api/idp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from __future__ import annotations

from typing import Any, Dict, Optional

import jwt

from stytch.consumer.models.idp import AccessTokenJWTClaims, AccessTokenJWTResponse
from stytch.core.api_base import ApiBase
from stytch.core.http.client import AsyncClient, SyncClient
from stytch.shared import jwt_helpers


class IDP:
def __init__(
self,
api_base: ApiBase,
sync_client: SyncClient,
async_client: AsyncClient,
jwks_client: jwt.PyJWKClient,
project_id: str,
) -> None:
self.api_base = api_base
self.sync_client = sync_client
self.async_client = async_client
self.jwks_client = jwks_client
self.project_id = project_id

def introspect_idp_access_token(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need introspect_idp_access_token since unlike Session JWTs there is no case where an access token can fail local validation and pass remote validation. If the access token is expired locally, it is also guaranteed to be expired serverside.

self,
access_token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
) -> Optional[AccessTokenJWTClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens and refresh tokens are JWTs signed with the project's JWKs.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
return self.introspect_idp_access_token_local(
access_token, client_id
) or self.introspect_idp_access_token_network(
access_token, client_id, client_secret, token_type_hint
)

async def introspect_idp_access_token_async(
self,
access_token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
) -> Optional[AccessTokenJWTClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens and refresh tokens are JWTs signed with the project's JWKs.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
local_introspection_response = self.introspect_idp_access_token_local(access_token, client_id)
if local_introspection_response is not None:
return local_introspection_response
return await self.introspect_idp_access_token_network_async(
access_token, client_id, client_secret, token_type_hint
)

def introspect_idp_access_token_network(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming nits:

  • Since the classname is already idp - do we need the second idp in the name? idp.introspect_idp_ could be idp.introspect_
  • Since this works for both access and refresh tokens, suggest removing access_ from the name

I think introspect_token_network works better.

self,
access_token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
) -> Optional[AccessTokenJWTClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens and refresh tokens are JWTs signed with the project's JWKs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refresh Tokens are not JWTs - this endpoint supports both Access Token JWTs and Refresh Token opaque tokens

Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this field takes in both access tokens and refresh tokens, we should name it token

- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
data: Dict[str, Any] = {
"token": access_token,
"client_id": client_id,
"token_type_hint": token_type_hint,
}
if client_secret is not None:
data["client_secret"] = client_secret

url = self.api_base.url_for(
f"/v1/public/{self.project_id}/oauth2/introspect", data
)
res = self.sync_client.postForm(url, data, headers)
jwtResponse = AccessTokenJWTResponse.from_json(
res.response.status_code, res.json
)
if not jwtResponse.active:
return None
return AccessTokenJWTClaims(
subject=jwtResponse.sub,
scope=jwtResponse.scope,
custom_claims={},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return custom claims that we retrieve from the introspection response here

audience=jwtResponse.aud,
expires_at=jwtResponse.exp,
issued_at=jwtResponse.iat,
issuer=jwtResponse.iss,
not_before=jwtResponse.nbf,
)

async def introspect_idp_access_token_network_async(
self,
access_token: str,
client_id: str,
client_secret: Optional[str] = None,
token_type_hint: str = "access_token",
) -> Optional[AccessTokenJWTClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens and refresh tokens are JWTs signed with the project's JWKs.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
- client_secret: The secret of the client.
- token_type_hint: A hint on what the token contains. Valid fields are 'access_token' and 'refresh_token'.
"""
headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
data: Dict[str, Any] = {
"token": access_token,
"client_id": client_id,
"token_type_hint": token_type_hint,
}
if client_secret is not None:
data["client_secret"] = client_secret

url = self.api_base.url_for(
f"/v1/public/{self.project_id}/oauth2/introspect", data
)
res = await self.async_client.postForm(url, data, headers)
jwtResponse = AccessTokenJWTResponse.from_json(
res.response.status, res.json
)
if not jwtResponse.active:
return None
return AccessTokenJWTClaims(
subject=jwtResponse.sub,
scope=jwtResponse.scope,
custom_claims={},
audience=jwtResponse.aud,
expires_at=jwtResponse.exp,
issued_at=jwtResponse.iat,
issuer=jwtResponse.iss,
not_before=jwtResponse.nbf,
)

def introspect_idp_access_token_local(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in contrast this method is only valid for access tokens

self,
access_token: str,
client_id: str,
) -> Optional[AccessTokenJWTClaims]:
"""Introspects a token JWT from an authorization code response.
Access tokens and refresh tokens are JWTs signed with the project's JWKs.
Access tokens contain a standard set of claims as well as any custom claims generated from templates.

Fields:
- access_token: The access token (or refresh token) to introspect.
- client_id: The ID of the client.
"""
_scope_claim = "scope"
generic_claims = jwt_helpers.authenticate_jwt_local(
project_id=self.project_id,
jwks_client=self.jwks_client,
jwt=access_token,
custom_audience=client_id,
custom_issuer=f"https://stytch.com/{self.project_id}",
)
if generic_claims is None:
return None

custom_claims = {
k: v for k, v in generic_claims.untyped_claims.items() if k != _scope_claim
}

return AccessTokenJWTClaims(
subject=generic_claims.reserved_claims["sub"],
scope=generic_claims.untyped_claims[_scope_claim],
custom_claims=custom_claims,
audience=generic_claims.reserved_claims["aud"],
expires_at=generic_claims.reserved_claims["exp"],
issued_at=generic_claims.reserved_claims["iat"],
issuer=generic_claims.reserved_claims["iss"],
not_before=generic_claims.reserved_claims["nbf"],
)
8 changes: 8 additions & 0 deletions stytch/consumer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from stytch.consumer.api.crypto_wallets import CryptoWallets
from stytch.consumer.api.fraud import Fraud
from stytch.consumer.api.idp import IDP
from stytch.consumer.api.m2m import M2M
from stytch.consumer.api.magic_links import MagicLinks
from stytch.consumer.api.oauth import OAuth
Expand Down Expand Up @@ -114,6 +115,13 @@ def __init__(
sync_client=self.sync_client,
async_client=self.async_client,
)
self.idp = IDP(
api_base=self.api_base,
sync_client=self.sync_client,
async_client=self.async_client,
jwks_client=self.jwks_client,
project_id=project_id,
)

def get_jwks_client(self, project_id: str) -> jwt.PyJWKClient:
data = {"project_id": project_id}
Expand Down
51 changes: 51 additions & 0 deletions stytch/consumer/models/idp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any, Dict, List, Optional

import pydantic

from stytch.core.response_base import ResponseBase


class AccessTokenJWTResponse(ResponseBase):
"""Response type for `IDP.introspect_idp_access_token`.
Fields:
- active: Whether or not this token is active.
- sub: Subject of this JWT.
- scope: A space-delimited string of scopes this JWT is granted.
- aud: Audience of this JWT. Usually the user or member ID, and any custom audience, if present.
- exp: Expiration of this access token, in Unix time.
- iat: The time this access token was issued.
- iss: The issuer of this access token.
- nbf: The time before which the JWT must not be accepted for processing.
""" # noqa

active: bool
sub: Optional[str] = None
scope: Optional[str] = None
aud: Optional[List[str]] = []
exp: Optional[int] = None
iat: Optional[int] = None
iss: Optional[str] = None
nbf: Optional[int] = None


class AccessTokenJWTClaims(pydantic.BaseModel):
"""Response type for `IDP.introspect_idp_access_token`.
Fields:
- subject: The subject (either user_id or member_id) that the JWT is intended for.
- scope: A space-delimited string of scopes this JWT is granted.
- custom_claims: A dict of custom claims of the JWT.
- audience: Audience of this JWT. Usually the user or member ID, and any custom audience, if present.
- expires_at: Expiration of this access token, in Unix time.
- issued_at: The time this access token was issued.
- issuer: The issuer of this access token.
- not_before: The time before which the JWT must not be accepted for processing.
""" # noqa

subject: str
scope: Optional[str]
custom_claims: Optional[Dict[str, Any]] = None
audience: Optional[List[str]]
expires_at: Optional[int]
issued_at: Optional[int]
issuer: Optional[str]
not_before: Optional[int]
36 changes: 36 additions & 0 deletions stytch/core/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import requests
import requests.auth

import json

from stytch.version import __version__

HEADERS = {
Expand Down Expand Up @@ -66,6 +68,17 @@ def post(
resp = requests.post(url, json=json, headers=final_headers, auth=self.auth)
return self._response_from_request(resp)

def postForm(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we name this post_form instead? We use snake case everywhere else in this package

self,
url: str,
form: Optional[Dict[str, Any]],
headers: Optional[Dict[str, str]] = None,
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
resp = requests.post(url, data=form, headers=final_headers, auth=self.auth)
return self._response_from_request(resp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does one flavor of post_form use _response_from_post_form_request and the other not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the sync flavor uses the requests library which plays well with _response_from_request, but async's aiohttp doesn't play well here


def put(
self,
url: str,
Expand Down Expand Up @@ -127,6 +140,16 @@ async def _response_from_request(
except Exception:
resp_json = {}
return ResponseWithJson(response=r, json=resp_json)

async def _response_from_post_form_request(
cls, r: aiohttp.ClientResponse
) -> ResponseWithJson:
try:
content = await r.content.read()
resp_json = json.loads(content.decode())
except Exception as e:
resp_json = {}
return ResponseWithJson(response=r, json=resp_json)

async def get(
self,
Expand All @@ -153,6 +176,19 @@ async def post(
url, json=json, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)

async def postForm(
self,
url: str,
form: Optional[Dict[str, Any]],
headers: Optional[Dict[str, str]] = None,
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
resp = await self._session.post(
url, data=form, headers=final_headers, auth=self.auth
)
return await self._response_from_post_form_request(resp)

async def put(
self,
Expand Down
6 changes: 4 additions & 2 deletions stytch/shared/jwt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def authenticate_jwt_local(
jwt: str,
max_token_age_seconds: Optional[int] = None,
leeway: int = 0,
custom_audience: Optional[str] = None,
custom_issuer: Optional[str] = None,
) -> Optional[GenericClaims]:
"""Parse a JWT and verify the signature locally
(without calling /authenticate in the API).
Expand All @@ -32,8 +34,8 @@ def authenticate_jwt_local(
The value for leeway is the maximum allowable difference in seconds when
comparing timestamps. It defaults to zero.
"""
jwt_audience = project_id
jwt_issuer = f"stytch.com/{project_id}"
jwt_audience = custom_audience if custom_audience else project_id
jwt_issuer = custom_issuer if custom_issuer else f"stytch.com/{project_id}"

now = time.time()

Expand Down
Loading