Adjusting token cache format #19

Merged 8 commits on Mar 4, 2019
7 changes: 5 additions & 2 deletions msal/application.py
@@ -292,12 +292,14 @@ def acquire_token_silent(
})
now = time.time()
for entry in matches:
if entry["expires_on"] - now < 5*60:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60:
continue # Removal is not necessary, it will be overwritten
logger.debug("Cache hit an AT")
return { # Mimic a real response
"access_token": entry["secret"],
"token_type": "Bearer",
"expires_in": entry["expires_on"] - now,
"expires_in": int(expires_in), # OAuth2 specs defines it as int

Reviewer comment on the int cast:

It seems like expires_in is already an int? Did subtracting now change it to something else? Should that then have been converted to an int?

Collaborator (Author) reply:

Good question! It is true that the expires_in in the original server-side response is an integer. But in this code path we hit a token in the cache and recalculate a new expires_in from its original value and the current time(), which is a float in Python. So we do an explicit int() conversion to keep expires_in an integer, at the exact place where it is needed.
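A tiny illustration of that point (made-up values; this snippet is not part of the diff):

```python
# Illustrative only: time.time() returns a float, so the recomputed lifetime
# needs an explicit int() to keep the response shape spec-compliant.
import time

entry = {"expires_on": str(int(time.time()) + 3600)}  # stored as a string per the new schema
now = time.time()                                     # a float
expires_in = int(entry["expires_on"]) - now           # int - float yields a float
response = {                                          # mimic a real response, as in the diff
    "access_token": "an access token",
    "token_type": "Bearer",
    "expires_in": int(expires_in),                    # cast back; OAuth2 defines expires_in as an int
}
assert isinstance(response["expires_in"], int)
```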

}

matches = self.token_cache.find(
@@ -311,6 +313,7 @@ def acquire_token_silent(
})
client = self._build_client(self.client_credential, the_authority)
for entry in matches:
logger.debug("Cache hit an RT")
response = client.obtain_token_by_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
scope=decorate_scope(scopes, self.client_id))
54 changes: 33 additions & 21 deletions msal/token_cache.py
@@ -38,13 +38,19 @@ def __init__(self):
def find(self, credential_type, target=None, query=None):
target = target or []
assert isinstance(target, list), "Invalid parameter type"
target_set = set(target)

Reviewer comment on target_set:

Is target the list of scopes? If so, is there a place where you remove dupes?

Collaborator (Author) reply:

Yes, target is exactly what the Unified Schema calls scope.

set(...) creates a native Python set, which automatically deduplicates the scopes.
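A small illustration of the deduplication and the subset matching below (scope names are made up; this snippet is not part of the diff):

```python
# Illustrative only: set() collapses duplicate scopes, and "<=" tests subset membership
# against the space-delimited target string stored in the cache entry.
requested = ["User.Read", "User.Read", "Mail.Read"]   # duplicates in the request
target_set = set(requested)                           # deduplicated
cached_target = "User.Read Mail.Read openid profile"  # per schema, stored space-delimited
print(target_set <= set(cached_target.split()))       # True: the cached entry covers the request
```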

with self._lock:
# Since the target inside token cache key is (per schema) unsorted,
# there is no point to attempt an O(1) key-value search here.
# So we always do an O(n) in-memory search.
return [entry
for entry in self._cache.get(credential_type, {}).values()
if is_subdict_of(query or {}, entry)
and set(target) <= set(entry.get("target", []))]
and (target_set <= set(entry.get("target", "").split())
if target else True)
]

def add(self, event):
def add(self, event, now=None):
# type: (dict) -> None
# event typically contains: client_id, scope, token_endpoint,
# response, params, data, grant_type
@@ -56,9 +62,9 @@ def add(self, event):
default=str, # A workaround when assertion is in bytes in Python 3
))
response = event.get("response", {})
access_token = response.get("access_token", {})
refresh_token = response.get("refresh_token", {})
id_token = response.get("id_token", {})
access_token = response.get("access_token")
refresh_token = response.get("refresh_token")
id_token = response.get("id_token")
client_info = {}
home_account_id = None
if "client_info" in response:
@@ -67,6 +73,7 @@
environment = realm = None
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it

with self._lock:

@@ -77,20 +84,22 @@
self.CredentialType.ACCESS_TOKEN,
event.get("client_id", ""),
realm or "",
' '.join(sorted(event.get("scope", []))),
target,
]).lower()
now = time.time()
now = time.time() if now is None else now
expires_in = response.get("expires_in", 3599)
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
"credential_type": self.CredentialType.ACCESS_TOKEN,
"secret": access_token,
"home_account_id": home_account_id,
"environment": environment,
"client_id": event.get("client_id"),
"target": event.get("scope"),
"target": target,
"realm": realm,
"cached_at": now,
"expires_on": now + response.get("expires_in", 3599),
"extended_expires_on": now + response.get("ext_expires_in", 0),
"cached_at": str(int(now)), # Schema defines it as a string
"expires_on": str(int(now + expires_in)), # Same here
"extended_expires_on": str(int( # Same here
now + response.get("ext_expires_in", expires_in))),
}

if client_info:
@@ -108,7 +117,10 @@ def add(self, event):
"local_account_id": decoded_id_token.get(
"oid", decoded_id_token.get("sub")),
"username": decoded_id_token.get("preferred_username"),
"authority_type": "AAD", # Always AAD?
"authority_type":
"ADFS" if realm == "adfs"
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA
# "client_info": response.get("client_info"), # Optional
}

if id_token:
@@ -118,6 +130,7 @@ def add(self, event):
self.CredentialType.ID_TOKEN,
event.get("client_id", ""),
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
"credential_type": self.CredentialType.ID_TOKEN,
@@ -132,16 +145,14 @@ def add(self, event):
if refresh_token:
key = self._build_rt_key(
home_account_id, environment,
event.get("client_id", ""), event.get("scope", []))
event.get("client_id", ""), target)
rt = {
"credential_type": self.CredentialType.REFRESH_TOKEN,
"secret": refresh_token,
"home_account_id": home_account_id,
"environment": environment,
"client_id": event.get("client_id"),
# Fields below are considered optional
"target": event.get("scope"),
"client_info": response.get("client_info"),
"target": target, # Optional per schema though
}
if "foci" in response:
rt["family_id"] = response["foci"]
@@ -158,7 +169,7 @@ def _build_rt_key(
cls.CredentialType.REFRESH_TOKEN,
client_id or "",
"", # RT is cross-tenant in AAD
' '.join(sorted(target or [])),
target or "", # raw value could be None if deserialized from other SDK
]).lower()

def remove_rt(self, rt_item):
@@ -169,7 +180,8 @@ def remove_rt(self, rt_item):
def update_rt(self, rt_item, new_rt):
key = self._build_rt_key(**rt_item)
with self._lock:
rt = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key]
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
rt["secret"] = new_rt


@@ -195,8 +207,8 @@ class SerializableTokenCache(TokenCache):
Indicates whether the cache state has changed since last
:func:`~serialize` or :func:`~deserialize` call.
"""
def add(self, event):
super(SerializableTokenCache, self).add(event)
def add(self, event, **kwargs):
super(SerializableTokenCache, self).add(event, **kwargs)
self.has_state_changed = True

def remove_rt(self, rt_item):
@@ -219,5 +231,5 @@ def serialize(self):
"""Serialize the current cache state into a string."""
with self._lock:
self.has_state_changed = False
return json.dumps(self._cache)
return json.dumps(self._cache, indent=4)
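
For context, here is a minimal persistence sketch built on the SerializableTokenCache members shown above (serialize(), deserialize(), has_state_changed). The cache file name and the save-on-change flow are illustrative assumptions, not something this PR prescribes; the samples below simply link to the docs for this class.

```python
# A sketch only: the file location and the save-on-change flow are assumptions for illustration.
import os
import msal

CACHE_FILE = "token_cache.json"  # hypothetical path

cache = msal.SerializableTokenCache()
if os.path.exists(CACHE_FILE):
    with open(CACHE_FILE) as f:
        cache.deserialize(f.read())        # restore a previously saved cache state

app = msal.PublicClientApplication(
    "your_client_id",
    authority="https://login.microsoftonline.com/organizations",
    token_cache=cache,                     # plug the serializable cache into the app
)

# ... acquire tokens here, exactly as the samples below do ...

if cache.has_state_changed:                # only rewrite the file when something changed
    with open(CACHE_FILE, "w") as f:
        f.write(cache.serialize())
```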

5 changes: 3 additions & 2 deletions sample/client_credential_sample.py
@@ -30,7 +30,8 @@
config["client_id"], authority=config["authority"],
client_credential=config["secret"],
# token_cache=... # Default cache is in memory only.
# See SerializableTokenCache for more details.
# You can learn how to use SerializableTokenCache from
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
)

# The pattern to acquire a token looks like this.
@@ -42,7 +43,7 @@
result = app.acquire_token_silent(config["scope"], account=None)

if not result:
# So no suitable token exists in cache. Let's get a new one from AAD.
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
9 changes: 5 additions & 4 deletions sample/device_flow_sample.py
@@ -4,7 +4,7 @@
{
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"scope": ["user.read"]
"scope": ["User.Read"]
}

You can then run this sample with a JSON configuration file:
@@ -28,7 +28,8 @@
app = msal.PublicClientApplication(
config["client_id"], authority=config["authority"],
# token_cache=... # Default cache is in memory only.
# See SerializableTokenCache for more details.
# You can learn how to use SerializableTokenCache from
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
)

# The pattern to acquire a token looks like this.
@@ -39,7 +40,7 @@
# We now check the cache to see if we have some end users signed in before.
accounts = app.get_accounts()
if accounts:
# If so, you could then somehow display these accounts and let end user choose
logging.info("Account(s) exists in cache, probably with token too. Let's try.")
print("Pick the account you want to use to proceed:")
for a in accounts:
print(a["username"])
@@ -49,7 +50,7 @@
result = app.acquire_token_silent(config["scope"], account=chosen)

if not result:
# So no suitable token exists in cache. Let's get a new one from AAD.
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
flow = app.initiate_device_flow(scopes=config["scope"])
print(flow["message"])
# Ideally you should wait here, in order to save some unnecessary polling
9 changes: 5 additions & 4 deletions sample/username_password_sample.py
@@ -5,7 +5,7 @@
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"username": "your_username@your_tenant.com",
"scope": ["user.read"],
"scope": ["User.Read"],
"password": "This is a sample only. You better NOT persist your password."
}

@@ -30,7 +30,8 @@
app = msal.PublicClientApplication(
config["client_id"], authority=config["authority"],
# token_cache=... # Default cache is in memory only.
# See SerializableTokenCache for more details.
# You can learn how to use SerializableTokenCache from
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
)

# The pattern to acquire a token looks like this.
@@ -39,11 +40,11 @@
# Firstly, check the cache to see if this end user has signed in before
accounts = app.get_accounts(username=config["username"])
if accounts:
# It means the account(s) exists in cache, probably with token too. Let's try.
logging.info("Account(s) exists in cache, probably with token too. Let's try.")
result = app.acquire_token_silent(config["scope"], account=accounts[0])

if not result:
# So no suitable token exists in cache. Let's get a new one from AAD.
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_by_username_password(
config["username"], config["password"], scopes=config["scope"])

119 changes: 119 additions & 0 deletions tests/test_token_cache.py
@@ -0,0 +1,119 @@
import logging
import base64
import json

from msal.token_cache import *
from tests import unittest


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


class TokenCacheTestCase(unittest.TestCase):

def setUp(self):
self.cache = TokenCache()

def testAdd(self):
client_info = base64.b64encode(b'''
{"uid": "uid", "utid": "utid"}
''').decode('utf-8')
id_token = "header.%s.signature" % base64.b64encode(b'''{
"sub": "subject",
"oid": "object1234",
"preferred_username": "John Doe"
}''').decode('utf-8')
self.cache.add({
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": {
"access_token": "an access token",
"token_type": "some type",
"expires_in": 3600,
"refresh_token": "a refresh token",
"client_info": client_info,
"id_token": id_token,
},
}, now=1000)
self.assertEqual(
{
'cached_at': "1000",
'client_id': 'my_client_id',
'credential_type': 'AccessToken',
'environment': 'login.example.com',
'expires_on': "4600",
'extended_expires_on': "4600",
'home_account_id': "uid.utid",
'realm': 'contoso',
'secret': 'an access token',
'target': 's2 s1 s3',
},
self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
)
self.assertEqual(
{
'client_id': 'my_client_id',
'credential_type': 'RefreshToken',
'environment': 'login.example.com',
'home_account_id': "uid.utid",
'secret': 'a refresh token',
'target': 's2 s1 s3',
},
self.cache._cache["RefreshToken"].get(
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
)
self.assertEqual(
{
'home_account_id': "uid.utid",
'environment': 'login.example.com',
'realm': 'contoso',
'local_account_id': "object1234",
'username': "John Doe",
'authority_type': "MSSTS",
},
self.cache._cache["Account"].get('uid.utid-login.example.com-contoso')
)
self.assertEqual(
{
'credential_type': 'IdToken',
'secret': id_token,
'home_account_id': "uid.utid",
'environment': 'login.example.com',
'realm': 'contoso',
'client_id': 'my_client_id',
},
self.cache._cache["IdToken"].get(
'uid.utid-login.example.com-idtoken-my_client_id-contoso-')
)


class SerializableTokenCacheTestCase(TokenCacheTestCase):
# Run all inherited test methods, and have extra check in tearDown()

def setUp(self):
self.cache = SerializableTokenCache()
self.cache.deserialize("""
{
"AccessToken": {
"an-entry": {
"foo": "bar"
}
},
"customized": "whatever"
}
""")

def tearDown(self):
state = self.cache.serialize()
logger.debug("serialize() = %s", state)
# Now assert all extended content are kept intact
output = json.loads(state)
self.assertEqual(output.get("customized"), "whatever",
"Undefined cache keys and their values should be intact")
self.assertEqual(
output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"},
"Undefined token keys and their values should be intact")