Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/azure/cli/_azure_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@
class ENDPOINT_URLS: #pylint: disable=too-few-public-methods,old-style-class,no-init
MANAGEMENT = 'management'
ACTIVE_DIRECTORY_AUTHORITY = 'active_directory_authority'
ACTIVE_DIRECTORY_GRAPH_RESOURCE_ID = 'active_directory_graph_resource_id'

_environments = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

black forest?

Copy link
Contributor Author

@yugangw-msft yugangw-msft Jul 7, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right on this, we are missing usgovenment as well. Since we have not exposed the --env flag in login, this can wait till we add that flag.

ENV_DEFAULT: {
ENDPOINT_URLS.MANAGEMENT: 'https://management.core.windows.net/',
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY : 'https://login.microsoftonline.com'
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY : 'https://login.microsoftonline.com',
ENDPOINT_URLS.ACTIVE_DIRECTORY_GRAPH_RESOURCE_ID: 'https://graph.windows.net/'
},
ENV_CHINA: {
ENDPOINT_URLS.MANAGEMENT: 'https://management.core.chinacloudapi.cn/',
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY: 'https://login.chinacloudapi.cn'
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY: 'https://login.chinacloudapi.cn',
ENDPOINT_URLS.ACTIVE_DIRECTORY_GRAPH_RESOURCE_ID: 'https://graph.chinacloudapi.cn/'
},
ENV_US_GOVERNMENT: {
ENDPOINT_URLS.MANAGEMENT: 'https://management.core.usgovcloudapi.net/',
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY: 'https://login.microsoftonline.com'
ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY: 'https://login.microsoftonline.com',
ENDPOINT_URLS.ACTIVE_DIRECTORY_GRAPH_RESOURCE_ID: 'https://graph.windows.net/'
}
}

Expand All @@ -36,8 +40,3 @@ def get_env(env_name=None):
def get_authority_url(tenant=None, env_name=None):
env = get_env(env_name)
return env[ENDPOINT_URLS.ACTIVE_DIRECTORY_AUTHORITY] + '/' + (tenant or COMMON_TENANT)

def get_management_endpoint_url(env_name=None):
env = get_env(env_name)
return env[ENDPOINT_URLS.MANAGEMENT]

61 changes: 33 additions & 28 deletions src/azure/cli/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from azure.mgmt.resource.subscriptions import SubscriptionClient
from .main import ACCOUNT
from ._util import CLIError
from ._azure_env import (get_authority_url, CLIENT_ID, get_management_endpoint_url,
ENV_DEFAULT, COMMON_TENANT)
from ._azure_env import (get_authority_url, get_env, ENDPOINT_URLS,
CLIENT_ID, ENV_DEFAULT, COMMON_TENANT)
from .adal_authentication import AdalAuthentication
import azure.cli._logging as _logging
logger = _logging.get_az_logger(__name__)
Expand Down Expand Up @@ -67,6 +67,9 @@ def __init__(self, storage=None, auth_ctx_factory=None):
factory = auth_ctx_factory or _AUTH_CTX_FACTORY
self._creds_cache = CredsCache(factory)
self._subscription_finder = SubscriptionFinder(factory, self._creds_cache.adal_token_cache)
env = get_env()
self._management_resource_uri = env[ENDPOINT_URLS.MANAGEMENT]
self._graph_resource_uri = env[ENDPOINT_URLS.ACTIVE_DIRECTORY_GRAPH_RESOURCE_ID]

