Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 99 additions & 32 deletions google/auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
_AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token"
# The AWS authorization header name for the auto-generated date.
_AWS_DATE_HEADER = "x-amz-date"
# The default AWS regional credential verification URL.
_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL = (
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
)


class RequestSigner(object):
Expand Down Expand Up @@ -360,8 +364,9 @@ class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta):

@abc.abstractmethod
def get_aws_security_credentials(self, context, request):
"""Returns the AWS security credentials for the requested context. This is not cached by the calling
Google credential, so caching logic should be implemented in the supplier.
"""Returns the AWS security credentials for the requested context.

.. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier.

Args:
context (google.auth.externalaccount.SupplierContext): The context object
Expand Down Expand Up @@ -603,20 +608,40 @@ def __init__(
self,
audience,
subject_token_type,
token_url,
token_url=external_account._DEFAULT_TOKEN_URL,
credential_source=None,
aws_security_credentials_supplier=None,
*args,
**kwargs
):
"""Instantiates an AWS workload external account credentials object.

Args:
audience (str): The STS audience field.
subject_token_type (str): The subject token type.
token_url (str): The STS endpoint URL.
credential_source (Mapping): The credential source dictionary used
to provide instructions on how to retrieve external credential
to be exchanged for Google access tokens.
subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec.
Expected values include::

“urn:ietf:params:aws:token-type:aws4_request”

token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token".
credential_source (Optional [Mapping]): The credential source dictionary used
to provide instructions on how to retrieve external credential to be exchanged for Google access tokens.
Either a credential source or an AWS security credentials supplier must be provided.

Example credential_source for AWS credential::

{
"environment_id": "aws1",
"regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone",
"url": "http://169.254.169.254/latest/meta-data/iam/security-credentials",
imdsv2_session_token_url": "http://169.254.169.254/latest/api/token"
}

aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier.
This will be called to supply valid AWS security credentails which will then
be exchanged for Google access tokens. Either an AWS security credentials supplier
or a credential source must be provided.
args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method.
kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method.

Expand All @@ -637,33 +662,52 @@ def __init__(
*args,
**kwargs
)
credential_source = credential_source or {}
self._target_resource = audience
environment_id = credential_source.get("environment_id") or ""
self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier(
credential_source
)
self._cred_verification_url = credential_source.get(
"regional_cred_verification_url"
)
if credential_source is None and aws_security_credentials_supplier is None:
raise exceptions.InvalidValue(
"A valid credential source or AWS security credentials supplier must be provided."
)
if (
credential_source is not None
and aws_security_credentials_supplier is not None
):
raise exceptions.InvalidValue(
"AWS credential cannot have both a credential source and an AWS security credentials supplier."
)

# Get the environment ID. Currently, only one version supported (v1).
matches = re.match(r"^(aws)([\d]+)$", environment_id)
if matches:
env_id, env_version = matches.groups()
if aws_security_credentials_supplier:
self._aws_security_credentials_supplier = aws_security_credentials_supplier
# The regional cred verification URL would normally be provided through the credential source. So set it to the default one here.
self._cred_verification_url = (
_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL
)
else:
env_id, env_version = (None, None)

if env_id != "aws" or self._cred_verification_url is None:
raise exceptions.InvalidResource(
"No valid AWS 'credential_source' provided"
environment_id = credential_source.get("environment_id") or ""
self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier(
credential_source
)
elif int(env_version or "") != 1:
raise exceptions.InvalidValue(
"aws version '{}' is not supported in the current build.".format(
env_version
)
self._cred_verification_url = credential_source.get(
"regional_cred_verification_url"
)

# Get the environment ID. Currently, only one version supported (v1).
matches = re.match(r"^(aws)([\d]+)$", environment_id)
if matches:
env_id, env_version = matches.groups()
else:
env_id, env_version = (None, None)

if env_id != "aws" or self._cred_verification_url is None:
raise exceptions.InvalidResource(
"No valid AWS 'credential_source' provided"
)
elif int(env_version or "") != 1:
raise exceptions.InvalidValue(
"aws version '{}' is not supported in the current build.".format(
env_version
)
)

self._target_resource = audience
self._request_signer = None

def retrieve_subject_token(self, request):
Expand Down Expand Up @@ -758,9 +802,26 @@ def retrieve_subject_token(self, request):

def _create_default_metrics_options(self):
metrics_options = super(Credentials, self)._create_default_metrics_options()
metrics_options["source"] = "aws"
if self._has_custom_supplier():
metrics_options["source"] = "programmatic"
else:
metrics_options["source"] = "aws"
return metrics_options

def _has_custom_supplier(self):
return self._credential_source is None

def _constructor_args(self):
args = super(Credentials, self)._constructor_args()
# If a custom supplier was used, append it to the args dict.
if self._has_custom_supplier():
args.update(
{
"aws_security_credentials_supplier": self._aws_security_credentials_supplier
}
)
return args

@classmethod
def from_info(cls, info, **kwargs):
"""Creates an AWS Credentials instance from parsed external account info.
Expand All @@ -776,6 +837,12 @@ def from_info(cls, info, **kwargs):
Raises:
ValueError: For invalid parameters.
"""
aws_security_credentials_supplier = info.get(
"aws_security_credentials_supplier"
)
kwargs.update(
{"aws_security_credentials_supplier": aws_security_credentials_supplier}
)
return super(Credentials, cls).from_info(info, **kwargs)

@classmethod
Expand Down
38 changes: 28 additions & 10 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
# Cloud resource manager URL used to retrieve project information.
_CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/"
# Default Google sts token url.
_DEFAULT_TOKEN_URL = "https://sts.googleapis.com/v1/token"


@dataclass
Expand All @@ -60,11 +62,13 @@ class SupplierContext:

Attributes:
subject_token_type (str): The requested subject token type based on the Oauth2.0 token exchange spec.
Expected values include:
“urn:ietf:params:oauth:token-type:jwt”
“urn:ietf:params:oauth:token-type:id-token”
“urn:ietf:params:oauth:token-type:saml2”
“urn:ietf:params:aws:token-type:aws4_request”
Expected values include::

“urn:ietf:params:oauth:token-type:jwt”
“urn:ietf:params:oauth:token-type:id-token”
“urn:ietf:params:oauth:token-type:saml2”
“urn:ietf:params:aws:token-type:aws4_request”

audience (str): The requested audience for the subject token.
"""

