Skip to content

Commit

Permalink
Merge pull request #7445 from kyleknap/v2-update-sso-cred-provider
Browse files Browse the repository at this point in the history
[v2] Port SSO credential provider updates
kyleknap authored Nov 16, 2022
2 parents 85ea90c + 0ca979c commit 6a7280e
Showing 5 changed files with 306 additions and 78 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-sso-81091.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "sso",
"description": "Add support for loading sso-session profiles for SSO credential provider"
}
152 changes: 77 additions & 75 deletions awscli/botocore/credentials.py
Original file line number Diff line number Diff line change
@@ -39,10 +39,12 @@
UnauthorizedSSOTokenError,
UnknownCredentialError,
)
from botocore.tokens import SSOTokenProvider
from botocore.utils import (
ContainerMetadataFetcher,
FileWebIdentityTokenLoader,
InstanceMetadataFetcher,
JSONFileCache,
SSOTokenLoader,
original_ld_library_path,
parse_key_val_file,
@@ -211,6 +213,7 @@ def _create_sso_provider(self, profile_name):
profile_name=profile_name,
cache=self._cache,
token_cache=self._sso_token_cache,
token_provider=SSOTokenProvider(self._session),
)


@@ -282,57 +285,6 @@ def __call__(self):
return _Refresher(actual_refresh)


class JSONFileCache(object):
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=_serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError, IOError):
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError("Value cannot be cached, must be "
"JSON serializable: %s" % value)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(os.open(full_key,
os.O_WRONLY | os.O_CREAT, 0o600), 'w') as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path


class Credentials(object):
"""
Holds the credentials needed to authenticate requests.
@@ -2002,13 +1954,16 @@ class SSOCredentialFetcher(CachedCredentialFetcher):

def __init__(self, start_url, sso_region, role_name, account_id,
client_creator, token_loader=None, cache=None,
expiry_window_seconds=None):
expiry_window_seconds=None, token_provider=None,
sso_session_name=None):
self._client_creator = client_creator
self._sso_region = sso_region
self._role_name = role_name
self._account_id = account_id
self._start_url = start_url
self._token_loader = token_loader
self._token_provider = token_provider
self._sso_session_name = sso_session_name
super(SSOCredentialFetcher, self).__init__(
cache, expiry_window_seconds
)
@@ -2019,10 +1974,13 @@ def _create_cache_key(self):
The cache key is intended to be compatible with file names.
"""
args = {
'startUrl': self._start_url,
'roleName': self._role_name,
'accountId': self._account_id,
}
if self._sso_session_name:
args['sessionName'] = self._sso_session_name
else:
args['startUrl'] = self._start_url
# NOTE: It would be good to hoist this cache key construction logic
# into the CachedCredentialFetcher class as we should be consistent.
# Unfortunately, the current assume role fetchers that sub class don't
@@ -2045,12 +2003,16 @@ def _get_credentials(self):
region_name=self._sso_region,
)
client = self._client_creator('sso', config=config)
if self._token_provider:
initial_token_data = self._token_provider.load_token()
token = initial_token_data.get_frozen_token().token
else:
token = self._token_loader(self._start_url)['accessToken']

token_dict = self._token_loader(self._start_url)
kwargs = {
'roleName': self._role_name,
'accountId': self._account_id,
'accessToken': token_dict['accessToken'],
'accessToken': token,
}
try:
response = client.get_role_credentials(**kwargs)
@@ -2076,18 +2038,24 @@ class SSOProvider(CredentialProvider):
_SSO_TOKEN_CACHE_DIR = os.path.expanduser(
os.path.join('~', '.aws', 'sso', 'cache')
)
_SSO_CONFIG_VARS = [
'sso_start_url',
'sso_region',
_PROFILE_REQUIRED_CONFIG_VARS = (
'sso_role_name',
'sso_account_id',
]
)
_SSO_REQUIRED_CONFIG_VARS = (
'sso_start_url',
'sso_region',
)
_ALL_REQUIRED_CONFIG_VARS = (
_PROFILE_REQUIRED_CONFIG_VARS + _SSO_REQUIRED_CONFIG_VARS
)

def __init__(self, load_config, client_creator, profile_name,
cache=None, token_cache=None):
cache=None, token_cache=None, token_provider=None):
if token_cache is None:
token_cache = JSONFileCache(self._SSO_TOKEN_CACHE_DIR)
self._token_cache = token_cache
self._token_provider = token_provider
if cache is None:
cache = {}
self.cache = cache
@@ -2100,17 +2068,24 @@ def _load_sso_config(self):
profiles = loaded_config.get('profiles', {})
profile_name = self._profile_name
profile_config = profiles.get(self._profile_name, {})
sso_sessions = loaded_config.get('sso_sessions', {})

# Role name & Account ID indicate the cred provider should be used
sso_cred_vars = ('sso_role_name', 'sso_account_id')
if all(c not in profile_config for c in sso_cred_vars):
if all(
c not in profile_config for c in self._PROFILE_REQUIRED_CONFIG_VARS
):
return None

