From 6d9efce3761e7a529d564e3162414e2d1029e977 Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Wed, 18 Nov 2020 11:39:13 -0800 Subject: [PATCH 1/6] feat: Adding support for url-based credential files --- google/auth/identity_pool.py | 48 +++++++++++++++++++++++++-------- tests/test_identity_pool.py | 51 +++++++++++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 15 deletions(-) diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 43df96273..28287b40b 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -33,6 +33,8 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import external_account +from six.moves import http_client +from six.moves import urllib class Credentials(external_account.Credentials): @@ -52,6 +54,7 @@ def __init__( client_secret=None, quota_project_id=None, scopes=None, + success_codes=(http_client.OK,), ): """Instantiates a file-sourced external account credentials object. @@ -91,9 +94,14 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, ) - if isinstance(credential_source, dict): + if not isinstance(credential_source, dict): + self._credential_source_file = None + self._credential_source_url = None + else: self._credential_source_file = credential_source.get("file") + self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") + self._success_codes = success_codes credential_source_format = credential_source.get("format") or {} # Get credential_source format type. When not provided, this # defaults to text. @@ -117,28 +125,46 @@ def __init__( ) else: self._credential_source_field_name = None - else: - self._credential_source_file = None - if not self._credential_source_file: - raise ValueError("Missing credential_source file") + + if self._credential_source_file and self._credential_source_url: + raise ValueError("Ambiguous credential_source") + if not self._credential_source_file and not self._credential_source_url: + raise ValueError("Missing credential_source") @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): - return self._get_token_file( - self._credential_source_file, + return self._parse_token_data( + self._get_token_data(), self._credential_source_format_type, self._credential_source_field_name, ) - def _get_token_file( - self, filename, format_type="text", subject_token_field_name=None - ): + def _get_token_data(self): + if self._credential_source_file: + return self._get_file_data(self._credential_source_file) + if self._credential_source_url: + return self._get_url_data(self._credential_source_url) + + def _get_file_data(self, filename): if not os.path.exists(filename): raise exceptions.RefreshError("File '{}' was not found.".format(filename)) with io.open(filename, "r", encoding="utf-8") as file_obj: - content = file_obj.read() + return file_obj.read(), filename + + def _get_url_data(self, url): + response = urllib.request.urlopen(url) + if response.status not in self._success_codes: + raise exceptions.RefreshError("Url '{}' was not found.".format(url)) + return response.read(), url + def _parse_token_data( + self, + token_content, + format_type="text", + subject_token_field_name=None + ): + content, filename = token_content if format_type == "text": token = content else: diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index 47a6fbdb3..a2dd0d287 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -47,8 +47,8 @@ TEXT_FILE_SUBJECT_TOKEN = fh.read() with open(SUBJECT_TOKEN_JSON_FILE) as fh: - content = json.load(fh) - JSON_FILE_SUBJECT_TOKEN = content.get(SUBJECT_TOKEN_FIELD_NAME) + JSON_FILE_CONTENT = json.load(fh) + JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME) TOKEN_URL = "https://sts.googleapis.com/v1/token" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" @@ -69,6 +69,16 @@ class TestCredentials(object): "scope": " ".join(SCOPES), } + class FakeResponse: + def __init__(self, data, status=http_client.OK): + self.status = status + self.data = data + if isinstance(data, dict): + self.data = json.dumps(data) + + def read(self): + return self.data + @classmethod def make_mock_request( cls, @@ -352,13 +362,13 @@ def test_constructor_invalid_options(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source=credential_source) - assert excinfo.match(r"Missing credential_source file") + assert excinfo.match(r"Missing credential_source") def test_constructor_invalid_credential_source(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") - assert excinfo.match(r"Missing credential_source file") + assert excinfo.match(r"Missing credential_source") def test_constructor_invalid_credential_source_format_type(self): credential_source = {"format": {"type": "xml"}} @@ -551,3 +561,36 @@ def test_refresh_with_retrieve_subject_token_error(self): SUBJECT_TOKEN_JSON_FILE, "not_found" ) ) + + @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( + TEXT_FILE_SUBJECT_TOKEN)) + def test_retrieve_subject_token_from_url(self, mock_urlopen): + credential_source = { + "url": "http://fakeurl.com", + } + credentials = self.make_credentials(credential_source=credential_source) + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + + @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( + JSON_FILE_CONTENT)) + def test_retrieve_subject_token_from_url_json(self, mock_urlopen): + credential_source = { + "url": "http://fakeurl.com", + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } + credentials = self.make_credentials(credential_source=credential_source) + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + + @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( + TEXT_FILE_SUBJECT_TOKEN, status=http_client.NOT_FOUND)) + def test_retrieve_subject_token_from_url_not_found(self, mock_urlopen): + credential_source = { + "url": "http://fakeurl.com", + } + credentials = self.make_credentials(credential_source=credential_source) + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) From 96d61db33b30cbf6cae4e40c03ac3178f3eadd5f Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Wed, 2 Dec 2020 09:35:21 -0800 Subject: [PATCH 2/6] Making changes requested by Bassam --- google/auth/identity_pool.py | 60 +++++--- tests/test_identity_pool.py | 286 ++++++++++++++++++++++++++--------- 2 files changed, 252 insertions(+), 94 deletions(-) diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 28287b40b..8e9e06783 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -15,15 +15,17 @@ """Identity Pool Credentials. This module provides credentials that are initialized using external_account -arguments which are typically loaded from the external credentials file. -Unlike other Credentials that can be initialized with a list of explicit -arguments, secrets or credentials, external account clients use the -environment and hints/guidelines provided by the external_account JSON -file to retrieve credentials and exchange them for Google access tokens. +arguments which are typically loaded from an external credentials file or +an external credentials url. Unlike other Credentials that can be initialized +with a list of explicit arguments, secrets or credentials, external account +clients use the environment and hints/guidelines provided by the +external_account JSON file to retrieve credentials and exchange them for Google +access tokens. Identity Pool Credentials are used with external credentials (eg. OIDC ID tokens) retrieved from a file location, typical for K8s workloads -registered with Hub with Hub workload identity enabled. +registered with Hub with Hub workload identity enabled, or retrieved from an +url, typical for AWS and Azure based workflows. """ import io @@ -33,15 +35,10 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import external_account -from six.moves import http_client -from six.moves import urllib class Credentials(external_account.Credentials): - """File-sourced external account credentials. - This is typically used to exchange OIDC ID tokens in K8s (file-sourced - credentials) for Google access tokens. - """ + """External account credentials sourced from files and urls.""" def __init__( self, @@ -54,9 +51,8 @@ def __init__( client_secret=None, quota_project_id=None, scopes=None, - success_codes=(http_client.OK,), ): - """Instantiates a file-sourced external account credentials object. + """Instantiates an external account credentials object from a file/url. Args: audience (str): The STS audience field. @@ -101,7 +97,6 @@ def __init__( self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") - self._success_codes = success_codes credential_source_format = credential_source.get("format") or {} # Get credential_source format type. When not provided, this # defaults to text. @@ -134,16 +129,19 @@ def __init__( @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): return self._parse_token_data( - self._get_token_data(), + self._get_token_data(request), self._credential_source_format_type, self._credential_source_field_name, ) - def _get_token_data(self): + def _get_token_data(self, request): if self._credential_source_file: return self._get_file_data(self._credential_source_file) if self._credential_source_url: - return self._get_url_data(self._credential_source_url) + return self._get_url_data( + request, + self._credential_source_url, + self._credential_source_headers) def _get_file_data(self, filename): if not os.path.exists(filename): @@ -152,11 +150,27 @@ def _get_file_data(self, filename): with io.open(filename, "r", encoding="utf-8") as file_obj: return file_obj.read(), filename - def _get_url_data(self, url): - response = urllib.request.urlopen(url) - if response.status not in self._success_codes: - raise exceptions.RefreshError("Url '{}' was not found.".format(url)) - return response.read(), url + def _get_url_data(self, request, url, headers): + response = request( + url=url, + method="GET", + headers=headers + ) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", + response_body + ) + + return response_body, url def _parse_token_data( self, diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index a2dd0d287..88841b53c 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -61,6 +61,11 @@ class TestCredentials(object): "file": SUBJECT_TOKEN_JSON_FILE, "format": {"type": "json", "subject_token_field_name": "access_token"}, } + CREDENTIAL_SOURCE_TEXT_URL = {"url": "http://fakeurl.com"} + CREDENTIAL_SOURCE_JSON_URL = { + "url": "http://fakeurl.com", + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", @@ -69,39 +74,30 @@ class TestCredentials(object): "scope": " ".join(SCOPES), } - class FakeResponse: - def __init__(self, data, status=http_client.OK): - self.status = status - self.data = data - if isinstance(data, dict): - self.data = json.dumps(data) - - def read(self): - return self.data + @classmethod + def make_mock_response(cls, status, data): + response = mock.create_autospec(transport.Response, instance=True) + response.status = status + if isinstance(data, dict): + response.data = json.dumps(data).encode("utf-8") + else: + response.data = data + return response @classmethod def make_mock_request( cls, token_status=http_client.OK, token_data=None, - impersonation_status=None, - impersonation_data=None, + *extra_requests, ): responses = [] - # STS token exchange request. - token_response = mock.create_autospec(transport.Response, instance=True) - token_response.status = token_status - token_response.data = json.dumps(token_data).encode("utf-8") - responses.append(token_response) - - # If service account impersonation is requested, mock the expected response. - if impersonation_status: - impersonation_response = mock.create_autospec( - transport.Response, instance=True - ) - impersonation_response.status = impersonation_status - impersonation_response.data = json.dumps(impersonation_data).encode("utf-8") - responses.append(impersonation_response) + responses.append(cls.make_mock_response(token_status, token_data)) + + while len(extra_requests) > 0: + # If service account impersonation is requested, mock the expected response. + status, data, extra_requests = extra_requests[0], extra_requests[1], extra_requests[2:] + responses.append(cls.make_mock_response(status, data)) request = mock.create_autospec(transport.Request) request.side_effect = responses @@ -148,30 +144,24 @@ def assert_underlying_credentials_refresh( basic_auth_encoding=None, quota_project_id=None, scopes=None, + credential_data=None, ): """Utility to assert that a credentials are initialized with the expected attributes by calling refresh functionality and confirming response matches expected one and that the underlying requests were populated with the expected parameters. """ - - expire_time = ( - _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) - ).isoformat("T") + "Z" # STS token exchange request/response. token_response = cls.SUCCESS_RESPONSE.copy() - token_headers = {"Content-Type": "application/x-www-form-urlencoded"} if basic_auth_encoding: token_headers["Authorization"] = "Basic " + basic_auth_encoding + if service_account_impersonation_url: token_scopes = "https://www.googleapis.com/auth/iam" - impersonation_status = http_client.OK - total_requests = 2 else: token_scopes = " ".join(scopes or []) - impersonation_status = None - total_requests = 1 + token_request_data = { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "audience": audience, @@ -180,44 +170,57 @@ def assert_underlying_credentials_refresh( "subject_token": subject_token, "subject_token_type": subject_token_type, } - # Service account impersonation request/response. - impersonation_response = { - "accessToken": "SA_ACCESS_TOKEN", - "expireTime": expire_time, - } - impersonation_headers = { - "Content-Type": "application/json", - "authorization": "Bearer {}".format(token_response["access_token"]), - } - impersonation_request_data = { - "delegates": None, - "scope": scopes, - "lifetime": "3600s", - } - # Initialize mock request to handle token exchange and service account - # impersonation request. + + if service_account_impersonation_url: + # Service account impersonation request/response. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": scopes, + "lifetime": "3600s", + } + + # Initialize mock request to handle token retrieval, token exchange and + # service account impersonation request. + requests = [] + if credential_data: + requests.append((http_client.OK, credential_data)) + + token_request_index = len(requests) + requests.append((http_client.OK, token_response)) + + if service_account_impersonation_url: + impersonation_request_index = len(requests) + requests.append((http_client.OK, impersonation_response)) + request = cls.make_mock_request( - token_status=http_client.OK, - token_data=token_response, - impersonation_status=impersonation_status, - impersonation_data=impersonation_response, - ) + *[el for req in requests for el in req]) credentials.refresh(request) - assert len(request.call_args_list) == total_requests + assert len(request.call_args_list) == len(requests) # Verify token exchange request parameters. cls.assert_token_request_kwargs( - request.call_args_list[0].kwargs, + request.call_args_list[token_request_index].kwargs, token_headers, token_request_data, token_url, ) # Verify service account impersonation request parameters if the request # is processed. - if impersonation_status: + if service_account_impersonation_url: cls.assert_impersonation_request_kwargs( - request.call_args_list[1].kwargs, + request.call_args_list[impersonation_request_index].kwargs, impersonation_headers, impersonation_request_data, service_account_impersonation_url, @@ -562,35 +565,176 @@ def test_refresh_with_retrieve_subject_token_error(self): ) ) - @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( - TEXT_FILE_SUBJECT_TOKEN)) - def test_retrieve_subject_token_from_url(self, mock_urlopen): + def test_retrieve_subject_token_from_url(self): credential_source = { "url": "http://fakeurl.com", } credentials = self.make_credentials(credential_source=credential_source) - subject_token = credentials.retrieve_subject_token(None) + subject_token = credentials.retrieve_subject_token( + self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN)) assert subject_token == TEXT_FILE_SUBJECT_TOKEN - @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( - JSON_FILE_CONTENT)) - def test_retrieve_subject_token_from_url_json(self, mock_urlopen): + def test_retrieve_subject_token_from_url_json(self): credential_source = { "url": "http://fakeurl.com", "format": {"type": "json", "subject_token_field_name": "access_token"}, } credentials = self.make_credentials(credential_source=credential_source) - subject_token = credentials.retrieve_subject_token(None) + subject_token = credentials.retrieve_subject_token( + self.make_mock_request(token_data=JSON_FILE_CONTENT)) assert subject_token == JSON_FILE_SUBJECT_TOKEN - @mock.patch.object(urllib.request, "urlopen", return_value=FakeResponse( - TEXT_FILE_SUBJECT_TOKEN, status=http_client.NOT_FOUND)) - def test_retrieve_subject_token_from_url_not_found(self, mock_urlopen): + def test_retrieve_subject_token_from_url_not_found(self): credential_source = { "url": "http://fakeurl.com", } credentials = self.make_credentials(credential_source=credential_source) with pytest.raises(exceptions.RefreshError) as excinfo: - credentials.retrieve_subject_token(None) + credentials.retrieve_subject_token( + self.make_mock_request( + token_status=500, + token_data=JSON_FILE_CONTENT)) + + def test_retrieve_subject_token_from_url_json_invalid_field(self): + url = "http://fakeurl.com" + credential_source = { + "url": url, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + url, "not_found" + ) + ) + + def test_retrieve_subject_token_from_url_json_invalid_format(self): + url = "http://fakeurl.com" + credential_source = { + "url": url, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_data="{")) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + url, "access_token" + ) + ) + + def test_refresh_text_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + scopes=SCOPES, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_text_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + scopes=SCOPES, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_json_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + scopes=SCOPES, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_json_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + scopes=SCOPES, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_with_retrieve_subject_token_error_url(self): + url = "http://fakeurl.com" + credential_source = { + "url": url, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh( + self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + url, "not_found" + ) + ) From 333899bbf2bbc3cb1bd85c7e43d08e9372e7a492 Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Fri, 4 Dec 2020 12:18:32 -0800 Subject: [PATCH 3/6] More changes requested by bojeil-google --- google/auth/identity_pool.py | 2 +- tests/test_identity_pool.py | 57 ++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 8e9e06783..5b50a3a09 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -25,7 +25,7 @@ Identity Pool Credentials are used with external credentials (eg. OIDC ID tokens) retrieved from a file location, typical for K8s workloads registered with Hub with Hub workload identity enabled, or retrieved from an -url, typical for AWS and Azure based workflows. +url, typical for Azure based workflows. """ import io diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index 88841b53c..1f36beab0 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -61,9 +61,10 @@ class TestCredentials(object): "file": SUBJECT_TOKEN_JSON_FILE, "format": {"type": "json", "subject_token_field_name": "access_token"}, } - CREDENTIAL_SOURCE_TEXT_URL = {"url": "http://fakeurl.com"} + CREDENTIAL_URL = "http://fakeurl.com" + CREDENTIAL_SOURCE_TEXT_URL = {"url": CREDENTIAL_URL} CREDENTIAL_SOURCE_JSON_URL = { - "url": "http://fakeurl.com", + "url": CREDENTIAL_URL, "format": {"type": "json", "subject_token_field_name": "access_token"}, } SUCCESS_RESPONSE = { @@ -104,6 +105,15 @@ def make_mock_request( return request + @classmethod + def assert_credential_request_kwargs( + cls, request_kwargs, url=CREDENTIAL_URL + ): + assert request_kwargs["url"] == url + assert request_kwargs["method"] == "GET" + assert request_kwargs["headers"] is None + assert request_kwargs.get("body", None) is None + @classmethod def assert_token_request_kwargs( cls, request_kwargs, headers, request_data, token_url=TOKEN_URL @@ -209,6 +219,10 @@ def assert_underlying_credentials_refresh( credentials.refresh(request) assert len(request.call_args_list) == len(requests) + if credential_data: + cls.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, + ) # Verify token exchange request parameters. cls.assert_token_request_kwargs( request.call_args_list[token_request_index].kwargs, @@ -566,41 +580,32 @@ def test_refresh_with_retrieve_subject_token_error(self): ) def test_retrieve_subject_token_from_url(self): - credential_source = { - "url": "http://fakeurl.com", - } - credentials = self.make_credentials(credential_source=credential_source) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) subject_token = credentials.retrieve_subject_token( self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN)) assert subject_token == TEXT_FILE_SUBJECT_TOKEN def test_retrieve_subject_token_from_url_json(self): - credential_source = { - "url": "http://fakeurl.com", - "format": {"type": "json", "subject_token_field_name": "access_token"}, - } - credentials = self.make_credentials(credential_source=credential_source) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_JSON_URL) subject_token = credentials.retrieve_subject_token( self.make_mock_request(token_data=JSON_FILE_CONTENT)) assert subject_token == JSON_FILE_SUBJECT_TOKEN def test_retrieve_subject_token_from_url_not_found(self): - credential_source = { - "url": "http://fakeurl.com", - } - credentials = self.make_credentials(credential_source=credential_source) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( self.make_mock_request( - token_status=500, + token_status=404, token_data=JSON_FILE_CONTENT)) + assert excinfo.match("Unable to retrieve Identity Pool subject token") + def test_retrieve_subject_token_from_url_json_invalid_field(self): - url = "http://fakeurl.com" credential_source = { - "url": url, + "url": self.CREDENTIAL_URL, "format": {"type": "json", "subject_token_field_name": "not_found"}, } credentials = self.make_credentials(credential_source=credential_source) @@ -611,17 +616,12 @@ def test_retrieve_subject_token_from_url_json_invalid_field(self): assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format( - url, "not_found" + self.CREDENTIAL_URL, "not_found" ) ) def test_retrieve_subject_token_from_url_json_invalid_format(self): - url = "http://fakeurl.com" - credential_source = { - "url": url, - "format": {"type": "json", "subject_token_field_name": "access_token"}, - } - credentials = self.make_credentials(credential_source=credential_source) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_JSON_URL) with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( @@ -629,7 +629,7 @@ def test_retrieve_subject_token_from_url_json_invalid_format(self): assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format( - url, "access_token" + self.CREDENTIAL_URL, "access_token" ) ) @@ -722,9 +722,8 @@ def test_refresh_json_file_success_with_impersonation_url(self): ) def test_refresh_with_retrieve_subject_token_error_url(self): - url = "http://fakeurl.com" credential_source = { - "url": url, + "url": self.CREDENTIAL_URL, "format": {"type": "json", "subject_token_field_name": "not_found"}, } credentials = self.make_credentials(credential_source=credential_source) @@ -735,6 +734,6 @@ def test_refresh_with_retrieve_subject_token_error_url(self): assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format( - url, "not_found" + self.CREDENTIAL_URL, "not_found" ) ) From 12feca32ef42024df7e9a2b138caec1925021755 Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Fri, 4 Dec 2020 16:23:14 -0800 Subject: [PATCH 4/6] Adding tests that include headers --- tests/test_identity_pool.py | 60 ++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index 1f36beab0..2b9147798 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -107,11 +107,11 @@ def make_mock_request( @classmethod def assert_credential_request_kwargs( - cls, request_kwargs, url=CREDENTIAL_URL + cls, request_kwargs, headers, url=CREDENTIAL_URL ): assert request_kwargs["url"] == url assert request_kwargs["method"] == "GET" - assert request_kwargs["headers"] is None + assert request_kwargs["headers"] == headers assert request_kwargs.get("body", None) is None @classmethod @@ -222,6 +222,7 @@ def assert_underlying_credentials_refresh( if credential_data: cls.assert_credential_request_kwargs( request.call_args_list[0].kwargs, + None, ) # Verify token exchange request parameters. cls.assert_token_request_kwargs( @@ -580,21 +581,59 @@ def test_refresh_with_retrieve_subject_token_error(self): ) def test_retrieve_subject_token_from_url(self): - credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) - subject_token = credentials.retrieve_subject_token( - self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN)) + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, + None) + + def test_retrieve_subject_token_from_url_with_headers(self): + credentials = self.make_credentials( + credential_source={ + "url": self.CREDENTIAL_URL, + "headers": {"foo": "bar"}}) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, + {"foo": "bar"}) def test_retrieve_subject_token_from_url_json(self): - credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_JSON_URL) - subject_token = credentials.retrieve_subject_token( - self.make_mock_request(token_data=JSON_FILE_CONTENT)) + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, + None) + + def test_retrieve_subject_token_from_url_json_with_headers(self): + credentials = self.make_credentials( + credential_source={ + "url": self.CREDENTIAL_URL, + "format": { + "type": "json", + "subject_token_field_name": "access_token"}, + "headers": {"foo": "bar"}}) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, + {"foo": "bar"}) def test_retrieve_subject_token_from_url_not_found(self): - credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( self.make_mock_request( @@ -621,7 +660,8 @@ def test_retrieve_subject_token_from_url_json_invalid_field(self): ) def test_retrieve_subject_token_from_url_json_invalid_format(self): - credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE_JSON_URL) + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL) with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( From a76c698295ff7de00dc013f1822f302ec92d876c Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Thu, 10 Dec 2020 16:08:14 -0800 Subject: [PATCH 5/6] changes requested by @busunkim96 --- google/auth/identity_pool.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 5b50a3a09..12de30424 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -32,6 +32,8 @@ import json import os +from collections.abc import Mapping + from google.auth import _helpers from google.auth import exceptions from google.auth import external_account @@ -60,7 +62,19 @@ def __init__( token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary used to provide instructions on how to retrieve external credential to be - exchanged for Google access tokens.. + exchanged for Google access tokens + Example credential_source's: + { + "url": "http://www.example.com", + "format": { + "type": "json", + "subject_token_field_name": "access_token", + }, + "headers": {"foo": "bar"}, + } + { + "file": "/path/to/token/file.txt" + } service_account_impersonation_url (Optional[str]): The optional service account impersonation getAccessToken URL. client_id (Optional[str]): The optional client ID. @@ -90,14 +104,14 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, ) - if not isinstance(credential_source, dict): + if not isinstance(credential_source, Mapping): self._credential_source_file = None self._credential_source_url = None else: self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") - credential_source_format = credential_source.get("format") or {} + credential_source_format = credential_source.get("format", {}) # Get credential_source format type. When not provided, this # defaults to text. self._credential_source_format_type = ( @@ -122,9 +136,11 @@ def __init__( self._credential_source_field_name = None if self._credential_source_file and self._credential_source_url: - raise ValueError("Ambiguous credential_source") + raise ValueError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'.") if not self._credential_source_file and not self._credential_source_url: - raise ValueError("Missing credential_source") + raise ValueError( + "Missing credential_source. A 'file' or 'url' must be provided.") @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): From b1eb9ea4b73340d86bf0f2555c4c5bd2b02fffa6 Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Fri, 11 Dec 2020 11:56:38 -0800 Subject: [PATCH 6/6] Making changes suggested by @bojeil-google to pass Kokoro Run --- google/auth/_default.py | 12 ++--- google/auth/identity_pool.py | 40 ++++++++------- tests/test_identity_pool.py | 98 +++++++++++++++++++++--------------- 3 files changed, 84 insertions(+), 66 deletions(-) diff --git a/google/auth/_default.py b/google/auth/_default.py index 894c5c222..acf76e924 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -311,16 +311,16 @@ def _get_external_account_credentials(info, filename, scopes=None, request=None) """ # There are currently 2 types of external_account credentials. try: - # Check if configuration corresponds to an Identity Pool credentials. - from google.auth import identity_pool + # Check if configuration corresponds to an AWS credentials. + from google.auth import aws - credentials = identity_pool.Credentials.from_info(info, scopes=scopes) + credentials = aws.Credentials.from_info(info, scopes=scopes) except ValueError: try: - # Check if configuration corresponds to an AWS credentials. - from google.auth import aws + # Check if configuration corresponds to an Identity Pool credentials. + from google.auth import identity_pool - credentials = aws.Credentials.from_info(info, scopes=scopes) + credentials = identity_pool.Credentials.from_info(info, scopes=scopes) except ValueError: # If the configuration is invalid or does not correspond to any # supported external_account credentials, raise an error. diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index 12de30424..5eed7a77c 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -28,12 +28,15 @@ url, typical for Azure based workflows. """ +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping import io import json import os -from collections.abc import Mapping - from google.auth import _helpers from google.auth import exceptions from google.auth import external_account @@ -117,6 +120,12 @@ def __init__( self._credential_source_format_type = ( credential_source_format.get("type") or "text" ) + # environment_id is only supported in AWS or dedicated future external + # account credentials. + if "environment_id" in credential_source: + raise ValueError( + "Invalid Identity Pool credential_source field 'environment_id'" + ) if self._credential_source_format_type not in ["text", "json"]: raise ValueError( "Invalid credential_source format '{}'".format( @@ -137,10 +146,12 @@ def __init__( if self._credential_source_file and self._credential_source_url: raise ValueError( - "Ambiguous credential_source. 'file' is mutually exclusive with 'url'.") + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) if not self._credential_source_file and not self._credential_source_url: raise ValueError( - "Missing credential_source. A 'file' or 'url' must be provided.") + "Missing credential_source. A 'file' or 'url' must be provided." + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): @@ -153,11 +164,10 @@ def retrieve_subject_token(self, request): def _get_token_data(self, request): if self._credential_source_file: return self._get_file_data(self._credential_source_file) - if self._credential_source_url: + else: return self._get_url_data( - request, - self._credential_source_url, - self._credential_source_headers) + request, self._credential_source_url, self._credential_source_headers + ) def _get_file_data(self, filename): if not os.path.exists(filename): @@ -167,11 +177,7 @@ def _get_file_data(self, filename): return file_obj.read(), filename def _get_url_data(self, request, url, headers): - response = request( - url=url, - method="GET", - headers=headers - ) + response = request(url=url, method="GET", headers=headers) # support both string and bytes type response.data response_body = ( @@ -182,17 +188,13 @@ def _get_url_data(self, request, url, headers): if response.status != 200: raise exceptions.RefreshError( - "Unable to retrieve Identity Pool subject token", - response_body + "Unable to retrieve Identity Pool subject token", response_body ) return response_body, url def _parse_token_data( - self, - token_content, - format_type="text", - subject_token_field_name=None + self, token_content, format_type="text", subject_token_field_name=None ): content, filename = token_content if format_type == "text": diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index 2b9147798..fc8ad8de9 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -87,17 +87,18 @@ def make_mock_response(cls, status, data): @classmethod def make_mock_request( - cls, - token_status=http_client.OK, - token_data=None, - *extra_requests, + cls, token_status=http_client.OK, token_data=None, *extra_requests ): responses = [] responses.append(cls.make_mock_response(token_status, token_data)) while len(extra_requests) > 0: # If service account impersonation is requested, mock the expected response. - status, data, extra_requests = extra_requests[0], extra_requests[1], extra_requests[2:] + status, data, extra_requests = ( + extra_requests[0], + extra_requests[1], + extra_requests[2:], + ) responses.append(cls.make_mock_response(status, data)) request = mock.create_autospec(transport.Request) @@ -184,7 +185,8 @@ def assert_underlying_credentials_refresh( if service_account_impersonation_url: # Service account impersonation request/response. expire_time = ( - _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + _helpers.utcnow().replace(microsecond=0) + + datetime.timedelta(seconds=3600) ).isoformat("T") + "Z" impersonation_response = { "accessToken": "SA_ACCESS_TOKEN", @@ -213,17 +215,13 @@ def assert_underlying_credentials_refresh( impersonation_request_index = len(requests) requests.append((http_client.OK, impersonation_response)) - request = cls.make_mock_request( - *[el for req in requests for el in req]) + request = cls.make_mock_request(*[el for req in requests for el in req]) credentials.refresh(request) assert len(request.call_args_list) == len(requests) if credential_data: - cls.assert_credential_request_kwargs( - request.call_args_list[0].kwargs, - None, - ) + cls.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) # Verify token exchange request parameters. cls.assert_token_request_kwargs( request.call_args_list[token_request_index].kwargs, @@ -382,6 +380,27 @@ def test_constructor_invalid_options(self): assert excinfo.match(r"Missing credential_source") + def test_constructor_invalid_options_url_and_file(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "file": SUBJECT_TOKEN_TEXT_FILE, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + + def test_constructor_invalid_options_environment_id(self): + credential_source = {"url": self.CREDENTIAL_URL, "environment_id": "aws1"} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match( + r"Invalid Identity Pool credential_source field 'environment_id'" + ) + def test_constructor_invalid_credential_source(self): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") @@ -582,63 +601,60 @@ def test_refresh_with_retrieve_subject_token_error(self): def test_retrieve_subject_token_from_url(self): credentials = self.make_credentials( - credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) subject_token = credentials.retrieve_subject_token(request) assert subject_token == TEXT_FILE_SUBJECT_TOKEN - self.assert_credential_request_kwargs( - request.call_args_list[0].kwargs, - None) + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) def test_retrieve_subject_token_from_url_with_headers(self): credentials = self.make_credentials( - credential_source={ - "url": self.CREDENTIAL_URL, - "headers": {"foo": "bar"}}) + credential_source={"url": self.CREDENTIAL_URL, "headers": {"foo": "bar"}} + ) request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) subject_token = credentials.retrieve_subject_token(request) assert subject_token == TEXT_FILE_SUBJECT_TOKEN self.assert_credential_request_kwargs( - request.call_args_list[0].kwargs, - {"foo": "bar"}) + request.call_args_list[0].kwargs, {"foo": "bar"} + ) def test_retrieve_subject_token_from_url_json(self): credentials = self.make_credentials( - credential_source=self.CREDENTIAL_SOURCE_JSON_URL) + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) request = self.make_mock_request(token_data=JSON_FILE_CONTENT) subject_token = credentials.retrieve_subject_token(request) assert subject_token == JSON_FILE_SUBJECT_TOKEN - self.assert_credential_request_kwargs( - request.call_args_list[0].kwargs, - None) + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) def test_retrieve_subject_token_from_url_json_with_headers(self): credentials = self.make_credentials( credential_source={ "url": self.CREDENTIAL_URL, - "format": { - "type": "json", - "subject_token_field_name": "access_token"}, - "headers": {"foo": "bar"}}) + "format": {"type": "json", "subject_token_field_name": "access_token"}, + "headers": {"foo": "bar"}, + } + ) request = self.make_mock_request(token_data=JSON_FILE_CONTENT) subject_token = credentials.retrieve_subject_token(request) assert subject_token == JSON_FILE_SUBJECT_TOKEN self.assert_credential_request_kwargs( - request.call_args_list[0].kwargs, - {"foo": "bar"}) + request.call_args_list[0].kwargs, {"foo": "bar"} + ) def test_retrieve_subject_token_from_url_not_found(self): credentials = self.make_credentials( - credential_source=self.CREDENTIAL_SOURCE_TEXT_URL) + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( - self.make_mock_request( - token_status=404, - token_data=JSON_FILE_CONTENT)) + self.make_mock_request(token_status=404, token_data=JSON_FILE_CONTENT) + ) assert excinfo.match("Unable to retrieve Identity Pool subject token") @@ -651,7 +667,8 @@ def test_retrieve_subject_token_from_url_json_invalid_field(self): with pytest.raises(exceptions.RefreshError) as excinfo: credentials.retrieve_subject_token( - self.make_mock_request(token_data=JSON_FILE_CONTENT)) + self.make_mock_request(token_data=JSON_FILE_CONTENT) + ) assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format( @@ -661,11 +678,11 @@ def test_retrieve_subject_token_from_url_json_invalid_field(self): def test_retrieve_subject_token_from_url_json_invalid_format(self): credentials = self.make_credentials( - credential_source=self.CREDENTIAL_SOURCE_JSON_URL) + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) with pytest.raises(exceptions.RefreshError) as excinfo: - credentials.retrieve_subject_token( - self.make_mock_request(token_data="{")) + credentials.retrieve_subject_token(self.make_mock_request(token_data="{")) assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format( @@ -769,8 +786,7 @@ def test_refresh_with_retrieve_subject_token_error_url(self): credentials = self.make_credentials(credential_source=credential_source) with pytest.raises(exceptions.RefreshError) as excinfo: - credentials.refresh( - self.make_mock_request(token_data=JSON_FILE_CONTENT)) + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) assert excinfo.match( "Unable to parse subject_token from JSON file '{}' using key '{}'".format(