Expand Down Expand Up @@ -108,7 +112,14 @@ def __init__(

Args:
audience (str): The STS audience field.
subject_token_type (str): The subject token type.
subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec.
Expected values include::

“urn:ietf:params:oauth:token-type:jwt”
“urn:ietf:params:oauth:token-type:id-token”
“urn:ietf:params:oauth:token-type:saml2”
“urn:ietf:params:aws:token-type:aws4_request”

token_url (str): The STS endpoint URL.
credential_source (Mapping): The credential source dictionary.
service_account_impersonation_url (Optional[str]): The optional service account
Expand Down Expand Up @@ -165,10 +176,7 @@ def __init__(

self._metrics_options = self._create_default_metrics_options()

if self._service_account_impersonation_url:
self._impersonated_credentials = self._initialize_impersonated_credentials()
else:
self._impersonated_credentials = None
self._impersonated_credentials = None
self._project_id = None
self._supplier_context = SupplierContext(
self._subject_token_type, self._audience
Expand Down Expand Up @@ -381,6 +389,10 @@ def get_project_id(self, request):
@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
scopes = self._scopes if self._scopes is not None else self._default_scopes

if self._should_initialize_impersonated_credentials():
self._impersonated_credentials = self._initialize_impersonated_credentials()

if self._impersonated_credentials:
self._impersonated_credentials.refresh(request)
self.token = self._impersonated_credentials.token
Expand Down Expand Up @@ -444,6 +456,12 @@ def with_universe_domain(self, universe_domain):
new_cred._metrics_options = self._metrics_options
return new_cred

def _should_initialize_impersonated_credentials(self):
return (
self._service_account_impersonation_url is not None
and self._impersonated_credentials is None
)

def _initialize_impersonated_credentials(self):
"""Generates an impersonated credentials.

Expand Down
Loading