Skip to content

Commit

Permalink
Add POC commit of enabling SSO session for cred provider
Browse files Browse the repository at this point in the history
This commit is only a prototype to get a fully working
end-to-end example of a fully configure SSO session environment.

REMOVE THIS COMMIT PRIOR TO MERGING
  • Loading branch information
kyleknap committed Oct 26, 2022
1 parent 1c270ba commit a3cf58e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 81 deletions.
183 changes: 106 additions & 77 deletions awscli/botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
UnauthorizedSSOTokenError,
UnknownCredentialError,
)
from botocore.tokens import SSOTokenProvider
from botocore.utils import (
ContainerMetadataFetcher,
FileWebIdentityTokenLoader,
InstanceMetadataFetcher,
JSONFileCache,
SSOTokenLoader,
original_ld_library_path,
parse_key_val_file,
Expand Down Expand Up @@ -211,6 +213,7 @@ def _create_sso_provider(self, profile_name):
profile_name=profile_name,
cache=self._cache,
token_cache=self._sso_token_cache,
token_provider=SSOTokenProvider(self._session)
)


Expand Down Expand Up @@ -282,57 +285,6 @@ def __call__(self):
return _Refresher(actual_refresh)


class JSONFileCache(object):
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=_serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError, IOError):
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError("Value cannot be cached, must be "
"JSON serializable: %s" % value)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(os.open(full_key,
os.O_WRONLY | os.O_CREAT, 0o600), 'w') as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path


class Credentials(object):
"""
Holds the credentials needed to authenticate requests.
Expand Down Expand Up @@ -2000,15 +1952,18 @@ def load_credentials(self):
class SSOCredentialFetcher(CachedCredentialFetcher):
_UTC_DATE_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

def __init__(self, start_url, sso_region, role_name, account_id,
client_creator, token_loader=None, cache=None,
expiry_window_seconds=None):
def __init__(self, sso_region, role_name, account_id, client_creator,
start_url=None, token_loader=None,
cache=None, expiry_window_seconds=None,
token_provider=None, sso_session_name=None):
self._client_creator = client_creator
self._sso_region = sso_region
self._role_name = role_name
self._account_id = account_id
self._start_url = start_url
self._token_loader = token_loader
self._token_provider = token_provider
self._sso_session_name = sso_session_name
super(SSOCredentialFetcher, self).__init__(
cache, expiry_window_seconds
)
Expand All @@ -2019,10 +1974,13 @@ def _create_cache_key(self):
The cache key is intended to be compatible with file names.
"""
args = {
'startUrl': self._start_url,
'roleName': self._role_name,
'accountId': self._account_id,
}
if self._sso_session_name:
args['sessionName'] = self._sso_session_name
else:
args['startUrl'] = self._start_url
# NOTE: It would be good to hoist this cache key construction logic
# into the CachedCredentialFetcher class as we should be consistent.
# Unfortunately, the current assume role fetchers that sub class don't
Expand All @@ -2046,11 +2004,15 @@ def _get_credentials(self):
)
client = self._client_creator('sso', config=config)

token_dict = self._token_loader(self._start_url)
if self._token_provider:
token = self._token_provider.load_token().get_frozen_token().token
else:
token = self._token_loader(self._start_url)['accessToken']

kwargs = {
'roleName': self._role_name,
'accountId': self._account_id,
'accessToken': token_dict['accessToken'],
'accessToken': token,
}
try:
response = client.get_role_credentials(**kwargs)
Expand All @@ -2076,18 +2038,21 @@ class SSOProvider(CredentialProvider):
_SSO_TOKEN_CACHE_DIR = os.path.expanduser(
os.path.join('~', '.aws', 'sso', 'cache')
)
_SSO_CONFIG_VARS = [
'sso_start_url',
'sso_region',
_PROFILE_ONLY_REQUIRED_CONFIG_VARS = [
'sso_role_name',
'sso_account_id',
]
_OTHER_REQUIRED_CONFIG_VARS = [
'sso_start_url',
'sso_region',
]

def __init__(self, load_config, client_creator, profile_name,
cache=None, token_cache=None):
cache=None, token_cache=None, token_provider=None):
if token_cache is None:
token_cache = JSONFileCache(self._SSO_TOKEN_CACHE_DIR)
self._token_cache = token_cache
self._token_provider = token_provider
if cache is None:
cache = {}
self.cache = cache
Expand All @@ -2102,17 +2067,29 @@ def _load_sso_config(self):
profile_config = profiles.get(self._profile_name, {})

# Role name & Account ID indicate the cred provider should be used
sso_cred_vars = ('sso_role_name', 'sso_account_id')
if all(c not in profile_config for c in sso_cred_vars):
if all(c not in profile_config for c in self._PROFILE_ONLY_REQUIRED_CONFIG_VARS):
return None

