diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 9d1be08ddcc..e0d45196a51 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -95,9 +95,9 @@ class CredentialType(Enum): # pylint: disable=too-few-public-methods class Profile(object): def __init__(self, storage=None, auth_ctx_factory=None): self._storage = storage or ACCOUNT - factory = auth_ctx_factory or _AUTH_CTX_FACTORY - self._creds_cache = CredsCache(factory) - self._subscription_finder = SubscriptionFinder(factory, self._creds_cache.adal_token_cache) + self.auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY + self._creds_cache = CredsCache(self.auth_ctx_factory) + self._management_resource_uri = CLOUD.endpoints.management self._ad_resource_uri = CLOUD.endpoints.active_directory_resource_id def find_subscriptions_on_login(self, # pylint: disable=too-many-arguments @@ -105,22 +105,27 @@ def find_subscriptions_on_login(self, # pylint: disable=too-many-arguments username, password, is_service_principal, - tenant): + tenant, + subscription_finder=None): from azure.cli.core._debug import allow_debug_adal_connection allow_debug_adal_connection() subscriptions = [] + + if not subscription_finder: + subscription_finder = SubscriptionFinder(self.auth_ctx_factory, + self._creds_cache.adal_token_cache) if interactive: - subscriptions = self._subscription_finder.find_through_interactive_flow( + subscriptions = subscription_finder.find_through_interactive_flow( tenant, self._ad_resource_uri) else: if is_service_principal: if not tenant: raise CLIError('Please supply tenant using "--tenant"') sp_auth = ServicePrincipalAuth(password) - subscriptions = self._subscription_finder.find_from_service_principal_id( + subscriptions = subscription_finder.find_from_service_principal_id( username, sp_auth, tenant, self._ad_resource_uri) else: - subscriptions = self._subscription_finder.find_from_user_account( + subscriptions = subscription_finder.find_from_user_account( username, password, tenant, self._ad_resource_uri) if not subscriptions: @@ -132,7 +137,7 @@ def find_subscriptions_on_login(self, # pylint: disable=too-many-arguments if self._creds_cache.adal_token_cache.has_state_changed: self._creds_cache.persist_cached_creds() - consolidated = Profile._normalize_properties(self._subscription_finder.user_id, + consolidated = Profile._normalize_properties(subscription_finder.user_id, subscriptions, is_service_principal) self._set_subscriptions(consolidated) @@ -256,6 +261,9 @@ def get_subscription(self, subscription=None): # take id or name raise CLIError("Please run 'az account set' to select active account.") return result[0] + def get_subscription_id(self): + return self.get_subscription()[_SUBSCRIPTION_ID] + def get_login_credentials(self, resource=CLOUD.endpoints.active_directory_resource_id, subscription_id=None): account = self.get_subscription(subscription_id) @@ -428,8 +436,7 @@ def __init__(self, auth_ctx_factory=None): os.path.join(get_config_dir(), 'accessTokens.json')) self._service_principal_creds = [] self._auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY - self.adal_token_cache = None - self._load_creds() + self._adal_token_cache_attr = None def persist_cached_creds(self): with os.fdopen(os.open(self._token_file, os.O_RDWR | os.O_CREAT | os.O_TRUNC, 0o600), @@ -459,6 +466,7 @@ def retrieve_token_for_user(self, username, tenant, resource): return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN]) def retrieve_token_for_service_principal(self, sp_id, resource): + self.load_adal_token_cache() matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]] if not matched: raise CLIError("Please run 'az account set' to select active account.") @@ -471,23 +479,28 @@ def retrieve_token_for_service_principal(self, sp_id, resource): return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN]) def retrieve_secret_of_service_principal(self, sp_id): + self.load_adal_token_cache() matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]] if not matched: raise CLIError("No matched service principal found") cred = matched[0] return cred[_ACCESS_TOKEN] - def _load_creds(self): - import adal - if self.adal_token_cache is not None: - return self.adal_token_cache - all_entries = _load_tokens_from_file(self._token_file) - self._load_service_principal_creds(all_entries) - real_token = [x for x in all_entries if x not in self._service_principal_creds] - self.adal_token_cache = adal.TokenCache(json.dumps(real_token)) - return self.adal_token_cache + @property + def adal_token_cache(self): + return self.load_adal_token_cache() + + def load_adal_token_cache(self): + if self._adal_token_cache_attr is None: + import adal + all_entries = _load_tokens_from_file(self._token_file) + self._load_service_principal_creds(all_entries) + real_token = [x for x in all_entries if x not in self._service_principal_creds] + self._adal_token_cache_attr = adal.TokenCache(json.dumps(real_token)) + return self._adal_token_cache_attr def save_service_principal_cred(self, sp_entry): + self.load_adal_token_cache() matched = [x for x in self._service_principal_creds if sp_entry[_SERVICE_PRINCIPAL_ID] == x[_SERVICE_PRINCIPAL_ID] and sp_entry[_SERVICE_PRINCIPAL_TENANT] == x[_SERVICE_PRINCIPAL_TENANT]] diff --git a/src/azure-cli-core/azure/cli/core/telemetry.py b/src/azure-cli-core/azure/cli/core/telemetry.py index c5cdc53d219..f5ae41f25b9 100644 --- a/src/azure-cli-core/azure/cli/core/telemetry.py +++ b/src/azure-cli-core/azure/cli/core/telemetry.py @@ -322,7 +322,7 @@ def _get_env_string(): @decorators.suppress_all_exceptions(fallback_return=None) def _get_azure_subscription_id(): - return _get_profile().get_login_credentials()[1] + return _get_profile().get_subscription_id() def _get_shell_type(): diff --git a/src/azure-cli-core/tests/test_profile.py b/src/azure-cli-core/tests/test_profile.py index ff837886818..e344826e285 100644 --- a/src/azure-cli-core/tests/test_profile.py +++ b/src/azure-cli-core/tests/test_profile.py @@ -238,8 +238,12 @@ def test_get_expanded_subscription_info_for_logged_in_service_principal(self, storage_mock = {'subscriptions': []} profile = Profile(storage_mock) profile._management_resource_uri = 'https://management.core.windows.net/' - profile._subscription_finder = finder - profile.find_subscriptions_on_login(False, '1234', 'my-secret', True, self.tenant_id) + profile.find_subscriptions_on_login(False, + '1234', + 'my-secret', + True, + self.tenant_id, + finder) # action extended_info = profile.get_expanded_subscription_info() # assert @@ -310,7 +314,6 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file): token_type, token = cred._token_retriever() self.assertEqual(token, self.raw_token1) self.assertEqual(some_token_type, token_type) - self.assertEqual(mock_read_cred_file.call_count, 1) mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, 'https://management.core.windows.net/') self.assertEqual(mock_get_token.call_count, 1) @@ -535,7 +538,7 @@ def test_credscache_load_tokens_and_sp_creds_with_secret(self, mock_read_file): creds_cache = CredsCache() # assert - token_entries = [entry for _, entry in creds_cache.adal_token_cache.read_items()] + token_entries = [entry for _, entry in creds_cache.load_adal_token_cache().read_items()] self.assertEqual(token_entries, [self.token_entry1]) self.assertEqual(creds_cache._service_principal_creds, [test_sp]) @@ -550,6 +553,7 @@ def test_credscache_load_tokens_and_sp_creds_with_cert(self, mock_read_file): # action creds_cache = CredsCache() + creds_cache.load_adal_token_cache() # assert self.assertEqual(creds_cache._service_principal_creds, [test_sp])