resolved_config, extra_reqs = self._resolve_sso_session_reference(
profile_config, sso_sessions
)

config = {}
missing_config_vars = []
for config_var in self._SSO_CONFIG_VARS:
if config_var in profile_config:
config[config_var] = profile_config[config_var]
all_required_configs = self._ALL_REQUIRED_CONFIG_VARS + extra_reqs
for config_var in all_required_configs:
if config_var in resolved_config:
config[config_var] = resolved_config[config_var]
else:
missing_config_vars.append(config_var)

@@ -2122,23 +2097,50 @@ def _load_sso_config(self):
'required configuration: %s' % (profile_name, missing)
)
)

return config

def _resolve_sso_session_reference(self, profile_config, sso_sessions):
sso_session_name = profile_config.get('sso_session')
if sso_session_name is None:
# No reference to resolve, proceed with legacy flow
return profile_config, ()

if sso_session_name not in sso_sessions:
error_msg = f'The specified sso-session does not exist: "{sso_session_name}"'
raise InvalidConfigError(error_msg=error_msg)

config = profile_config.copy()
session = sso_sessions[sso_session_name]
for config_var, val in session.items():
# Validate any keys referenced in both profile and sso_session match
if config.get(config_var, val) != val:
error_msg = (
f"The value for {config_var} is inconsistent between "
f"profile ({config[config_var]}) and sso-session ({val})."
)
raise InvalidConfigError(error_msg=error_msg)
config[config_var] = val
return config, ('sso_session',)

def load(self):
sso_config = self._load_sso_config()
if not sso_config:
return None

sso_fetcher = SSOCredentialFetcher(
sso_config['sso_start_url'],
sso_config['sso_region'],
sso_config['sso_role_name'],
sso_config['sso_account_id'],
self._client_creator,
token_loader=SSOTokenLoader(cache=self._token_cache),
cache=self.cache,
)
fetcher_kwargs = {
'start_url': sso_config['sso_start_url'],
'sso_region': sso_config['sso_region'],
'role_name': sso_config['sso_role_name'],
'account_id': sso_config['sso_account_id'],
'client_creator': self._client_creator,
'token_loader': SSOTokenLoader(cache=self._token_cache),
'cache': self.cache,
}
if 'sso_session' in sso_config:
fetcher_kwargs['sso_session_name'] = sso_config['sso_session']
fetcher_kwargs['token_provider'] = self._token_provider

sso_fetcher = SSOCredentialFetcher(**fetcher_kwargs)

return DeferredRefreshableCredentials(
method=self.METHOD,
6 changes: 3 additions & 3 deletions awscli/botocore/tokens.py
Original file line number Diff line number Diff line change
@@ -24,13 +24,12 @@
from botocore import UNSIGNED
from botocore.compat import total_seconds
from botocore.config import Config
from botocore.credentials import JSONFileCache
from botocore.exceptions import (
ClientError,
InvalidConfigError,
TokenRetrievalError,
)
from botocore.utils import CachedProperty, SSOTokenLoader
from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader

logger = logging.getLogger(__name__)

@@ -184,11 +183,12 @@ class SSOTokenProvider:
"sso_region",
]
_GRANT_TYPE = "refresh_token"
DEFAULT_CACHE_CLS = JSONFileCache

def __init__(self, session, cache=None, time_fetcher=_utc_now):
self._session = session
if cache is None:
cache = JSONFileCache(
cache = self.DEFAULT_CACHE_CLS(
self._SSO_TOKEN_CACHE_DIR,
dumps_func=_sso_json_dumps,
)
70 changes: 70 additions & 0 deletions awscli/botocore/utils.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
import time
import warnings
import weakref
from pathlib import Path

import botocore
import botocore.awsrequest
@@ -2887,3 +2888,72 @@ def _get_global_endpoint(self, endpoint, endpoint_variant_tags=None):
dns_suffix = self._DEFAULT_DNS_SUFFIX

return f"https://{endpoint}.endpoint.events.{dns_suffix}/"


class JSONFileCache:
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=self._serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError):
raise KeyError(cache_key)

def __delitem__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
try:
key_path = Path(actual_key)
key_path.unlink()
except FileNotFoundError:
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError(
f"Value cannot be cached, must be "
f"JSON serializable: {value}"
)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(
os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w'
) as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path

def _serialize_if_needed(self, value, iso=False):
if isinstance(value, datetime.datetime):
if iso:
return value.isoformat()
return value.strftime('%Y-%m-%dT%H:%M:%S%Z')
return value
151 changes: 151 additions & 0 deletions tests/functional/botocore/test_credentials.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json
import uuid
import threading
import os
@@ -21,6 +22,7 @@
from datetime import datetime, timedelta
import sys

import pytest
from dateutil.tz import tzlocal
from botocore.exceptions import CredentialRetrievalError

@@ -42,8 +44,12 @@
from botocore.session import Session
from botocore.exceptions import InvalidConfigError, InfiniteLoopConfigError
from botocore.stub import Stubber
from botocore.tokens import SSOTokenProvider
from botocore.utils import datetime2timestamp

