Skip to content

Commit

Permalink
Filter out refresh_in from auth responses
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Mar 4, 2021
1 parent 36365ac commit 76348db
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 32 deletions.
44 changes: 25 additions & 19 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def _str2bytes(raw):
return raw


def _clean_up(result):
if isinstance(result, dict):
result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not
return result


class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
Expand Down Expand Up @@ -507,7 +513,7 @@ def authorize(): # A controller in a web app
return redirect(url_for("index"))
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_auth_code_flow(
return _clean_up(self.client.obtain_token_by_auth_code_flow(
auth_code_flow,
auth_response,
scope=decorate_scope(scopes, self.client_id) if scopes else None,
Expand All @@ -521,7 +527,7 @@ def authorize(): # A controller in a web app
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs)
**kwargs))

def acquire_token_by_authorization_code(
self,
Expand Down Expand Up @@ -580,7 +586,7 @@ def acquire_token_by_authorization_code(
"Change your acquire_token_by_authorization_code() "
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return self.client.obtain_token_by_authorization_code(
return _clean_up(self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
Expand All @@ -593,7 +599,7 @@ def acquire_token_by_authorization_code(
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs)
**kwargs))

def get_accounts(self, username=None):
"""Get a list of accounts which previously signed in, i.e. exists in cache.
Expand Down Expand Up @@ -855,13 +861,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
result = _clean_up(result)
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
logger.exception("Refresh token failed") # Potential AAD outage?
return access_token_from_cache


def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
query = {
Expand Down Expand Up @@ -987,7 +993,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
* A dict contains no "error" key means migration was successful.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_refresh_token(
return _clean_up(self.client.obtain_token_by_refresh_token(
refresh_token,
scope=decorate_scope(scopes, self.client_id),
headers={
Expand All @@ -998,7 +1004,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
rt_getter=lambda rt: rt,
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs)
**kwargs))


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down Expand Up @@ -1072,7 +1078,7 @@ def acquire_token_interactive(
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
return self.client.obtain_token_by_browser(
return _clean_up(self.client.obtain_token_by_browser(
scope=decorate_scope(scopes, self.client_id) if scopes else None,
extra_scope_to_consent=extra_scopes_to_consent,
redirect_uri="http://localhost:{port}".format(
Expand All @@ -1091,7 +1097,7 @@ def acquire_token_interactive(
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_INTERACTIVE),
},
**kwargs)
**kwargs))

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand Down Expand Up @@ -1134,7 +1140,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
return self.client.obtain_token_by_device_flow(
return _clean_up(self.client.obtain_token_by_device_flow(
flow,
data=dict(
kwargs.pop("data", {}),
Expand All @@ -1150,7 +1156,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
},
**kwargs)
**kwargs))

def acquire_token_by_username_password(
self, username, password, scopes, claims_challenge=None, **kwargs):
Expand Down Expand Up @@ -1188,15 +1194,15 @@ def acquire_token_by_username_password(
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
return _clean_up(self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
headers=headers, **kwargs))
return _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs)
**kwargs))

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -1256,7 +1262,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_for_client(
return _clean_up(self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
Expand All @@ -1267,7 +1273,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs)
**kwargs))

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down Expand Up @@ -1297,7 +1303,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
"""
# The implementation is NOT based on Token Exchange
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
user_assertion,
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
Expand All @@ -1316,4 +1322,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
**kwargs)
**kwargs))
34 changes: 21 additions & 13 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,23 @@ def test_fresh_token_should_be_returned_from_cache(self):
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
access_token = "An access token prepopulated into cache"
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
self.assertEqual(
access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

def test_aging_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt to refresh unexpired token when AAD available
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
def mock_post(*args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
}))
self.app.http_client.post = mock_post
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

def test_aging_token_and_unavailable_aad_should_return_old_token(self):
# a.k.a. Attempt refresh unexpired token when AAD unavailable
Expand All @@ -393,9 +397,13 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt refresh expired token when AAD available
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
def mock_post(*args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
}))
self.app.http_client.post = mock_post
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

0 comments on commit 76348db

Please sign in to comment.