def find_subscriptions_on_login(self, #pylint: disable=too-many-arguments
interactive,
Expand All @@ -77,17 +80,18 @@ def find_subscriptions_on_login(self, #pylint: disable=too-many-arguments
self._creds_cache.remove_cached_creds(username)
subscriptions = []
if interactive:
subscriptions = self._subscription_finder.find_through_interactive_flow()
subscriptions = self._subscription_finder.find_through_interactive_flow(
self._management_resource_uri)
else:
if is_service_principal:
if not tenant:
raise CLIError('Please supply tenant using "--tenant"')

subscriptions = self._subscription_finder.find_from_service_principal_id(username,
password,
tenant)
subscriptions = self._subscription_finder.find_from_service_principal_id(
username, password, tenant, self._management_resource_uri)
else:
subscriptions = self._subscription_finder.find_from_user_account(username, password)
subscriptions = self._subscription_finder.find_from_user_account(
username, password, self._management_resource_uri)

if not subscriptions:
raise CLIError('No subscriptions found for this account.')
Expand Down Expand Up @@ -192,7 +196,7 @@ def load_cached_subscriptions(self):
def _cache_subscriptions_to_local_storage(self, subscriptions):
self._storage[_SUBSCRIPTIONS] = subscriptions

def get_login_credentials(self):
def get_login_credentials(self, for_graph_client=False):
subscriptions = self.load_cached_subscriptions()
if not subscriptions:
raise CLIError('Please run login to setup account.')
Expand All @@ -204,52 +208,54 @@ def get_login_credentials(self):

user_type = active_account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = active_account[_USER_ENTITY][_USER_NAME]
resource = self._graph_resource_uri if for_graph_client else self._management_resource_uri
if user_type == _USER:
token_retriever = lambda: self._creds_cache.retrieve_token_for_user(
username_or_sp_id, active_account[_TENANT_ID])
username_or_sp_id, active_account[_TENANT_ID], resource)
auth_object = AdalAuthentication(token_retriever)
else:
token_retriever = lambda: self._creds_cache.retrieve_token_for_service_principal(
username_or_sp_id)
username_or_sp_id, resource)
auth_object = AdalAuthentication(token_retriever)

return auth_object, str(active_account[_SUBSCRIPTION_ID])
return (auth_object,
str(active_account[_SUBSCRIPTION_ID]),
str(active_account[_TENANT_ID]))


class SubscriptionFinder(object):
'''finds all subscriptions for a user or service principal'''
def __init__(self, auth_context_factory, adal_token_cache, arm_client_factory=None):
self._adal_token_cache = adal_token_cache
self._auth_context_factory = auth_context_factory
self._resource = get_management_endpoint_url(ENV_DEFAULT)
self.user_id = None # will figure out after log user in
self._arm_client_factory = arm_client_factory or \
(lambda config: SubscriptionClient(config)) #pylint: disable=unnecessary-lambda

def find_from_user_account(self, username, password):
def find_from_user_account(self, username, password, resource):
context = self._create_auth_context(COMMON_TENANT)
token_entry = context.acquire_token_with_username_password(
self._resource,
resource,
username,
password,
CLIENT_ID)
self.user_id = token_entry[_TOKEN_ENTRY_USER_ID]
result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN])
result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN], resource)
return result

def find_through_interactive_flow(self):
def find_through_interactive_flow(self, resource):
context = self._create_auth_context(COMMON_TENANT)
code = context.acquire_user_code(self._resource, CLIENT_ID)
code = context.acquire_user_code(resource, CLIENT_ID)
logger.warning(code['message'])
token_entry = context.acquire_token_with_device_code(self._resource, code, CLIENT_ID)
token_entry = context.acquire_token_with_device_code(resource, code, CLIENT_ID)
self.user_id = token_entry[_TOKEN_ENTRY_USER_ID]
result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN])
result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN], resource)
return result

def find_from_service_principal_id(self, client_id, secret, tenant):
def find_from_service_principal_id(self, client_id, secret, tenant, resource):
context = self._create_auth_context(tenant, False)
token_entry = context.acquire_token_with_client_credentials(
self._resource,
resource,
client_id,
secret)
self.user_id = client_id
Expand All @@ -261,15 +267,15 @@ def _create_auth_context(self, tenant, use_token_cache=True):
authority = get_authority_url(tenant, ENV_DEFAULT)
return self._auth_context_factory(authority, token_cache)

def _find_using_common_tenant(self, access_token):
def _find_using_common_tenant(self, access_token, resource):
all_subscriptions = []
token_credential = BasicTokenAuthentication({'access_token': access_token})
client = self._arm_client_factory(token_credential)
tenants = client.tenants.list()
for t in tenants:
tenant_id = t.tenant_id
temp_context = self._create_auth_context(tenant_id)
temp_credentials = temp_context.acquire_token(self._resource, self.user_id, CLIENT_ID)
temp_credentials = temp_context.acquire_token(resource, self.user_id, CLIENT_ID)
subscriptions = self._find_using_specific_tenant(
tenant_id,
temp_credentials[_ACCESS_TOKEN])
Expand Down Expand Up @@ -297,7 +303,6 @@ def __init__(self, auth_ctx_factory=None):
self._auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY
self.adal_token_cache = None
self._load_creds()
self._resource = get_management_endpoint_url(ENV_DEFAULT)

def persist_cached_creds(self):
#be compatible with azure-xplat-cli, use 'ascii' so to save w/o a BOM
Expand All @@ -314,25 +319,25 @@ def persist_cached_creds(self):
cred_file.write(json.dumps(all_creds))
self.adal_token_cache.has_state_changed = False

