Skip to content

Commit af86417

Browse files
joguSDnateprewitt
authored andcommitted
Update SSO token provider configuration and caching:
* Migrate SSO token provider configuration to dedicated section * Cache SSO tokens based on the sso-session name
1 parent e5e848a commit af86417

File tree

7 files changed

+150
-36
lines changed

7 files changed

+150
-36
lines changed

botocore/configloader.py

+9
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def build_profile_map(parsed_ini_config):
253253
"""
254254
parsed_config = copy.deepcopy(parsed_ini_config)
255255
profiles = {}
256+
sso_sessions = {}
256257
final_config = {}
257258
for key, values in parsed_config.items():
258259
if key.startswith("profile"):
@@ -262,6 +263,13 @@ def build_profile_map(parsed_ini_config):
262263
continue
263264
if len(parts) == 2:
264265
profiles[parts[1]] = values
266+
elif key.startswith("sso-session"):
267+
try:
268+
parts = shlex.split(key)
269+
except ValueError:
270+
continue
271+
if len(parts) == 2:
272+
sso_sessions[parts[1]] = values
265273
elif key == 'default':
266274
# default section is special and is considered a profile
267275
# name but we don't require you use 'profile "default"'
@@ -270,4 +278,5 @@ def build_profile_map(parsed_ini_config):
270278
else:
271279
final_config[key] = values
272280
final_config['profiles'] = profiles
281+
final_config['sso_sessions'] = sso_sessions
273282
return final_config

botocore/tokens.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -198,24 +198,43 @@ def __init__(self, session, cache=None, time_fetcher=_utc_now):
198198
def _load_sso_config(self):
199199
loaded_config = self._session.full_config
200200
profiles = loaded_config.get("profiles", {})
201+
sso_sessions = loaded_config.get("sso_sessions", {})
201202
profile_name = self._session.get_config_variable("profile")
202203
if not profile_name:
203204
profile_name = "default"
204205
profile_config = profiles.get(profile_name, {})
205206

206-
if "sso_start_url" not in profile_config:
207-
return None
207+
if "sso_session" not in profile_config:
208+
return
209+
210+
sso_session_name = profile_config["sso_session"]
211+
sso_config = sso_sessions.get(sso_session_name, None)
212+
213+
if not sso_config:
214+
error_msg = (
215+
f'The profile "{profile_name}" is configured to use the SSO '
216+
f'token provider but the "{sso_session_name}" sso_session '
217+
f"configuration does not exist."
218+
)
219+
raise InvalidConfigError(error_msg=error_msg)
220+
221+
missing_configs = []
222+
for var in self._SSO_CONFIG_VARS:
223+
if var not in sso_config:
224+
missing_configs.append(var)
208225

209-
if "sso_region" not in profile_config:
226+
if missing_configs:
210227
error_msg = (
211228
f'The profile "{profile_name}" is configured to use the SSO '
212-
f"token provider but is missing the sso_region configuration"
229+
f"token provider but is missing the following configuration: "
230+
f"{missing_configs}."
213231
)
214232
raise InvalidConfigError(error_msg=error_msg)
215233

216234
return {
217-
"sso_region": profile_config["sso_region"],
218-
"sso_start_url": profile_config["sso_start_url"],
235+
"session_name": sso_session_name,
236+
"sso_region": sso_config["sso_region"],
237+
"sso_start_url": sso_config["sso_start_url"],
219238
}
220239

221240
@CachedProperty
@@ -250,9 +269,6 @@ def _attempt_create_token(self, token):
250269
}
251270
if "refreshToken" in response:
252271
new_token["refreshToken"] = response["refreshToken"]
253-
elif "refreshToken" in token:
254-
# TODO: Verify if we should preserve the old refresh token
255-
new_token["refreshToken"] = token["refreshToken"]
256272
logger.info("SSO Token refresh succeeded")
257273
return new_token
258274

@@ -282,8 +298,9 @@ def _refresh_access_token(self, token):
282298

283299
def _refresher(self):
284300
start_url = self._sso_config["sso_start_url"]
285-
logger.info(f"Loading cached SSO token for {start_url}")
286-
token_dict = self._token_loader(start_url)
301+
session_name = self._sso_config["session_name"]
302+
logger.info(f"Loading cached SSO token for {session_name}")
303+
token_dict = self._token_loader(start_url, session_name=session_name)
287304
expiration = dateutil.parser.parse(token_dict["expiresAt"])
288305
logger.debug(f"Cached SSO token expires at {expiration}")
289306

@@ -293,7 +310,9 @@ def _refresher(self):
293310
if new_token_dict is not None:
294311
token_dict = new_token_dict
295312
expiration = token_dict["expiresAt"]
296-
self._token_loader.save_token(start_url, token_dict)
313+
self._token_loader.save_token(
314+
start_url, token_dict, session_name=session_name
315+
)
297316

298317
return FrozenAuthToken(
299318
token_dict["accessToken"], expiration=expiration

botocore/utils.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -2764,17 +2764,24 @@ def __init__(self, cache=None):
27642764
cache = {}
27652765
self._cache = cache
27662766

2767-
def _generate_cache_key(self, start_url):
2768-
return hashlib.sha1(start_url.encode('utf-8')).hexdigest()
2769-
2770-
def save_token(self, start_url, token):
2771-
cache_key = self._generate_cache_key(start_url)
2767+
def _generate_cache_key(self, start_url, session_name):
2768+
input_str = start_url
2769+
if session_name is not None:
2770+
input_str = session_name
2771+
return hashlib.sha1(input_str.encode('utf-8')).hexdigest()
2772+
2773+
def save_token(self, start_url, token, session_name=None):
2774+
cache_key = self._generate_cache_key(start_url, session_name)
27722775
self._cache[cache_key] = token
27732776

2774-
def __call__(self, start_url):
2775-
cache_key = self._generate_cache_key(start_url)
2777+
def __call__(self, start_url, session_name=None):
2778+
cache_key = self._generate_cache_key(start_url, session_name)
2779+
logger.debug(f'Checking for cached token at: {cache_key}')
27762780
if cache_key not in self._cache:
2777-
error_msg = f'Token for {start_url} does not exist'
2781+
name = start_url
2782+
if session_name is not None:
2783+
name = session_name
2784+
error_msg = f'Token for {name} does not exist'
27782785
raise SSOTokenLoadError(error_msg=error_msg)
27792786

27802787
token = self._cache[cache_key]

tests/unit/cfg/aws_sso_session_config

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[default]
2+
sso_session = sso
3+
4+
[sso-session sso]
5+
sso_start_url = https://example.com
6+
sso_region = us-east-1

tests/unit/test_configloader.py

+10
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ def test_unicode_bytes_path(self):
178178
self.assertIn('default', loaded_config['profiles'])
179179
self.assertIn('personal', loaded_config['profiles'])
180180

181+
def test_sso_session_config(self):
182+
filename = path('aws_sso_session_config')
183+
loaded_config = load_config(filename)
184+
self.assertIn('profiles', loaded_config)
185+
self.assertIn('default', loaded_config['profiles'])
186+
self.assertIn('sso_sessions', loaded_config)
187+
self.assertIn('sso', loaded_config['sso_sessions'])
188+
sso_config = loaded_config['sso_sessions']['sso']
189+
self.assertEqual(sso_config['sso_region'], 'us-east-1')
190+
self.assertEqual(sso_config['sso_start_url'], 'https://example.com')
181191

182192
if __name__ == "__main__":
183193
unittest.main()

tests/unit/test_tokens.py

+56-16
Original file line numberDiff line numberDiff line change
@@ -33,42 +33,79 @@ def parametrize(cases):
3333
sso_provider_resolution_cases = [
3434
{
3535
"documentation": "Full valid profile",
36-
"profile": {
37-
"sso_region": "us-east-1",
38-
"sso_start_url": "https://d-abc123.awsapps.com/start",
36+
"config": {
37+
"profiles": {"test": {"sso_session": "admin"}},
38+
"sso_sessions": {
39+
"admin": {
40+
"sso_region": "us-east-1",
41+
"sso_start_url": "https://d-abc123.awsapps.com/start",
42+
}
43+
},
3944
},
4045
"resolves": True,
4146
},
4247
{
4348
"documentation": "Non-SSO profiles are skipped",
44-
"profile": {"region": "us-west-2"},
49+
"config": {"profiles": {"test": {"region": "us-west-2"}}},
4550
"resolves": False,
4651
},
4752
{
4853
"documentation": "Only start URL is invalid",
49-
"profile": {"sso_start_url": "https://d-abc123.awsapps.com/start"},
54+
"config": {
55+
"profiles": {"test": {"sso_session": "admin"}},
56+
"sso_sessions": {
57+
"admin": {
58+
"sso_start_url": "https://d-abc123.awsapps.com/start"
59+
}
60+
},
61+
},
62+
"resolves": False,
63+
"expectedException": InvalidConfigError,
64+
},
65+
{
66+
"documentation": "Only sso_region is invalid",
67+
"config": {
68+
"profiles": {"test": {"sso_session": "admin"}},
69+
"sso_sessions": {"admin": {"sso_region": "us-east-1"}},
70+
},
5071
"resolves": False,
5172
"expectedException": InvalidConfigError,
5273
},
5374
{
54-
"documentation": "SSO Region only is skipped",
55-
"profile": {"sso_region": "us-east-1"},
75+
"documentation": "Specified sso-session must exist",
76+
"config": {
77+
"profiles": {"test": {"sso_session": "dev"}},
78+
"sso_sessions": {"admin": {"sso_region": "us-east-1"}},
79+
},
80+
"resolves": False,
81+
"expectedException": InvalidConfigError,
82+
},
83+
{
84+
"documentation": "The sso_session must be specified",
85+
"config": {
86+
"profiles": {"test": {"region": "us-west-2"}},
87+
"sso_sessions": {
88+
"admin": {
89+
"sso_region": "us-east-1",
90+
"sso_start_url": "https://d-abc123.awsapps.com/start",
91+
}
92+
},
93+
},
5694
"resolves": False,
5795
},
5896
]
5997

6098

6199
def _create_mock_session(config):
62100
mock_session = mock.Mock(spec=Session)
63-
mock_session.get_config_variable.return_value = "default"
64-
mock_session.full_config = {"profiles": {"default": config}}
101+
mock_session.get_config_variable.return_value = "test"
102+
mock_session.full_config = config
65103
return mock_session
66104

67105

68106
@parametrize(sso_provider_resolution_cases)
69107
def test_sso_token_provider_resolution(test_case):
70-
config = test_case["profile"]
71-
mock_session = _create_mock_session(config)
108+
mock_session = _create_mock_session(test_case["config"])
72109
resolver = SSOTokenProvider(mock_session)
73110

74111
expected_exception = test_case.get("expectedException")
@@ -196,8 +233,6 @@ def test_sso_token_provider_resolution(test_case):
196233
"clientId": "clientid",
197234
"clientSecret": "YSBzZWNyZXQ=",
198235
"registrationExpiresAt": "2022-12-25T13:30:00Z",
199-
# TODO: Verify if we should preserve old refresh token
200-
"refreshToken": "cachedrefreshtoken",
201236
},
202237
"expectedToken": {
203238
"token": "newtoken",
@@ -225,10 +260,15 @@ def test_sso_token_provider_resolution(test_case):
225260
@parametrize(sso_provider_refresh_cases)
226261
def test_sso_token_provider_refresh(test_case):
227262
config = {
228-
"sso_region": "us-west-2",
229-
"sso_start_url": "https://d-123.awsapps.com/start",
263+
"profiles": {"test": {"sso_session": "admin"}},
264+
"sso_sessions": {
265+
"admin": {
266+
"sso_region": "us-west-2",
267+
"sso_start_url": "https://d-123.awsapps.com/start",
268+
}
269+
},
230270
}
231-
cache_key = "2b829a45f04c9828cb45b7d092d8e4aa30818393"
271+
cache_key = "d033e22ae348aeb5660fc2140aec35850c4da997"
232272
token_cache = {}
233273

234274
# Prepopulate the token cache

tests/unit/test_utils.py

+23
Original file line numberDiff line numberDiff line change
@@ -3148,8 +3148,10 @@ def test_imds_service_endpoint_overrides_ipv6_endpoint(self, send):
31483148
class TestSSOTokenLoader(unittest.TestCase):
31493149
def setUp(self):
31503150
super().setUp()
3151+
self.session_name = 'admin'
31513152
self.start_url = 'https://d-abc123.awsapps.com/start'
31523153
self.cache_key = '40a89917e3175433e361b710a9d43528d7f1890a'
3154+
self.session_cache_key = 'd033e22ae348aeb5660fc2140aec35850c4da997'
31533155
self.access_token = 'totally.a.token'
31543156
self.cached_token = {
31553157
'accessToken': self.access_token,
@@ -3172,6 +3174,27 @@ def test_can_handle_invalid_cache(self):
31723174
with self.assertRaises(SSOTokenLoadError):
31733175
self.loader(self.start_url)
31743176

3177+
def test_can_save_token(self):
3178+
self.loader.save_token(self.start_url, self.cached_token)
3179+
access_token = self.loader(self.start_url)
3180+
self.assertEqual(self.cached_token, access_token)
3181+
3182+
def test_can_save_token_sso_session(self):
3183+
self.loader.save_token(
3184+
self.start_url, self.cached_token, session_name=self.session_name,
3185+
)
3186+
access_token = self.loader(
3187+
self.start_url, session_name=self.session_name,
3188+
)
3189+
self.assertEqual(self.cached_token, access_token)
3190+
3191+
def test_can_load_token_exists_sso_session_name(self):
3192+
self.cache[self.session_cache_key] = self.cached_token
3193+
access_token = self.loader(
3194+
self.start_url, session_name=self.session_name,
3195+
)
3196+
self.assertEqual(self.cached_token, access_token)
3197+
31753198

31763199
@pytest.mark.parametrize(
31773200
'header_name, headers, expected',

0 commit comments

Comments
 (0)