diff --git a/google/auth/_default.py b/google/auth/_default.py index dc54c44b6..15799cea7 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -36,11 +36,13 @@ _SERVICE_ACCOUNT_TYPE = "service_account" _EXTERNAL_ACCOUNT_TYPE = "external_account" _IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account" +_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account" _VALID_TYPES = ( _AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE, _EXTERNAL_ACCOUNT_TYPE, _IMPERSONATED_SERVICE_ACCOUNT_TYPE, + _GDCH_SERVICE_ACCOUNT_TYPE, ) # Help message when no credentials can be found. @@ -134,6 +136,8 @@ def load_credentials_from_file( def _load_credentials_from_info( filename, info, scopes, default_scopes, quota_project_id, request ): + from google.auth.credentials import CredentialsWithQuotaProject + credential_type = info.get("type") if credential_type == _AUTHORIZED_USER_TYPE: @@ -158,6 +162,8 @@ def _load_credentials_from_info( credentials, project_id = _get_impersonated_service_account_credentials( filename, info, scopes ) + elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_gdch_service_account_credentials(info) else: raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " @@ -165,7 +171,8 @@ def _load_credentials_from_info( file=filename, type=credential_type, valid_types=_VALID_TYPES ) ) - credentials = _apply_quota_project_id(credentials, quota_project_id) + if isinstance(credentials, CredentialsWithQuotaProject): + credentials = _apply_quota_project_id(credentials, quota_project_id) return credentials, project_id @@ -430,6 +437,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes): return credentials, None +def _get_gdch_service_account_credentials(info): + from google.oauth2 import gdch_credentials + + k8s_ca_cert_path = info.get("k8s_ca_cert_path") + k8s_cert_path = info.get("k8s_cert_path") + k8s_key_path = info.get("k8s_key_path") + k8s_token_endpoint = info.get("k8s_token_endpoint") + ais_ca_cert_path = info.get("ais_ca_cert_path") + ais_token_endpoint = info.get("ais_token_endpoint") + + format_version = info.get("format_version") + if format_version != "v1": + raise exceptions.DefaultCredentialsError( + "format_version is not provided or unsupported. Supported version is: v1" + ) + + return ( + gdch_credentials.ServiceAccountCredentials( + k8s_ca_cert_path, + k8s_cert_path, + k8s_key_path, + k8s_token_endpoint, + ais_ca_cert_path, + ais_token_endpoint, + None, + ), + None, + ) + + def _apply_quota_project_id(credentials, quota_project_id): if quota_project_id: credentials = credentials.with_quota_project(quota_project_id) @@ -465,6 +502,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non endpoint. The project ID returned in this case is the one corresponding to the underlying workload identity pool resource if determinable. + + If the environment variable is set to the path of a valid GDCH service + account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH + credential will be returned. The project ID returned is None unless it + is set via `GOOGLE_CLOUD_PROJECT` environment variable. 2. If the `Google Cloud SDK`_ is installed and has application default credentials set they are loaded and returned. @@ -499,6 +541,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non .. _Metadata Service: https://cloud.google.com/compute/docs\ /storing-retrieving-metadata .. _Cloud Run: https://cloud.google.com/run + .. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\ + /hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted Example:: diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 2f4e8474b..8831baf27 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -44,11 +44,13 @@ def _handle_error_response(response_data): """Translates an error response into an exception. Args: - response_data (Mapping): The decoded response data. + response_data (Mapping | str): The decoded response data. Raises: google.auth.exceptions.RefreshError: The errors contained in response_data. """ + if isinstance(response_data, six.string_types): + raise exceptions.RefreshError(response_data) try: error_details = "{}: {}".format( response_data["error"], response_data.get("error_description") @@ -79,7 +81,13 @@ def _parse_expiry(response_data): def _token_endpoint_request_no_throw( - request, token_uri, body, access_token=None, use_json=False + request, + token_uri, + body, + access_token=None, + use_json=False, + expected_status_code=http_client.OK, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. This function doesn't throw on response errors. @@ -93,6 +101,16 @@ def _token_endpoint_request_no_throw( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + expected_status_code (Optional(int)): The expected the status code of + the token response. The default value is 200. We may expect other + status code like 201 for GDCH credentials. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Tuple(bool, Mapping[str, str]): A boolean indicating if the request is @@ -112,32 +130,46 @@ def _token_endpoint_request_no_throw( # retry to fetch token for maximum of two times if any internal failure # occurs. while True: - response = request(method="POST", url=token_uri, headers=headers, body=body) + response = request( + method="POST", url=token_uri, headers=headers, body=body, **kwargs + ) response_body = ( response.data.decode("utf-8") if hasattr(response.data, "decode") else response.data ) - response_data = json.loads(response_body) - if response.status == http_client.OK: + if response.status == expected_status_code: + # response_body should be a JSON + response_data = json.loads(response_body) break else: - error_desc = response_data.get("error_description") or "" - error_code = response_data.get("error") or "" - if ( - any(e == "internal_failure" for e in (error_code, error_desc)) - and retry < 1 - ): - retry += 1 - continue - return response.status == http_client.OK, response_data - - return response.status == http_client.OK, response_data + # For a failed response, response_body could be a string + try: + response_data = json.loads(response_body) + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + if ( + any(e == "internal_failure" for e in (error_code, error_desc)) + and retry < 1 + ): + retry += 1 + continue + except ValueError: + response_data = response_body + return False, response_data + + return response.status == expected_status_code, response_data def _token_endpoint_request( - request, token_uri, body, access_token=None, use_json=False + request, + token_uri, + body, + access_token=None, + use_json=False, + expected_status_code=http_client.OK, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. @@ -150,6 +182,16 @@ def _token_endpoint_request( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + expected_status_code (Optional(int)): The expected the status code of + the token response. The default value is 200. We may expect other + status code like 201 for GDCH credentials. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Mapping[str, str]: The JSON-decoded response data. @@ -159,7 +201,13 @@ def _token_endpoint_request( an error. """ response_status_ok, response_data = _token_endpoint_request_no_throw( - request, token_uri, body, access_token=access_token, use_json=use_json + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + expected_status_code=expected_status_code, + **kwargs ) if not response_status_ok: _handle_error_response(response_data) diff --git a/google/oauth2/gdch_credentials.py b/google/oauth2/gdch_credentials.py new file mode 100644 index 000000000..e0edbf039 --- /dev/null +++ b/google/oauth2/gdch_credentials.py @@ -0,0 +1,194 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental GDCH credentials support. +""" + +import six +from six.moves import http_client + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.oauth2 import _client + + +TOKEN_EXCHANGE_TYPE = "urn:ietf:params:oauth:token-type:token-exchange" +ACCESS_TOKEN_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +JWT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" +SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" + + +class ServiceAccountCredentials(credentials.Credentials): + """Credentials for GDCH (`Google Distributed Cloud Hosted`_) for service + account users. + + .. _Google Distributed Cloud Hosted: + https://cloud.google.com/blog/topics/hybrid-cloud/\ + announcing-google-distributed-cloud-edge-and-hosted + + Besides the constructor, a GDCH credential can be created via application + default credentials. + + To do so, user first creates a JSON file of the + following format:: + + { + "type":"gdch_service_account", + "format_version":"v1", + "k8s_ca_cert_path":"", + "k8s_cert_path":"", + "k8s_key_path":"", + "k8s_token_endpoint":"", + "ais_ca_cert_path":"", + "ais_token_endpoint":"" + } + + Here "k8s_*" files are used to request a k8s token from k8s token endpoint + using mutual TLS connection. The k8s token is then sent to AIS token endpoint + to exchange for an AIS token. The AIS token will be used to talk to Google + API services. + + "k8s_ca_cert_path" field is not needed if the k8s server uses well known CA. + "ais_ca_cert_path" field is not needed if the AIS server uses well known CA. + These two fields can be used for testing environments. + + The "format_version" field stands for the format of the JSON file. For now + it is always "v1". + + After the JSON file is created, set `GOOGLE_APPLICATION_CREDENTIALS` environment + variable to the JSON file path, then use the following code to create the + credential:: + + import google.auth + + credential, _ = google.auth.default() + credential = credential.with_audience("") + + The audience denotes the scope the AIS token is requested, for example, it + could be either a k8s cluster or API service. + """ + + def __init__( + self, + k8s_ca_cert_path, + k8s_cert_path, + k8s_key_path, + k8s_token_endpoint, + ais_ca_cert_path, + ais_token_endpoint, + audience, + ): + """ + Args: + k8s_ca_cert_path (str): CA cert path for k8s calls. This field is + useful if the specific k8s server doesn't use well known CA, + for instance, a testing k8s server. If the CA is well known, + you can pass `None` for this parameter. + k8s_cert_path (str): Certificate path for k8s calls + k8s_key_path (str): Key path for k8s calls + k8s_token_endpoint (str): k8s token endpoint url + ais_ca_cert_path (str): CA cert path for AIS token endpoint calls. + This field is useful if the specific AIS token server doesn't + uses well known CA, for instance, a testing AIS server. If the + CA is well known, you can pass `None` for this parameter. + ais_token_endpoint (str): AIS token endpoint url + audience (str): The audience for the requested AIS token. For + example, it could be a k8s cluster or API service. + """ + super(ServiceAccountCredentials, self).__init__() + self._k8s_ca_cert_path = k8s_ca_cert_path + self._k8s_cert_path = k8s_cert_path + self._k8s_key_path = k8s_key_path + self._k8s_token_endpoint = k8s_token_endpoint + self._ais_ca_cert_path = ais_ca_cert_path + self._ais_token_endpoint = ais_token_endpoint + self._audience = audience + + def _make_k8s_token_request(self, request): + k8s_request_body = { + "kind": "TokenRequest", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"audiences": [self._ais_token_endpoint]}, + } + # mTLS connection to k8s token endpoint to get a k8s token. + k8s_response_data = _client._token_endpoint_request( + request, + self._k8s_token_endpoint, + k8s_request_body, + access_token=None, + use_json=True, + expected_status_code=http_client.CREATED, + cert=(self._k8s_cert_path, self._k8s_key_path), + verify=self._k8s_ca_cert_path, + ) + + try: + k8s_token = k8s_response_data["status"]["token"] + return k8s_token + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No access token in k8s token response.", k8s_response_data + ) + six.raise_from(new_exc, caught_exc) + + def _make_ais_token_request(self, k8s_token, request): + # send a request to AIS token point with the k8s token + ais_request_body = { + "grant_type": TOKEN_EXCHANGE_TYPE, + "audience": self._audience, + "requested_token_type": ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": k8s_token, + "subject_token_type": SERVICE_ACCOUNT_TOKEN_TYPE, + } + ais_response_data = _client._token_endpoint_request( + request, + self._ais_token_endpoint, + ais_request_body, + access_token=None, + use_json=True, + verify=self._ais_ca_cert_path, + ) + ais_token, _, ais_expiry, _ = _client._handle_refresh_grant_response( + ais_response_data, None + ) + return ais_token, ais_expiry + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + import google.auth.transport.requests + + if not isinstance(request, google.auth.transport.requests.Request): + raise exceptions.RefreshError( + "For GDCH service account credentials, request must be a google.auth.transport.requests.Request object" + ) + + k8s_token = self._make_k8s_token_request(request) + self.token, self.expiry = self._make_ais_token_request(k8s_token, request) + + def with_audience(self, audience): + """Create a copy of GDCH credentials with the specified audience. + + Args: + audience (str): The intended audience for GDCH credentials. + """ + return self.__class__( + self._k8s_ca_cert_path, + self._k8s_cert_path, + self._k8s_key_path, + self._k8s_token_endpoint, + self._ais_ca_cert_path, + self._ais_token_endpoint, + audience, + ) diff --git a/tests/data/gdch_service_account.json b/tests/data/gdch_service_account.json new file mode 100644 index 000000000..c6c441bfd --- /dev/null +++ b/tests/data/gdch_service_account.json @@ -0,0 +1,10 @@ +{ + "type":"gdch_service_account", + "format_version": "v1", + "k8s_ca_cert_path":"./k8s_ca_cert.pem", + "k8s_cert_path":"./k8s_cert.pem", + "k8s_key_path":"./k8s_key.pem", + "k8s_token_endpoint":"https://k8s_endpoint/api/v1/namespaces/sa-token-test/serviceaccounts/sa-token-user/token", + "ais_ca_cert_path":"./ais_ca_cert.pem", + "ais_token_endpoint":"https://ais_endpoint/sts/v1beta/token" +} diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index 5485bed84..400582fc3 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -56,7 +56,7 @@ def test__handle_error_response(): assert excinfo.match(r"help: I\'m alive") -def test__handle_error_response_non_json(): +def test__handle_error_response_no_error(): response_data = {"foo": "bar"} with pytest.raises(exceptions.RefreshError) as excinfo: @@ -65,6 +65,15 @@ def test__handle_error_response_non_json(): assert excinfo.match(r"{\"foo\": \"bar\"}") +def test__handle_error_response_not_json(): + response_data = "this is an error message" + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(response_data) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test__parse_expiry(unused_utcnow): result = _client._parse_expiry({"expires_in": 500}) @@ -145,6 +154,8 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error_description": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 request = make_request( {"error": "internal_failure"}, status=http_client.BAD_REQUEST @@ -154,6 +165,33 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 + + +def test__token_endpoint_request_string_error(): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.BAD_REQUEST + response.data = "this is an error message" + request = mock.create_autospec(transport.Request) + request.return_value = response + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._token_endpoint_request(request, "http://example.com", {}) + assert excinfo.match("this is an error message") + + +def test__token_endpoint_request_expected_status_code(): + request = make_request({}, status=http_client.CREATED) + + # It doesn't throw if the response code is the expected one. + _client._token_endpoint_request( + request, "http://example.com", {}, expected_status_code=http_client.CREATED + ) + + # It throws since the default status code is 200 OK, but we are expecting 201 CREATED. + with pytest.raises(exceptions.RefreshError): + _client._token_endpoint_request(request, "http://example.com", {}) def verify_request_params(request, params): diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py new file mode 100644 index 000000000..41aa399af --- /dev/null +++ b/tests/oauth2/test_gdch_credentials.py @@ -0,0 +1,180 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import mock +import pytest # type: ignore +from six.moves import http_client + +from google.auth import exceptions +from google.auth.transport import requests +from google.oauth2 import gdch_credentials + + +class TestCredentials(object): + K8S_CA_CERT_PATH = "./k8s_ca_cert.pem" + K8S_CERT_PATH = "./k8s_cert.pem" + K8S_KEY_PATH = "./k8s_key.pem" + K8S_TOKEN = "k8s_token" + K8S_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" + AIS_CA_CERT_PATH = "./ais_ca_cert.pem" + AIS_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" + AUDIENCE = "audience_foo" + + @classmethod + def make_credentials(cls): + return gdch_credentials.ServiceAccountCredentials( + cls.K8S_CA_CERT_PATH, + cls.K8S_CERT_PATH, + cls.K8S_KEY_PATH, + cls.K8S_TOKEN_ENDPOINT, + cls.AIS_CA_CERT_PATH, + cls.AIS_TOKEN_ENDPOINT, + cls.AUDIENCE, + ) + + def test_with_audience(self): + creds = self.make_credentials() + assert creds._audience == self.AUDIENCE + + new_creds = creds.with_audience("bar") + assert new_creds._audience == "bar" + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + def test__make_k8s_token_request(self, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + token_endpoint_request.return_value = { + "status": { + "token": self.K8S_TOKEN, + "expirationTimestamp": "2022-02-22T06:51:46Z", + } + } + assert creds._make_k8s_token_request(req) == self.K8S_TOKEN + token_endpoint_request.assert_called_with( + req, + creds._k8s_token_endpoint, + { + "kind": "TokenRequest", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"audiences": [creds._ais_token_endpoint]}, + }, + None, + True, + http_client.CREATED, + cert=(creds._k8s_cert_path, creds._k8s_key_path), + verify=creds._k8s_ca_cert_path, + ) + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + def test__make_k8s_token_request_no_token(self, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + token_endpoint_request.return_value = { + "status": {"expirationTimestamp": "2022-02-22T06:51:46Z"} + } + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds._make_k8s_token_request(req) + assert excinfo.match("No access token in k8s token response") + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + @mock.patch("google.auth._helpers.utcnow", autospec=True) + def test__make_ais_token_request(self, utcnow, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + issue_time = datetime.datetime(2022, 1, 1, 0, 0, 0) + utcnow.return_value = issue_time + expires_in_seconds = 3599 + + token_endpoint_request.return_value = { + "access_token": "ais_token", + "expires_in": expires_in_seconds, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + } + + k8s_token = self.K8S_TOKEN + ais_token, ais_expiry = creds._make_ais_token_request(k8s_token, req) + assert ais_token == "ais_token" + assert ais_expiry == issue_time + datetime.timedelta(seconds=expires_in_seconds) + token_endpoint_request.assert_called_with( + req, + creds._ais_token_endpoint, + { + "grant_type": gdch_credentials.TOKEN_EXCHANGE_TYPE, + "audience": creds._audience, + "requested_token_type": gdch_credentials.ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": k8s_token, + "subject_token_type": gdch_credentials.SERVICE_ACCOUNT_TOKEN_TYPE, + }, + None, + True, + verify=creds._ais_ca_cert_path, + ) + + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_k8s_token_request", + autospec=True, + ) + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_ais_token_request", + autospec=True, + ) + def test_refresh(self, ais_token_request, k8s_token_request): + k8s_token_request.return_value = self.K8S_TOKEN + mock_expiry = mock.Mock() + ais_token_request.return_value = ("ais_token", mock_expiry) + + creds = self.make_credentials() + req = requests.Request() + creds.refresh(req) + + k8s_token_request.assert_called_with(creds, req) + ais_token_request.assert_called_with(creds, self.K8S_TOKEN, req) + assert creds.token == "ais_token" + assert creds.expiry == mock_expiry + + def test_refresh_request_not_requests_type(self): + creds = self.make_credentials() + req = mock.Mock() + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(req) + assert excinfo.match( + "request must be a google.auth.transport.requests.Request object" + ) + + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_k8s_token_request", + autospec=True, + ) + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_ais_token_request", + autospec=True, + ) + def test_before_request(self, ais_token_request, k8s_token_request): + ais_token_request.return_value = ("ais_token", mock.Mock()) + + cred = self.make_credentials() + headers = {} + + cred.before_request(requests.Request(), "GET", "https://example.com", headers) + k8s_token_request.assert_called() + ais_token_request.assert_called() + assert headers["authorization"] == "Bearer ais_token" diff --git a/tests/test__default.py b/tests/test__default.py index 4e7eeb84e..e92166811 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -29,6 +29,7 @@ from google.auth import identity_pool from google.auth import impersonated_credentials from google.auth import pluggable +from google.oauth2 import gdch_credentials from google.oauth2 import service_account import google.oauth2.credentials @@ -51,6 +52,8 @@ CLIENT_SECRETS_FILE = os.path.join(DATA_DIR, "client_secrets.json") +GDCH_SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "gdch_service_account.json") + with open(SERVICE_ACCOUNT_FILE) as fh: SERVICE_ACCOUNT_FILE_DATA = json.load(fh) @@ -645,6 +648,22 @@ def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_proj assert get_project_id.called +def test__get_gdch_service_account_credentials_no_format_version(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default._get_gdch_service_account_credentials({}) + assert excinfo.match( + "format_version is not provided or unsupported. Supported version is: v1" + ) + + +def test__get_gdch_service_account_credentials_invalid_format_version(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default._get_gdch_service_account_credentials({"format_version": "v2"}) + assert excinfo.match( + "format_version is not provided or unsupported. Supported version is: v1" + ) + + class _AppIdentityModule(object): """The interface of the App Idenity app engine module. See https://cloud.google.com/appengine/docs/standard/python/refdocs\ @@ -1150,6 +1169,31 @@ def test_default_impersonated_service_account_set_both_scopes_and_default_scopes assert credentials._target_scopes == scopes +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +@mock.patch("google.auth._default._apply_quota_project_id", autospec=True) +def test_default_gdch_service_account_credentials(apply_quota_project_id, get_adc_path): + get_adc_path.return_value = GDCH_SERVICE_ACCOUNT_FILE + + credentials, _ = _default.default(quota_project_id="project-foo") + + # make sure _apply_quota_project_id is not called since GDCH service account + # credential doesn't inheirt from CredentialsWithQuotaProject. + apply_quota_project_id.assert_not_called() + + assert isinstance(credentials, gdch_credentials.ServiceAccountCredentials) + assert credentials._k8s_ca_cert_path == "./k8s_ca_cert.pem" + assert credentials._k8s_cert_path == "./k8s_cert.pem" + assert credentials._k8s_key_path == "./k8s_key.pem" + assert ( + credentials._k8s_token_endpoint + == "https://k8s_endpoint/api/v1/namespaces/sa-token-test/serviceaccounts/sa-token-user/token" + ) + assert credentials._ais_ca_cert_path == "./ais_ca_cert.pem" + assert credentials._ais_token_endpoint == "https://ais_endpoint/sts/v1beta/token" + + @EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH def test_load_credentials_from_external_account_pluggable(get_project_id, tmpdir): config_file = tmpdir.join("config.json")