config = {}
missing_config_vars = []
for config_var in self._SSO_CONFIG_VARS:
if config_var in profile_config:
config[config_var] = profile_config[config_var]
else:
missing_config_vars.append(config_var)
if 'sso_session' in profile_config:
self._collect_sso_session_config_vars(
full_config=loaded_config,
sso_session_name=profile_config['sso_session'],
sso_config=config,
missing=missing_config_vars
)
else:
self._collect_legacy_profile_config_vars(
profile_config=profile_config,
sso_config=config,
missing=missing_config_vars,
)
self._collect_account_id_and_role_name(
profile_config=profile_config,
sso_config=config,
missing=missing_config_vars,
)

if missing_config_vars:
missing = ', '.join(missing_config_vars)
Expand All @@ -2125,21 +2102,73 @@ def _load_sso_config(self):

return config

def _collect_sso_session_config_vars(
self, full_config, sso_session_name, sso_config, missing
):
sso_sessions = full_config.get('sso_sessions', {})
if sso_session_name not in sso_sessions:
raise InvalidConfigError(
error_msg=(
f'The specified sso-session does not exist: '
f'"{sso_session_name}"'
)
)
sso_config['sso_session'] = sso_session_name
self._collect_config_vars(
source_config=sso_sessions[sso_session_name],
required=self._OTHER_REQUIRED_CONFIG_VARS,
config=sso_config,
missing=missing,
)

def _collect_legacy_profile_config_vars(
self, profile_config, sso_config, missing
):
self._collect_config_vars(
source_config=profile_config,
required=self._OTHER_REQUIRED_CONFIG_VARS,
config=sso_config,
missing=missing,
)

def _collect_account_id_and_role_name(
self, profile_config, sso_config, missing
):
self._collect_config_vars(
source_config=profile_config,
required=self._PROFILE_ONLY_REQUIRED_CONFIG_VARS,
config=sso_config,
missing=missing,
)

def _collect_config_vars(
self, source_config, required, config, missing
):
for config_var in required:
if config_var in source_config:
config[config_var] = source_config[config_var]
else:
missing.append(config_var)

def load(self):
sso_config = self._load_sso_config()
if not sso_config:
return None

sso_fetcher = SSOCredentialFetcher(
sso_config['sso_start_url'],
sso_config['sso_region'],
sso_config['sso_role_name'],
sso_config['sso_account_id'],
self._client_creator,
token_loader=SSOTokenLoader(cache=self._token_cache),
cache=self.cache,
)
fetcher_kwargs = {
'start_url': sso_config['sso_start_url'],
'sso_region': sso_config['sso_region'],
'role_name': sso_config['sso_role_name'],
'account_id': sso_config['sso_account_id'],
'client_creator': self._client_creator,
'token_loader': SSOTokenLoader(cache=self._token_cache),
'cache': self.cache,
}
if 'sso_session' in sso_config:
fetcher_kwargs['sso_session_name'] = sso_config['sso_session']
fetcher_kwargs['token_provider'] = self._token_provider

sso_fetcher = SSOCredentialFetcher(**fetcher_kwargs)
return DeferredRefreshableCredentials(
method=self.METHOD,
refresh_using=sso_fetcher.fetch_credentials,
Expand Down
3 changes: 1 addition & 2 deletions awscli/botocore/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
from botocore import UNSIGNED
from botocore.compat import total_seconds
from botocore.config import Config
from botocore.credentials import JSONFileCache
from botocore.exceptions import (
ClientError,
InvalidConfigError,
TokenRetrievalError,
)
from botocore.utils import CachedProperty, SSOTokenLoader
from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader

logger = logging.getLogger(__name__)

Expand Down
57 changes: 57 additions & 0 deletions awscli/botocore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2495,6 +2495,63 @@ def __call__(self):
with self._open(self._web_identity_token_path) as token_file:
return token_file.read()

class JSONFileCache(object):
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=self._serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError, IOError):
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError("Value cannot be cached, must be "
"JSON serializable: %s" % value)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(os.open(full_key,
os.O_WRONLY | os.O_CREAT, 0o600), 'w') as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path

def _serialize_if_needed(self, value, iso=False):
if isinstance(value, datetime.datetime):
if iso:
return value.isoformat()
return value.strftime('%Y-%m-%dT%H:%M:%S%Z')
return value


class SSOTokenFetcher(object):
# The device flow RFC defines the slow down delay to be an additional
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/botocore/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3262,8 +3262,10 @@ def setUp(self):
self.loader = mock.Mock(spec=SSOTokenLoader)
self.loader.return_value = self.access_token
self.fetcher = SSOCredentialFetcher(
self.start_url, self.sso_region, self.role_name, self.account_id,
self.mock_session.create_client, token_loader=self.loader,
self.sso_region, self.role_name, self.account_id,
self.mock_session.create_client,
start_url=self.start_url,
token_loader=self.loader,
cache=self.cache,
)

Expand Down

0 comments on commit a3cf58e

Please sign in to comment.