diff --git a/msal/application.py b/msal/application.py index 72bbecf3..4e1fba84 100644 --- a/msal/application.py +++ b/msal/application.py @@ -100,6 +100,12 @@ def _str2bytes(raw): return raw +def _clean_up(result): + if isinstance(result, dict): + result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not + return result + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -507,7 +513,7 @@ def authorize(): # A controller in a web app return redirect(url_for("index")) """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_by_auth_code_flow( + return _clean_up(self.client.obtain_token_by_auth_code_flow( auth_code_flow, auth_response, scope=decorate_scope(scopes, self.client_id) if scopes else None, @@ -521,7 +527,7 @@ def authorize(): # A controller in a web app claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, auth_code_flow.pop("claims_challenge", None))), - **kwargs) + **kwargs)) def acquire_token_by_authorization_code( self, @@ -580,7 +586,7 @@ def acquire_token_by_authorization_code( "Change your acquire_token_by_authorization_code() " "to acquire_token_by_auth_code_flow()", DeprecationWarning) with warnings.catch_warnings(record=True): - return self.client.obtain_token_by_authorization_code( + return _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=decorate_scope(scopes, self.client_id), headers={ @@ -593,7 +599,7 @@ def acquire_token_by_authorization_code( claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), nonce=nonce, - **kwargs) + **kwargs)) def get_accounts(self, username=None): """Get a list of accounts which previously signed in, i.e. exists in cache. @@ -855,13 +861,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, decorate_scope(scopes, self.client_id), account, force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs) + result = _clean_up(result) if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent logger.exception("Refresh token failed") # Potential AAD outage? return access_token_from_cache - def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self, authority, scopes, account, **kwargs): query = { @@ -987,7 +993,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): * A dict contains no "error" key means migration was successful. """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_by_refresh_token( + return _clean_up(self.client.obtain_token_by_refresh_token( refresh_token, scope=decorate_scope(scopes, self.client_id), headers={ @@ -998,7 +1004,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs): rt_getter=lambda rt: rt, on_updating_rt=False, on_removing_rt=lambda rt_item: None, # No OP - **kwargs) + **kwargs)) class PublicClientApplication(ClientApplication): # browser app or mobile app @@ -1072,7 +1078,7 @@ def acquire_token_interactive( self._validate_ssh_cert_input_data(kwargs.get("data", {})) claims = _merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge) - return self.client.obtain_token_by_browser( + return _clean_up(self.client.obtain_token_by_browser( scope=decorate_scope(scopes, self.client_id) if scopes else None, extra_scope_to_consent=extra_scopes_to_consent, redirect_uri="http://localhost:{port}".format( @@ -1091,7 +1097,7 @@ def acquire_token_interactive( CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_INTERACTIVE), }, - **kwargs) + **kwargs)) def initiate_device_flow(self, scopes=None, **kwargs): """Initiate a Device Flow instance, @@ -1134,7 +1140,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): - A successful response would contain "access_token" key, - an error response would contain "error" and usually "error_description". """ - return self.client.obtain_token_by_device_flow( + return _clean_up(self.client.obtain_token_by_device_flow( flow, data=dict( kwargs.pop("data", {}), @@ -1150,7 +1156,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID), }, - **kwargs) + **kwargs)) def acquire_token_by_username_password( self, username, password, scopes, claims_challenge=None, **kwargs): @@ -1188,15 +1194,15 @@ def acquire_token_by_username_password( user_realm_result = self.authority.user_realm_discovery( username, correlation_id=headers[CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": - return self._acquire_token_by_username_password_federated( + return _clean_up(self._acquire_token_by_username_password_federated( user_realm_result, username, password, scopes=scopes, data=data, - headers=headers, **kwargs) - return self.client.obtain_token_by_username_password( + headers=headers, **kwargs)) + return _clean_up(self.client.obtain_token_by_username_password( username, password, scope=scopes, headers=headers, data=data, - **kwargs) + **kwargs)) def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): @@ -1256,7 +1262,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): """ # TBD: force_refresh behavior self._validate_ssh_cert_input_data(kwargs.get("data", {})) - return self.client.obtain_token_for_client( + return _clean_up(self.client.obtain_token_for_client( scope=scopes, # This grant flow requires no scope decoration headers={ CLIENT_REQUEST_ID: _get_new_correlation_id(), @@ -1267,7 +1273,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge)), - **kwargs) + **kwargs)) def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. @@ -1297,7 +1303,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No """ # The implementation is NOT based on Token Exchange # https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16 - return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 + return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs scope=decorate_scope(scopes, self.client_id), # Decoration is used for: @@ -1316,4 +1322,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID), }, - **kwargs) + **kwargs)) diff --git a/tests/test_application.py b/tests/test_application.py index 3c3b4644..2ba66a8b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -354,19 +354,23 @@ def test_fresh_token_should_be_returned_from_cache(self): # a.k.a. Return unexpired token that is not above token refresh expiration threshold access_token = "An access token prepopulated into cache" self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) - self.assertEqual( - access_token, - self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") def test_aging_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt to refresh unexpired token when AAD available self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) new_access_token = "new AT" - self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( - lambda *args, **kwargs: {"access_token": new_access_token}) - self.assertEqual( - new_access_token, - self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + def mock_post(*args, **kwargs): + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + self.app.http_client.post = mock_post + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") def test_aging_token_and_unavailable_aad_should_return_old_token(self): # a.k.a. Attempt refresh unexpired token when AAD unavailable @@ -393,9 +397,13 @@ def test_expired_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt refresh expired token when AAD available self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) new_access_token = "new AT" - self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = ( - lambda *args, **kwargs: {"access_token": new_access_token}) - self.assertEqual( - new_access_token, - self.app.acquire_token_silent(['s1'], self.account).get("access_token")) + def mock_post(*args, **kwargs): + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": new_access_token, + "refresh_in": 123, + })) + self.app.http_client.post = mock_post + result = self.app.acquire_token_silent(['s1'], self.account) + self.assertEqual(new_access_token, result.get("access_token")) + self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")