def retrieve_token_for_user(self, username, tenant):
def retrieve_token_for_user(self, username, tenant, resource):
authority = get_authority_url(tenant, ENV_DEFAULT)
context = self._auth_ctx_factory(authority, cache=self.adal_token_cache)
token_entry = context.acquire_token(self._resource, username, CLIENT_ID)
token_entry = context.acquire_token(resource, username, CLIENT_ID)
if not token_entry:
raise CLIError('Could not retrieve token from local cache, please run \'login\'.')

if self.adal_token_cache.has_state_changed:
self.persist_cached_creds()
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])

def retrieve_token_for_service_principal(self, sp_id):
def retrieve_token_for_service_principal(self, sp_id, resource):
matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]]
if not matched:
raise CLIError('Please run "account set" to select active account.')
cred = matched[0]
authority_url = get_authority_url(cred[_SERVICE_PRINCIPAL_TENANT], ENV_DEFAULT)
context = self._auth_ctx_factory(authority_url, None)
token_entry = context.acquire_token_with_client_credentials(self._resource,
token_entry = context.acquire_token_with_client_credentials(resource,
sp_id,
cred[_ACCESS_TOKEN])
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])
Expand Down
2 changes: 1 addition & 1 deletion src/azure/cli/commands/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_subscription_service_client(client_type):
def _get_mgmt_service_client(client_type, subscription_bound=True):
logger.info('Getting management service client client_type=%s', client_type.__name__)
profile = Profile()
cred, subscription_id = profile.get_login_credentials()
cred, subscription_id, _ = profile.get_login_credentials()
if subscription_bound:
client = client_type(cred, subscription_id)
else:
Expand Down
52 changes: 39 additions & 13 deletions src/azure/cli/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
ENV_DEFAULT)
profile._set_subscriptions(consolidated)
#action
cred, subscription_id = profile.get_login_credentials()
cred, subscription_id, _ = profile.get_login_credentials()

#verify
self.assertEqual(subscription_id, '1')
Expand All @@ -207,8 +207,30 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)
self.assertEqual(mock_read_cred_file.call_count, 1)
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://management.core.windows.net/')
self.assertEqual(mock_get_token.call_count, 1)

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.CredsCache.retrieve_token_for_user', autospec=True)
def test_get_login_credentials_for_graph_client(self, mock_get_token, mock_read_cred_file):
some_token_type = 'Bearer'
mock_read_cred_file.return_value = json.dumps([Test_Profile.token_entry1])
mock_get_token.return_value = (some_token_type, Test_Profile.raw_token1)
#setup
storage_mock = {'subscriptions': None}
profile = Profile(storage_mock)
consolidated = Profile._normalize_properties(self.user1, [self.subscription1],
False, ENV_DEFAULT)
profile._set_subscriptions(consolidated)
#action
cred, _, tenant_id = profile.get_login_credentials(for_graph_client=True)
_, _ = cred._token_retriever()
#verify
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://graph.windows.net/')
self.assertEqual(tenant_id, self.tenant_id)

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.CredsCache.persist_cached_creds', autospec=True)
def test_logout(self, mock_persist_creds, mock_read_cred_file):
Expand Down Expand Up @@ -264,16 +286,16 @@ def test_find_subscriptions_thru_username_password(self, mock_auth_context):
finder = SubscriptionFinder(lambda _, _2: mock_auth_context,
None,
lambda _: mock_arm_client)

mgmt_resource = 'https://management.core.windows.net/'
#action
subs = finder.find_from_user_account(self.user1, 'bar')
subs = finder.find_from_user_account(self.user1, 'bar', mgmt_resource)

#assert
self.assertEqual([self.subscription1], subs)
mock_auth_context.acquire_token_with_username_password.assert_called_once_with(
'https://management.core.windows.net/', self.user1, 'bar', mock.ANY)
mgmt_resource, self.user1, 'bar', mock.ANY)
mock_auth_context.acquire_token.assert_called_once_with(
'https://management.core.windows.net/', self.user1, mock.ANY)
mgmt_resource, self.user1, mock.ANY)

@mock.patch('adal.AuthenticationContext', autospec=True)
def test_find_subscriptions_through_interactive_flow(self, mock_auth_context):
Expand All @@ -286,18 +308,18 @@ def test_find_subscriptions_through_interactive_flow(self, mock_auth_context):
finder = SubscriptionFinder(lambda _, _2: mock_auth_context,
None,
lambda _: mock_arm_client)

