Skip to content
Merged
51 changes: 32 additions & 19 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,37 @@ class CredentialType(Enum): # pylint: disable=too-few-public-methods
class Profile(object):
def __init__(self, storage=None, auth_ctx_factory=None):
self._storage = storage or ACCOUNT
factory = auth_ctx_factory or _AUTH_CTX_FACTORY
self._creds_cache = CredsCache(factory)
self._subscription_finder = SubscriptionFinder(factory, self._creds_cache.adal_token_cache)
self.auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY
self._creds_cache = CredsCache(self.auth_ctx_factory)
self._management_resource_uri = CLOUD.endpoints.management
self._ad_resource_uri = CLOUD.endpoints.active_directory_resource_id

def find_subscriptions_on_login(self, # pylint: disable=too-many-arguments
interactive,
username,
password,
is_service_principal,
tenant):
tenant,
subscription_finder=None):
from azure.cli.core._debug import allow_debug_adal_connection
allow_debug_adal_connection()
subscriptions = []

if not subscription_finder:
subscription_finder = SubscriptionFinder(self.auth_ctx_factory,
self._creds_cache.adal_token_cache)
if interactive:
subscriptions = self._subscription_finder.find_through_interactive_flow(
subscriptions = subscription_finder.find_through_interactive_flow(
tenant, self._ad_resource_uri)
else:
if is_service_principal:
if not tenant:
raise CLIError('Please supply tenant using "--tenant"')
sp_auth = ServicePrincipalAuth(password)
subscriptions = self._subscription_finder.find_from_service_principal_id(
subscriptions = subscription_finder.find_from_service_principal_id(
username, sp_auth, tenant, self._ad_resource_uri)
else:
subscriptions = self._subscription_finder.find_from_user_account(
subscriptions = subscription_finder.find_from_user_account(
username, password, tenant, self._ad_resource_uri)

if not subscriptions:
Expand All @@ -132,7 +137,7 @@ def find_subscriptions_on_login(self, # pylint: disable=too-many-arguments

if self._creds_cache.adal_token_cache.has_state_changed:
self._creds_cache.persist_cached_creds()
consolidated = Profile._normalize_properties(self._subscription_finder.user_id,
consolidated = Profile._normalize_properties(subscription_finder.user_id,
subscriptions,
is_service_principal)
self._set_subscriptions(consolidated)
Expand Down Expand Up @@ -256,6 +261,9 @@ def get_subscription(self, subscription=None): # take id or name
raise CLIError("Please run 'az account set' to select active account.")
return result[0]

def get_subscription_id(self):
return self.get_subscription()[_SUBSCRIPTION_ID]

def get_login_credentials(self, resource=CLOUD.endpoints.active_directory_resource_id,
subscription_id=None):
account = self.get_subscription(subscription_id)
Expand Down Expand Up @@ -428,8 +436,7 @@ def __init__(self, auth_ctx_factory=None):
os.path.join(get_config_dir(), 'accessTokens.json'))
self._service_principal_creds = []
self._auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY
self.adal_token_cache = None
self._load_creds()
self._adal_token_cache_attr = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need _attr suffix and I don't see we have other code using this convention. PEP 8 calls out we only need a leading _. But I am not going to have a naming convention discussion here :), if you feel it fine, i am fine with it too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


def persist_cached_creds(self):
with os.fdopen(os.open(self._token_file, os.O_RDWR | os.O_CREAT | os.O_TRUNC, 0o600),
Expand Down Expand Up @@ -459,6 +466,7 @@ def retrieve_token_for_user(self, username, tenant, resource):
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])

def retrieve_token_for_service_principal(self, sp_id, resource):
self.load_adal_token_cache()
matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]]
if not matched:
raise CLIError("Please run 'az account set' to select active account.")
Expand All @@ -471,23 +479,28 @@ def retrieve_token_for_service_principal(self, sp_id, resource):
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])

def retrieve_secret_of_service_principal(self, sp_id):
self.load_adal_token_cache()
matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]]
if not matched:
raise CLIError("No matched service principal found")
cred = matched[0]
return cred[_ACCESS_TOKEN]

def _load_creds(self):
import adal
if self.adal_token_cache is not None:
return self.adal_token_cache
all_entries = _load_tokens_from_file(self._token_file)
self._load_service_principal_creds(all_entries)
real_token = [x for x in all_entries if x not in self._service_principal_creds]
self.adal_token_cache = adal.TokenCache(json.dumps(real_token))
return self.adal_token_cache
@property
def adal_token_cache(self):
return self.load_adal_token_cache()

def load_adal_token_cache(self):
if self._adal_token_cache_attr is None:
import adal
all_entries = _load_tokens_from_file(self._token_file)
self._load_service_principal_creds(all_entries)
real_token = [x for x in all_entries if x not in self._service_principal_creds]
self._adal_token_cache_attr = adal.TokenCache(json.dumps(real_token))
return self._adal_token_cache_attr

def save_service_principal_cred(self, sp_entry):
self.load_adal_token_cache()
matched = [x for x in self._service_principal_creds
if sp_entry[_SERVICE_PRINCIPAL_ID] == x[_SERVICE_PRINCIPAL_ID] and
sp_entry[_SERVICE_PRINCIPAL_TENANT] == x[_SERVICE_PRINCIPAL_TENANT]]
Expand Down
2 changes: 1 addition & 1 deletion src/azure-cli-core/azure/cli/core/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _get_env_string():

@decorators.suppress_all_exceptions(fallback_return=None)
def _get_azure_subscription_id():
return _get_profile().get_login_credentials()[1]
return _get_profile().get_subscription_id()


def _get_shell_type():
Expand Down
12 changes: 8 additions & 4 deletions src/azure-cli-core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,12 @@ def test_get_expanded_subscription_info_for_logged_in_service_principal(self,
storage_mock = {'subscriptions': []}
profile = Profile(storage_mock)
profile._management_resource_uri = 'https://management.core.windows.net/'
profile._subscription_finder = finder
profile.find_subscriptions_on_login(False, '1234', 'my-secret', True, self.tenant_id)
profile.find_subscriptions_on_login(False,
'1234',
'my-secret',
True,
self.tenant_id,
finder)
# action
extended_info = profile.get_expanded_subscription_info()
# assert
Expand Down Expand Up @@ -310,7 +314,6 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
token_type, token = cred._token_retriever()
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)
self.assertEqual(mock_read_cred_file.call_count, 1)
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://management.core.windows.net/')
self.assertEqual(mock_get_token.call_count, 1)
Expand Down Expand Up @@ -535,7 +538,7 @@ def test_credscache_load_tokens_and_sp_creds_with_secret(self, mock_read_file):
creds_cache = CredsCache()

# assert
token_entries = [entry for _, entry in creds_cache.adal_token_cache.read_items()]
token_entries = [entry for _, entry in creds_cache.load_adal_token_cache().read_items()]
self.assertEqual(token_entries, [self.token_entry1])
self.assertEqual(creds_cache._service_principal_creds, [test_sp])

Expand All @@ -550,6 +553,7 @@ def test_credscache_load_tokens_and_sp_creds_with_cert(self, mock_read_file):

# action
creds_cache = CredsCache()
creds_cache.load_adal_token_cache()

# assert
self.assertEqual(creds_cache._service_principal_creds, [test_sp])
Expand Down