From 0b0a583fd2fd0c447e81aa249924d8894fcef297 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 7 Feb 2019 10:51:19 -0800 Subject: [PATCH 1/8] Handle potential race-condition in RT updating --- msal/token_cache.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 556ecd60..2318bc6f 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -56,9 +56,9 @@ def add(self, event): default=str, # A workaround when assertion is in bytes in Python 3 )) response = event.get("response", {}) - access_token = response.get("access_token", {}) - refresh_token = response.get("refresh_token", {}) - id_token = response.get("id_token", {}) + access_token = response.get("access_token") + refresh_token = response.get("refresh_token") + id_token = response.get("id_token") client_info = {} home_account_id = None if "client_info" in response: @@ -169,7 +169,8 @@ def remove_rt(self, rt_item): def update_rt(self, rt_item, new_rt): key = self._build_rt_key(**rt_item) with self._lock: - rt = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] + RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}) + rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence rt["secret"] = new_rt From e8e35baa574f7fa70a8c49bd467a25d615628f88 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Fri, 8 Feb 2019 19:18:52 -0800 Subject: [PATCH 2/8] Per Unified Schema, change target to be an unsorted string --- msal/token_cache.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 2318bc6f..67d30fad 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -38,11 +38,17 @@ def __init__(self): def find(self, credential_type, target=None, query=None): target = target or [] assert isinstance(target, list), "Invalid parameter type" + target_set = set(target) with self._lock: + # Since the target inside token cache key is (per schema) unsorted, + # there is no point to attempt an O(1) key-value search here. + # So we always do an O(n) in-memory search. return [entry for entry in self._cache.get(credential_type, {}).values() if is_subdict_of(query or {}, entry) - and set(target) <= set(entry.get("target", []))] + and (target_set <= set(entry.get("target", "").split()) + if target else True) + ] def add(self, event): # type: (dict) -> None @@ -67,6 +73,7 @@ def add(self, event): environment = realm = None if "token_endpoint" in event: _, environment, realm = canonicalize(event["token_endpoint"]) + target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it with self._lock: @@ -77,7 +84,7 @@ def add(self, event): self.CredentialType.ACCESS_TOKEN, event.get("client_id", ""), realm or "", - ' '.join(sorted(event.get("scope", []))), + target, ]).lower() now = time.time() self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { @@ -86,7 +93,7 @@ def add(self, event): "home_account_id": home_account_id, "environment": environment, "client_id": event.get("client_id"), - "target": event.get("scope"), + "target": target, "realm": realm, "cached_at": now, "expires_on": now + response.get("expires_in", 3599), @@ -132,7 +139,7 @@ def add(self, event): if refresh_token: key = self._build_rt_key( home_account_id, environment, - event.get("client_id", ""), event.get("scope", [])) + event.get("client_id", ""), target) rt = { "credential_type": self.CredentialType.REFRESH_TOKEN, "secret": refresh_token, @@ -140,7 +147,7 @@ def add(self, event): "environment": environment, "client_id": event.get("client_id"), # Fields below are considered optional - "target": event.get("scope"), + "target": target, "client_info": response.get("client_info"), } if "foci" in response: @@ -158,7 +165,7 @@ def _build_rt_key( cls.CredentialType.REFRESH_TOKEN, client_id or "", "", # RT is cross-tenant in AAD - ' '.join(sorted(target or [])), + target, ]).lower() def remove_rt(self, rt_item): From 4a5255ca66345577ddc206e337373796ce464112 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 14 Feb 2019 17:20:51 -0800 Subject: [PATCH 3/8] Remove optional field client_info --- msal/token_cache.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 67d30fad..47392014 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -116,6 +116,7 @@ def add(self, event): "oid", decoded_id_token.get("sub")), "username": decoded_id_token.get("preferred_username"), "authority_type": "AAD", # Always AAD? + # "client_info": response.get("client_info"), # Optional } if id_token: @@ -146,9 +147,7 @@ def add(self, event): "home_account_id": home_account_id, "environment": environment, "client_id": event.get("client_id"), - # Fields below are considered optional - "target": target, - "client_info": response.get("client_info"), + "target": target, # Optional per schema though } if "foci" in response: rt["family_id"] = response["foci"] From 63bf224f3b0187cf51e583a9f2e20a827511dd7d Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Sat, 9 Feb 2019 01:10:50 -0800 Subject: [PATCH 4/8] Add test cases for TokenCache and SerializableTokenCache --- msal/token_cache.py | 8 +-- tests/test_token_cache.py | 119 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 tests/test_token_cache.py diff --git a/msal/token_cache.py b/msal/token_cache.py index 47392014..9bdb115d 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -50,7 +50,7 @@ def find(self, credential_type, target=None, query=None): if target else True) ] - def add(self, event): + def add(self, event, now=None): # type: (dict) -> None # event typically contains: client_id, scope, token_endpoint, # resposne, params, data, grant_type @@ -86,7 +86,7 @@ def add(self, event): realm or "", target, ]).lower() - now = time.time() + now = time.time() if now is None else now self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { "credential_type": self.CredentialType.ACCESS_TOKEN, "secret": access_token, @@ -202,8 +202,8 @@ class SerializableTokenCache(TokenCache): Indicates whether the cache state has changed since last :func:`~serialize` or :func:`~deserialize` call. """ - def add(self, event): - super(SerializableTokenCache, self).add(event) + def add(self, event, **kwargs): + super(SerializableTokenCache, self).add(event, **kwargs) self.has_state_changed = True def remove_rt(self, rt_item): diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py new file mode 100644 index 00000000..79fa4ab8 --- /dev/null +++ b/tests/test_token_cache.py @@ -0,0 +1,119 @@ +import logging +import base64 +import json + +from msal.token_cache import * +from tests import unittest + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + + +class TokenCacheTestCase(unittest.TestCase): + + def setUp(self): + self.cache = TokenCache() + + def testAdd(self): + client_info = base64.b64encode(b''' + {"uid": "uid", "utid": "utid"} + ''').decode('utf-8') + id_token = "header.%s.signature" % base64.b64encode(b'''{ + "sub": "subject", + "oid": "object1234", + "preferred_username": "John Doe" + }''').decode('utf-8') + self.cache.add({ + "client_id": "my_client_id", + "scope": ["s2", "s1", "s3"], # Not in particular order + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": { + "access_token": "an access token", + "token_type": "some type", + "expires_in": 3600, + "refresh_token": "a refresh token", + "client_info": client_info, + "id_token": id_token, + }, + }, now=1000) + self.assertEqual( + { + 'cached_at': 1000, + 'client_id': 'my_client_id', + 'credential_type': 'AccessToken', + 'environment': 'login.example.com', + 'expires_on': 4600, + 'extended_expires_on': 1000, + 'home_account_id': "uid.utid", + 'realm': 'contoso', + 'secret': 'an access token', + 'target': 's2 s1 s3', + }, + self.cache._cache["AccessToken"].get( + 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3') + ) + self.assertEqual( + { + 'client_id': 'my_client_id', + 'credential_type': 'RefreshToken', + 'environment': 'login.example.com', + 'home_account_id': "uid.utid", + 'secret': 'a refresh token', + 'target': 's2 s1 s3', + }, + self.cache._cache["RefreshToken"].get( + 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') + ) + self.assertEqual( + { + 'home_account_id': "uid.utid", + 'environment': 'login.example.com', + 'realm': 'contoso', + 'local_account_id': "object1234", + 'username': "John Doe", + 'authority_type': "AAD", + }, + self.cache._cache["Account"].get('uid.utid-login.example.com-contoso') + ) + self.assertEqual( + { + 'credential_type': 'IdToken', + 'secret': id_token, + 'home_account_id': "uid.utid", + 'environment': 'login.example.com', + 'realm': 'contoso', + 'client_id': 'my_client_id', + }, + self.cache._cache["IdToken"].get( + 'uid.utid-login.example.com-idtoken-my_client_id-contoso') + ) + + +class SerializableTokenCacheTestCase(TokenCacheTestCase): + # Run all inherited test methods, and have extra check in tearDown() + + def setUp(self): + self.cache = SerializableTokenCache() + self.cache.deserialize(""" + { + "AccessToken": { + "an-entry": { + "foo": "bar" + } + }, + "customized": "whatever" + } + """) + + def tearDown(self): + state = self.cache.serialize() + logger.debug("serialize() = %s", state) + # Now assert all extended content are kept intact + output = json.loads(state) + self.assertEqual(output.get("customized"), "whatever", + "Undefined cache keys and their values should be intact") + self.assertEqual( + output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"}, + "Undefined token keys and their values should be intact") + From a0eab4e3e9f69546d164edb4c1433630103faeb3 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Tue, 26 Feb 2019 14:50:46 -0800 Subject: [PATCH 5/8] Adjusting IdToken key, RT target behavior, and authority_type value --- msal/token_cache.py | 7 +++++-- tests/test_token_cache.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 9bdb115d..2cdedcb3 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -115,7 +115,9 @@ def add(self, event, now=None): "local_account_id": decoded_id_token.get( "oid", decoded_id_token.get("sub")), "username": decoded_id_token.get("preferred_username"), - "authority_type": "AAD", # Always AAD? + "authority_type": + "ADFS" if realm == "adfs" + else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA # "client_info": response.get("client_info"), # Optional } @@ -126,6 +128,7 @@ def add(self, event, now=None): self.CredentialType.ID_TOKEN, event.get("client_id", ""), realm or "", + "" # Albeit irrelevant, schema requires an empty scope here ]).lower() self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = { "credential_type": self.CredentialType.ID_TOKEN, @@ -164,7 +167,7 @@ def _build_rt_key( cls.CredentialType.REFRESH_TOKEN, client_id or "", "", # RT is cross-tenant in AAD - target, + target or "", # raw value could be None if deserialized from other SDK ]).lower() def remove_rt(self, rt_item): diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 79fa4ab8..40a86c07 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -72,7 +72,7 @@ def testAdd(self): 'realm': 'contoso', 'local_account_id': "object1234", 'username': "John Doe", - 'authority_type': "AAD", + 'authority_type': "MSSTS", }, self.cache._cache["Account"].get('uid.utid-login.example.com-contoso') ) @@ -86,7 +86,7 @@ def testAdd(self): 'client_id': 'my_client_id', }, self.cache._cache["IdToken"].get( - 'uid.utid-login.example.com-idtoken-my_client_id-contoso') + 'uid.utid-login.example.com-idtoken-my_client_id-contoso-') ) From 31bc31e7c73181942681ecac6ef5f078c1b8c87f Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Tue, 26 Feb 2019 14:55:32 -0800 Subject: [PATCH 6/8] Schema defines cached_at, expires_on, ext_expires_on as string --- msal/application.py | 7 +++++-- msal/token_cache.py | 8 +++++--- tests/test_token_cache.py | 6 +++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/msal/application.py b/msal/application.py index 113fe51b..79547c16 100644 --- a/msal/application.py +++ b/msal/application.py @@ -292,12 +292,14 @@ def acquire_token_silent( }) now = time.time() for entry in matches: - if entry["expires_on"] - now < 5*60: + expires_in = int(entry["expires_on"]) - now + if expires_in < 5*60: continue # Removal is not necessary, it will be overwritten + logger.debug("Cache hit an AT") return { # Mimic a real response "access_token": entry["secret"], "token_type": "Bearer", - "expires_in": entry["expires_on"] - now, + "expires_in": int(expires_in), # OAuth2 specs defines it as int } matches = self.token_cache.find( @@ -311,6 +313,7 @@ def acquire_token_silent( }) client = self._build_client(self.client_credential, the_authority) for entry in matches: + logger.debug("Cache hit an RT") response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], scope=decorate_scope(scopes, self.client_id)) diff --git a/msal/token_cache.py b/msal/token_cache.py index 2cdedcb3..dc649919 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -87,6 +87,7 @@ def add(self, event, now=None): target, ]).lower() now = time.time() if now is None else now + expires_in = response.get("expires_in", 3599) self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { "credential_type": self.CredentialType.ACCESS_TOKEN, "secret": access_token, @@ -95,9 +96,10 @@ def add(self, event, now=None): "client_id": event.get("client_id"), "target": target, "realm": realm, - "cached_at": now, - "expires_on": now + response.get("expires_in", 3599), - "extended_expires_on": now + response.get("ext_expires_in", 0), + "cached_at": str(int(now)), # Schema defines it as a string + "expires_on": str(int(now + expires_in)), # Same here + "extended_expires_on": str(int( # Same here + now + response.get("ext_expires_in", expires_in))), } if client_info: diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 40a86c07..eebd751d 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -39,12 +39,12 @@ def testAdd(self): }, now=1000) self.assertEqual( { - 'cached_at': 1000, + 'cached_at': "1000", 'client_id': 'my_client_id', 'credential_type': 'AccessToken', 'environment': 'login.example.com', - 'expires_on': 4600, - 'extended_expires_on': 1000, + 'expires_on': "4600", + 'extended_expires_on': "4600", 'home_account_id': "uid.utid", 'realm': 'contoso', 'secret': 'an access token', From adcb6372a48d759c37a7a4d89d9fa55c435da8f3 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Tue, 26 Feb 2019 14:57:25 -0800 Subject: [PATCH 7/8] Indentation in serialization for easier debugging --- msal/token_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index dc649919..116be878 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -231,5 +231,5 @@ def serialize(self): """Serialize the current cache state into a string.""" with self._lock: self.has_state_changed = False - return json.dumps(self._cache) + return json.dumps(self._cache, indent=4) From 9110eca2ce0dc850437b19c8acddac98364a2c9c Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Tue, 26 Feb 2019 19:10:35 -0800 Subject: [PATCH 8/8] Use case-sensitive scope, reference SerializableTokenCache, and log behaviors for debugging --- sample/client_credential_sample.py | 5 +++-- sample/device_flow_sample.py | 9 +++++---- sample/username_password_sample.py | 9 +++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sample/client_credential_sample.py b/sample/client_credential_sample.py index cb5ccc26..5f539465 100644 --- a/sample/client_credential_sample.py +++ b/sample/client_credential_sample.py @@ -30,7 +30,8 @@ config["client_id"], authority=config["authority"], client_credential=config["secret"], # token_cache=... # Default cache is in memory only. - # See SerializableTokenCache for more details. + # You can learn how to use SerializableTokenCache from + # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. @@ -42,7 +43,7 @@ result = app.acquire_token_silent(config["scope"], account=None) if not result: - # So no suitable token exists in cache. Let's get a new one from AAD. + logging.info("No suitable token exists in cache. Let's get a new one from AAD.") result = app.acquire_token_for_client(scopes=config["scope"]) if "access_token" in result: diff --git a/sample/device_flow_sample.py b/sample/device_flow_sample.py index 182fcce9..8c46c6b0 100644 --- a/sample/device_flow_sample.py +++ b/sample/device_flow_sample.py @@ -4,7 +4,7 @@ { "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", - "scope": ["user.read"] + "scope": ["User.Read"] } You can then run this sample with a JSON configuration file: @@ -28,7 +28,8 @@ app = msal.PublicClientApplication( config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. - # See SerializableTokenCache for more details. + # You can learn how to use SerializableTokenCache from + # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. @@ -39,7 +40,7 @@ # We now check the cache to see if we have some end users signed in before. accounts = app.get_accounts() if accounts: - # If so, you could then somehow display these accounts and let end user choose + logging.info("Account(s) exists in cache, probably with token too. Let's try.") print("Pick the account you want to use to proceed:") for a in accounts: print(a["username"]) @@ -49,7 +50,7 @@ result = app.acquire_token_silent(config["scope"], account=chosen) if not result: - # So no suitable token exists in cache. Let's get a new one from AAD. + logging.info("No suitable token exists in cache. Let's get a new one from AAD.") flow = app.initiate_device_flow(scopes=config["scope"]) print(flow["message"]) # Ideally you should wait here, in order to save some unnecessary polling diff --git a/sample/username_password_sample.py b/sample/username_password_sample.py index a34acaee..0137ae6e 100644 --- a/sample/username_password_sample.py +++ b/sample/username_password_sample.py @@ -5,7 +5,7 @@ "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", "username": "your_username@your_tenant.com", - "scope": ["user.read"], + "scope": ["User.Read"], "password": "This is a sample only. You better NOT persist your password." } @@ -30,7 +30,8 @@ app = msal.PublicClientApplication( config["client_id"], authority=config["authority"], # token_cache=... # Default cache is in memory only. - # See SerializableTokenCache for more details. + # You can learn how to use SerializableTokenCache from + # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache ) # The pattern to acquire a token looks like this. @@ -39,11 +40,11 @@ # Firstly, check the cache to see if this end user has signed in before accounts = app.get_accounts(username=config["username"]) if accounts: - # It means the account(s) exists in cache, probably with token too. Let's try. + logging.info("Account(s) exists in cache, probably with token too. Let's try.") result = app.acquire_token_silent(config["scope"], account=accounts[0]) if not result: - # So no suitable token exists in cache. Let's get a new one from AAD. + logging.info("No suitable token exists in cache. Let's get a new one from AAD.") result = app.acquire_token_by_username_password( config["username"], config["password"], scopes=config["scope"])