TIME_IN_ONE_HOUR = datetime.utcnow() + timedelta(hours=1)
TIME_IN_SIX_MONTHS = datetime.utcnow() + timedelta(hours=4320)


class TestCredentialRefreshRaces(unittest.TestCase):
def assert_consistent_credentials_seen(self, creds, func):
@@ -842,3 +848,148 @@ def test_imds_use_truncated_user_agent(self, send):
provider.load()
args, _ = send.call_args
self.assertEqual(args[0].headers['User-Agent'], 'Botocore/24.0')


class MockCache:
"""Mock for JSONFileCache to avoid touching files on disk"""

def __init__(self, working_dir=None, dumps_func=None):
self.working_dir = working_dir
self.dumps_func = dumps_func

def __contains__(self, cache_key):
return True

def __getitem__(self, cache_key):
return {
"startUrl": "https://test.awsapps.com/start",
"region": "us-east-1",
"accessToken": "access-token",
"expiresAt": TIME_IN_ONE_HOUR.strftime('%Y-%m-%dT%H:%M:%SZ'),
"expiresIn": 3600,
"clientId": "client-12345",
"clientSecret": "client-secret",
"registrationExpiresAt": TIME_IN_SIX_MONTHS.strftime(
'%Y-%m-%dT%H:%M:%SZ'
),
"refreshToken": "refresh-here",
}

def __delitem__(self, cache_key):
pass


class SSOSessionTest(BaseEnvVar):
def setUp(self):
super().setUp()
self.tempdir = tempfile.mkdtemp()
self.config_file = os.path.join(self.tempdir, 'config')
self.environ['AWS_CONFIG_FILE'] = self.config_file
self.access_key_id = 'ASIA123456ABCDEFG'
self.secret_access_key = 'secret-key'
self.session_token = 'session-token'

def tearDown(self):
shutil.rmtree(self.tempdir)
super().tearDown()

def write_config(self, config):
with open(self.config_file, 'w') as f:
f.write(config)

def test_token_chosen_from_provider(self):
profile = (
'[profile sso-test]\n'
'region = us-east-1\n'
'sso_session = sso-test-session\n'
'sso_account_id = 12345678901234\n'
'sso_role_name = ViewOnlyAccess\n'
'\n'
'[sso-session sso-test-session]\n'
'sso_region = us-east-1\n'
'sso_start_url = https://test.awsapps.com/start\n'
'sso_registration_scopes = sso:account:access\n'
)
self.write_config(profile)

session = Session(profile='sso-test')
with SessionHTTPStubber(session) as stubber:
self.add_credential_response(stubber)
stubber.add_response()
with mock.patch.object(
SSOTokenProvider, 'DEFAULT_CACHE_CLS', MockCache
):
c = session.create_client('s3')
c.list_buckets()

self.assert_valid_sso_call(
stubber.requests[0],
(
'https://portal.sso.us-east-1.amazonaws.com/federation/credentials'
'?role_name=ViewOnlyAccess&account_id=12345678901234'
),
b'access-token',
)
self.assert_credentials_used(
stubber.requests[1],
self.access_key_id.encode('utf-8'),
self.session_token.encode('utf-8'),
)

def test_mismatched_session_values(self):
profile = (
'[profile sso-test]\n'
'region = us-east-1\n'
'sso_session = sso-test-session\n'
'sso_start_url = https://test2.awsapps.com/start\n'
'sso_account_id = 12345678901234\n'
'sso_role_name = ViewOnlyAccess\n'
'\n'
'[sso-session sso-test-session]\n'
'sso_region = us-east-1\n'
'sso_start_url = https://test.awsapps.com/start\n'
'sso_registration_scopes = sso:account:access\n'
)
self.write_config(profile)

session = Session(profile='sso-test')
with pytest.raises(InvalidConfigError):
c = session.create_client('s3')
c.list_buckets()

def test_missing_sso_session(self):
profile = (
'[profile sso-test]\n'
'region = us-east-1\n'
'sso_session = sso-test-session\n'
'sso_start_url = https://test2.awsapps.com/start\n'
'sso_account_id = 12345678901234\n'
'sso_role_name = ViewOnlyAccess\n'
'\n'
)
self.write_config(profile)

session = Session(profile='sso-test')
with pytest.raises(InvalidConfigError):
c = session.create_client('s3')
c.list_buckets()

def assert_valid_sso_call(self, request, url, access_token):
assert request.url == url
assert 'x-amz-sso_bearer_token' in request.headers
assert request.headers['x-amz-sso_bearer_token'] == access_token

def assert_credentials_used(self, request, access_key, session_token):
assert access_key in request.headers.get('Authorization')
assert request.headers.get('X-Amz-Security-Token') == session_token

def add_credential_response(self, stubber):
response = {
'roleCredentials': {
'accessKeyId': self.access_key_id,
'secretAccessKey': self.secret_access_key,
'sessionToken': self.session_token,
'expiration': TIME_IN_ONE_HOUR.timestamp() * 1000,
}
}
stubber.add_response(body=json.dumps(response).encode('utf-8'))

0 comments on commit 6a7280e

Please sign in to comment.