diff --git a/google/auth/aws.py b/google/auth/aws.py index fc644095f..a3f497442 100644 --- a/google/auth/aws.py +++ b/google/auth/aws.py @@ -64,6 +64,10 @@ _AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token" # The AWS authorization header name for the auto-generated date. _AWS_DATE_HEADER = "x-amz-date" +# The default AWS regional credential verification URL. +_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) class RequestSigner(object): @@ -360,8 +364,9 @@ class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta): @abc.abstractmethod def get_aws_security_credentials(self, context, request): - """Returns the AWS security credentials for the requested context. This is not cached by the calling - Google credential, so caching logic should be implemented in the supplier. + """Returns the AWS security credentials for the requested context. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. Args: context (google.auth.externalaccount.SupplierContext): The context object @@ -603,8 +608,9 @@ def __init__( self, audience, subject_token_type, - token_url, + token_url=external_account._DEFAULT_TOKEN_URL, credential_source=None, + aws_security_credentials_supplier=None, *args, **kwargs ): @@ -612,11 +618,30 @@ def __init__( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used - to provide instructions on how to retrieve external credential - to be exchanged for Google access tokens. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:aws:token-type:aws4_request” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used + to provide instructions on how to retrieve external credential to be exchanged for Google access tokens. + Either a credential source or an AWS security credentials supplier must be provided. + + Example credential_source for AWS credential:: + + { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + imdsv2_session_token_url": "http://169.254.169.254/latest/api/token" + } + + aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier. + This will be called to supply valid AWS security credentails which will then + be exchanged for Google access tokens. Either an AWS security credentials supplier + or a credential source must be provided. args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. @@ -637,33 +662,52 @@ def __init__( *args, **kwargs ) - credential_source = credential_source or {} - self._target_resource = audience - environment_id = credential_source.get("environment_id") or "" - self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier( - credential_source - ) - self._cred_verification_url = credential_source.get( - "regional_cred_verification_url" - ) + if credential_source is None and aws_security_credentials_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or AWS security credentials supplier must be provided." + ) + if ( + credential_source is not None + and aws_security_credentials_supplier is not None + ): + raise exceptions.InvalidValue( + "AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) - # Get the environment ID. Currently, only one version supported (v1). - matches = re.match(r"^(aws)([\d]+)$", environment_id) - if matches: - env_id, env_version = matches.groups() + if aws_security_credentials_supplier: + self._aws_security_credentials_supplier = aws_security_credentials_supplier + # The regional cred verification URL would normally be provided through the credential source. So set it to the default one here. + self._cred_verification_url = ( + _DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL + ) else: - env_id, env_version = (None, None) - - if env_id != "aws" or self._cred_verification_url is None: - raise exceptions.InvalidResource( - "No valid AWS 'credential_source' provided" + environment_id = credential_source.get("environment_id") or "" + self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier( + credential_source ) - elif int(env_version or "") != 1: - raise exceptions.InvalidValue( - "aws version '{}' is not supported in the current build.".format( - env_version - ) + self._cred_verification_url = credential_source.get( + "regional_cred_verification_url" ) + + # Get the environment ID. Currently, only one version supported (v1). + matches = re.match(r"^(aws)([\d]+)$", environment_id) + if matches: + env_id, env_version = matches.groups() + else: + env_id, env_version = (None, None) + + if env_id != "aws" or self._cred_verification_url is None: + raise exceptions.InvalidResource( + "No valid AWS 'credential_source' provided" + ) + elif int(env_version or "") != 1: + raise exceptions.InvalidValue( + "aws version '{}' is not supported in the current build.".format( + env_version + ) + ) + + self._target_resource = audience self._request_signer = None def retrieve_subject_token(self, request): @@ -758,9 +802,26 @@ def retrieve_subject_token(self, request): def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() - metrics_options["source"] = "aws" + if self._has_custom_supplier(): + metrics_options["source"] = "programmatic" + else: + metrics_options["source"] = "aws" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update( + { + "aws_security_credentials_supplier": self._aws_security_credentials_supplier + } + ) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an AWS Credentials instance from parsed external account info. @@ -776,6 +837,12 @@ def from_info(cls, info, **kwargs): Raises: ValueError: For invalid parameters. """ + aws_security_credentials_supplier = info.get( + "aws_security_credentials_supplier" + ) + kwargs.update( + {"aws_security_credentials_supplier": aws_security_credentials_supplier} + ) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/google/auth/external_account.py b/google/auth/external_account.py index fbd542843..c14001bc2 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -51,6 +51,8 @@ _STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" # Cloud resource manager URL used to retrieve project information. _CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" +# Default Google sts token url. +_DEFAULT_TOKEN_URL = "https://sts.googleapis.com/v1/token" @dataclass @@ -60,11 +62,13 @@ class SupplierContext: Attributes: subject_token_type (str): The requested subject token type based on the Oauth2.0 token exchange spec. - Expected values include: - “urn:ietf:params:oauth:token-type:jwt” - “urn:ietf:params:oauth:token-type:id-token” - “urn:ietf:params:oauth:token-type:saml2” - “urn:ietf:params:aws:token-type:aws4_request” + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + audience (str): The requested audience for the subject token. """ @@ -108,7 +112,14 @@ def __init__( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary. service_account_impersonation_url (Optional[str]): The optional service account @@ -165,10 +176,7 @@ def __init__( self._metrics_options = self._create_default_metrics_options() - if self._service_account_impersonation_url: - self._impersonated_credentials = self._initialize_impersonated_credentials() - else: - self._impersonated_credentials = None + self._impersonated_credentials = None self._project_id = None self._supplier_context = SupplierContext( self._subject_token_type, self._audience @@ -381,6 +389,10 @@ def get_project_id(self, request): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes + + if self._should_initialize_impersonated_credentials(): + self._impersonated_credentials = self._initialize_impersonated_credentials() + if self._impersonated_credentials: self._impersonated_credentials.refresh(request) self.token = self._impersonated_credentials.token @@ -444,6 +456,12 @@ def with_universe_domain(self, universe_domain): new_cred._metrics_options = self._metrics_options return new_cred + def _should_initialize_impersonated_credentials(self): + return ( + self._service_account_impersonation_url is not None + and self._impersonated_credentials is None + ) + def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 5350c35dc..77bc7735b 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -57,8 +57,9 @@ class SubjectTokenSupplier(metaclass=abc.ABCMeta): @abc.abstractmethod def get_subject_token(self, context, request): - """Returns the requested subject token. The subject token must be valid. This is not cached by the calling - Google credential, so caching logic should be implemented in the supplier. + """Returns the requested subject token. The subject token must be valid. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. Args: context (google.auth.externalaccount.SupplierContext): The context object @@ -167,8 +168,9 @@ def __init__( self, audience, subject_token_type, - token_url, - credential_source, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + subject_token_supplier=None, *args, **kwargs ): @@ -176,11 +178,18 @@ def __init__( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used to + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used to provide instructions on how to retrieve external credential to be - exchanged for Google access tokens. + exchanged for Google access tokens. Either a credential source or + a subject token supplier must be provided. Example credential_source for url-sourced credential:: @@ -198,6 +207,10 @@ def __init__( { "file": "/path/to/token/file.txt" } + subject_token_supplier (Optional [SubjectTokenSupplier]): Optional subject token supplier. + This will be called to supply a valid subject token which will then + be exchanged for Google access tokens. Either a subject token supplier + or a credential source must be provided. args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. @@ -219,10 +232,25 @@ def __init__( *args, **kwargs ) - if not isinstance(credential_source, Mapping): + if credential_source is None and subject_token_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or a subject token supplier must be provided." + ) + if credential_source is not None and subject_token_supplier is not None: + raise exceptions.InvalidValue( + "Identity pool credential cannot have both a credential source and a subject token supplier." + ) + + if subject_token_supplier is not None: + self._subject_token_supplier = subject_token_supplier self._credential_source_file = None self._credential_source_url = None else: + if not isinstance(credential_source, Mapping): + self._credential_source_executable = None + raise exceptions.MalformedError( + "Invalid credential_source. The credential_source is not a dict." + ) self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") @@ -256,28 +284,28 @@ def __init__( else: self._credential_source_field_name = None - if self._credential_source_file and self._credential_source_url: - raise exceptions.MalformedError( - "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." - ) - if not self._credential_source_file and not self._credential_source_url: - raise exceptions.MalformedError( - "Missing credential_source. A 'file' or 'url' must be provided." - ) + if self._credential_source_file and self._credential_source_url: + raise exceptions.MalformedError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) + if not self._credential_source_file and not self._credential_source_url: + raise exceptions.MalformedError( + "Missing credential_source. A 'file' or 'url' must be provided." + ) - if self._credential_source_file: - self._subject_token_supplier = _FileSupplier( - self._credential_source_file, - self._credential_source_format_type, - self._credential_source_field_name, - ) - else: - self._subject_token_supplier = _UrlSupplier( - self._credential_source_url, - self._credential_source_format_type, - self._credential_source_field_name, - self._credential_source_headers, - ) + if self._credential_source_file: + self._subject_token_supplier = _FileSupplier( + self._credential_source_file, + self._credential_source_format_type, + self._credential_source_field_name, + ) + else: + self._subject_token_supplier = _UrlSupplier( + self._credential_source_url, + self._credential_source_format_type, + self._credential_source_field_name, + self._credential_source_headers, + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): @@ -295,8 +323,20 @@ def _create_default_metrics_options(self): metrics_options["source"] = "file" else: metrics_options["source"] = "url" + else: + metrics_options["source"] = "programmatic" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update({"subject_token_supplier": self._subject_token_supplier}) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an Identity Pool Credentials instance from parsed external account info. @@ -313,6 +353,8 @@ def from_info(cls, info, **kwargs): Raises: ValueError: For invalid parameters. """ + subject_token_supplier = info.get("subject_token_supplier") + kwargs.update({"subject_token_supplier": subject_token_supplier}) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/tests/test_aws.py b/tests/test_aws.py index 4d6268d37..561482031 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -21,7 +21,7 @@ import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import aws from google.auth import environment_vars from google.auth import exceptions @@ -668,6 +668,36 @@ def test_get_request_options_with_missing_hostname_url(self): assert excinfo.match(r"Invalid AWS service URL") +class TestAwsSecurityCredentialsSupplier(aws.AwsSecurityCredentialsSupplier): + def __init__( + self, + security_credentials=None, + region=None, + credentials_exception=None, + region_exception=None, + expected_context=None, + ): + self._security_credentials = security_credentials + self._region = region + self._credentials_exception = credentials_exception + self._region_exception = region_exception + self._expected_context = expected_context + + def get_aws_security_credentials(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._credentials_exception is not None: + raise self._credentials_exception + return self._security_credentials + + def get_aws_region(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._region_exception is not None: + raise self._region_exception + return self._region + + class TestCredentials(object): AWS_REGION = "us-east-2" AWS_ROLE = "gcp-aws-role" @@ -840,7 +870,8 @@ def make_mock_request( @classmethod def make_credentials( cls, - credential_source, + credential_source=None, + aws_security_credentials_supplier=None, token_url=TOKEN_URL, token_info_url=TOKEN_INFO_URL, client_id=None, @@ -857,6 +888,7 @@ def make_credentials( token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + aws_security_credentials_supplier=aws_security_credentials_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -935,6 +967,7 @@ def test_from_info_full_options(self, mock_init): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -963,6 +996,38 @@ def test_from_info_required_options_only(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestAwsSecurityCredentialsSupplier() + + credentials = aws.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "aws_security_credentials_supplier": supplier, + } + ) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + aws_security_credentials_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -999,6 +1064,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1028,6 +1094,7 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1042,6 +1109,27 @@ def test_constructor_invalid_credential_source(self): assert excinfo.match(r"No valid AWS 'credential_source' provided") + def test_constructor_invalid_credential_source_and_supplier(self): + # Provide both a credential source and supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier="test", + ) + + assert excinfo.match( + r"AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + # Provide no credential source or supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or AWS security credentials supplier must be provided." + ) + def test_constructor_invalid_environment_id(self): # Provide invalid environment_id. credential_source = self.CREDENTIAL_SOURCE.copy() @@ -2086,3 +2174,249 @@ def test_refresh_with_retrieve_subject_token_error(self): credentials.refresh(request) assert excinfo.match(r"Unable to retrieve AWS region") + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_session_token(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_correct_context(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + expected_context = external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ) + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region=self.AWS_REGION, + expected_context=expected_context, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + credentials.retrieve_subject_token(request) + + def test_retrieve_subject_token_error_with_supplier(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + supplier = TestAwsSecurityCredentialsSupplier( + region=self.AWS_REGION, credentials_exception=expected_exception + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + def test_retrieve_subject_token_error_with_supplier_region(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region_exception=expected_exception, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier_with_impersonation( + self, utcnow, mock_auth_lib_value + ): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/true config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "https://www.googleapis.com/auth/iam", + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": QUOTA_PROJECT_ID, + "x-goog-api-client": IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + "x-allowed-locations": "0x0", + } + impersonation_request_data = { + "delegates": None, + "scope": SCOPES, + "lifetime": "3600s", + } + request = self.make_mock_request( + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 2 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # Second request should be sent to iamcredentials endpoint for service + # account impersonation. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + ) + assert credentials.token == impersonation_response["accessToken"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier(self, utcnow, mock_auth_lib_value): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/false config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": " ".join(SCOPES), + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + token_status=http_client.OK, token_data=self.SUCCESS_RESPONSE + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 1 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 586d06efb..c458b21b6 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -477,16 +477,6 @@ def test_with_quota_project_full_options_propagated(self): universe_domain=DEFAULT_UNIVERSE_DOMAIN, ) - def test_with_invalid_impersonation_target_principal(self): - invalid_url = "https://iamcredentials.googleapis.com/v1/invalid" - - with pytest.raises(exceptions.RefreshError) as excinfo: - self.make_credentials(service_account_impersonation_url=invalid_url) - - assert excinfo.match( - r"Unable to determine target principal from service account impersonation URL." - ) - def test_info(self): credentials = self.make_credentials(universe_domain="dummy_universe.com") @@ -1069,6 +1059,21 @@ def test_refresh_impersonation_without_client_auth_error(self): assert not credentials.expired assert credentials.token is None + def test_refresh_impersonation_invalid_impersonated_url_error(self): + credentials = self.make_credentials( + service_account_impersonation_url="https://iamcredentials.googleapis.com/v1/invalid", + scopes=self.SCOPES, + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(None) + + assert excinfo.match( + r"Unable to determine target principal from service account impersonation URL." + ) + assert not credentials.expired + assert credentials.token is None + @mock.patch( "google.auth.metrics.python_and_auth_lib_version", return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index be30c4e9b..ac1d9a0bb 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -21,7 +21,7 @@ import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import exceptions from google.auth import identity_pool from google.auth import metrics @@ -151,6 +151,22 @@ ] +class TestSubjectTokenSupplier(identity_pool.SubjectTokenSupplier): + def __init__( + self, subject_token=None, subject_token_exception=None, expected_context=None + ): + self._subject_token = subject_token + self._subject_token_exception = subject_token_exception + self._expected_context = expected_context + + def get_subject_token(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._subject_token_exception is not None: + raise self._subject_token_exception + return self._subject_token + + class TestCredentials(object): CREDENTIAL_SOURCE_TEXT = {"file": SUBJECT_TOKEN_TEXT_FILE} CREDENTIAL_SOURCE_JSON = { @@ -273,10 +289,13 @@ def assert_underlying_credentials_refresh( else: metrics_options["sa-impersonation"] = "false" metrics_options["config-lifetime"] = "false" - if credentials._credential_source_file: - metrics_options["source"] = "file" + if credentials._credential_source: + if credentials._credential_source_file: + metrics_options["source"] = "file" + else: + metrics_options["source"] = "url" else: - metrics_options["source"] = "url" + metrics_options["source"] = "programmatic" token_headers["x-goog-api-client"] = metrics.byoid_metrics_header( metrics_options @@ -386,6 +405,7 @@ def make_credentials( default_scopes=None, service_account_impersonation_url=None, credential_source=None, + subject_token_supplier=None, workforce_pool_user_project=None, ): return identity_pool.Credentials( @@ -395,6 +415,7 @@ def make_credentials( token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + subject_token_supplier=subject_token_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -432,6 +453,7 @@ def test_from_info_full_options(self, mock_init): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -460,6 +482,38 @@ def test_from_info_required_options_only(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestSubjectTokenSupplier() + + credentials = identity_pool.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "subject_token_supplier": supplier, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + subject_token_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -489,6 +543,7 @@ def test_from_info_workforce_pool(self, mock_init): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -524,6 +579,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -553,6 +609,7 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -583,6 +640,7 @@ def test_from_file_workforce_pool(self, mock_init, tmpdir): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -633,7 +691,29 @@ def test_constructor_invalid_credential_source(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") - assert excinfo.match(r"Missing credential_source") + assert excinfo.match( + r"Invalid credential_source. The credential_source is not a dict." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or a subject token supplier must be provided." + ) + + def test_constructor_invalid_both_credential_source_and_supplier(self): + supplier = TestSubjectTokenSupplier() + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=supplier, + ) + + assert excinfo.match( + r"Identity pool credential cannot have both a credential source and a subject token supplier." + ) def test_constructor_invalid_credential_source_format_type(self): credential_source = {"format": {"type": "xml"}} @@ -1297,3 +1377,78 @@ def test_refresh_with_retrieve_subject_token_error_url(self): self.CREDENTIAL_URL, "not_found" ) ) + + def test_retrieve_subject_token_supplier(self): + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + + def test_retrieve_subject_token_supplier_correct_context(self): + supplier = TestSubjectTokenSupplier( + subject_token=JSON_FILE_SUBJECT_TOKEN, + expected_context=external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ), + ) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + credentials.retrieve_subject_token(None) + + def test_retrieve_subject_token_supplier_error(self): + expected_exception = exceptions.RefreshError("test error") + supplier = TestSubjectTokenSupplier(subject_token_exception=expected_exception) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match("test error") + + def test_refresh_success_supplier_with_impersonation_url(self): + # Initialize credentials with service account impersonation and a supplier. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) + + def test_refresh_success_supplier_without_impersonation_url(self): + # Initialize supplier credentials without service account impersonation. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, scopes=SCOPES + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + )