Skip to content

Commit 802dc3c

Browse files
feat: add POST support in SAML credential plugin
* Add http method config for AWS SAML plugin * Add http_method to config document of SAMLCrossAccount
1 parent 27a9e2b commit 802dc3c

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

src/awsrun/config.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def __init__(self, type_):
338338
self.type = type_
339339

340340
def type_check(self, obj):
341-
return type(obj) == self.type
341+
return type(obj) == self.type # noqa: E721
342342

343343
def __str__(self):
344344
return self.type.__name__
@@ -354,7 +354,7 @@ def __init__(self, pattern):
354354
self.pattern = pattern
355355

356356
def type_check(self, obj):
357-
if type(obj) != str:
357+
if type(obj) != str: # noqa: E721
358358
return False
359359
return bool(re.search(self.pattern, obj))
360360

@@ -366,7 +366,7 @@ class IpAddress(Type):
366366
"""Represents a string matching an IP address (v4 or v6)."""
367367

368368
def type_check(self, obj):
369-
if type(obj) != str:
369+
if type(obj) != str: # noqa: E721
370370
return False
371371
try:
372372
ipaddress.ip_address(obj)
@@ -382,7 +382,7 @@ class IpNetwork(Type):
382382
"""Represents a string matching an IP network (v4 or v6)."""
383383

384384
def type_check(self, obj):
385-
if type(obj) != str:
385+
if type(obj) != str: # noqa: E721
386386
return False
387387
try:
388388
ipaddress.ip_network(obj)
@@ -398,7 +398,7 @@ class FileType(Type):
398398
"""Represents a string pointing to an existing file."""
399399

400400
def type_check(self, obj):
401-
if type(obj) != str:
401+
if type(obj) != str: # noqa: E721
402402
return False
403403
return Path(obj).exists()
404404

@@ -462,7 +462,7 @@ def __init__(self, element_type):
462462
self.element_type = element_type
463463

464464
def type_check(self, obj):
465-
if type(obj) != list:
465+
if type(obj) != list: # noqa: E721
466466
return False
467467
return all(self.element_type.type_check(e) for e in obj)
468468

@@ -485,7 +485,7 @@ def __init__(self, key_type, value_type):
485485
self.value_type = value_type
486486

487487
def type_check(self, obj):
488-
if type(obj) != dict:
488+
if type(obj) != dict: # noqa: E721
489489
return False
490490
return all(self.key_type.type_check(k) for k in obj.keys()) and all(
491491
self.value_type.type_check(v) for v in obj.values()

src/awsrun/plugins/creds/aws.py

+13
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class SAML(Plugin):
137137
role: STRING*
138138
url: STRING*
139139
auth_type: ("basic" | "digest" | "ntlm")
140+
http_method: ("GET"| "POST")
140141
http_headers:
141142
STRING: STRING
142143
no_verify: BOOLEAN
@@ -175,6 +176,11 @@ class SAML(Plugin):
175176
specified, it must be one of `basic`, `digest`, or `ntlm`. The default value
176177
is `basic`. If using NTLM, username should be specified as `domain\\username`.
177178
179+
`http_method`
180+
: The HTTP method to use when authenticating with the IdP. If
181+
specified, it must be one of `GET`, `POST`. The default value
182+
is `GET`.
183+
178184
`http_headers`
179185
: Additional HTTP headers to send in the request to the IdP. If specified,
180186
it must be a dictionary of `key: value` pairs, where keys and values are
@@ -275,6 +281,7 @@ def instantiate(self, args):
275281
role=args.saml_role,
276282
url=cfg("url", type=URL, must_exist=True),
277283
auth=auth(args.saml_username, args.saml_password),
284+
http_method=cfg("http_method", type=Choice("GET", "POST"), default="GET"),
278285
headers=cfg("http_headers", type=Dict(Str, Str), default={}),
279286
duration=args.saml_duration,
280287
saml_duration=args.saml_assertion_duration,
@@ -458,6 +465,7 @@ class SAMLCrossAccount(AbstractCrossAccount):
458465
role: STRING*
459466
url: STRING*
460467
auth_type: ("basic" | "digest" | "ntlm")
468+
http_method: ("GET"| "POST")
461469
http_headers:
462470
STRING: STRING
463471
no_verify: BOOLEAN
@@ -503,6 +511,11 @@ class SAMLCrossAccount(AbstractCrossAccount):
503511
specified, it must be one of `basic`, `digest`, or `ntlm`. The default value
504512
is `basic`. If using NTLM, username should be specified as `domain\\username`.
505513
514+
`http_method`
515+
: The HTTP method to use when authenticating with the IdP. If
516+
specified, it must be one of `GET`, `POST`. The default value
517+
is `GET`.
518+
506519
`http_headers`
507520
: Additional HTTP headers to send in the request to the IdP. If specified,
508521
it must be a dictionary of `key: value` pairs, where keys and values are

src/awsrun/session/aws.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def __init__(
384384
role,
385385
url,
386386
auth,
387+
http_method,
387388
headers=None,
388389
duration=3600,
389390
saml_duration=300,
@@ -392,6 +393,7 @@ def __init__(
392393
super().__init__(role, duration)
393394
self._url = url
394395
self._auth = auth
396+
self._http_method = http_method
395397
self._headers = {} if headers is None else headers
396398
self._cached_saml = ExpiringValue(self._request_assertion, saml_duration)
397399
self._no_verify = no_verify
@@ -414,7 +416,15 @@ def _request_assertion(self):
414416
with requests.Session() as s:
415417
s.auth = self._auth
416418
s.headers.update(self._headers)
417-
resp = s.get(self._url, verify=not self._no_verify)
419+
if self._http_method == "GET":
420+
resp = s.get(self._url, verify=not self._no_verify)
421+
else:
422+
authData = {
423+
"UserName": s.auth.username,
424+
"Password": s.auth.password,
425+
"AuthMethod": "FormsAuthentication",
426+
}
427+
resp = s.post(self._url, data=authData, verify=not self._no_verify)
418428

419429
if resp.status_code == 401:
420430
raise IDPAccessDeniedException("Could not authenticate")

0 commit comments

Comments
 (0)