Skip to content

Commit 36365ac

Browse files
committed
Implement refresh_in behavior, and some test cases
1 parent 34e0b82 commit 36365ac

File tree

4 files changed

+115
-15
lines changed

4 files changed

+115
-15
lines changed

msal/application.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
822822
force_refresh=False, # type: Optional[boolean]
823823
claims_challenge=None,
824824
**kwargs):
825+
access_token_from_cache = None
825826
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
826827
query={
827828
"client_id": self.client_id,
@@ -839,17 +840,27 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
839840
now = time.time()
840841
for entry in matches:
841842
expires_in = int(entry["expires_on"]) - now
842-
if expires_in < 5*60:
843+
if expires_in < 5*60: # Then consider it expired
843844
continue # Removal is not necessary, it will be overwritten
844845
logger.debug("Cache hit an AT")
845-
return { # Mimic a real response
846+
access_token_from_cache = { # Mimic a real response
846847
"access_token": entry["secret"],
847848
"token_type": entry.get("token_type", "Bearer"),
848849
"expires_in": int(expires_in), # OAuth2 specs defines it as int
849850
}
850-
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
851+
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
852+
break # With a fallback in hand, we break here to go refresh
853+
return access_token_from_cache # It is still good as new
854+
try:
855+
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
851856
authority, decorate_scope(scopes, self.client_id), account,
852857
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
858+
if (result and "error" not in result) or (not access_token_from_cache):
859+
return result
860+
except: # The exact HTTP exception is transportation-layer dependent
861+
logger.exception("Refresh token failed") # Potential AAD outage?
862+
return access_token_from_cache
863+
853864

854865
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
855866
self, authority, scopes, account, **kwargs):

msal/token_cache.py

+3
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ def __add(self, event, now=None):
170170
}
171171
if data.get("key_id"): # It happens in SSH-cert or POP scenario
172172
at["key_id"] = data.get("key_id")
173+
if "refresh_in" in response:
174+
refresh_in = response["refresh_in"] # It is an integer
175+
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
173176
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
174177

175178
if client_info and not event.get("skip_account_creation"):

tests/test_application.py

+80
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,83 @@ def test_only_client_capabilities_no_claims_merge(self):
319319

320320
def test_both_claims_and_capabilities_none(self):
321321
self.assertEqual(_merge_claims_challenge_and_capabilities(None, None), None)
322+
323+
324+
class TestApplicationForRefreshInBehaviors(unittest.TestCase):
325+
"""The following test cases were based on design doc here
326+
https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FRefreshAtExpirationPercentage%2Foverview.md&version=GBdev&_a=preview&anchor=scenarios
327+
"""
328+
def setUp(self):
329+
self.authority_url = "https://login.microsoftonline.com/common"
330+
self.authority = msal.authority.Authority(
331+
self.authority_url, MinimalHttpClient())
332+
self.scopes = ["s1", "s2"]
333+
self.uid = "my_uid"
334+
self.utid = "my_utid"
335+
self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)}
336+
self.rt = "this is a rt"
337+
self.cache = msal.SerializableTokenCache()
338+
self.client_id = "my_app"
339+
self.app = ClientApplication(
340+
self.client_id, authority=self.authority_url, token_cache=self.cache)
341+
342+
def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
343+
self.cache.add({
344+
"client_id": self.client_id,
345+
"scope": self.scopes,
346+
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
347+
"response": TokenCacheTestCase.build_response(
348+
access_token=access_token,
349+
expires_in=expires_in, refresh_in=refresh_in,
350+
uid=self.uid, utid=self.utid, refresh_token=self.rt),
351+
})
352+
353+
def test_fresh_token_should_be_returned_from_cache(self):
354+
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
355+
access_token = "An access token prepopulated into cache"
356+
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
357+
self.assertEqual(
358+
access_token,
359+
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
360+
361+
def test_aging_token_and_available_aad_should_return_new_token(self):
362+
# a.k.a. Attempt to refresh unexpired token when AAD available
363+
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
364+
new_access_token = "new AT"
365+
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
366+
lambda *args, **kwargs: {"access_token": new_access_token})
367+
self.assertEqual(
368+
new_access_token,
369+
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
370+
371+
def test_aging_token_and_unavailable_aad_should_return_old_token(self):
372+
# a.k.a. Attempt refresh unexpired token when AAD unavailable
373+
old_at = "old AT"
374+
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
375+
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
376+
lambda *args, **kwargs: {"error": "sth went wrong"})
377+
self.assertEqual(
378+
old_at,
379+
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
380+
381+
def test_expired_token_and_unavailable_aad_should_return_error(self):
382+
# a.k.a. Attempt refresh expired token when AAD unavailable
383+
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
384+
error = "something went wrong"
385+
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
386+
lambda *args, **kwargs: {"error": error})
387+
self.assertEqual(
388+
error,
389+
self.app.acquire_token_silent_with_error( # This variant preserves error
390+
['s1'], self.account).get("error"))
391+
392+
def test_expired_token_and_available_aad_should_return_new_token(self):
393+
# a.k.a. Attempt refresh expired token when AAD available
394+
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
395+
new_access_token = "new AT"
396+
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
397+
lambda *args, **kwargs: {"access_token": new_access_token})
398+
self.assertEqual(
399+
new_access_token,
400+
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
401+

tests/test_token_cache.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,20 @@ def build_id_token(
2929
def build_response( # simulate a response from AAD
3030
uid=None, utid=None, # If present, they will form client_info
3131
access_token=None, expires_in=3600, token_type="some type",
32-
refresh_token=None,
33-
foci=None,
34-
id_token=None, # or something generated by build_id_token()
35-
error=None,
32+
**kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ...
3633
):
3734
response = {}
3835
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
3936
response["client_info"] = base64.b64encode(json.dumps({
4037
"uid": uid, "utid": utid,
4138
}).encode()).decode('utf-8')
42-
if error:
43-
response["error"] = error
4439
if access_token:
4540
response.update({
4641
"access_token": access_token,
4742
"expires_in": expires_in,
4843
"token_type": token_type,
4944
})
50-
if refresh_token:
51-
response["refresh_token"] = refresh_token
52-
if id_token:
53-
response["id_token"] = id_token
54-
if foci:
55-
response["foci"] = foci
45+
response.update(kwargs) # Pass-through key-value pairs as top-level fields
5646
return response
5747

5848
def setUp(self):
@@ -222,6 +212,21 @@ def test_key_id_is_also_recorded(self):
222212
{}).get("key_id")
223213
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
224214

215+
def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep.
216+
self.cache.add({
217+
"client_id": "my_client_id",
218+
"scope": ["s2", "s1", "s3"], # Not in particular order
219+
"token_endpoint": "https://login.example.com/contoso/v2/token",
220+
"response": self.build_response(
221+
uid="uid", utid="utid", # client_info
222+
expires_in=3600, refresh_in=1800, access_token="an access token",
223+
), #refresh_token="a refresh token"),
224+
}, now=1000)
225+
refresh_on = self.cache._cache["AccessToken"].get(
226+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
227+
{}).get("refresh_on")
228+
self.assertEqual("2800", refresh_on, "Should save refresh_on")
229+
225230
def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
226231
sample = {
227232
'client_id': 'my_client_id',
@@ -241,6 +246,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
241246
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
242247
)
243248

249+
244250
class SerializableTokenCacheTestCase(TokenCacheTestCase):
245251
# Run all inherited test methods, and have extra check in tearDown()
246252

0 commit comments

Comments
 (0)