mgmt_resource = 'https://management.core.windows.net/'
#action
subs = finder.find_through_interactive_flow()
subs = finder.find_through_interactive_flow(mgmt_resource)

#assert
self.assertEqual([self.subscription1], subs)
mock_auth_context.acquire_user_code.assert_called_once_with(
'https://management.core.windows.net/', mock.ANY)
mgmt_resource, mock.ANY)
mock_auth_context.acquire_token_with_device_code.assert_called_once_with(
'https://management.core.windows.net/', test_nonsense_code, mock.ANY)
mgmt_resource, test_nonsense_code, mock.ANY)
mock_auth_context.acquire_token.assert_called_once_with(
'https://management.core.windows.net/', self.user1, mock.ANY)
mgmt_resource, self.user1, mock.ANY)

@mock.patch('adal.AuthenticationContext', autospec=True)
def test_find_subscriptions_from_service_principal_id(self, mock_auth_context):
Expand All @@ -307,15 +329,17 @@ def test_find_subscriptions_from_service_principal_id(self, mock_auth_context):
finder = SubscriptionFinder(lambda _, _2: mock_auth_context,
None,
lambda _: mock_arm_client)
mgmt_resource = 'https://management.core.windows.net/'
#action
subs = finder.find_from_service_principal_id('my app', 'my secret', self.tenant_id)
subs = finder.find_from_service_principal_id('my app', 'my secret',
self.tenant_id, mgmt_resource)

#assert
self.assertEqual([self.subscription1], subs)
mock_arm_client.tenants.list.assert_not_called()
mock_auth_context.acquire_token.assert_not_called()
mock_auth_context.acquire_token_with_client_credentials.assert_called_once_with(
'https://management.core.windows.net/', 'my app', 'my secret')
mgmt_resource, 'my app', 'my secret')

@mock.patch('azure.cli._profile._read_file_content', autospec=True)
def test_credscache_load_tokens_and_sp_creds(self, mock_read_file):
Expand Down Expand Up @@ -413,7 +437,9 @@ def get_auth_context(authority, **kwargs): # pylint: disable=unused-argument
creds_cache = CredsCache(auth_ctx_factory=get_auth_context)

#action
token_type, token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id)
mgmt_resource = 'https://management.core.windows.net/'
token_type, token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id,
mgmt_resource)
mock_adal_auth_context.acquire_token.assert_called_once_with(
'https://management.core.windows.net/',
self.user1,
Expand Down
4 changes: 2 additions & 2 deletions src/azure/cli/utils/vcr_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
def _mock_get_mgmt_service_client(client_type, subscription_bound=True):
# version of _get_mgmt_service_client to use when recording or playing tests
profile = Profile()
cred, subscription_id = profile.get_login_credentials()
cred, subscription_id, _ = profile.get_login_credentials()
if subscription_bound:
client = client_type(cred, subscription_id)
else:
Expand Down Expand Up @@ -60,7 +60,7 @@ def _mock_subscriptions(self): #pylint: disable=unused-argument
"tenantId": "123",
"isDefault": True}]

def _mock_user_access_token(_, _1, _2): #pylint: disable=unused-argument
def _mock_user_access_token(_, _1, _2, _3): #pylint: disable=unused-argument
return ('Bearer', 'top-secret-token-for-you')

def _mock_operation_delay(_):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def handle_folding(namespace):
def get_subscription_id():
from azure.cli.commands.client_factory import Profile
profile = Profile()
_, subscription_id = profile.get_login_credentials()
_, subscription_id, _ = profile.get_login_credentials()
return subscription_id
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def get_subscription_id_list(prefix, **kwargs):#pylint: disable=unused-argument
help='Organization id or service principal'
)

sp_name_type = CliArgumentType(
options_list=('--name', '-n')
)

register_cli_argument('login', 'password', password_type)
register_cli_argument('login', 'service_principal', service_principal_type)
register_cli_argument('login', 'username', username_type)
Expand All @@ -45,3 +49,6 @@ def get_subscription_id_list(prefix, **kwargs):#pylint: disable=unused-argument
register_cli_argument('logout', 'username', username_type)

register_cli_argument('account', 'subscription_name_or_id', subscription_name_or_id_type)

register_cli_argument('account create-sp', 'name', sp_name_type)
register_cli_argument('account reset-sp-credentials', 'name', sp_name_type)
Loading