Skip to content

Commit

Permalink
OBACK-290: Fix issue with python Authenticate JWT locally logic (#209)
Browse files Browse the repository at this point in the history
* Fix issue with python Authenticate JWT locally logic

* Cleanups

* More cleans

* Try adding an actual check

* Minor nits
  • Loading branch information
bgier-stytch authored Jul 17, 2024
1 parent 689d25b commit a584134
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ This client library supports all Stytch's live products:
- [x] [Crypto wallets](https://stytch.com/docs/guides/web3/api)
- [x] [Passwords](https://stytch.com/docs/guides/passwords/api)

**B2B**
## B2B

- [x] [Organizations](https://stytch.com/docs/b2b/api/organization-object)
- [x] [Members](https://stytch.com/docs/b2b/api/member-object)
Expand Down
74 changes: 55 additions & 19 deletions stytch/consumer/api/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
GetResponse,
RevokeResponse,
Session,
AuthenticateJWTLocalResponse,
)
from stytch.core.api_base import ApiBase
from stytch.core.http.client import AsyncClient, SyncClient
Expand Down Expand Up @@ -253,12 +254,13 @@ async def get_jwks_async(
# ADDIMPORT: from typing import Any, Dict, Optional
# ADDIMPORT: import jwt
# ADDIMPORT: import time
# ADDIMPORT: from stytch.consumer.models.sessions import AuthenticateJWTLocalResponse
def authenticate_jwt(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
) -> Optional[AuthenticateJWTLocalResponse]:

This comment has been minimized.

Copy link
@dvf

dvf Jul 23, 2024

It looks like this is always going to return AuthenticateJWTLocalResponse (not optional)?

"""Parse a JWT and verify the signature, preferring local verification
over remote.
Expand All @@ -269,22 +271,40 @@ def authenticate_jwt(
zero or use the authenticate method instead.
"""
# Return the local_result if available, otherwise call the Stytch API
return (
self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
local_resp = self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
if local_resp is not None:
return AuthenticateJWTLocalResponse.from_json(
status_code=200,
json={
"session": local_resp,
"session_jwt": session_jwt,
"status_code": 200,
"request_id": "",
},
)
or self.authenticate(
else:
authenticate_response = self.authenticate(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
).session
)
)
return AuthenticateJWTLocalResponse.from_json(
status_code=authenticate_response.status_code,
json={
"session": authenticate_response.session,
"session_jwt": authenticate_response.session_jwt,
"status_code": authenticate_response.status_code,
"request_id": authenticate_response.request_id,
}
)

async def authenticate_jwt_async(
self,
session_jwt: str,
max_token_age_seconds: Optional[int] = None,
session_custom_claims: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
) -> Optional[AuthenticateJWTLocalResponse]:
"""Parse a JWT and verify the signature, preferring local verification
over remote.
Expand All @@ -295,17 +315,33 @@ async def authenticate_jwt_async(
zero or use the authenticate method instead.
"""
# Return the local_result if available, otherwise call the Stytch API
return (
self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
or (
await self.authenticate_async(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
)
).session
local_token = self.authenticate_jwt_local(
session_jwt=session_jwt,
max_token_age_seconds=max_token_age_seconds,
)
if local_token is not None:
return AuthenticateJWTLocalResponse.from_json(
status_code=200,
json={
"session": local_token,
"session_jwt": session_jwt,
"status_code": 200,
"request_id": "",
},
)
else:
authenticate_response = await self.authenticate_async(
session_custom_claims=session_custom_claims, session_jwt=session_jwt
)
return AuthenticateJWTLocalResponse.from_json(
status_code=authenticate_response.status_code,
json={
"session": authenticate_response.session,
"session_jwt": authenticate_response.session_jwt,
"status_code": authenticate_response.status_code,
"request_id": authenticate_response.request_id,
}
)

# ENDMANUAL(authenticate_jwt)

Expand Down
6 changes: 6 additions & 0 deletions stytch/consumer/models/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,9 @@ class RevokeResponse(ResponseBase):
"""Response type for `Sessions.revoke`.
Fields:
""" # noqa

# MANUAL(AuthenticateJWTLocalResponse)(Types)
class AuthenticateJWTLocalResponse(ResponseBase):
session: Session
session_jwt: str
# ENDMANUAL(AuthenticateJWTLocalResponse)
40 changes: 22 additions & 18 deletions stytch/shared/jwt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,28 @@ def authenticate_jwt_local(

signing_key = jwks_client.get_signing_key_from_jwt(jwt)

# NOTE: The max_token_age_seconds value is applied after decoding.
payload = pyjwt.decode(
jwt,
signing_key.key,
algorithms=["RS256"],
options={
"require": ["aud", "iss", "exp", "iat", "nbf"],
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
audience=jwt_audience,
issuer=jwt_issuer,
leeway=leeway,
)
try:
# NOTE: The max_token_age_seconds value is applied after decoding.
payload = pyjwt.decode(
jwt,
signing_key.key,
algorithms=["RS256"],
options={
"require": ["aud", "iss", "exp", "iat", "nbf"],
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
audience=jwt_audience,
issuer=jwt_issuer,
leeway=leeway,
)
except Exception:
# In the event of a failure to decode, such as an expired token, we should return None
return None

if max_token_age_seconds is not None:
iat = payload["iat"]
Expand Down
2 changes: 1 addition & 1 deletion stytch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "10.1.0"
__version__ = "11.0.0"
1 change: 1 addition & 0 deletions test/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TEST_PW_HASH_TYPE = MigrateRequestHashType.BCRYPT
# Sessions test constants
TEST_SESSION_TOKEN = "WJtR5BCy38Szd5AfoDpf0iqFKEt4EE5JhjlWUY7l3FtY"
TEST_EXPIRED_JWT = "eyJhbGciOiJSUzI1NiIsImtpZCI6Imp3ay10ZXN0LWU5NzQ1ZmJmLTNiNDQtNDkxYi1iNDAyLWFmZjhmMTNlZWM2OSIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsicHJvamVjdC10ZXN0LTFiZWZkODEyLWU5ZTYtNDlmZS1iZGMxLWM4ZGYyYmZhNzAyYSJdLCJleHAiOjE3MjExNjI2MzMsImh0dHBzOi8vc3R5dGNoLmNvbS9zZXNzaW9uIjp7ImlkIjoic2Vzc2lvbi10ZXN0LWViOTQyMzNmLTg4MDAtNGViZC04NjQ1LTUxZGMxNWY5ZDAyOCIsInN0YXJ0ZWRfYXQiOiIyMDIxLTA4LTI4VDAwOjQxOjU4WiIsImxhc3RfYWNjZXNzZWRfYXQiOiIyMDI0LTA3LTE2VDIwOjM4OjUzWiIsImV4cGlyZXNfYXQiOiIyMDI0LTA3LTE2VDIxOjM4OjUzWiIsImF0dHJpYnV0ZXMiOnsidXNlcl9hZ2VudCI6IiIsImlwX2FkZHJlc3MiOiIifSwiYXV0aGVudGljYXRpb25fZmFjdG9ycyI6W3sidHlwZSI6Im1hZ2ljX2xpbmsiLCJkZWxpdmVyeV9tZXRob2QiOiJlbWFpbCIsImxhc3RfYXV0aGVudGljYXRlZF9hdCI6IjIwMjQtMDctMTZUMjA6Mzg6NTNaIiwiZW1haWxfZmFjdG9yIjp7ImVtYWlsX2lkIjoiZW1haWwtdGVzdC0yMzg3M2U4OS1kNGVkLTRlOTItYjNiOS1lNWM3MTk4ZmEyODYiLCJlbWFpbF9hZGRyZXNzIjoic2FuZGJveEBzdHl0Y2guY29tIn19XX0sImlhdCI6MTcyMTE2MjMzMywiaXNzIjoic3R5dGNoLmNvbS9wcm9qZWN0LXRlc3QtMWJlZmQ4MTItZTllNi00OWZlLWJkYzEtYzhkZjJiZmE3MDJhIiwibmJmIjoxNzIxMTYyMzMzLCJzdWIiOiJ1c2VyLXRlc3QtZTM3OTVjODEtZjg0OS00MTY3LWJmZGEtZTRhNmU5YzI4MGZkIn0.hm-vJeVBCun4I7Vkhj10VQevlQCdeyY2OWqFJ3B0tSSjE9s_8uAS8ag41hCyEzgVENyjIicU6r2Ow_dLol4dwXZrqU3cwNJsH6T62hxSzpG8cTVQJzQgirf7yjqMC-TzTrxKuPkta5sULMV6GBtW9-HDT5HlWQK5hp3BSxwq2Qc_WfmKv8ghV5YcBHRVCmu8IDnG0D19FrPOhLvLR6NPbNpidSA6gYP9OMl45f65hZHOQF178F6HtErZPpkDX2g3n6iymxznj5b5g2mEKhxjVYUxAWsFPE1FZXurc5ui47n3PaUu9J7zJmaUSi80lFQ-YKCZVSWAQ_DUOhwSJSEsLw"
# TOTP test constants
TEST_TOTP_USER_ID = "user-test-e3795c81-f849-4167-bfda-e4a6e9c280fd"
TEST_TOTP_CODE = "000000"
Expand Down
19 changes: 19 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import unittest

from test.constants import (
TEST_CRYPTO_SIGNATURE,
TEST_CRYPTO_WALLET_ADDRESS,
Expand All @@ -18,6 +19,7 @@
TEST_TOTP_RECOVERY_CODE,
TEST_TOTP_USER_ID,
TEST_USERS_NAME,
TEST_EXPIRED_JWT,
)
from test.integration_base import CreatedTestUser, IntegrationTestBase

Expand Down Expand Up @@ -232,6 +234,23 @@ def test_webauthn(self) -> None:
# TODO: No test public key credential (see skipTest above)
self.assertTrue(api.authenticate(public_key_credential="").is_success)

def test_authenticate(self) -> None:
api = self.b2c_client.sessions

with self._get_temporary_user() as user:
assert isinstance(user, CreatedTestUser)
self.assertTrue(api.get(user_id=user.user_id).is_success)
# Grab a recent JWT token and verify it's valid
auth_response = api.authenticate(session_token=TEST_SESSION_TOKEN)
response = self.b2c_client.sessions.authenticate_jwt(session_jwt=auth_response.session_jwt)
self.assertIsNotNone(response)
if response is not None:
self.assertEquals(auth_response.session_jwt, response.session_jwt)

def test_authenticate_jwt_local_returns_none_for_expired_token(self) -> None:
api = self.b2c_client.sessions
self.assertIsNone(api.authenticate_jwt_local(session_jwt=TEST_EXPIRED_JWT))


if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions test/test_integration_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,16 @@ async def test_webauthn_async(self) -> None:
self.assertTrue(
(await api.authenticate_async(public_key_credential="")).is_success
)

async def test_authenticate(self) -> None:
api = self.b2c_client.sessions

async with self._get_temporary_user_async() as user:
assert isinstance(user, CreatedTestUser)
self.assertTrue(api.get(user_id=user.user_id).is_success)
# Grab a recent JWT token and verify it's valid
auth_response = api.authenticate(session_token=TEST_SESSION_TOKEN)
response = await self.b2c_client.sessions.authenticate_jwt_async(session_jwt=auth_response.session_jwt)
self.assertIsNotNone(response)
if response is not None:
self.assertEquals(auth_response.session_jwt, response.session_jwt)

0 comments on commit a584134

Please sign in to comment.