Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OBACK-290: Fix issue with python Authenticate JWT locally logic #209

Merged
merged 5 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ This client library supports all Stytch's live products:
- [x] [Crypto wallets](https://stytch.com/docs/guides/web3/api)
- [x] [Passwords](https://stytch.com/docs/guides/passwords/api)

**B2B**
## B2B

- [x] [Organizations](https://stytch.com/docs/b2b/api/organization-object)
- [x] [Members](https://stytch.com/docs/b2b/api/member-object)
Expand Down
163 changes: 99 additions & 64 deletions stytch/consumer/api/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
GetJWKSResponse,
GetResponse,
RevokeResponse,
Session,
Session, AuthenticateJWTLocalResponse,
)
from stytch.core.api_base import ApiBase
from stytch.core.http.client import AsyncClient, SyncClient
Expand All @@ -24,12 +24,12 @@

class Sessions:
def __init__(
self,
api_base: ApiBase,
sync_client: SyncClient,
async_client: AsyncClient,
jwks_client: jwt.PyJWKClient,
project_id: str,
self,
api_base: ApiBase,
sync_client: SyncClient,
async_client: AsyncClient,
jwks_client: jwt.PyJWKClient,
project_id: str,
) -> None:
self.api_base = api_base
self.sync_client = sync_client
Expand All @@ -38,8 +38,8 @@ def __init__(
self.project_id = project_id

def get(
self,
user_id: str,
self,
user_id: str,
) -> GetResponse:
"""List all active Sessions for a given `user_id`. All timestamps are formatted according to the RFC 3339 standard and are expressed in UTC, e.g. `2021-12-29T12:33:09Z`.

Expand All @@ -56,8 +56,8 @@ def get(
return GetResponse.from_json(res.response.status_code, res.json)

async def get_async(
self,
user_id: str,
self,
user_id: str,
) -> GetResponse:
"""List all active Sessions for a given `user_id`. All timestamps are formatted according to the RFC 3339 standard and are expressed in UTC, e.g. `2021-12-29T12:33:09Z`.

Expand All @@ -74,11 +74,11 @@ async def get_async(
return GetResponse.from_json(res.response.status, res.json)

def authenticate(
self,
session_token: Optional[str] = None,
session_duration_minutes: Optional[int] = None,
session_jwt: Optional[str] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
self,
session_token: Optional[str] = None,
session_duration_minutes: Optional[int] = None,
session_jwt: Optional[str] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> AuthenticateResponse:
"""Authenticate a session token or session JWT and retrieve associated session data. If `session_duration_minutes` is included, update the lifetime of the session to be that many minutes from now. All timestamps are formatted according to the RFC 3339 standard and are expressed in UTC, e.g. `2021-12-29T12:33:09Z`. This endpoint requires exactly one `session_jwt` or `session_token` as part of the request. If both are included, you will receive a `too_many_session_arguments` error.

Expand Down Expand Up @@ -108,11 +108,11 @@ def authenticate(
return AuthenticateResponse.from_json(res.response.status_code, res.json)

async def authenticate_async(
self,
session_token: Optional[str] = None,
session_duration_minutes: Optional[int] = None,
session_jwt: Optional[str] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
self,
session_token: Optional[str] = None,
session_duration_minutes: Optional[int] = None,
session_jwt: Optional[str] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> AuthenticateResponse:
"""Authenticate a session token or session JWT and retrieve associated session data. If `session_duration_minutes` is included, update the lifetime of the session to be that many minutes from now. All timestamps are formatted according to the RFC 3339 standard and are expressed in UTC, e.g. `2021-12-29T12:33:09Z`. This endpoint requires exactly one `session_jwt` or `session_token` as part of the request. If both are included, you will receive a `too_many_session_arguments` error.

Expand Down Expand Up @@ -142,10 +142,10 @@ async def authenticate_async(
return AuthenticateResponse.from_json(res.response.status, res.json)

def revoke(
self,
session_id: Optional[str] = None,
session_token: Optional[str] = None,
session_jwt: Optional[str] = None,
self,
session_id: Optional[str] = None,
session_token: Optional[str] = None,
session_jwt: Optional[str] = None,
) -> RevokeResponse:
"""Revoke a Session, immediately invalidating all of its session tokens. You can revoke a session in three ways: using its ID, or using one of its session tokens, or one of its JWTs. This endpoint requires exactly one of those to be included in the request. It will return an error if multiple are present.

Expand All @@ -168,10 +168,10 @@ def revoke(
return RevokeResponse.from_json(res.response.status_code, res.json)

async def revoke_async(
self,
session_id: Optional[str] = None,
session_token: Optional[str] = None,
session_jwt: Optional[str] = None,
self,
session_id: Optional[str] = None,
session_token: Optional[str] = None,
session_jwt: Optional[str] = None,
) -> RevokeResponse:
"""Revoke a Session, immediately invalidating all of its session tokens. You can revoke a session in three ways: using its ID, or using one of its session tokens, or one of its JWTs. This endpoint requires exactly one of those to be included in the request. It will return an error if multiple are present.

Expand All @@ -194,8 +194,8 @@ async def revoke_async(
return RevokeResponse.from_json(res.response.status, res.json)

def get_jwks(
self,
project_id: str,
self,
project_id: str,
) -> GetJWKSResponse:
"""Get the JSON Web Key Set (JWKS) for a project.

Expand All @@ -222,8 +222,8 @@ def get_jwks(
return GetJWKSResponse.from_json(res.response.status_code, res.json)

async def get_jwks_async(
self,
project_id: str,
self,
project_id: str,
) -> GetJWKSResponse:
"""Get the JSON Web Key Set (JWKS) for a project.

Expand Down Expand Up @@ -254,11 +254,11 @@ async def get_jwks_async(
# ADDIMPORT: import jwt
# ADDIMPORT: import time
def authenticate_jwt(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[AuthenticateJWTLocalResponse]:
"""Parse a JWT and verify the signature, preferring local verification
over remote.

Expand All @@ -269,22 +269,41 @@ def authenticate_jwt(
zero or use the authenticate method instead.
"""
# Return the local_result if available, otherwise call the Stytch API
return (
self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
# Return the local_result if available, otherwise call the Stytch API
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated comment?

local_token = self.authenticate_jwt_local(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[dust] local_resp or local_res feels like a better name than local_token

session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
if local_token is not None:
return AuthenticateJWTLocalResponse.from_json(
status_code=200,
json={
"session": local_token,
"session_jwt": session_jwt,
"status_code": 200,
"request_id": "",
},
)
or self.authenticate(
else:
authenticate_response = self.authenticate(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
).session
)
)
return AuthenticateJWTLocalResponse.from_json(
status_code=authenticate_response.status_code,
json={
"session": authenticate_response.session,
"session_jwt": authenticate_response.session_jwt,
"status_code": authenticate_response.status_code,
"request_id": authenticate_response.request_id,
}
)

async def authenticate_jwt_async(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[AuthenticateJWTLocalResponse]:
"""Parse a JWT and verify the signature, preferring local verification
over remote.

Expand All @@ -295,28 +314,44 @@ async def authenticate_jwt_async(
zero or use the authenticate method instead.
"""
# Return the local_result if available, otherwise call the Stytch API
return (
self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
or (
await self.authenticate_async(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
)
).session
local_token = self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
if local_token is not None:
return AuthenticateJWTLocalResponse.from_json(
status_code=200,
json={
"session": local_token,
"session_jwt": session_jwt,
"status_code": 200,
"request_id": "",
},
)
else:
authenticate_response = await self.authenticate_async(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
)
return AuthenticateJWTLocalResponse.from_json(
status_code=authenticate_response.status_code,
json={
"session": authenticate_response.session,
"session_jwt": authenticate_response.session_jwt,
"status_code": authenticate_response.status_code,
"request_id": authenticate_response.request_id,
}
)

# ENDMANUAL(authenticate_jwt)

# MANUAL(authenticate_jwt_local)(SERVICE_METHOD)
# ADDIMPORT: from stytch.consumer.models.sessions import Session
# ADDIMPORT: from stytch.shared import jwt_helpers
def authenticate_jwt_local(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
leeway: int = 0,
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
leeway: int = 0,
) -> Optional[Session]:
_session_claim = "https://stytch.com/session"
generic_claims = jwt_helpers.authenticate_jwt_local(
Expand Down
6 changes: 6 additions & 0 deletions stytch/consumer/models/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,9 @@ class RevokeResponse(ResponseBase):
"""Response type for `Sessions.revoke`.
Fields:
""" # noqa

# MANUAL(AuthenticateJWTLocalResponse)(Types)
class AuthenticateJWTLocalResponse(ResponseBase):
session: Session
session_jwt: str
# ENDMANUAL(AuthenticateJWTLocalResponse)
40 changes: 22 additions & 18 deletions stytch/shared/jwt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,28 @@ def authenticate_jwt_local(

signing_key = jwks_client.get_signing_key_from_jwt(jwt)

# NOTE: The max_token_age_seconds value is applied after decoding.
payload = pyjwt.decode(
jwt,
signing_key.key,
algorithms=["RS256"],
options={
"require": ["aud", "iss", "exp", "iat", "nbf"],
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
audience=jwt_audience,
issuer=jwt_issuer,
leeway=leeway,
)
try:
# NOTE: The max_token_age_seconds value is applied after decoding.
payload = pyjwt.decode(
jwt,
signing_key.key,
algorithms=["RS256"],
options={
"require": ["aud", "iss", "exp", "iat", "nbf"],
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
audience=jwt_audience,
issuer=jwt_issuer,
leeway=leeway,
)
except Exception:
# In the event of a failure to decode, such as an expired token, we should return None
return None

if max_token_age_seconds is not None:
iat = payload["iat"]
Expand Down
2 changes: 1 addition & 1 deletion stytch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "10.1.0"
__version__ = "11.0.0"
1 change: 1 addition & 0 deletions test/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TEST_PW_HASH_TYPE = MigrateRequestHashType.BCRYPT
# Sessions test constants
TEST_SESSION_TOKEN = "WJtR5BCy38Szd5AfoDpf0iqFKEt4EE5JhjlWUY7l3FtY"
TEST_EXPIRED_JWT = "eyJhbGciOiJSUzI1NiIsImtpZCI6Imp3ay10ZXN0LWU5NzQ1ZmJmLTNiNDQtNDkxYi1iNDAyLWFmZjhmMTNlZWM2OSIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsicHJvamVjdC10ZXN0LTFiZWZkODEyLWU5ZTYtNDlmZS1iZGMxLWM4ZGYyYmZhNzAyYSJdLCJleHAiOjE3MjExNjI2MzMsImh0dHBzOi8vc3R5dGNoLmNvbS9zZXNzaW9uIjp7ImlkIjoic2Vzc2lvbi10ZXN0LWViOTQyMzNmLTg4MDAtNGViZC04NjQ1LTUxZGMxNWY5ZDAyOCIsInN0YXJ0ZWRfYXQiOiIyMDIxLTA4LTI4VDAwOjQxOjU4WiIsImxhc3RfYWNjZXNzZWRfYXQiOiIyMDI0LTA3LTE2VDIwOjM4OjUzWiIsImV4cGlyZXNfYXQiOiIyMDI0LTA3LTE2VDIxOjM4OjUzWiIsImF0dHJpYnV0ZXMiOnsidXNlcl9hZ2VudCI6IiIsImlwX2FkZHJlc3MiOiIifSwiYXV0aGVudGljYXRpb25fZmFjdG9ycyI6W3sidHlwZSI6Im1hZ2ljX2xpbmsiLCJkZWxpdmVyeV9tZXRob2QiOiJlbWFpbCIsImxhc3RfYXV0aGVudGljYXRlZF9hdCI6IjIwMjQtMDctMTZUMjA6Mzg6NTNaIiwiZW1haWxfZmFjdG9yIjp7ImVtYWlsX2lkIjoiZW1haWwtdGVzdC0yMzg3M2U4OS1kNGVkLTRlOTItYjNiOS1lNWM3MTk4ZmEyODYiLCJlbWFpbF9hZGRyZXNzIjoic2FuZGJveEBzdHl0Y2guY29tIn19XX0sImlhdCI6MTcyMTE2MjMzMywiaXNzIjoic3R5dGNoLmNvbS9wcm9qZWN0LXRlc3QtMWJlZmQ4MTItZTllNi00OWZlLWJkYzEtYzhkZjJiZmE3MDJhIiwibmJmIjoxNzIxMTYyMzMzLCJzdWIiOiJ1c2VyLXRlc3QtZTM3OTVjODEtZjg0OS00MTY3LWJmZGEtZTRhNmU5YzI4MGZkIn0.hm-vJeVBCun4I7Vkhj10VQevlQCdeyY2OWqFJ3B0tSSjE9s_8uAS8ag41hCyEzgVENyjIicU6r2Ow_dLol4dwXZrqU3cwNJsH6T62hxSzpG8cTVQJzQgirf7yjqMC-TzTrxKuPkta5sULMV6GBtW9-HDT5HlWQK5hp3BSxwq2Qc_WfmKv8ghV5YcBHRVCmu8IDnG0D19FrPOhLvLR6NPbNpidSA6gYP9OMl45f65hZHOQF178F6HtErZPpkDX2g3n6iymxznj5b5g2mEKhxjVYUxAWsFPE1FZXurc5ui47n3PaUu9J7zJmaUSi80lFQ-YKCZVSWAQ_DUOhwSJSEsLw"
# TOTP test constants
TEST_TOTP_USER_ID = "user-test-e3795c81-f849-4167-bfda-e4a6e9c280fd"
TEST_TOTP_CODE = "000000"
Expand Down
19 changes: 18 additions & 1 deletion test/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python3

import unittest
from time import sleep

from test.constants import (
TEST_CRYPTO_SIGNATURE,
TEST_CRYPTO_WALLET_ADDRESS,
Expand All @@ -17,7 +19,7 @@
TEST_TOTP_CODE,
TEST_TOTP_RECOVERY_CODE,
TEST_TOTP_USER_ID,
TEST_USERS_NAME,
TEST_USERS_NAME, TEST_EXPIRED_JWT,
)
from test.integration_base import CreatedTestUser, IntegrationTestBase

Expand Down Expand Up @@ -232,6 +234,21 @@ def test_webauthn(self) -> None:
# TODO: No test public key credential (see skipTest above)
self.assertTrue(api.authenticate(public_key_credential="").is_success)

def test_authenticate(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sweet! Thanks for adding a test!

api = self.b2c_client.sessions

with self._get_temporary_user() as user:
assert isinstance(user, CreatedTestUser)
self.assertTrue(api.get(user_id=user.user_id).is_success)
# Grab a recent JWT token and verify it's valid
auth_response = api.authenticate(session_token=TEST_SESSION_TOKEN)
response = self.b2c_client.sessions.authenticate_jwt(session_jwt=auth_response.session_jwt)
self.assertEquals(auth_response.session_jwt, response.session_jwt)

def test_authenticate_jwt_local_returns_none_for_expired_token(self) -> None:
api = self.b2c_client.sessions
self.assertIsNone(api.authenticate_jwt_local(session_jwt=TEST_EXPIRED_JWT))


if __name__ == "__main__":
unittest.main()
11 changes: 11 additions & 0 deletions test/test_integration_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,14 @@ async def test_webauthn_async(self) -> None:
self.assertTrue(
(await api.authenticate_async(public_key_credential="")).is_success
)

async def test_authenticate(self) -> None:
api = self.b2c_client.sessions

async with self._get_temporary_user_async() as user:
assert isinstance(user, CreatedTestUser)
self.assertTrue(api.get(user_id=user.user_id).is_success)
# Grab a recent JWT token and verify it's valid
auth_response = api.authenticate(session_token=TEST_SESSION_TOKEN)
response = await self.b2c_client.sessions.authenticate_jwt_async(session_jwt=auth_response.session_jwt)
self.assertEquals(auth_response.session_jwt, response.session_jwt)
Loading