diff --git a/msal/application.py b/msal/application.py index 7ca62d7c..812abbfb 100644 --- a/msal/application.py +++ b/msal/application.py @@ -21,13 +21,14 @@ import msal.telemetry from .region import _detect_region from .throttled_http_client import ThrottledHttpClient +from .cloudshell import _is_running_in_cloud_shell # The __init__.py will import this. Not the other way around. __version__ = "1.17.0" # When releasing, also check and bump our dependencies's versions if needed logger = logging.getLogger(__name__) - +_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL" def extract_certs(public_cert_content): # Parses raw public certificate file contents and returns a list of strings @@ -986,6 +987,10 @@ def get_accounts(self, username=None): return accounts def _find_msal_accounts(self, environment): + interested_authority_types = [ + TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS] + if _is_running_in_cloud_shell(): + interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL) grouped_accounts = { a.get("home_account_id"): # Grouped by home tenant's id { # These are minimal amount of non-tenant-specific account info @@ -1001,8 +1006,7 @@ def _find_msal_accounts(self, environment): for a in self.token_cache.find( TokenCache.CredentialType.ACCOUNT, query={"environment": environment}) - if a["authority_type"] in ( - TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS) + if a["authority_type"] in interested_authority_types } return list(grouped_accounts.values()) @@ -1062,6 +1066,21 @@ def _forget_me(self, home_account): TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account): self.token_cache.remove_account(a) + def _acquire_token_by_cloud_shell(self, scopes, data=None): + from .cloudshell import _obtain_token + response = _obtain_token( + self.http_client, scopes, client_id=self.client_id, data=data) + if "error" not in response: + self.token_cache.add(dict( + client_id=self.client_id, + scope=response["scope"].split() if "scope" in response else scopes, + token_endpoint=self.authority.token_endpoint, + response=response.copy(), + data=data or {}, + authority_type=_AUTHORITY_TYPE_CLOUDSHELL, + )) + return response + def acquire_token_silent( self, scopes, # type: List[str] @@ -1195,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( authority, # This can be different than self.authority force_refresh=False, # type: Optional[boolean] claims_challenge=None, + correlation_id=None, **kwargs): access_token_from_cache = None if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims @@ -1233,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge assert refresh_reason, "It should have been established at this point" try: + if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL: + return self._acquire_token_by_cloud_shell( + scopes, data=kwargs.get("data")) result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( authority, self._decorate_scope(scopes), account, refresh_reason=refresh_reason, claims_challenge=claims_challenge, + correlation_id=correlation_id, **kwargs)) if (result and "error" not in result) or (not access_token_from_cache): return result @@ -1574,6 +1598,9 @@ def acquire_token_interactive( - A dict containing an "error" key, when token refresh failed. """ self._validate_ssh_cert_input_data(kwargs.get("data", {})) + if _is_running_in_cloud_shell() and prompt == "none": + return self._acquire_token_by_cloud_shell( + scopes, data=kwargs.pop("data", {})) claims = _merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge) telemetry_context = self._build_telemetry_context( diff --git a/msal/cloudshell.py b/msal/cloudshell.py new file mode 100644 index 00000000..f4feaf44 --- /dev/null +++ b/msal/cloudshell.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. + +"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper""" +import base64 +import json +import logging +import os +import time +try: # Python 2 + from urlparse import urlparse +except: # Python 3 + from urllib.parse import urlparse +from .oauth2cli.oidc import decode_part + + +logger = logging.getLogger(__name__) + + +def _is_running_in_cloud_shell(): + return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell") + + +def _scope_to_resource(scope): # This is an experimental reasonable-effort approach + cloud_shell_supported_audiences = [ + "https://analysis.windows.net/powerbi/api", # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json + "https://pas.windows.net/CheckMyAccess/Linux/.default", # Cloud Shell accepts it as-is + ] + for a in cloud_shell_supported_audiences: + if scope.startswith(a): + return a + u = urlparse(scope) + if u.scheme: + return "{}://{}".format(u.scheme, u.netloc) + return scope # There is no much else we can do here + + +def _obtain_token(http_client, scopes, client_id=None, data=None): + resp = http_client.post( + "http://localhost:50342/oauth2/token", + data=dict( + data or {}, + resource=" ".join(map(_scope_to_resource, scopes))), + headers={"Metadata": "true"}, + ) + if resp.status_code >= 300: + logger.debug("Cloud Shell IMDS error: %s", resp.text) + cs_error = json.loads(resp.text).get("error", {}) + return {k: v for k, v in { + "error": cs_error.get("code"), + "error_description": cs_error.get("message"), + }.items() if v} + imds_payload = json.loads(resp.text) + BEARER = "Bearer" + oauth2_response = { + "access_token": imds_payload["access_token"], + "expires_in": int(imds_payload["expires_in"]), + "token_type": imds_payload.get("token_type", BEARER), + } + expected_token_type = (data or {}).get("token_type", BEARER) + if oauth2_response["token_type"] != expected_token_type: + return { # Generate a normal error (rather than an intrusive exception) + "error": "broker_error", + "error_description": "token_type {} is not supported by this version of Azure Portal".format( + expected_token_type), + } + parts = imds_payload["access_token"].split(".") + + # The following default values are useful in SSH Cert scenario + client_info = { # Default value, in case the real value will be unavailable + "uid": "user", + "utid": "cloudshell", + } + now = time.time() + preferred_username = "currentuser@cloudshell" + oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC + "iss": "cloudshell", + "sub": "user", + "aud": client_id, + "exp": now + 3600, + "iat": now, + "preferred_username": preferred_username, # Useful as MSAL account's username + } + + if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token. + try: + # Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims + jwt_payload = json.loads(decode_part(parts[1])) + client_info = { + # Mimic a real home_account_id, + # so that this pseudo account and a real account would interop. + "uid": jwt_payload.get("oid", "user"), + "utid": jwt_payload.get("tid", "cloudshell"), + } + oauth2_response["id_token_claims"] = { + "iss": jwt_payload["iss"], + "sub": jwt_payload["sub"], # Could use oid instead + "aud": client_id, + "exp": jwt_payload["exp"], + "iat": jwt_payload["iat"], + "preferred_username": jwt_payload.get("preferred_username") # V2 + or jwt_payload.get("unique_name") # V1 + or preferred_username, + } + except ValueError: + logger.debug("Unable to decode jwt payload: %s", parts[1]) + oauth2_response["client_info"] = base64.b64encode( + # Mimic a client_info, so that MSAL would create an account + json.dumps(client_info).encode("utf-8")).decode("utf-8") + oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD + + ## Note: Decided to not surface resource back as scope, + ## because they would cause the downstream OAuth2 code path to + ## cache the token with a different scope and won't hit them later. + #if imds_payload.get("resource"): + # oauth2_response["scope"] = imds_payload["resource"] + if imds_payload.get("refresh_token"): + oauth2_response["refresh_token"] = imds_payload["refresh_token"] + return oauth2_response + diff --git a/msal/token_cache.py b/msal/token_cache.py index 2ed819d7..f7d9f955 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -113,6 +113,7 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info return self.__add(event, now=now) finally: wipe(event.get("response", {}), ( # These claims were useful during __add() + "id_token_claims", # Provided by broker "access_token", "refresh_token", "id_token", "username")) wipe(event, ["username"]) # Needed for federated ROPC logger.debug("event=%s", json.dumps( @@ -150,7 +151,8 @@ def __add(self, event, now=None): id_token = response.get("id_token") id_token_claims = ( decode_id_token(id_token, client_id=event["client_id"]) - if id_token else {}) + if id_token + else response.get("id_token_claims", {})) # Broker would provide id_token_claims client_info, home_account_id = self.__parse_account(response, id_token_claims) target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it @@ -195,9 +197,10 @@ def __add(self, event, now=None): or data.get("username") # Falls back to ROPC username or event.get("username") # Falls back to Federated ROPC username or "", # The schema does not like null - "authority_type": + "authority_type": event.get( + "authority_type", # Honor caller's choice of authority_type self.AuthorityType.ADFS if realm == "adfs" - else self.AuthorityType.MSSTS, + else self.AuthorityType.MSSTS), # "client_info": response.get("client_info"), # Optional } self.modify(self.CredentialType.ACCOUNT, account, account) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index f74c0767..9a971f46 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -185,12 +185,14 @@ def _test_acquire_token_interactive( self, client_id=None, authority=None, scope=None, port=None, username_uri="", # But you would want to provide one data=None, # Needed by ssh-cert feature + prompt=None, **ignored): assert client_id and authority and scope self.app = msal.PublicClientApplication( client_id, authority=authority, http_client=MinimalHttpClient()) result = self.app.acquire_token_interactive( scope, + prompt=prompt, timeout=120, port=port, welcome_template= # This is an undocumented feature for testing @@ -237,6 +239,7 @@ def test_ssh_cert_for_user(self): scope=self.SCOPE, data=self.DATA1, username_uri="https://msidlab.com/api/user?usertype=cloud", + prompt="none" if msal.application._is_running_in_cloud_shell() else None, ) # It already tests reading AT from cache, and using RT to refresh # acquire_token_silent() would work because we pass in the same key self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format( @@ -254,6 +257,20 @@ def test_ssh_cert_for_user(self): self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token']) +@unittest.skipUnless( + msal.application._is_running_in_cloud_shell(), + "Manually run this test case from inside Cloud Shell") +class CloudShellTestCase(E2eTestCase): + app = msal.PublicClientApplication("client_id") + scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents + def test_access_token_should_be_obtained_for_a_supported_scope(self): + result = self.app.acquire_token_interactive( + [self.scope_that_requires_no_managed_device], prompt="none") + self.assertEqual( + "Bearer", result.get("token_type"), "Unexpected result: %s" % result) + self.assertIsNotNone(result.get("access_token")) + + THIS_FOLDER = os.path.dirname(__file__) CONFIG = os.path.join(THIS_FOLDER, "config.json") @unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)