-
Notifications
You must be signed in to change notification settings - Fork 202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adjusting token cache format #19
Changes from all commits
0b0a583
e8e35ba
4a5255c
63bf224
a0eab4e
31bc31e
adcb637
9110eca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,13 +38,19 @@ def __init__(self): | |
def find(self, credential_type, target=None, query=None): | ||
target = target or [] | ||
assert isinstance(target, list), "Invalid parameter type" | ||
target_set = set(target) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is target scopes? if yes, do you have a place where you remove dupes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes The |
||
with self._lock: | ||
# Since the target inside token cache key is (per schema) unsorted, | ||
# there is no point to attempt an O(1) key-value search here. | ||
# So we always do an O(n) in-memory search. | ||
return [entry | ||
for entry in self._cache.get(credential_type, {}).values() | ||
if is_subdict_of(query or {}, entry) | ||
and set(target) <= set(entry.get("target", []))] | ||
and (target_set <= set(entry.get("target", "").split()) | ||
if target else True) | ||
] | ||
|
||
def add(self, event): | ||
def add(self, event, now=None): | ||
# type: (dict) -> None | ||
# event typically contains: client_id, scope, token_endpoint, | ||
# resposne, params, data, grant_type | ||
|
@@ -56,9 +62,9 @@ def add(self, event): | |
default=str, # A workaround when assertion is in bytes in Python 3 | ||
)) | ||
response = event.get("response", {}) | ||
access_token = response.get("access_token", {}) | ||
refresh_token = response.get("refresh_token", {}) | ||
id_token = response.get("id_token", {}) | ||
access_token = response.get("access_token") | ||
refresh_token = response.get("refresh_token") | ||
id_token = response.get("id_token") | ||
client_info = {} | ||
home_account_id = None | ||
if "client_info" in response: | ||
|
@@ -67,6 +73,7 @@ def add(self, event): | |
environment = realm = None | ||
if "token_endpoint" in event: | ||
_, environment, realm = canonicalize(event["token_endpoint"]) | ||
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it | ||
|
||
with self._lock: | ||
|
||
|
@@ -77,20 +84,22 @@ def add(self, event): | |
self.CredentialType.ACCESS_TOKEN, | ||
event.get("client_id", ""), | ||
realm or "", | ||
' '.join(sorted(event.get("scope", []))), | ||
target, | ||
]).lower() | ||
now = time.time() | ||
now = time.time() if now is None else now | ||
expires_in = response.get("expires_in", 3599) | ||
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { | ||
"credential_type": self.CredentialType.ACCESS_TOKEN, | ||
"secret": access_token, | ||
"home_account_id": home_account_id, | ||
"environment": environment, | ||
"client_id": event.get("client_id"), | ||
"target": event.get("scope"), | ||
"target": target, | ||
"realm": realm, | ||
"cached_at": now, | ||
"expires_on": now + response.get("expires_in", 3599), | ||
"extended_expires_on": now + response.get("ext_expires_in", 0), | ||
"cached_at": str(int(now)), # Schema defines it as a string | ||
"expires_on": str(int(now + expires_in)), # Same here | ||
"extended_expires_on": str(int( # Same here | ||
now + response.get("ext_expires_in", expires_in))), | ||
} | ||
|
||
if client_info: | ||
|
@@ -108,7 +117,10 @@ def add(self, event): | |
"local_account_id": decoded_id_token.get( | ||
"oid", decoded_id_token.get("sub")), | ||
"username": decoded_id_token.get("preferred_username"), | ||
"authority_type": "AAD", # Always AAD? | ||
"authority_type": | ||
"ADFS" if realm == "adfs" | ||
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA | ||
# "client_info": response.get("client_info"), # Optional | ||
} | ||
|
||
if id_token: | ||
|
@@ -118,6 +130,7 @@ def add(self, event): | |
self.CredentialType.ID_TOKEN, | ||
event.get("client_id", ""), | ||
realm or "", | ||
"" # Albeit irrelevant, schema requires an empty scope here | ||
]).lower() | ||
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = { | ||
"credential_type": self.CredentialType.ID_TOKEN, | ||
|
@@ -132,16 +145,14 @@ def add(self, event): | |
if refresh_token: | ||
key = self._build_rt_key( | ||
home_account_id, environment, | ||
event.get("client_id", ""), event.get("scope", [])) | ||
event.get("client_id", ""), target) | ||
rt = { | ||
"credential_type": self.CredentialType.REFRESH_TOKEN, | ||
"secret": refresh_token, | ||
"home_account_id": home_account_id, | ||
"environment": environment, | ||
"client_id": event.get("client_id"), | ||
# Fields below are considered optional | ||
"target": event.get("scope"), | ||
"client_info": response.get("client_info"), | ||
"target": target, # Optional per schema though | ||
} | ||
if "foci" in response: | ||
rt["family_id"] = response["foci"] | ||
|
@@ -158,7 +169,7 @@ def _build_rt_key( | |
cls.CredentialType.REFRESH_TOKEN, | ||
client_id or "", | ||
"", # RT is cross-tenant in AAD | ||
' '.join(sorted(target or [])), | ||
target or "", # raw value could be None if deserialized from other SDK | ||
]).lower() | ||
|
||
def remove_rt(self, rt_item): | ||
|
@@ -169,7 +180,8 @@ def remove_rt(self, rt_item): | |
def update_rt(self, rt_item, new_rt): | ||
key = self._build_rt_key(**rt_item) | ||
with self._lock: | ||
rt = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] | ||
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}) | ||
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence | ||
rt["secret"] = new_rt | ||
|
||
|
||
|
@@ -195,8 +207,8 @@ class SerializableTokenCache(TokenCache): | |
Indicates whether the cache state has changed since last | ||
:func:`~serialize` or :func:`~deserialize` call. | ||
""" | ||
def add(self, event): | ||
super(SerializableTokenCache, self).add(event) | ||
def add(self, event, **kwargs): | ||
super(SerializableTokenCache, self).add(event, **kwargs) | ||
self.has_state_changed = True | ||
|
||
def remove_rt(self, rt_item): | ||
|
@@ -219,5 +231,5 @@ def serialize(self): | |
"""Serialize the current cache state into a string.""" | ||
with self._lock: | ||
self.has_state_changed = False | ||
return json.dumps(self._cache) | ||
return json.dumps(self._cache, indent=4) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import logging | ||
import base64 | ||
import json | ||
|
||
from msal.token_cache import * | ||
from tests import unittest | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
|
||
class TokenCacheTestCase(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.cache = TokenCache() | ||
|
||
def testAdd(self): | ||
client_info = base64.b64encode(b''' | ||
{"uid": "uid", "utid": "utid"} | ||
''').decode('utf-8') | ||
id_token = "header.%s.signature" % base64.b64encode(b'''{ | ||
"sub": "subject", | ||
"oid": "object1234", | ||
"preferred_username": "John Doe" | ||
}''').decode('utf-8') | ||
self.cache.add({ | ||
"client_id": "my_client_id", | ||
"scope": ["s2", "s1", "s3"], # Not in particular order | ||
"token_endpoint": "https://login.example.com/contoso/v2/token", | ||
"response": { | ||
"access_token": "an access token", | ||
"token_type": "some type", | ||
"expires_in": 3600, | ||
"refresh_token": "a refresh token", | ||
"client_info": client_info, | ||
"id_token": id_token, | ||
}, | ||
}, now=1000) | ||
self.assertEqual( | ||
{ | ||
'cached_at': "1000", | ||
'client_id': 'my_client_id', | ||
'credential_type': 'AccessToken', | ||
'environment': 'login.example.com', | ||
'expires_on': "4600", | ||
'extended_expires_on': "4600", | ||
'home_account_id': "uid.utid", | ||
'realm': 'contoso', | ||
'secret': 'an access token', | ||
'target': 's2 s1 s3', | ||
}, | ||
self.cache._cache["AccessToken"].get( | ||
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3') | ||
) | ||
self.assertEqual( | ||
{ | ||
'client_id': 'my_client_id', | ||
'credential_type': 'RefreshToken', | ||
'environment': 'login.example.com', | ||
'home_account_id': "uid.utid", | ||
'secret': 'a refresh token', | ||
'target': 's2 s1 s3', | ||
}, | ||
self.cache._cache["RefreshToken"].get( | ||
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') | ||
) | ||
self.assertEqual( | ||
{ | ||
'home_account_id': "uid.utid", | ||
'environment': 'login.example.com', | ||
'realm': 'contoso', | ||
'local_account_id': "object1234", | ||
'username': "John Doe", | ||
'authority_type': "MSSTS", | ||
}, | ||
self.cache._cache["Account"].get('uid.utid-login.example.com-contoso') | ||
) | ||
self.assertEqual( | ||
{ | ||
'credential_type': 'IdToken', | ||
'secret': id_token, | ||
'home_account_id': "uid.utid", | ||
'environment': 'login.example.com', | ||
'realm': 'contoso', | ||
'client_id': 'my_client_id', | ||
}, | ||
self.cache._cache["IdToken"].get( | ||
'uid.utid-login.example.com-idtoken-my_client_id-contoso-') | ||
) | ||
|
||
|
||
class SerializableTokenCacheTestCase(TokenCacheTestCase): | ||
# Run all inherited test methods, and have extra check in tearDown() | ||
|
||
def setUp(self): | ||
self.cache = SerializableTokenCache() | ||
self.cache.deserialize(""" | ||
{ | ||
"AccessToken": { | ||
"an-entry": { | ||
"foo": "bar" | ||
} | ||
}, | ||
"customized": "whatever" | ||
} | ||
""") | ||
|
||
def tearDown(self): | ||
state = self.cache.serialize() | ||
logger.debug("serialize() = %s", state) | ||
# Now assert all extended content are kept intact | ||
output = json.loads(state) | ||
self.assertEqual(output.get("customized"), "whatever", | ||
"Undefined cache keys and their values should be intact") | ||
self.assertEqual( | ||
output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"}, | ||
"Undefined token keys and their values should be intact") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems like expires_in is already an int? Did -now change it to something else? Should that then have been converted to an int?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question! It is true that the
expires_in
in the original server-side response is an integer. But here in this code path we hit a token in the cache, and then re-calculate a newexpires_in
based on its original value and the current time(), which happens to be a float in Python. So we do a Forceful Unobvious Conversion to Keep the Integer Time, at the exact place where it is needed.