diff --git a/google/auth/_default.py b/google/auth/_default.py index 894c5c222..acf76e924 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -311,16 +311,16 @@ def _get_external_account_credentials(info, filename, scopes=None, request=None) """ # There are currently 2 types of external_account credentials. try: - # Check if configuration corresponds to an Identity Pool credentials. - from google.auth import identity_pool + # Check if configuration corresponds to an AWS credentials. + from google.auth import aws - credentials = identity_pool.Credentials.from_info(info, scopes=scopes) + credentials = aws.Credentials.from_info(info, scopes=scopes) except ValueError: try: - # Check if configuration corresponds to an AWS credentials. - from google.auth import aws + # Check if configuration corresponds to an Identity Pool credentials. + from google.auth import identity_pool - credentials = aws.Credentials.from_info(info, scopes=scopes) + credentials = identity_pool.Credentials.from_info(info, scopes=scopes) except ValueError: # If the configuration is invalid or does not correspond to any # supported external_account credentials, raise an error. diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 43df96273..5eed7a77c 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -15,17 +15,24 @@ """Identity Pool Credentials. This module provides credentials that are initialized using external_account -arguments which are typically loaded from the external credentials file. -Unlike other Credentials that can be initialized with a list of explicit -arguments, secrets or credentials, external account clients use the -environment and hints/guidelines provided by the external_account JSON -file to retrieve credentials and exchange them for Google access tokens. +arguments which are typically loaded from an external credentials file or +an external credentials url. Unlike other Credentials that can be initialized +with a list of explicit arguments, secrets or credentials, external account +clients use the environment and hints/guidelines provided by the +external_account JSON file to retrieve credentials and exchange them for Google +access tokens. Identity Pool Credentials are used with external credentials (eg. OIDC ID tokens) retrieved from a file location, typical for K8s workloads -registered with Hub with Hub workload identity enabled. +registered with Hub with Hub workload identity enabled, or retrieved from an +url, typical for Azure based workflows. """ +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping import io import json import os @@ -36,10 +43,7 @@ class Credentials(external_account.Credentials): - """File-sourced external account credentials. - This is typically used to exchange OIDC ID tokens in K8s (file-sourced - credentials) for Google access tokens. - """ + """External account credentials sourced from files and urls.""" def __init__( self, @@ -53,7 +57,7 @@ def __init__( quota_project_id=None, scopes=None, ): - """Instantiates a file-sourced external account credentials object. + """Instantiates an external account credentials object from a file/url. Args: audience (str): The STS audience field. @@ -61,7 +65,19 @@ def __init__( token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary used to provide instructions on how to retrieve external credential to be - exchanged for Google access tokens.. + exchanged for Google access tokens + Example credential_source's: + { + "url": "http://www.example.com", + "format": { + "type": "json", + "subject_token_field_name": "access_token", + }, + "headers": {"foo": "bar"}, + } + { + "file": "/path/to/token/file.txt" + } service_account_impersonation_url (Optional[str]): The optional service account impersonation getAccessToken URL. client_id (Optional[str]): The optional client ID. @@ -91,15 +107,25 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, ) - if isinstance(credential_source, dict): + if not isinstance(credential_source, Mapping): + self._credential_source_file = None + self._credential_source_url = None + else: self._credential_source_file = credential_source.get("file") + self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") - credential_source_format = credential_source.get("format") or {} + credential_source_format = credential_source.get("format", {}) # Get credential_source format type. When not provided, this # defaults to text. self._credential_source_format_type = ( credential_source_format.get("type") or "text" ) + # environment_id is only supported in AWS or dedicated future external + # account credentials. + if "environment_id" in credential_source: + raise ValueError( + "Invalid Identity Pool credential_source field 'environment_id'" + ) if self._credential_source_format_type not in ["text", "json"]: raise ValueError( "Invalid credential_source format '{}'".format( @@ -117,28 +143,60 @@ def __init__( ) else: self._credential_source_field_name = None - else: - self._credential_source_file = None - if not self._credential_source_file: - raise ValueError("Missing credential_source file") + + if self._credential_source_file and self._credential_source_url: + raise ValueError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) + if not self._credential_source_file and not self._credential_source_url: + raise ValueError( + "Missing credential_source. A 'file' or 'url' must be provided." + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): - return self._get_token_file( - self._credential_source_file, + return self._parse_token_data( + self._get_token_data(request), self._credential_source_format_type, self._credential_source_field_name, ) - def _get_token_file( - self, filename, format_type="text", subject_token_field_name=None - ): + def _get_token_data(self, request): + if self._credential_source_file: + return self._get_file_data(self._credential_source_file) + else: + return self._get_url_data( + request, self._credential_source_url, self._credential_source_headers + ) + + def _get_file_data(self, filename): if not os.path.exists(filename): raise exceptions.RefreshError("File '{}' was not found.".format(filename)) with io.open(filename, "r", encoding="utf-8") as file_obj: - content = file_obj.read() + return file_obj.read(), filename + + def _get_url_data(self, request, url, headers): + response = request(url=url, method="GET", headers=headers) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", response_body + ) + + return response_body, url + + def _parse_token_data( + self, token_content, format_type="text", subject_token_field_name=None + ): + content, filename = token_content if format_type == "text": token = content else: diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index 47a6fbdb3..fc8ad8de9 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -47,8 +47,8 @@ TEXT_FILE_SUBJECT_TOKEN = fh.read() with open(SUBJECT_TOKEN_JSON_FILE) as fh: - content = json.load(fh) - JSON_FILE_SUBJECT_TOKEN = content.get(SUBJECT_TOKEN_FIELD_NAME) + JSON_FILE_CONTENT = json.load(fh) + JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME) TOKEN_URL = "https://sts.googleapis.com/v1/token" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" @@ -61,6 +61,12 @@ class TestCredentials(object): "file": SUBJECT_TOKEN_JSON_FILE, "format": {"type": "json", "subject_token_field_name": "access_token"}, } + CREDENTIAL_URL = "http://fakeurl.com" + CREDENTIAL_SOURCE_TEXT_URL = {"url": CREDENTIAL_URL} + CREDENTIAL_SOURCE_JSON_URL = { + "url": CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", @@ -69,35 +75,46 @@ class TestCredentials(object): "scope": " ".join(SCOPES), } + @classmethod + def make_mock_response(cls, status, data): + response = mock.create_autospec(transport.Response, instance=True) + response.status = status + if isinstance(data, dict): + response.data = json.dumps(data).encode("utf-8") + else: + response.data = data + return response + @classmethod def make_mock_request( - cls, - token_status=http_client.OK, - token_data=None, - impersonation_status=None, - impersonation_data=None, + cls, token_status=http_client.OK, token_data=None, *extra_requests ): responses = [] - # STS token exchange request. - token_response = mock.create_autospec(transport.Response, instance=True) - token_response.status = token_status - token_response.data = json.dumps(token_data).encode("utf-8") - responses.append(token_response) - - # If service account impersonation is requested, mock the expected response. - if impersonation_status: - impersonation_response = mock.create_autospec( - transport.Response, instance=True + responses.append(cls.make_mock_response(token_status, token_data)) + + while len(extra_requests) > 0: + # If service account impersonation is requested, mock the expected response. + status, data, extra_requests = ( + extra_requests[0], + extra_requests[1], + extra_requests[2:], ) - impersonation_response.status = impersonation_status - impersonation_response.data = json.dumps(impersonation_data).encode("utf-8") - responses.append(impersonation_response) + responses.append(cls.make_mock_response(status, data)) request = mock.create_autospec(transport.Request) request.side_effect = responses return request + @classmethod + def assert_credential_request_kwargs( + cls, request_kwargs, headers, url=CREDENTIAL_URL + ): + assert request_kwargs["url"] == url + assert request_kwargs["method"] == "GET" + assert request_kwargs["headers"] == headers + assert request_kwargs.get("body", None) is None + @classmethod def assert_token_request_kwargs( cls, request_kwargs, headers, request_data, token_url=TOKEN_URL @@ -138,30 +155,24 @@ def assert_underlying_credentials_refresh( basic_auth_encoding=None, quota_project_id=None, scopes=None, + credential_data=None, ): """Utility to assert that a credentials are initialized with the expected attributes by calling refresh functionality and confirming response matches expected one and that the underlying requests were populated with the expected parameters. """ - - expire_time = ( - _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) - ).isoformat("T") + "Z" # STS token exchange request/response. token_response = cls.SUCCESS_RESPONSE.copy() - token_headers = {"Content-Type": "application/x-www-form-urlencoded"} if basic_auth_encoding: token_headers["Authorization"] = "Basic " + basic_auth_encoding + if service_account_impersonation_url: token_scopes = "https://www.googleapis.com/auth/iam" - impersonation_status = http_client.OK - total_requests = 2 else: token_scopes = " ".join(scopes or []) - impersonation_status = None - total_requests = 1 + token_request_data = { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "audience": audience, @@ -170,44 +181,59 @@ def assert_underlying_credentials_refresh( "subject_token": subject_token, "subject_token_type": subject_token_type, } - # Service account impersonation request/response. - impersonation_response = { - "accessToken": "SA_ACCESS_TOKEN", - "expireTime": expire_time, - } - impersonation_headers = { - "Content-Type": "application/json", - "authorization": "Bearer {}".format(token_response["access_token"]), - } - impersonation_request_data = { - "delegates": None, - "scope": scopes, - "lifetime": "3600s", - } - # Initialize mock request to handle token exchange and service account - # impersonation request. - request = cls.make_mock_request( - token_status=http_client.OK, - token_data=token_response, - impersonation_status=impersonation_status, - impersonation_data=impersonation_response, - ) + + if service_account_impersonation_url: + # Service account impersonation request/response. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": scopes, + "lifetime": "3600s", + } + + # Initialize mock request to handle token retrieval, token exchange and + # service account impersonation request. + requests = [] + if credential_data: + requests.append((http_client.OK, credential_data)) + + token_request_index = len(requests) + requests.append((http_client.OK, token_response)) + + if service_account_impersonation_url: + impersonation_request_index = len(requests) + requests.append((http_client.OK, impersonation_response)) + + request = cls.make_mock_request(*[el for req in requests for el in req]) credentials.refresh(request) - assert len(request.call_args_list) == total_requests + assert len(request.call_args_list) == len(requests) + if credential_data: + cls.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) # Verify token exchange request parameters. cls.assert_token_request_kwargs( - request.call_args_list[0].kwargs, + request.call_args_list[token_request_index].kwargs, token_headers, token_request_data, token_url, ) # Verify service account impersonation request parameters if the request # is processed. - if impersonation_status: + if service_account_impersonation_url: cls.assert_impersonation_request_kwargs( - request.call_args_list[1].kwargs, + request.call_args_list[impersonation_request_index].kwargs, impersonation_headers, impersonation_request_data, service_account_impersonation_url, @@ -352,13 +378,34 @@ def test_constructor_invalid_options(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source=credential_source) - assert excinfo.match(r"Missing credential_source file") + assert excinfo.match(r"Missing credential_source") + + def test_constructor_invalid_options_url_and_file(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "file": SUBJECT_TOKEN_TEXT_FILE, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + + def test_constructor_invalid_options_environment_id(self): + credential_source = {"url": self.CREDENTIAL_URL, "environment_id": "aws1"} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match( + r"Invalid Identity Pool credential_source field 'environment_id'" + ) def test_constructor_invalid_credential_source(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") - assert excinfo.match(r"Missing credential_source file") + assert excinfo.match(r"Missing credential_source") def test_constructor_invalid_credential_source_format_type(self): credential_source = {"format": {"type": "xml"}} @@ -551,3 +598,198 @@ def test_refresh_with_retrieve_subject_token_error(self): SUBJECT_TOKEN_JSON_FILE, "not_found" ) ) + + def test_retrieve_subject_token_from_url(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) + + def test_retrieve_subject_token_from_url_with_headers(self): + credentials = self.make_credentials( + credential_source={"url": self.CREDENTIAL_URL, "headers": {"foo": "bar"}} + ) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, {"foo": "bar"} + ) + + def test_retrieve_subject_token_from_url_json(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) + + def test_retrieve_subject_token_from_url_json_with_headers(self): + credentials = self.make_credentials( + credential_source={ + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + "headers": {"foo": "bar"}, + } + ) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, {"foo": "bar"} + ) + + def test_retrieve_subject_token_from_url_not_found(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_status=404, token_data=JSON_FILE_CONTENT) + ) + + assert excinfo.match("Unable to retrieve Identity Pool subject token") + + def test_retrieve_subject_token_from_url_json_invalid_field(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_data=JSON_FILE_CONTENT) + ) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "not_found" + ) + ) + + def test_retrieve_subject_token_from_url_json_invalid_format(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(self.make_mock_request(token_data="{")) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "access_token" + ) + ) + + def test_refresh_text_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + scopes=SCOPES, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_text_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + scopes=SCOPES, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_json_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + scopes=SCOPES, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_json_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + scopes=SCOPES, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_with_retrieve_subject_token_error_url(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "not_found" + ) + )