diff --git a/.changes/next-release/enhancement-ssologin-96466.json b/.changes/next-release/enhancement-ssologin-96466.json new file mode 100644 index 000000000000..18539a0cae21 --- /dev/null +++ b/.changes/next-release/enhancement-ssologin-96466.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``sso login``", + "description": "Add ``--sso-session`` argument to enable direct SSO login with a ``sso-session``" +} diff --git a/.changes/next-release/feature-configuresso-52515.json b/.changes/next-release/feature-configuresso-52515.json new file mode 100644 index 000000000000..8e92e962b726 --- /dev/null +++ b/.changes/next-release/feature-configuresso-52515.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``configure sso``", + "description": "Add support for configuring ``sso-session`` as part of configuring SSO-enabled profile" +} diff --git a/.changes/next-release/feature-configuressosession-45599.json b/.changes/next-release/feature-configuressosession-45599.json new file mode 100644 index 000000000000..5a0035f5451f --- /dev/null +++ b/.changes/next-release/feature-configuressosession-45599.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``configure sso-session``", + "description": "Add new ``configure sso-session`` command for creating and updating ``sso-session`` configurations" +} diff --git a/awscli/customizations/configure/__init__.py b/awscli/customizations/configure/__init__.py index 6055529251d4..ab0630533545 100644 --- a/awscli/customizations/configure/__init__.py +++ b/awscli/customizations/configure/__init__.py @@ -46,6 +46,10 @@ def profile_to_section(profile_name): """Converts a profile name to a section header to be used in the config.""" if profile_name == 'default': return profile_name - if any(c in _WHITESPACE for c in profile_name): - profile_name = shlex_quote(profile_name) - return 'profile %s' % profile_name + return get_section_header('profile', profile_name) + + +def get_section_header(section_type, section_name): + if any(c in _WHITESPACE for c in section_name): + section_name = shlex_quote(section_name) + return f'{section_type} {section_name}' diff --git a/awscli/customizations/configure/configure.py b/awscli/customizations/configure/configure.py index 51d3b82fc037..e1365d1b9c75 100644 --- a/awscli/customizations/configure/configure.py +++ b/awscli/customizations/configure/configure.py @@ -25,6 +25,7 @@ from awscli.customizations.configure.importer import ConfigureImportCommand from awscli.customizations.configure.listprofiles import ListProfilesCommand from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso import ConfigureSSOSessionCommand from awscli.customizations.configure.exportcreds import \ ConfigureExportCredentialsCommand @@ -82,6 +83,7 @@ class ConfigureCommand(BasicCommand): {'name': 'import', 'command_class': ConfigureImportCommand}, {'name': 'list-profiles', 'command_class': ListProfilesCommand}, {'name': 'sso', 'command_class': ConfigureSSOCommand}, + {'name': 'sso-session', 'command_class': ConfigureSSOSessionCommand}, {'name': 'export-credentials', 'command_class': ConfigureExportCredentialsCommand}, ] diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py index 7c9c73231304..759d2cc832e2 100644 --- a/awscli/customizations/configure/sso.py +++ b/awscli/customizations/configure/sso.py @@ -10,9 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import collections +import itertools +import json import os import logging +import re +import colorama from botocore import UNSIGNED from botocore.config import Config from botocore.configprovider import ConstantProvider @@ -20,39 +25,89 @@ from botocore.utils import is_valid_endpoint_url from prompt_toolkit import prompt as ptk_prompt +from prompt_toolkit.application import get_app from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.formatted_text import FormattedText +from prompt_toolkit.styles import Style from prompt_toolkit.validation import Validator from prompt_toolkit.validation import ValidationError from awscli.customizations.utils import uni_print -from awscli.customizations.commands import BasicCommand -from awscli.customizations.configure import profile_to_section +from awscli.customizations.configure import ( + profile_to_section, get_section_header, +) from awscli.customizations.configure.writer import ConfigFileWriter from awscli.customizations.wizard.ui.selectmenu import select_menu from awscli.customizations.sso.utils import ( - do_sso_login, PrintOnlyHandler, LOGIN_ARGS, + do_sso_login, parse_sso_registration_scopes, PrintOnlyHandler, LOGIN_ARGS, + BaseSSOCommand, ) from awscli.formatter import CLI_OUTPUT_FORMATS logger = logging.getLogger(__name__) +_CMD_PROMPT_USAGE = ( + 'To keep an existing value, hit enter when prompted for the value. When ' + 'you are prompted for information, the current value will be displayed in ' + '[brackets]. If the config item has no value, it is displayed as ' + '[None] or omitted entirely.\n\n' +) +_CONFIG_EXTRA_INFO = ( + 'Note: The configuration is saved in the shared configuration file. ' + 'By default, ``~/.aws/config``. For more information, see the ' + '"Configuring the AWS CLI to use AWS Single Sign-On" section in the AWS ' + 'CLI User Guide:' + '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' +) + -class StartUrlValidator(Validator): +class ValidatorWithDefault(Validator): def __init__(self, default=None): - super(StartUrlValidator, self).__init__() + super(ValidatorWithDefault, self).__init__() self._default = default + def _raise_validation_error(self, document, message): + index = len(document.text) + raise ValidationError(index, message) + + +class StartUrlValidator(ValidatorWithDefault): def validate(self, document): # If there's a default, allow an empty prompt if not document.text and self._default: return if not is_valid_endpoint_url(document.text): - index = len(document.text) - raise ValidationError(index, 'Not a valid Start URL') + self._raise_validation_error(document, 'Not a valid Start URL') + + +class RequiredInputValidator(ValidatorWithDefault): + def validate(self, document): + if document.text or self._default: + return + self._raise_validation_error(document, 'A value is required') + + +class ScopesValidator(ValidatorWithDefault): + def validate(self, document): + # If there's a default, allow an empty prompt + if not document.text and self._default: + return + if not self._is_comma_separated_list(document.text): + self._raise_validation_error( + document, 'Scope values must be separated by commas') + + def _is_comma_separated_list(self, value): + scopes = value.split(',') + for scope in scopes: + if re.findall(r'\s', scope.strip()): + return False + return True class PTKPrompt(object): + _DEFAULT_PROMPT_FORMAT = '{prompt_text} [{current_value}]: ' + def __init__(self, prompter=None): if prompter is None: prompter = ptk_prompt @@ -61,29 +116,35 @@ def __init__(self, prompter=None): def _create_completer(self, completions): if completions is None: completions = [] + completer_kwargs = { + 'words': completions, + 'pattern': re.compile(r'\S+') + } if isinstance(completions, dict): - meta_dict = completions - completions = list(meta_dict.keys()) - completer = WordCompleter( - completions, - sentence=True, - meta_dict=meta_dict, - ) - else: - completer = WordCompleter(completions, sentence=True) - return completer + completer_kwargs['meta_dict'] = completions + completer_kwargs['words'] = list(completions.keys()) + return WordCompleter(**completer_kwargs) def get_value(self, current_value, prompt_text='', - completions=None, validator=None): - completer = self._create_completer(completions) - prompt_string = u'{} [{}]: '.format(prompt_text, current_value) - response = self._prompter( - prompt_string, - validator=validator, - validate_while_typing=False, - completer=completer, - complete_while_typing=True, + completions=None, validator=None, toolbar=None, + prompt_fmt=None): + if prompt_fmt is None: + prompt_fmt = self._DEFAULT_PROMPT_FORMAT + prompt_string = prompt_fmt.format( + prompt_text=prompt_text, + current_value=current_value ) + prompter_kwargs = { + 'validator': validator, + 'validate_while_typing': False, + 'completer': self._create_completer(completions), + 'complete_while_typing': True, + 'style': self._get_prompt_style(), + } + if toolbar: + prompter_kwargs['bottom_toolbar'] = toolbar + prompter_kwargs['refresh_interval'] = 0.2 + response = self._prompter(prompt_string, **prompter_kwargs) # Strip any extra white space response = response.strip() if not response: @@ -91,6 +152,13 @@ def get_value(self, current_value, prompt_text='', response = current_value return response + def _get_prompt_style(self): + return Style.from_dict( + { + 'bottom-toolbar': 'noreverse', + } + ) + def display_account(account): """Converts an SSO account response into a display string. @@ -111,67 +179,219 @@ def display_account(account): return account_template.format(**account) -class ConfigureSSOCommand(BasicCommand): +class SSOSessionConfigurationPrompter: + _DEFAULT_SSO_SCOPE = 'sso:account:access' + _KNOWN_SSO_SCOPES = { + 'sso:account:access': ( + 'Grants access to AWS IAM Identity Center accounts and permission ' + 'sets' + ) + } + + def __init__(self, botocore_session, prompter): + self._botocore_session = botocore_session + self._prompter = prompter + self._sso_sessions = self._botocore_session.full_config.get( + 'sso_sessions', {}) + self._sso_session = None + self.sso_session_config = {} + + @property + def sso_session(self): + return self._sso_session + + @sso_session.setter + def sso_session(self, value): + self._sso_session = value + self.sso_session_config = self._sso_sessions.get( + self._sso_session, {}).copy() + + def prompt_for_sso_session(self, required=True): + prompt_text = 'SSO session name' + prompt_fmt = None + validator_cls = None + if required: + validator_cls = RequiredInputValidator + if not self.sso_session: + prompt_fmt = f'{prompt_text}: ' + if not required: + prompt_fmt = f'{prompt_text} (Recommended): ' + sso_session = self._prompt_for( + 'sso_session', prompt_text, + completions=sorted(self._sso_sessions), + toolbar=self._get_sso_session_toolbar, + validator_cls=validator_cls, + prompt_fmt=prompt_fmt, + current_value=self.sso_session, + ) + self.sso_session = sso_session + return sso_session + + def prompt_for_sso_start_url(self): + return self._prompt_for( + 'sso_start_url', 'SSO start URL', + completions=self._get_potential_start_urls(), + validator_cls=StartUrlValidator, + ) + + def prompt_for_sso_region(self): + return self._prompt_for( + 'sso_region', 'SSO region', + completions=self._get_potential_sso_regions(), + validator_cls=RequiredInputValidator, + ) + + def prompt_for_sso_registration_scopes(self): + if 'sso_registration_scopes' not in self.sso_session_config: + self.sso_session_config['sso_registration_scopes'] = \ + self._DEFAULT_SSO_SCOPE + raw_scopes = self._prompt_for( + 'sso_registration_scopes', 'SSO registration scopes', + completions=self._get_potential_sso_registrations_scopes(), + validator_cls=ScopesValidator, + ) + return parse_sso_registration_scopes(raw_scopes) + + def _prompt_for(self, config_name, text, + completions=None, validator_cls=None, + toolbar=None, prompt_fmt=None, current_value=None): + if current_value is None: + current_value = self.sso_session_config.get(config_name) + validator = None + if validator_cls: + validator = validator_cls(current_value) + value = self._prompter.get_value( + current_value, text, + completions=completions, + validator=validator, + toolbar=toolbar, + prompt_fmt=prompt_fmt + ) + if value: + self.sso_session_config[config_name] = value + return value + + def _get_sso_session_toolbar(self): + current_input = get_app().current_buffer.document.text + if current_input in self._sso_sessions: + selected_sso_config = self._sso_sessions[current_input] + return FormattedText([ + ('', self._get_toolbar_border()), + ('', '\n'), + ('bold', f'Configuration for SSO session: {current_input}\n\n'), + ('', json.dumps(selected_sso_config, indent=2)), + ]) + + def _get_toolbar_border(self): + horizontal_line_char = '\u2500' + return horizontal_line_char * get_app().output.get_size().columns + + def _get_potential_start_urls(self): + profiles = self._botocore_session.full_config.get('profiles', {}) + configs_to_search = itertools.chain( + profiles.values(), + self._sso_sessions.values() + ) + potential_start_urls = set() + for config_to_search in configs_to_search: + if 'sso_start_url' in config_to_search: + start_url = config_to_search['sso_start_url'] + potential_start_urls.add(start_url) + return list(potential_start_urls) + + def _get_potential_sso_regions(self): + return self._botocore_session.get_available_regions('sso-oidc') + + def _get_potential_sso_registrations_scopes(self): + potential_scopes = self._KNOWN_SSO_SCOPES.copy() + scopes_to_sessions = self._get_previously_used_scopes_to_sso_sessions() + for scope, sso_sessions in scopes_to_sessions.items(): + if scope not in potential_scopes: + potential_scopes[scope] = ( + f'Used in SSO sessions: {", ".join(sso_sessions)}' + ) + return potential_scopes + + def _get_previously_used_scopes_to_sso_sessions(self): + scopes_to_sessions = collections.defaultdict(list) + for sso_session, sso_session_config in self._sso_sessions.items(): + if 'sso_registration_scopes' in sso_session_config: + parsed_scopes = parse_sso_registration_scopes( + sso_session_config['sso_registration_scopes'] + ) + for parsed_scope in parsed_scopes: + scopes_to_sessions[parsed_scope].append(sso_session) + return scopes_to_sessions + + +class BaseSSOConfigurationCommand(BaseSSOCommand): + def __init__(self, session, prompter=None, config_writer=None): + super(BaseSSOConfigurationCommand, self).__init__(session) + if prompter is None: + prompter = PTKPrompt() + self._prompter = prompter + if config_writer is None: + config_writer = ConfigFileWriter() + self._config_writer = config_writer + self._sso_sessions = self._session.full_config.get('sso_sessions', {}) + self._sso_session_prompter = SSOSessionConfigurationPrompter( + botocore_session=session, prompter=self._prompter, + ) + + def _write_sso_configuration(self): + self._update_section( + section_header=get_section_header( + 'sso-session', self._sso_session_prompter.sso_session), + new_values=self._sso_session_prompter.sso_session_config + ) + + def _update_section(self, section_header, new_values): + config_path = self._session.get_config_variable('config_file') + config_path = os.path.expanduser(config_path) + new_values['__section__'] = section_header + self._config_writer.update_config(new_values, config_path) + + +class ConfigureSSOCommand(BaseSSOConfigurationCommand): NAME = 'sso' SYNOPSIS = ('aws configure sso [--profile profile-name]') DESCRIPTION = ( 'The ``aws configure sso`` command interactively prompts for the ' 'configuration values required to create a profile that sources ' - 'temporary AWS credentials from AWS Single Sign-On. To keep an ' - 'existing value, hit enter when prompted for the value. When you ' - 'are prompted for information, the current value will be displayed in ' - '[brackets]. If the config item has no value, it is displayed as ' - '[None]. When providing the ``--profile`` parameter the named profile ' + 'temporary AWS credentials from AWS Single Sign-On.\n\n' + f'{_CMD_PROMPT_USAGE}' + 'When providing the ``--profile`` parameter the named profile ' 'will be created or updated. When a profile is not explicitly set ' - 'the profile name will be prompted for.' - '\n\nNote: The configuration is saved in the shared configuration ' - 'file. By default, ``~/.aws/config``.' - 'For more information, see the "Configuring the AWS CLI to use AWS ' - 'Single Sign-On" section in the AWS CLI User Guide:' - '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' + 'the profile name will be prompted for.\n\n' + f'{_CONFIG_EXTRA_INFO}' ) # TODO: Add CLI parameters to skip prompted values, --start-url, etc. ARG_TABLE = LOGIN_ARGS def __init__(self, session, prompter=None, selector=None, config_writer=None, sso_token_cache=None, sso_login=None): - super(ConfigureSSOCommand, self).__init__(session) - if prompter is None: - prompter = PTKPrompt() - self._prompter = prompter + super(ConfigureSSOCommand, self).__init__( + session, prompter=prompter, config_writer=config_writer) if selector is None: selector = select_menu self._selector = selector - if config_writer is None: - config_writer = ConfigFileWriter() if sso_login is None: sso_login = do_sso_login self._sso_login = sso_login - self._config_writer = config_writer self._sso_token_cache = sso_token_cache - self._new_values = {} + self._new_profile_config_values = {} self._original_profile_name = self._session.profile try: - self._config = self._session.get_scoped_config() + self._profile_config = self._session.get_scoped_config() except ProfileNotFound: - self._config = {} + self._profile_config = {} + self._set_sso_session_if_configured_in_profile() - def _prompt_for(self, config_name, text, - completions=None, validator_cls=None): - current_value = self._config.get(config_name) - if validator_cls is None: - validator = None - else: - validator = validator_cls(current_value) - new_value = self._prompter.get_value( - current_value, text, - completions=completions, - validator=validator, - ) - if new_value: - self._new_values[config_name] = new_value - return new_value + def _set_sso_session_if_configured_in_profile(self): + if 'sso_session' in self._profile_config: + self._sso_session_prompter.sso_session = \ + self._profile_config['sso_session'] def _handle_single_account(self, accounts): sso_account_id = accounts[0]['accountId'] @@ -186,7 +406,8 @@ def _handle_multiple_accounts(self, accounts): 'There are {} AWS accounts available to you.\n' ) uni_print(available_accounts_msg.format(len(accounts))) - selected_account = self._selector(accounts, display_account) + selected_account = self._selector( + accounts, display_format=display_account) sso_account_id = selected_account['accountId'] return sso_account_id @@ -204,7 +425,7 @@ def _prompt_for_account(self, sso, sso_token): else: sso_account_id = self._handle_multiple_accounts(accounts) uni_print('Using the account ID {}\n'.format(sso_account_id)) - self._new_values['sso_account_id'] = sso_account_id + self._new_profile_config_values['sso_account_id'] = sso_account_id return sso_account_id def _handle_single_role(self, roles): @@ -238,57 +459,43 @@ def _prompt_for_role(self, sso, sso_token, sso_account_id): else: sso_role_name = self._handle_multiple_roles(roles) uni_print('Using the role name "{}"\n'.format(sso_role_name)) - self._new_values['sso_role_name'] = sso_role_name + self._new_profile_config_values['sso_role_name'] = sso_role_name return sso_role_name - def _prompt_for_profile(self, sso_account_id, sso_role_name): + def _prompt_for_profile(self, sso_account_id=None, sso_role_name=None): if self._original_profile_name: profile_name = self._original_profile_name else: - default_profile = '{}-{}'.format(sso_role_name, sso_account_id) text = 'CLI profile name' - profile_name = self._prompter.get_value(default_profile, text) + default_profile = None + if sso_account_id and sso_role_name: + default_profile = f'{sso_role_name}-{sso_account_id}' + validator = RequiredInputValidator(default_profile) + profile_name = self._prompter.get_value( + default_profile, text, validator=validator) return profile_name - def _get_potential_start_urls(self): - profiles = self._session.full_config.get('profiles', []) - potential_start_urls = set() - for profile, config in profiles.items(): - if 'sso_start_url' in config: - start_url = config['sso_start_url'] - potential_start_urls.add(start_url) - return list(potential_start_urls) - - def _prompt_for_start_url(self): - potential_start_urls = self._get_potential_start_urls() - start_url = self._prompt_for( - 'sso_start_url', 'SSO start URL', - completions=potential_start_urls, - validator_cls=StartUrlValidator, - ) - return start_url - - def _get_potential_sso_regions(self): - return self._session.get_available_regions('sso-oidc') - - def _prompt_for_sso_region(self): - potential_sso_regions = self._get_potential_sso_regions() - sso_region = self._prompt_for( - 'sso_region', 'SSO Region', - completions=potential_sso_regions, - ) - return sso_region - def _prompt_for_cli_default_region(self): # TODO: figure out a way to get a list of reasonable client regions - return self._prompt_for('region', 'CLI default client Region') + return self._prompt_for_profile_config( + 'region', 'CLI default client Region') def _prompt_for_cli_output_format(self): - return self._prompt_for( + return self._prompt_for_profile_config( 'output', 'CLI default output format', completions=list(CLI_OUTPUT_FORMATS.keys()), ) + def _prompt_for_profile_config(self, config_name, text, completions=None): + current_value = self._profile_config.get(config_name) + new_value = self._prompter.get_value( + current_value, text, + completions=completions, + ) + if new_value: + self._new_profile_config_values[config_name] = new_value + return new_value + def _unset_session_profile(self): # The profile provided to the CLI as --profile may not exist. # This means we cannot use the session as is to create clients. @@ -302,28 +509,28 @@ def _unset_session_profile(self): def _run_main(self, parsed_args, parsed_globals): self._unset_session_profile() - start_url = self._prompt_for_start_url() - sso_region = self._prompt_for_sso_region() on_pending_authorization = None if parsed_args.no_browser: on_pending_authorization = PrintOnlyHandler() + sso_registration_args = self._prompt_for_sso_registration_args() sso_token = self._sso_login( self._session, - sso_region, - start_url, token_cache=self._sso_token_cache, on_pending_authorization=on_pending_authorization, + **sso_registration_args ) # Construct an SSO client to explore the accounts / roles client_config = Config( signature_version=UNSIGNED, - region_name=sso_region, + region_name=sso_registration_args['sso_region'], ) sso = self._session.create_client('sso', config=client_config) - sso_account_id = self._prompt_for_account(sso, sso_token) - sso_role_name = self._prompt_for_role(sso, sso_token, sso_account_id) + sso_account_id, sso_role_name = self._prompt_for_sso_account_and_role( + sso, sso_token + ) + configured_for_aws_credentials = all((sso_account_id, sso_role_name)) # General CLI configuration self._prompt_for_cli_default_region() @@ -331,20 +538,153 @@ def _run_main(self, parsed_args, parsed_globals): profile_name = self._prompt_for_profile(sso_account_id, sso_role_name) - usage_msg = ( - '\nTo use this profile, specify the profile name using ' - '--profile, as shown:\n\n' - 'aws s3 ls --profile {}\n' - ) - uni_print(usage_msg.format(profile_name)) - self._write_new_config(profile_name) + self._print_conclusion(configured_for_aws_credentials, profile_name) return 0 + def _prompt_for_sso_registration_args(self): + sso_session = self._sso_session_prompter.prompt_for_sso_session( + required=False) + if sso_session is None: + self._warn_configuring_using_legacy_format() + return self._prompt_for_registration_args_with_legacy_format() + else: + self._set_sso_session_in_profile_config(sso_session) + if sso_session in self._sso_sessions: + return self._get_sso_registration_args_from_sso_config( + sso_session) + else: + return self._prompt_for_registration_args_for_new_sso_session( + sso_session=sso_session + ) + + def _prompt_for_registration_args_with_legacy_format(self): + self._store_sso_session_prompter_answers_to_profile_config() + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + return { + 'start_url': start_url, + 'sso_region': sso_region + } + + def _get_sso_registration_args_from_sso_config(self, sso_session): + sso_config = self._get_sso_session_config(sso_session) + return { + 'session_name': sso_session, + 'start_url': sso_config['sso_start_url'], + 'sso_region': sso_config['sso_region'], + 'registration_scopes': sso_config.get('registration_scopes') + } + + def _prompt_for_registration_args_for_new_sso_session(self, sso_session): + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + scopes = self._sso_session_prompter.prompt_for_sso_registration_scopes() + return { + 'session_name': sso_session, + 'start_url': start_url, + 'sso_region': sso_region, + 'registration_scopes': scopes, + # We force refresh for any new SSO sessions to ensure we are not + # using any cached tokens from any previous of attempts to + # create/authenticate a new SSO session as part of the configure + # sso flow. + 'force_refresh': True + } + + def _store_sso_session_prompter_answers_to_profile_config(self): + # Wire the SSO session prompter to set config values to the + # dictionary used for writing to the profile section + self._sso_session_prompter.sso_session_config = \ + self._new_profile_config_values + + def _set_sso_session_in_profile_config(self, sso_session): + self._new_profile_config_values['sso_session'] = sso_session + + def _set_sso_session_defaults_from_profile_config(self): + # This is to ensure the SSO session prompter pulls in existing + # SSO configuration as part of the prompt if a profile was explicitly + # provided that already had SSO configuration + if 'sso_start_url' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_start_url'] = \ + self._profile_config['sso_start_url'] + if 'sso_region' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_region'] = \ + self._profile_config['sso_region'] + + def _prompt_for_sso_start_url_and_sso_region(self): + start_url = self._sso_session_prompter.prompt_for_sso_start_url() + sso_region = self._sso_session_prompter.prompt_for_sso_region() + return start_url, sso_region + + def _warn_configuring_using_legacy_format(self): + uni_print( + f'{colorama.Style.BRIGHT}WARNING: Configuring using legacy format ' + f'(e.g. without an SSO session).\n' + f'Consider re-running "configure sso" command and providing ' + f'a session name.\n{colorama.Style.RESET_ALL}' + ) + + def _prompt_for_sso_account_and_role(self, sso, sso_token): + sso_account_id = None + sso_role_name = None + try: + sso_account_id = self._prompt_for_account(sso, sso_token) + sso_role_name = self._prompt_for_role( + sso, sso_token, sso_account_id) + except sso.exceptions.UnauthorizedException as e: + uni_print( + 'Unable to list AWS accounts and/or roles. ' + 'Skipping configuring AWS credential provider for profile.\n' + ) + return sso_account_id, sso_role_name + def _write_new_config(self, profile): - config_path = self._session.get_config_variable('config_file') - config_path = os.path.expanduser(config_path) - if self._new_values: - section = profile_to_section(profile) - self._new_values['__section__'] = section - self._config_writer.update_config(self._new_values, config_path) + if self._new_profile_config_values: + profile_section = profile_to_section(profile) + self._update_section( + profile_section, self._new_profile_config_values) + if self._sso_session_prompter.sso_session: + self._write_sso_configuration() + + def _print_conclusion(self, configured_for_aws_credentials, profile_name): + if configured_for_aws_credentials: + msg = ( + '\nTo use this profile, specify the profile name using ' + '--profile, as shown:\n\n' + 'aws s3 ls --profile {}\n' + ) + else: + msg = 'Successfully configured SSO for profile: {}\n' + uni_print(msg.format(profile_name)) + + +class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand): + NAME = 'sso-session' + SYNOPSIS = ('aws configure sso-session') + DESCRIPTION = ( + 'The ``aws configure sso-session`` command interactively prompts for ' + 'the configuration values required to create a SSO session. ' + 'The SSO session can then be associated to a profile to retrieve ' + 'SSO access tokens and AWS credentials.\n\n' + f'{_CMD_PROMPT_USAGE}' + f'{_CONFIG_EXTRA_INFO}' + ) + + def _run_main(self, parsed_args, parsed_globals): + self._sso_session_prompter.prompt_for_sso_session() + self._sso_session_prompter.prompt_for_sso_start_url() + self._sso_session_prompter.prompt_for_sso_region() + self._sso_session_prompter.prompt_for_sso_registration_scopes() + self._write_sso_configuration() + self._print_configuration_success() + return 0 + + def _print_configuration_success(self): + sso_session = self._sso_session_prompter.sso_session + uni_print( + f'\nCompleted configuring SSO session: {sso_session}\n' + f'Run the following to login and refresh access token for ' + f'this session:\n\n' + f'aws sso login --sso-session {sso_session}\n' + ) diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index c921d06a5373..bf7f5489726d 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -23,14 +23,24 @@ class LoginCommand(BaseSSOCommand): 'credentials. To login, the requested profile must have first been ' 'setup using ``aws configure sso``. Each time the ``login`` command ' 'is called, a new SSO access token will be retrieved. Please note ' - 'that only one login session can be active for a given SSO Start URL ' + 'that only one login session can be active for a given SSO Session ' 'and creating multiple profiles does not allow for multiple users to ' - 'be authenticated against the same SSO Start URL.' + 'be authenticated against the same SSO Session.' ) - ARG_TABLE = LOGIN_ARGS + ARG_TABLE = LOGIN_ARGS + [ + { + 'name': 'sso-session', + 'help_text': ( + 'An explicit SSO session to use to login. By default, this ' + 'command will login using the SSO session configured as part ' + 'of the requested profile and generally does not require this ' + 'argument to be set.' + ) + } + ] def _run_main(self, parsed_args, parsed_globals): - sso_config = self._get_sso_config() + sso_config = self._get_sso_config(sso_session=parsed_args.sso_session) on_pending_authorization = None if parsed_args.no_browser: on_pending_authorization = PrintOnlyHandler() diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae809d8031a5..ae9a83e9e8b8 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -77,6 +77,14 @@ def do_sso_login(session, sso_region, start_url, token_cache=None, ) +def parse_sso_registration_scopes(raw_scopes): + parsed_scopes = [] + for scope in raw_scopes.split(','): + if scope := scope.strip(): + parsed_scopes.append(scope) + return parsed_scopes + + def open_browser_with_original_ld_path(url): with original_ld_library_path(): webbrowser.open_new_tab(url) @@ -148,25 +156,16 @@ class BaseSSOCommand(BasicCommand): 'sso_region', ] - def _get_sso_config(self): + def _get_sso_config(self, sso_session=None): scoped_config = self._session.get_scoped_config() - sso_session_config = self._get_sso_session_config(scoped_config) - if sso_session_config: - return sso_session_config - return self._get_legacy_sso_config(scoped_config) - - def _get_sso_session_config(self, scoped_config): - if 'sso_session' not in scoped_config: - return None - - for config_var in self._REQUIRED_SSO_CONFIG_VARS: - if config_var in scoped_config: - raise InvalidSSOConfigError( - 'Inline SSO configuration and sso_session cannot be ' - 'configured on the same profile.' - ) - - session_name = scoped_config['sso_session'] + if sso_session is None: + sso_session = scoped_config.get('sso_session') + if sso_session: + return self._get_sso_session_config(sso_session) + else: + return self._get_legacy_sso_config(scoped_config) + + def _get_sso_session_config(self, session_name): full_config = self._session.full_config if session_name not in full_config.get('sso_sessions', {}): raise InvalidSSOConfigError( @@ -179,7 +178,7 @@ def _get_sso_session_config(self, scoped_config): scopes_var = 'sso_registration_scopes' if scopes_var in session_config: raw_scopes = session_config[scopes_var] - parsed_scopes = self._parse_registration_scopes(raw_scopes) + parsed_scopes = parse_sso_registration_scopes(raw_scopes) sso_config['registration_scopes'] = parsed_scopes if missing: @@ -190,14 +189,6 @@ def _get_sso_session_config(self, scoped_config): return sso_config - def _parse_registration_scopes(self, raw_scopes): - parsed_scopes = [] - for scope in raw_scopes.split(','): - scope = scope.strip() - if scope: - parsed_scopes.append(scope) - return parsed_scopes - def _get_legacy_sso_config(self, scoped_config): sso_config, missing = self._get_required_config_vars(scoped_config) if missing: diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 5dd426dfc826..98036f9173b8 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -67,12 +67,16 @@ def get_legacy_config(self): ) return content - def get_sso_session_config(self, session_name): - content = ( - f'[default]\n' - f'sso_session={session_name}\n' - f'sso_role_name={self.role_name}\n' - f'sso_account_id={self.account}\n' + def get_sso_session_config(self, session_name, include_profile=True): + content = '' + if include_profile: + content += ( + f'[default]\n' + f'sso_session={session_name}\n' + f'sso_role_name={self.role_name}\n' + f'sso_account_id={self.account}\n' + ) + content += ( f'[sso-session {session_name}]\n' f'sso_start_url={self.start_url}\n' f'sso_region={self.sso_region}\n' diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index dd8af1126f4f..e24c47591716 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -195,6 +195,19 @@ def test_login_sso_session(self): expected_token=self.access_token, ) + def test_login_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_workflow_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + def test_login_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') @@ -211,21 +224,6 @@ def test_login_sso_session_with_scopes(self): self.assertEqual(operation.name, 'RegisterClient') self.assertEqual(params.get('scopes'), self.registration_scopes) - def test_login_sso_session_and_legacy_config_errors(self): - content = self.get_legacy_config() - content += ( - f'sso_session=test\n' - f'[sso-session test]\n' - f'sso_start_url={self.start_url}\n' - f'sso_region={self.sso_region}\n' - ) - self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) - self.assertIn( - 'cannot be configured on the same profile', - stderr - ) - def test_login_sso_session_missing_config(self): content = ( f'[default]\n' diff --git a/tests/unit/customizations/configure/__init__.py b/tests/unit/customizations/configure/__init__.py index 6cfcd74ccf8b..60cfce39f4ad 100644 --- a/tests/unit/customizations/configure/__init__.py +++ b/tests/unit/customizations/configure/__init__.py @@ -26,6 +26,7 @@ def __init__(self, all_variables, profile_does_not_exist=False, self.variables = all_variables self.profile_does_not_exist = profile_does_not_exist self.config = {} + self.full_config = {} if config_file_vars is None: config_file_vars = {} self.config_file_vars = config_file_vars diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py index 0ef24dbd8fe8..b515e763726b 100644 --- a/tests/unit/customizations/configure/test_sso.py +++ b/tests/unit/customizations/configure/test_sso.py @@ -10,29 +10,1575 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import argparse +import dataclasses +import json +import typing + import mock from datetime import datetime, timedelta + +import prompt_toolkit +import pytest from dateutil.tz import tzlocal from prompt_toolkit import prompt as ptk_prompt from prompt_toolkit.document import Document +from prompt_toolkit.validation import Validator from prompt_toolkit.validation import DummyValidator from prompt_toolkit.validation import ValidationError -from botocore.session import Session from botocore.stub import Stubber -from botocore.exceptions import ProfileNotFound from awscli.testutils import unittest from awscli.customizations.configure.sso import display_account -from awscli.customizations.configure.sso import select_menu from awscli.customizations.configure.sso import PTKPrompt +from awscli.customizations.configure.sso import SSOSessionConfigurationPrompter from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso import ConfigureSSOSessionCommand from awscli.customizations.configure.sso import StartUrlValidator -from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.customizations.configure.sso import RequiredInputValidator +from awscli.customizations.configure.sso import ScopesValidator +from awscli.customizations.sso.utils import parse_sso_registration_scopes from awscli.customizations.sso.utils import do_sso_login, PrintOnlyHandler from awscli.formatter import CLI_OUTPUT_FORMATS +from tests import StubbedSession + + +@pytest.fixture +def aws_config(tmp_path): + return tmp_path / "config" + + +@pytest.fixture +def env(aws_config): + env_vars = { + "AWS_DEFAULT_REGION": "us-west-2", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_CONFIG_FILE": aws_config, + "AWS_SHARED_CREDENTIALS_FILE": "", + } + with mock.patch("os.environ", env_vars): + yield env_vars + + +@pytest.fixture +def access_token(): + return "access.token.string" + + +@pytest.fixture +def account_id(): + return "0123456789" + + +@pytest.fixture +def role_name(): + return "roleA" + + +@pytest.fixture +def sso_session_name(): + return "dev" + + +@pytest.fixture +def scopes(): + return "scope-1, scope-2" + + +@pytest.fixture +def default_sso_scope(): + return "sso:account:access" + + +@pytest.fixture +def existing_profile_name(): + return "existing-profile" + + +@pytest.fixture +def existing_sso_session(): + return "existing-sso-session" + + +@pytest.fixture +def existing_start_url(): + return "https://existing-start-url" + + +@pytest.fixture +def existing_sso_region(): + return "existing-sso-region" + + +@pytest.fixture +def existing_scopes(): + return "existing-scope-1, existing-scope-2" + + +@pytest.fixture +def existing_region(): + return "existing-region" + + +@pytest.fixture +def existing_output(): + return "existing-output" + + +@pytest.fixture +def botocore_session(env): + return StubbedSession() + + +@pytest.fixture +def all_sso_oidc_regions(botocore_session): + return botocore_session.get_available_regions("sso-oidc") + + +@pytest.fixture +def sso_stubber_factory(env, botocore_session): + def create_sso_stubber(session=None): + if session is None: + session = botocore_session + sso_client = session.create_client("sso") + stubber = Stubber(sso_client) + stubber.activate() + return stubber + + return create_sso_stubber + + +@pytest.fixture +def sso_stubber(sso_stubber_factory): + return sso_stubber_factory() + + +@pytest.fixture +def stub_sso_list_accounts(sso_stubber, access_token): + def _do_stub_list_accounts(accounts, override_sso_stubber=None): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_response( + "list_accounts", + service_response={ + "accountList": accounts, + }, + expected_params={"accessToken": access_token}, + ) + + return _do_stub_list_accounts + + +@pytest.fixture +def stub_sso_list_roles(sso_stubber, access_token): + def _do_stub_list_accounts( + role_names, expected_account_id, override_sso_stubber=None + ): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_response( + "list_account_roles", + service_response={ + "roleList": [ + {"roleName": role_name} for role_name in role_names + ], + }, + expected_params={ + "accountId": expected_account_id, + "accessToken": access_token, + }, + ) + + return _do_stub_list_accounts + + +@pytest.fixture +def stub_simple_single_item_sso_responses( + sso_stubber, access_token, stub_sso_list_accounts, stub_sso_list_roles +): + def _do_stub_simple_single_item_sso_responses( + account_id, role_name, override_sso_stubber=None + ): + stub_sso_list_accounts( + accounts=[ + { + "accountId": account_id, + "emailAddress": "account@site.com", + } + ], + override_sso_stubber=override_sso_stubber, + ) + stub_sso_list_roles( + role_names=[role_name], + expected_account_id=account_id, + override_sso_stubber=override_sso_stubber, + ) + + return _do_stub_simple_single_item_sso_responses + + +@pytest.fixture +def stub_sso_authorization_error(sso_stubber): + def _do_stub_authorization_error(override_sso_stubber=None): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_client_error( + "list_accounts", service_error_code="UnauthorizedException" + ) + + return _do_stub_authorization_error + + +@pytest.fixture() +def ptk_stubber(): + return PTKStubber() + + +@pytest.fixture +def prompter(ptk_stubber): + return PTKPrompt(prompter=ptk_stubber.prompt) + + +@pytest.fixture +def sso_config_prompter_factory(env, botocore_session, prompter): + def create_sso_config_prompter(session=None, prompt=None): + if session is None: + session = botocore_session + if prompt is None: + prompt = prompter + return SSOSessionConfigurationPrompter( + botocore_session=session, prompter=prompt + ) + + return create_sso_config_prompter + + +@pytest.fixture +def sso_config_prompter(sso_config_prompter_factory): + return sso_config_prompter_factory() + + +@pytest.fixture +def selector(ptk_stubber): + return ptk_stubber.select_menu + + +@pytest.fixture +def mock_ptk_app(): + mock_app = mock.Mock(spec=prompt_toolkit.application.DummyApplication()) + with prompt_toolkit.application.current.set_app(mock_app): + yield mock_app + + +@pytest.fixture +def mock_do_sso_login(): + login_mock = mock.Mock(spec=do_sso_login) + login_mock.return_value = { + "accessToken": "access.token.string", + "expiresAt": datetime.now(tzlocal()) + timedelta(hours=24), + } + return login_mock + + +@pytest.fixture +def sso_cmd_factory( + env, botocore_session, prompter, mock_do_sso_login, selector +): + def create_sso_cmd(**override_kwargs): + kwargs = { + "session": botocore_session, + "prompter": prompter, + "sso_login": mock_do_sso_login, + "selector": selector, + } + kwargs.update(**override_kwargs) + return ConfigureSSOCommand(**kwargs) + + return create_sso_cmd + + +@pytest.fixture +def sso_cmd(sso_cmd_factory): + return sso_cmd_factory() + + +@pytest.fixture +def sso_session_cmd_factory(env, botocore_session, prompter): + def create_sso_session_cmd(**override_kwargs): + kwargs = {"session": botocore_session, "prompter": prompter} + kwargs.update(**override_kwargs) + return ConfigureSSOSessionCommand(**kwargs) + + return create_sso_session_cmd + + +@pytest.fixture +def sso_session_cmd(sso_session_cmd_factory): + return sso_session_cmd_factory() + + +@pytest.fixture +def args(): + return [] + + +@pytest.fixture +def parsed_globals(): + return argparse.Namespace() + + +@pytest.fixture +def start_url_prompt(): + return StartUrlPrompt(answer="https://starturl", expected_default=None) + + +@pytest.fixture +def sso_region_prompt(): + return SSORegionPrompt(answer="us-west-2", expected_default=None) + + +@pytest.fixture +def scopes_prompt(scopes, default_sso_scope): + return ScopesPrompt(answer=scopes, expected_default=default_sso_scope) + + +@pytest.fixture +def account_id_select(account_id): + selected_account = { + "accountId": account_id, + "emailAddress": "account@site.com", + } + return SelectMenu( + answer=selected_account, + expected_choices=[ + selected_account, + {"accountId": "1234567890", "emailAddress": "account2@site.com"}, + ], + ) + + +@pytest.fixture +def role_name_select(role_name): + return SelectMenu(answer=role_name, expected_choices=[role_name, "roleB"]) + + +@pytest.fixture +def region_prompt(): + return RegionPrompt(answer="us-west-2", expected_default=None) + + +@pytest.fixture +def output_prompt(): + return OutputPrompt(answer="json", expected_default=None) + + +@pytest.fixture +def profile_prompt(role_name, account_id): + return ProfilePrompt( + answer="dev", expected_default=f"{role_name}-{account_id}" + ) + + +@pytest.fixture +def configure_sso_legacy_inputs( + start_url_prompt, + sso_region_prompt, + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=""), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture +def configure_sso_legacy_with_existing_defaults_inputs( + configure_sso_legacy_inputs, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, +): + inputs = configure_sso_legacy_inputs + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture +def configure_sso_using_new_session_inputs( + start_url_prompt, + sso_region_prompt, + scopes_prompt, + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, + sso_session_name, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=sso_session_name), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + scopes_prompt=scopes_prompt, + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture() +def configure_sso_using_existing_session_inputs( + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, + existing_sso_session, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=existing_sso_session), + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture +def configure_sso_with_existing_defaults_inputs( + configure_sso_using_existing_session_inputs, + existing_sso_session, + existing_region, + existing_output, + sso_session_name, +): + inputs = configure_sso_using_existing_session_inputs + inputs.session_prompt = SessionWithDefaultPrompt( + answer=sso_session_name, expected_default=existing_sso_session + ) + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture +def configure_sso_using_new_session_from_legacy_profile_inputs( + configure_sso_using_new_session_inputs, + sso_session_name, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, +): + inputs = configure_sso_using_new_session_inputs + inputs.clear_answers() + inputs.session_prompt.answer = sso_session_name + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture() +def configure_sso_session_inputs( + sso_session_name, start_url_prompt, sso_region_prompt, scopes_prompt +): + return UserInputs( + session_prompt=RequiredSessionPrompt(answer=sso_session_name), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + scopes_prompt=scopes_prompt, + ) + + +@pytest.fixture +def configure_sso_session_with_existing_defaults_inputs( + configure_sso_session_inputs, + existing_start_url, + existing_sso_region, + existing_scopes, +): + inputs = configure_sso_session_inputs + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.scopes_prompt.expected_default = existing_scopes + return inputs + + +@pytest.fixture +def aws_config_lines_for_existing_legacy_profile( + existing_profile_name, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, + account_id, + role_name, +): + return [ + f"[profile {existing_profile_name}]", + f"sso_start_url = {existing_start_url}", + f"sso_region = {existing_sso_region}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {existing_region}", + f"output = {existing_output}", + ] + + +@pytest.fixture +def aws_config_lines_for_existing_sso_session( + existing_sso_session, + existing_start_url, + existing_sso_region, + existing_scopes, +): + return [ + f"[sso-session {existing_sso_session}]", + f"sso_start_url = {existing_start_url}", + f"sso_region = {existing_sso_region}", + f"sso_registration_scopes = {existing_scopes}", + ] + + +@pytest.fixture +def aws_config_lines_for_existing_profile_and_session( + existing_profile_name, + existing_sso_session, + existing_region, + existing_output, + account_id, + role_name, + aws_config_lines_for_existing_sso_session, +): + return [ + f"[profile {existing_profile_name}]", + f"sso_session = {existing_sso_session}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {existing_region}", + f"output = {existing_output}", + ] + aws_config_lines_for_existing_sso_session + + +@dataclasses.dataclass +class UserInput: + answer: typing.Any + + +@dataclasses.dataclass +class Prompt(UserInput): + expected_validator_cls: typing.Optional[Validator] = None + expected_completions: typing.Optional[typing.List[str]] = None + _expected_message: typing.Optional[str] = dataclasses.field( + init=False, repr=False, default=None + ) + + @property + def expected_message(self): + return self._expected_message + + @expected_message.setter + def expected_message(self, value): + self._expected_message = value + + +@dataclasses.dataclass +class PromptWithDefault(Prompt): + expected_default: typing.Any = None + msg_format: str = dataclasses.field(init=False) + + @property + def expected_message(self): + if self._expected_message is None: + self._expected_message = self.msg_format.format( + default=self.expected_default + ) + return self._expected_message + + +@dataclasses.dataclass +class StartUrlPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO start URL [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = StartUrlValidator + + +@dataclasses.dataclass +class SSORegionPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO region [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + +@dataclasses.dataclass +class ScopesPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO registration scopes [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = ScopesValidator + + +@dataclasses.dataclass +class RequiredSessionPrompt(Prompt): + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + def __post_init__(self): + super().__init__( + answer=self.answer, + expected_validator_cls=self.expected_validator_cls, + ) + self.expected_message = "SSO session name: " + + +@dataclasses.dataclass +class RecommendedSessionPrompt(Prompt): + def __post_init__(self): + super().__init__(answer=self.answer) + self.expected_message = "SSO session name (Recommended): " + + +@dataclasses.dataclass +class SessionWithDefaultPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO session name [{default}]: " + ) + + +@dataclasses.dataclass +class RegionPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI default client Region [{default}]: " + ) + + +@dataclasses.dataclass +class OutputPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI default output format [{default}]: " + ) + + +@dataclasses.dataclass +class ProfilePrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI profile name [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + +@dataclasses.dataclass +class SelectMenu(UserInput): + expected_choices: typing.Optional[typing.List[typing.Any]] = None + + +@dataclasses.dataclass +class UserInputs: + session_prompt: typing.Optional[Prompt] = None + start_url_prompt: typing.Optional[StartUrlPrompt] = None + sso_region_prompt: typing.Optional[SSORegionPrompt] = None + scopes_prompt: typing.Optional[ScopesPrompt] = None + account_id_select: typing.Optional[SelectMenu] = None + role_name_select: typing.Optional[SelectMenu] = None + region_prompt: typing.Optional[RegionPrompt] = None + output_prompt: typing.Optional[OutputPrompt] = None + profile_prompt: typing.Optional[ProfilePrompt] = None + + def get_expected_inputs(self): + expected_inputs = [] + for possible_input_field in dataclasses.fields(self): + possible_input = getattr(self, possible_input_field.name) + if possible_input is not None: + expected_inputs.append(possible_input) + return expected_inputs + + def clear_answers(self): + for user_input in self.get_expected_inputs(): + user_input.answer = "" + + def skip_account_and_role_selection(self): + self.account_id_select = None + self.role_name_select = None + + def skip_profile_prompt(self): + self.profile_prompt = None + + +class PTKStubber: + _ALLOWED_PROMPT_KWARGS = { + "validator", + "validate_while_typing", + "completer", + "complete_while_typing", + "bottom_toolbar", + "refresh_interval", + "style", + } + _ALLOWED_SELECT_MENU_KWARGS = { + "display_format", + "max_height", + } + + def __init__(self, user_inputs=None): + if user_inputs is None: + user_inputs = UserInputs() + self.user_inputs = user_inputs + self._expected_inputs = None + + def prompt(self, message, **kwargs): + self._initialize_expected_inputs_if_needed() + self._validate_kwargs(kwargs, self._ALLOWED_PROMPT_KWARGS) + if not self._expected_inputs: + raise AssertionError( + f'Received prompt with no stubbed answer: "{message}"' + ) + prompt = self._expected_inputs.pop(0) + assert isinstance( + prompt, Prompt + ), f'Did not receive user input of type Prompt for: "{message}"' + if prompt.expected_message is not None: + assert message == prompt.expected_message, ( + f"Prompt does not match expected " + f'prompt for answer: "{prompt}"' + ) + if prompt.expected_validator_cls: + assert isinstance( + kwargs.get("validator"), prompt.expected_validator_cls + ) + if prompt.expected_completions is not None: + provided_completer = kwargs.get("completer") + assert provided_completer is not None, ( + f"Expected completions but no completer was provided for " + f"prompt: {prompt}" + ) + assert provided_completer.words == prompt.expected_completions + return prompt.answer + + def select_menu(self, items, **kwargs): + self._initialize_expected_inputs_if_needed() + self._validate_kwargs(kwargs, self._ALLOWED_SELECT_MENU_KWARGS) + if not self._expected_inputs: + raise AssertionError( + f'Received select_menu with no stubbed answer: "{items}"' + ) + select_menu = self._expected_inputs.pop(0) + assert isinstance( + select_menu, SelectMenu + ), f'Did not receive user input of type SelectMenu for: "{items}"' + if select_menu.expected_choices is not None: + assert items == select_menu.expected_choices, ( + f"Choices does not match expected select_menu choices " + f'for answer: "{select_menu.answer}"' + ) + return select_menu.answer + + def _initialize_expected_inputs_if_needed(self): + if self._expected_inputs is None: + self._expected_inputs = self.user_inputs.get_expected_inputs() + + def _validate_kwargs(self, provided_kwargs, allowed_kwargs): + assert set(provided_kwargs).issubset( + allowed_kwargs + ), "Arguments provided does not matched allowed keyword arguments" + + +def write_aws_config(aws_config, lines): + with open(aws_config, "w") as f: + content = "\n".join(lines) + f.write(content + "\n") + + +def assert_aws_config(aws_config, expected_lines): + with open(aws_config, "r") as f: + assert f.read().splitlines() == expected_lines + + +class TestConfigureSSOCommand: + def assert_do_sso_login_call( + self, + mock_do_sso_login, + botocore_session, + expected_sso_region, + expected_start_url, + expected_session_name=None, + expected_scopes=None, + expected_auth_handler_cls=None, + expected_force_refresh=None, + ): + expected_kwargs = { + "sso_region": expected_sso_region, + "start_url": expected_start_url, + "on_pending_authorization": None, + "token_cache": None, + } + if expected_session_name is not None: + expected_kwargs["session_name"] = expected_session_name + if expected_scopes is not None: + expected_kwargs["registration_scopes"] = expected_scopes + if expected_auth_handler_cls: + expected_kwargs["on_pending_authorization"] = mock.ANY + if expected_force_refresh is not None: + expected_kwargs["force_refresh"] = expected_force_refresh + + mock_do_sso_login.assert_called_with( + botocore_session, **expected_kwargs + ) + + if expected_auth_handler_cls: + _, _, login_kwargs = mock_do_sso_login.mock_calls[0] + auth_handler = login_kwargs["on_pending_authorization"] + assert isinstance(auth_handler, expected_auth_handler_cls) + + def test_legacy_configure_sso_flow( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_list_roles, + stub_sso_list_accounts, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + capsys, + ): + inputs = configure_sso_legacy_inputs + selected_account_id = inputs.account_id_select.answer["accountId"] + ptk_stubber.user_inputs = inputs + stub_sso_list_accounts(inputs.account_id_select.expected_choices) + stub_sso_list_roles( + inputs.role_name_select.expected_choices, + expected_account_id=selected_account_id, + ) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {selected_account_id}", + f"sso_role_name = {inputs.role_name_select.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + assert "WARNING: Configuring using legacy format" in stdout + assert f"aws s3 ls --profile {inputs.profile_prompt.answer}" in stdout + + def test_single_account_single_role_flow_no_browser( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + botocore_session, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(["--no-browser"], parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + expected_auth_handler_cls=PrintOnlyHandler, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_single_account_single_role_flow( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_no_accounts_flow_raises_error( + self, + sso_cmd, + ptk_stubber, + sso_stubber, + stub_sso_list_accounts, + args, + parsed_globals, + configure_sso_legacy_inputs, + ): + ptk_stubber.user_inputs = configure_sso_legacy_inputs + stub_sso_list_accounts([]) + with pytest.raises(RuntimeError): + sso_cmd(args, parsed_globals) + sso_stubber.assert_no_pending_responses() + + def test_no_roles_flow_raises_error( + self, + sso_cmd, + ptk_stubber, + sso_stubber, + stub_sso_list_accounts, + stub_sso_list_roles, + args, + parsed_globals, + configure_sso_legacy_inputs, + ): + only_account = configure_sso_legacy_inputs.account_id_select.answer + configure_sso_legacy_inputs.account_id_select = None + ptk_stubber.user_inputs = configure_sso_legacy_inputs + stub_sso_list_accounts([only_account]) + stub_sso_list_roles([], expected_account_id=only_account["accountId"]) + with pytest.raises(RuntimeError): + sso_cmd(args, parsed_globals) + sso_stubber.assert_no_pending_responses() + + def test_defaults_to_scoped_config( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_legacy_with_existing_defaults_inputs, + aws_config_lines_for_existing_legacy_profile, + account_id, + role_name, + existing_profile_name, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_legacy_with_existing_defaults_inputs + inputs.skip_account_and_role_selection() + inputs.skip_profile_prompt() + inputs.clear_answers() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_sso_region=inputs.sso_region_prompt.expected_default, + expected_start_url=inputs.start_url_prompt.expected_default, + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_legacy_profile, + ) + + def test_handles_non_existent_profile( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + botocore_session, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.skip_profile_prompt() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + new_session = StubbedSession(profile="new-profile") + # We use the default session to create the stubbed clients because + # if we create the stubbed clients with a non-existent profile, we will + # get a ProfileNotFound error. So after the clients' creation we + # assign them to be used in the session using the new profile. + new_session.cached_clients.update(botocore_session.cached_clients) + new_session.client_stubs.update(botocore_session.client_stubs) + + sso_cmd = sso_cmd_factory(session=new_session) + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + new_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile new-profile]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_cli_config_is_none_not_written( + self, + sso_cmd, + ptk_stubber, + aws_config, + botocore_session, + stub_simple_single_item_sso_responses, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.region_prompt.answer = "" + inputs.output_prompt.answer = "" + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(args, parsed_globals) + + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + ], + ) + + def test_prompts_suggest_values_from_profiles( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + aws_config_lines_for_existing_legacy_profile, + existing_start_url, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + all_sso_oidc_regions, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession() + + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.start_url_prompt.expected_completions = [existing_start_url] + inputs.sso_region_prompt.expected_completions = all_sso_oidc_regions + inputs.output_prompt.expected_completions = list( + CLI_OUTPUT_FORMATS.keys() + ) + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + assert sso_cmd(args, parsed_globals) == 0 + + def test_configure_sso_with_new_sso_session( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_list_roles, + stub_sso_list_accounts, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + capsys, + ): + inputs = configure_sso_using_new_session_inputs + selected_account_id = inputs.account_id_select.answer["accountId"] + ptk_stubber.user_inputs = inputs + + stub_sso_list_accounts(inputs.account_id_select.expected_choices) + stub_sso_list_roles( + inputs.role_name_select.expected_choices, + expected_account_id=selected_account_id, + ) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + expected_scopes=parse_sso_registration_scopes( + inputs.scopes_prompt.answer + ), + expected_force_refresh=True, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"sso_account_id = {selected_account_id}", + f"sso_role_name = {inputs.role_name_select.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + assert "WARNING: Configuring using legacy format" not in stdout + assert f"aws s3 ls --profile {inputs.profile_prompt.answer}" in stdout + + def test_configure_sso_with_existing_sso_session( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_using_existing_session_inputs, + aws_config_lines_for_existing_sso_session, + account_id, + role_name, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + session = StubbedSession() + + inputs = configure_sso_using_existing_session_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=existing_sso_region, + expected_start_url=existing_start_url, + expected_scopes=parse_sso_registration_scopes(existing_scopes), + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_sso_session + + [ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_configure_sso_reusing_existing_configuration( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_with_existing_defaults_inputs, + aws_config_lines_for_existing_profile_and_session, + account_id, + role_name, + existing_profile_name, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_profile_and_session + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_with_existing_defaults_inputs + inputs.skip_account_and_role_selection() + inputs.clear_answers() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.expected_default, + expected_sso_region=existing_sso_region, + expected_start_url=existing_start_url, + expected_scopes=parse_sso_registration_scopes(existing_scopes), + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_profile_and_session, + ) + + def test_configure_sso_skips_account_role_config_when_no_access( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_authorization_error, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + capsys, + ): + inputs = configure_sso_using_new_session_inputs + inputs.skip_account_and_role_selection() + inputs.profile_prompt.expected_default = None + ptk_stubber.user_inputs = inputs + + stub_sso_authorization_error() + + sso_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + profile_answer = inputs.profile_prompt.answer + assert "Unable to list AWS accounts" in stdout + assert f"configured SSO for profile: {profile_answer}" in stdout + + def test_configure_sso_uses_profile_values_when_making_new_session( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_using_new_session_from_legacy_profile_inputs, + aws_config_lines_for_existing_legacy_profile, + account_id, + role_name, + existing_profile_name, + default_sso_scope, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_using_new_session_from_legacy_profile_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=inputs.sso_region_prompt.expected_default, + expected_start_url=inputs.start_url_prompt.expected_default, + expected_scopes=[default_sso_scope], + expected_force_refresh=True, + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_legacy_profile + + [ + f"sso_session = {inputs.session_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.expected_default}", + f"sso_region = {inputs.sso_region_prompt.expected_default}", + f"sso_registration_scopes = {default_sso_scope}", + ], + ) + + def test_configure_sso_suggests_values_from_sessions( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_start_url, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + account_id, + role_name, + all_sso_oidc_regions, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + session = StubbedSession() + + inputs = configure_sso_using_new_session_inputs + inputs.skip_account_and_role_selection() + inputs.session_prompt.expected_completions = [existing_sso_session] + inputs.start_url_prompt.expected_completions = [existing_start_url] + inputs.sso_region_prompt.expected_completions = all_sso_oidc_regions + inputs.output_prompt.expected_completions = list( + CLI_OUTPUT_FORMATS.keys() + ) + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + assert sso_cmd(args, parsed_globals) == 0 + + +class TestConfigureSSOSessionCommand: + def test_new_sso_session( + self, + sso_session_cmd, + ptk_stubber, + aws_config, + configure_sso_session_inputs, + args, + parsed_globals, + capsys, + ): + inputs = configure_sso_session_inputs + ptk_stubber.user_inputs = inputs + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + expected_login = ( + f"aws sso login --sso-session {inputs.session_prompt.answer}" + ) + assert expected_login in capsys.readouterr().out + + def test_can_used_default_scope_for_new_session( + self, + sso_session_cmd, + ptk_stubber, + aws_config, + configure_sso_session_inputs, + args, + parsed_globals, + default_sso_scope, + ): + inputs = configure_sso_session_inputs + inputs.scopes_prompt.answer = "" + ptk_stubber.user_inputs = inputs + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {default_sso_scope}", + ], + ) + + def test_reuse_existing_sso_session_configurations( + self, + sso_session_cmd_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + configure_sso_session_with_existing_defaults_inputs, + args, + parsed_globals, + existing_sso_session, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + inputs = configure_sso_session_with_existing_defaults_inputs + inputs.clear_answers() + inputs.session_prompt.answer = existing_sso_session + ptk_stubber.user_inputs = inputs + + sso_session_cmd = sso_session_cmd_factory(session=StubbedSession()) + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, expected_lines=aws_config_lines_for_existing_sso_session + ) + + def test_override_existing_sso_session_configurations( + self, + sso_session_cmd_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + configure_sso_session_with_existing_defaults_inputs, + args, + parsed_globals, + existing_sso_session, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + inputs = configure_sso_session_with_existing_defaults_inputs + inputs.session_prompt.answer = existing_sso_session + ptk_stubber.user_inputs = inputs + + sso_session_cmd = sso_session_cmd_factory(session=StubbedSession()) + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {existing_sso_session}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + class TestPTKPrompt(unittest.TestCase): def setUp(self): @@ -40,403 +1586,389 @@ def setUp(self): self.prompter = PTKPrompt(prompter=self.mock_prompter) def test_returns_input(self): - self.mock_prompter.return_value = 'new_value' - response = self.prompter.get_value('default_value', 'Prompt Text') - self.assertEqual(response, 'new_value') + self.mock_prompter.return_value = "new_value" + response = self.prompter.get_value("default_value", "Prompt Text") + self.assertEqual(response, "new_value") def test_user_hits_enter_returns_current(self): - self.mock_prompter.return_value = '' - response = self.prompter.get_value('default_value', 'Prompt Text') + self.mock_prompter.return_value = "" + response = self.prompter.get_value("default_value", "Prompt Text") # We convert the empty string to the default value - self.assertEqual(response, 'default_value') + self.assertEqual(response, "default_value") def assert_expected_completions(self, completions): # The order of the completion list can vary becuase it comes from the # dict's keys. Asserting that each expected completion is in the list _, kwargs = self.mock_prompter.call_args_list[0] - completer = kwargs['completer'] + completer = kwargs["completer"] self.assertEqual(len(completions), len(completer.words)) for completion in completions: self.assertIn(completion, completer.words) def assert_expected_meta_dict(self, meta_dict): _, kwargs = self.mock_prompter.call_args_list[0] - self.assertEqual(kwargs['completer'].meta_dict, meta_dict) + self.assertEqual(kwargs["completer"].meta_dict, meta_dict) def assert_expected_validator(self, validator): _, kwargs = self.mock_prompter.call_args_list[0] - self.assertEqual(kwargs['validator'], validator) + self.assertEqual(kwargs["validator"], validator) + + def assert_expected_toolbar(self, expected_toolbar): + _, kwargs = self.mock_prompter.call_args_list[0] + self.assertEqual(kwargs["bottom_toolbar"], expected_toolbar) + + def assert_expected_prompt_message(self, expected_message): + args, _ = self.mock_prompter.call_args_list[0] + self.assertEqual(args[0], expected_message) def test_handles_list_completions(self): - completions = ['a', 'b'] - self.prompter.get_value('', '', completions=completions) + completions = ["a", "b"] + self.prompter.get_value("", "", completions=completions) self.assert_expected_completions(completions) def test_handles_dict_completions(self): descriptions = { - 'a': 'the letter a', - 'b': 'the letter b', + "a": "the letter a", + "b": "the letter b", } - expected_completions = ['a', 'b'] - self.prompter.get_value('', '', completions=descriptions) + expected_completions = ["a", "b"] + self.prompter.get_value("", "", completions=descriptions) self.assert_expected_completions(expected_completions) self.assert_expected_meta_dict(descriptions) def test_passes_validator(self): validator = DummyValidator() - self.prompter.get_value('', '', validator=validator) + self.prompter.get_value("", "", validator=validator) self.assert_expected_validator(validator) def test_strips_extra_whitespace(self): - self.mock_prompter.return_value = ' no_whitespace \t ' - response = self.prompter.get_value('default_value', 'Prompt Text') - self.assertEqual(response, 'no_whitespace') + self.mock_prompter.return_value = " no_whitespace \t " + response = self.prompter.get_value("default_value", "Prompt Text") + self.assertEqual(response, "no_whitespace") + def test_can_provide_toolbar(self): + toolbar = "Toolbar content" + self.prompter.get_value("default_value", "Prompt Text", toolbar=toolbar) + self.assert_expected_toolbar(toolbar) -class TestStartUrlValidator(unittest.TestCase): - def setUp(self): - self.document = mock.Mock(spec=Document) - self.validator = StartUrlValidator() - - def _validate_text(self, text): - self.document.text = text - self.validator.validate(self.document) - - def assert_text_not_allowed(self, text): - with self.assertRaises(ValidationError): - self._validate_text(text) - - def test_disallowed_text(self): - not_start_urls = [ - '', - 'd-abc123', - 'foo bar baz', - ] - for text in not_start_urls: - self.assert_text_not_allowed(text) - - def test_allowed_text(self): - valid_start_urls = [ - 'https://d-abc123.awsapps.com/start', - 'https://d-abc123.awsapps.com/start#', - 'https://d-abc123.awsapps.com/start/', - 'https://d-abc123.awsapps.com/start-beta', - 'https://start.url', - ] - for text in valid_start_urls: - self._validate_text(text) - - def test_allows_empty_string_if_default(self): - default = 'https://some.default' - self.validator = StartUrlValidator(default) - self._validate_text('') - - -class TestConfigureSSOCommand(unittest.TestCase): - def setUp(self): - self.global_args = mock.Mock() - self._session = Session() - self.sso_client = self._session.create_client( - 'sso', - region_name='us-west-2', - ) - self.sso_stub = Stubber(self.sso_client) - self.profile = 'a-profile' - self.scoped_config = {} - self.full_config = { - 'profiles': { - self.profile: self.scoped_config - } - } - self.mock_session = mock.Mock(spec=Session) - self.mock_session.get_scoped_config.return_value = self.scoped_config - self.mock_session.emit_first_non_none_response.return_value = None - self.mock_session.full_config = self.full_config - self.mock_session.create_client.return_value = self.sso_client - self.mock_session.profile = self.profile - self.config_path = '/some/path' - self.session_config = { - 'config_file': self.config_path, - } - self.mock_session.get_config_variable = self.session_config.get - self.mock_session.get_available_regions.return_value = ['us-east-1'] - self.token_cache = {} - self.writer = mock.Mock(spec=ConfigFileWriter) - self.prompter = mock.Mock(spec=PTKPrompt) - self.selector = mock.Mock(spec=select_menu) - self.region = 'us-west-2' - self.output = 'json' - self.sso_region = 'us-east-1' - self.start_url = 'https://d-92671207e4.awsapps.com/start' - self.account_id = '0123456789' - self.role_name = 'roleA' - self.expires_at = datetime.now(tzlocal()) + timedelta(hours=24) - self.access_token = { - 'accessToken': 'access.token.string', - 'expiresAt': self.expires_at, - } - self.do_sso_login_mock = mock.Mock(spec=do_sso_login) - self.do_sso_login_mock.return_value = self.access_token - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - - def _add_list_accounts_response(self, accounts): - params = { - 'accessToken': self.access_token['accessToken'], - } - response = { - 'accountList': accounts, - } - self.sso_stub.add_response('list_accounts', response, params) + def test_can_provide_prompt_format(self): + self.prompter.get_value( + "default_value", + "Prompt Text", + prompt_fmt="{prompt_text} [default: {current_value}]: ", + ) + self.assert_expected_prompt_message( + "Prompt Text [default: default_value]: " + ) + + +class TestSSOSessionConfigurationPrompter: + def get_toolbar_content(self, toolbar_render): + formatted_text = toolbar_render() + content_lines = [line for _, line in formatted_text] + return "".join(content_lines) - def _add_list_account_roles_response(self, roles): - params = { - 'accountId': self.account_id, - 'accessToken': self.access_token['accessToken'], + def test_prompt_for_session_name(self, sso_config_prompter, ptk_stubber): + ptk_stubber.user_inputs = UserInputs( + session_prompt=RequiredSessionPrompt("dev") + ) + assert sso_config_prompter.prompt_for_sso_session() == "dev" + assert sso_config_prompter.sso_session == "dev" + + def test_prompt_for_session_name_opt_out_of_required( + self, sso_config_prompter, ptk_stubber + ): + ptk_stubber.user_inputs = UserInputs( + session_prompt=RecommendedSessionPrompt("") + ) + answer = sso_config_prompter.prompt_for_sso_session(required=False) + assert answer is None + assert sso_config_prompter.sso_session is None + + def test_manually_set_session_name(self, sso_config_prompter): + sso_config_prompter.sso_session = "override" + assert sso_config_prompter.sso_session == "override" + + def test_setting_session_name_updates_sso_config( + self, + sso_config_prompter_factory, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_sso_region, + existing_start_url, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + sso_config_prompter = sso_config_prompter_factory(session=session) + sso_config_prompter.sso_session = existing_sso_session + assert sso_config_prompter.sso_session_config == { + "sso_region": existing_sso_region, + "sso_start_url": existing_start_url, + "sso_registration_scopes": existing_scopes, } - response = { - 'roleList': roles, + + def test_prompt_for_session_suggests_existing_sessions( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + sso_config_prompter = sso_config_prompter_factory(session=session) + + ptk_stubber.user_inputs = UserInputs( + session_prompt=RequiredSessionPrompt( + "dev", expected_completions=[existing_sso_session] + ), + ) + assert sso_config_prompter.prompt_for_sso_session() == "dev" + + def test_prompt_for_session_name_shows_session_config_in_toolbar( + self, + sso_config_prompter_factory, + ptk_stubber, + mock_ptk_app, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + mock_ptk_prompt = mock.Mock(ptk_prompt) + prompter = PTKPrompt(mock_ptk_prompt) + sso_config_prompter = sso_config_prompter_factory( + session=session, + prompt=prompter, + ) + sso_config_prompter.prompt_for_sso_session() + toolbar_render = mock_ptk_prompt.call_args_list[0][1]["bottom_toolbar"] + mock_ptk_app.current_buffer.document.text = existing_sso_session + mock_ptk_app.output.get_size.return_value.columns = 1 + actual_toolbar_content = self.get_toolbar_content(toolbar_render) + expected_sso_config_in_toolbar = json.dumps( + { + "sso_start_url": existing_start_url, + "sso_region": existing_sso_region, + "sso_registration_scopes": existing_scopes, + }, + indent=2, + ) + assert expected_sso_config_in_toolbar in actual_toolbar_content + + def test_prompt_for_start_url(self, sso_config_prompter, ptk_stubber): + url = "https://start.here" + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt(url) + ) + assert sso_config_prompter.prompt_for_sso_start_url() == url + assert sso_config_prompter.sso_session_config == {"sso_start_url": url} + + def test_prompt_for_start_url_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_start_url + ): + sso_config_prompter.sso_session_config[ + "sso_start_url" + ] = existing_start_url + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt( + "", expected_default=existing_start_url + ) + ) + answer = sso_config_prompter.prompt_for_sso_start_url() + assert answer == existing_start_url + assert sso_config_prompter.sso_session_config == { + "sso_start_url": existing_start_url } - self.sso_stub.add_response('list_account_roles', response, params) - - def _add_prompt_responses(self): - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - self.region, - self.output, - ] - - def _add_simple_single_item_responses(self): - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + + def test_prompt_for_start_url_suggests_previously_used_start_urls( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_start_url, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + url = "https://start.here" + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt( + answer=url, expected_completions=[existing_start_url] + ) + ) + sso_config_prompter = sso_config_prompter_factory(session=session) + answer = sso_config_prompter.prompt_for_sso_start_url() + assert answer == url + + def test_prompt_for_sso_region(self, sso_config_prompter, ptk_stubber): + sso_region = "us-west-2" + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt(sso_region) + ) + assert sso_config_prompter.prompt_for_sso_region() == sso_region + assert sso_config_prompter.sso_session_config == { + "sso_region": sso_region } - self._add_list_accounts_response([selected_account]) - self._add_list_account_roles_response([{'roleName': self.role_name}]) - - def assert_config_updates(self, config=None): - if config is None: - config = { - '__section__': 'profile %s' % self.profile, - 'sso_start_url': self.start_url, - 'sso_region': self.sso_region, - 'sso_account_id': self.account_id, - 'sso_role_name': self.role_name, - 'region': self.region, - 'output': self.output, - } - self.writer.update_config.assert_called_with(config, self.config_path) - - def test_basic_configure_sso_flow(self): - self._add_prompt_responses() - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + + def test_prompt_for_sso_region_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_sso_region + ): + sso_config_prompter.sso_session_config[ + "sso_region" + ] = existing_sso_region + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt( + "", expected_default=existing_sso_region + ) + ) + answer = sso_config_prompter.prompt_for_sso_region() + assert answer == existing_sso_region + assert sso_config_prompter.sso_session_config == { + "sso_region": existing_sso_region } - self.selector.side_effect = [ - selected_account, - self.role_name, - ] - accounts = [ - selected_account, - {'accountId': '1234567890', 'emailAddress': 'account2@site.com'}, - ] - self._add_list_accounts_response(accounts) - roles = [ - {'roleName': self.role_name}, - {'roleName': 'roleB'}, - ] - self._add_list_account_roles_response(roles) - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - - def test_single_account_single_role_flow_no_browser(self): - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso( - args=['--no-browser'], - parsed_globals=self.global_args, + + def test_prompt_for_sso_region_suggests_all_valid_sso_oidc_regions( + self, sso_config_prompter, ptk_stubber, all_sso_oidc_regions + ): + sso_region = "us-west-2" + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt( + sso_region, expected_completions=all_sso_oidc_regions + ), + ) + assert sso_config_prompter.prompt_for_sso_region() == sso_region + + def test_prompt_for_scopes( + self, sso_config_prompter, ptk_stubber, default_sso_scope + ): + scopes = "scope-1, scope-2" + parsed_scopes = ["scope-1", "scope-2"] + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt( + scopes, expected_default=default_sso_scope ) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - _, _, login_kwargs = self.do_sso_login_mock.mock_calls[0] - auth_handler = login_kwargs['on_pending_authorization'] - self.assertIsInstance(auth_handler, PrintOnlyHandler) - # Account / Role should be auto selected if only one is returned - self.assertEqual(self.selector.call_count, 0) - - def test_single_account_single_role_flow(self): - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - # Account / Role should be auto selected if only one is returned - self.assertEqual(self.selector.call_count, 0) - - def test_no_accounts_flow_raises_error(self): - self.prompter.get_value.side_effect = [self.start_url, self.sso_region] - self._add_list_accounts_response([]) - with self.assertRaises(RuntimeError): - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - - def test_no_roles_flow_raises_error(self): - self._add_prompt_responses() - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == parsed_scopes + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": scopes } - self._add_list_accounts_response([selected_account]) - self._add_list_account_roles_response([]) - with self.assertRaises(RuntimeError): - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - - def assert_default_prompt_args(self, defaults): - calls = self.prompter.get_value.call_args_list - self.assertEqual(len(calls), len(defaults)) - for call, default in zip(calls, defaults): - # The default to the prompt call is the first positional param - self.assertEqual(call[0][0], default) - - def assert_prompt_completions(self, completions): - calls = self.prompter.get_value.call_args_list - self.assertEqual(len(calls), len(completions)) - for call, completions in zip(calls, completions): - _, kwargs = call - self.assertEqual(kwargs['completions'], completions) - - def test_defaults_to_scoped_config(self): - self.scoped_config['sso_start_url'] = 'default-url' - self.scoped_config['sso_region'] = 'default-sso-region' - self.scoped_config['region'] = 'default-region' - self.scoped_config['output'] = 'default-output' - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - expected_defaults = [ - 'default-url', - 'default-sso-region', - 'default-region', - 'default-output', - ] - self.assert_default_prompt_args(expected_defaults) - - def test_handles_no_profile(self): - expected_profile = 'profile-a' - self.profile = None - self.mock_session.profile = None - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - # If there is no profile, it will be prompted for as the last value - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - self.region, - self.output, - expected_profile, - ] - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.profile = expected_profile - self.assert_config_updates() - - def test_handles_non_existant_profile(self): - not_found_exception = ProfileNotFound(profile=self.profile) - self.mock_session.get_scoped_config.side_effect = not_found_exception - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - - def test_cli_config_is_none_not_written(self): - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - # The CLI region and output format shouldn't be written - # to the config as they are None - None, - None - ] - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - expected_config = { - '__section__': 'profile %s' % self.profile, - 'sso_start_url': self.start_url, - 'sso_region': self.sso_region, - 'sso_account_id': self.account_id, - 'sso_role_name': self.role_name, + + def test_prompt_for_scopes_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_scopes + ): + sso_config_prompter.sso_session_config[ + "sso_registration_scopes" + ] = existing_scopes + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt("", expected_default=existing_scopes) + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == parse_sso_registration_scopes(existing_scopes) + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": existing_scopes } - self.assert_config_updates(config=expected_config) - def test_prompts_suggest_values(self): - self.full_config['profiles']['another_profile'] = { - 'sso_start_url': self.start_url, + def test_prompt_for_scopes_used_defaults_account_scope( + self, sso_config_prompter, ptk_stubber, default_sso_scope + ): + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt("", expected_default=default_sso_scope) + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == [default_sso_scope] + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": default_sso_scope } - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - expected_start_urls = [self.start_url] - expected_sso_regions = ['us-east-1'] - expected_cli_regions = None - expected_cli_outputs = list(CLI_OUTPUT_FORMATS.keys()) - expected_completions = [ - expected_start_urls, - expected_sso_regions, - expected_cli_regions, - expected_cli_outputs, - ] - self.assert_prompt_completions(expected_completions) + + def test_prompt_for_scopes_suggest_known_and_previously_used_scopes( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + default_sso_scope, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt( + "", + expected_default=default_sso_scope, + expected_completions=[default_sso_scope] + + parse_sso_registration_scopes(existing_scopes), + ) + ) + sso_config_prompter = sso_config_prompter_factory(session=session) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == [default_sso_scope] + + +def passes_validator(validator, text): + document = mock.Mock(spec=Document) + document.text = text + try: + validator.validate(document) + except ValidationError: + return False + return True + + +@pytest.mark.parametrize( + "validator_cls,input_value,default,is_valid", + [ + # StartUrlValidator cases + (StartUrlValidator, "https://d-abc123.awsapps.com/start", None, True), + (StartUrlValidator, "https://d-abc123.awsapps.com/start#", None, True), + (StartUrlValidator, "https://d-abc123.awsapps.com/start/", None, True), + ( + StartUrlValidator, + "https://d-abc123.awsapps.com/start-beta", + None, + True, + ), + (StartUrlValidator, "https://start.url", None, True), + (StartUrlValidator, "", "https://some.default", True), + (StartUrlValidator, "", None, False), + (StartUrlValidator, "d-abc123", None, False), + (StartUrlValidator, "foo bar baz", None, False), + # RequiredInputValidator cases + (RequiredInputValidator, "input-value", "default-value", True), + (RequiredInputValidator, "input-value", None, True), + (RequiredInputValidator, "", "default-value", True), + (RequiredInputValidator, "", None, False), + # ScopesValidator cases + (ScopesValidator, "sso:account:access", "sso:account:access", True), + (ScopesValidator, "", "sso:account:access", True), + (ScopesValidator, "value-1, value-2", None, True), + (ScopesValidator, " value-1, value-2 ", None, True), + (ScopesValidator, "value-1 value-2", None, False), + (ScopesValidator, "value-1, value-2 value3", None, False), + ], +) +def test_validators(validator_cls, input_value, default, is_valid): + validator = validator_cls(default) + assert passes_validator(validator, input_value) == is_valid class TestDisplayAccount(unittest.TestCase): def setUp(self): - self.account_id = '1234' - self.email_address = 'test@test.com' - self.account_name = 'FooBar' + self.account_id = "1234" + self.email_address = "test@test.com" + self.account_name = "FooBar" self.account = { - 'accountId': self.account_id, - 'emailAddress': self.email_address, - 'accountName': self.account_name, + "accountId": self.account_id, + "emailAddress": self.email_address, + "accountName": self.account_name, } def test_display_account_all_fields(self): @@ -446,22 +1978,22 @@ def test_display_account_all_fields(self): self.assertIn(self.account_id, account_str) def test_display_account_missing_email(self): - del self.account['emailAddress'] + del self.account["emailAddress"] account_str = display_account(self.account) self.assertIn(self.account_name, account_str) self.assertNotIn(self.email_address, account_str) self.assertIn(self.account_id, account_str) def test_display_account_missing_name(self): - del self.account['accountName'] + del self.account["accountName"] account_str = display_account(self.account) self.assertNotIn(self.account_name, account_str) self.assertIn(self.email_address, account_str) self.assertIn(self.account_id, account_str) def test_display_account_missing_name_and_email(self): - del self.account['accountName'] - del self.account['emailAddress'] + del self.account["accountName"] + del self.account["emailAddress"] account_str = display_account(self.account) self.assertNotIn(self.account_name, account_str) self.assertNotIn(self.email_address, account_str) diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 0bc85a28f40d..47f5762b3d8c 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. import os import webbrowser + +import pytest + from awscli.testutils import mock from awscli.testutils import unittest @@ -19,12 +22,30 @@ from botocore.exceptions import ClientError from awscli.compat import StringIO +from awscli.customizations.sso.utils import parse_sso_registration_scopes from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler from awscli.customizations.sso.utils import open_browser_with_original_ld_path +@pytest.mark.parametrize( + 'raw_scopes, parsed_scopes', + [ + ('scope', ['scope']), + (' scope ', ['scope']), + ('', []), + ('scope, ', ['scope']), + ('scope-1,scope-2', ['scope-1', 'scope-2']), + ('scope-1, scope-2', ['scope-1', 'scope-2']), + (' scope-1, scope-2 ', ['scope-1', 'scope-2']), + ('scope-1,scope-2,scope-3', ['scope-1', 'scope-2', 'scope-3']) + ] +) +def test_parse_registration_scopes(raw_scopes, parsed_scopes): + assert parse_sso_registration_scopes(raw_scopes) == parsed_scopes + + class TestDoSSOLogin(unittest.TestCase): def setUp(self): self.region = 'us-west-2' diff --git a/tests/utils/botocore/__init__.py b/tests/utils/botocore/__init__.py index 88cdb0635188..964ec88232d7 100644 --- a/tests/utils/botocore/__init__.py +++ b/tests/utils/botocore/__init__.py @@ -506,6 +506,14 @@ def __init__(self, *args, **kwargs): self._cached_clients = {} self._client_stubs = {} + @property + def cached_clients(self): + return self._cached_clients + + @property + def client_stubs(self): + return self._client_stubs + def create_client(self, service_name, *args, **kwargs): if service_name not in self._cached_clients: client = self._create_stubbed_client(service_name, *args, **kwargs)