Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed .DS_Store
Binary file not shown.
97 changes: 83 additions & 14 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions
from google.auth import impersonated_credentials
from google.oauth2 import sts
from google.oauth2 import utils

Expand All @@ -58,6 +60,7 @@ def __init__(
subject_token_type,
token_url,
credential_source,
service_account_impersonation_url=None,
client_id=None,
client_secret=None,
quota_project_id=None,
Expand All @@ -70,17 +73,23 @@ def __init__(
subject_token_type (str): The subject token type.
token_url (str): The STS endpoint URL.
credential_source (Mapping): The credential source dictionary.
service_account_impersonation_url (Optional[str]): The optional service account
impersonation generateAccessToken URL.
client_id (Optional[str]): The optional client ID.
client_secret (Optional[str]): The optional client secret.
quota_project_id (Optional[str]): The optional quota project ID.
scopes (Optional[Sequence[str]]): Optional scopes to request during the
authorization grant.
Raises:
google.auth.exceptions.RefreshError: If the generateAccessToken
endpoint returned an error.
"""
super(Credentials, self).__init__()
self._audience = audience
self._subject_token_type = subject_token_type
self._token_url = token_url
self._credential_source = credential_source
self._service_account_impersonation_url = service_account_impersonation_url
self._client_id = client_id
self._client_secret = client_secret
self._quota_project_id = quota_project_id
Expand All @@ -94,6 +103,11 @@ def __init__(
self._client_auth = None
self._sts_client = sts.Client(self._token_url, self._client_auth)

if self._service_account_impersonation_url:
self._impersonated_credentials = self._initialize_impersonated_credentials()
else:
self._impersonated_credentials = None

@property
def requires_scopes(self):
"""Checks if the credentials requires scopes.
Expand Down Expand Up @@ -132,20 +146,24 @@ def retrieve_subject_token(self, request):

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
now = _helpers.utcnow()
response_data = self._sts_client.exchange_token(
request=request,
grant_type=_STS_GRANT_TYPE,
subject_token=self.retrieve_subject_token(request),
subject_token_type=self._subject_token_type,
audience=self._audience,
scopes=self._scopes,
requested_token_type=_STS_REQUESTED_TOKEN_TYPE,
)

self.token = response_data.get("access_token")
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime
if self._impersonated_credentials:
self._impersonated_credentials.refresh(request)
self.token = self._impersonated_credentials.token
self.expiry = self._impersonated_credentials.expiry
else:
now = _helpers.utcnow()
response_data = self._sts_client.exchange_token(
request=request,
grant_type=_STS_GRANT_TYPE,
subject_token=self.retrieve_subject_token(request),
subject_token_type=self._subject_token_type,
audience=self._audience,
scopes=self._scopes,
requested_token_type=_STS_REQUESTED_TOKEN_TYPE,
)
self.token = response_data.get("access_token")
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
Expand All @@ -155,8 +173,59 @@ def with_quota_project(self, quota_project_id):
subject_token_type=self._subject_token_type,
token_url=self._token_url,
credential_source=self._credential_source,
service_account_impersonation_url=self._service_account_impersonation_url,
client_id=self._client_id,
client_secret=self._client_secret,
quota_project_id=quota_project_id,
scopes=self._scopes,
)

def _initialize_impersonated_credentials(self):
"""Generates an impersonated credentials.

For more details, see `projects.serviceAccounts.generateAccessToken`_.

.. _projects.serviceAccounts.generateAccessToken: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken

Returns:
impersonated_credentials.Credential: The impersonated credentials
object.

Raises:
google.auth.exceptions.RefreshError: If the generateAccessToken
endpoint returned an error.
"""
# Return copy of instance with no service account impersonation.
source_credentials = self.__class__(
audience=self._audience,
subject_token_type=self._subject_token_type,
token_url=self._token_url,
credential_source=self._credential_source,
service_account_impersonation_url=None,
client_id=self._client_id,
client_secret=self._client_secret,
quota_project_id=self._quota_project_id,
scopes=self._scopes,
)

# Determine target_principal.
start_index = self._service_account_impersonation_url.rfind("/")
end_index = self._service_account_impersonation_url.find(":generateAccessToken")
if start_index != -1 and end_index != -1 and start_index < end_index:
start_index = start_index + 1
target_principal = self._service_account_impersonation_url[
start_index:end_index
]
else:
raise exceptions.RefreshError(
"Unable to determine target principal from service account impersonation URL."
)

# Initialize and return impersonated credentials.
return impersonated_credentials.Credentials(
source_credentials=source_credentials,
target_principal=target_principal,
target_scopes=self._scopes,
quota_project_id=self._quota_project_id,
iam_endpoint_override=self._service_account_impersonation_url,
)
16 changes: 14 additions & 2 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,19 @@
_DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token"


def _make_iam_token_request(request, principal, headers, body):
def _make_iam_token_request(
request, principal, headers, body, iam_endpoint_override=None
):
"""Makes a request to the Google Cloud IAM service for an access token.
Args:
request (Request): The Request object to use.
principal (str): The principal to request an access token for.
headers (Mapping[str, str]): Map of headers to transmit.
body (Mapping[str, str]): JSON Payload body for the iamcredentials
API call.
iam_endpoint_override (Optiona[str]): The full IAM endpoint override
with the target_principal embedded. This is useful when supporting
impersonation with regional endpoints.

Raises:
google.auth.exceptions.TransportError: Raised if there is an underlying
Expand All @@ -82,7 +87,7 @@ def _make_iam_token_request(request, principal, headers, body):
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = _IAM_ENDPOINT.format(principal)
iam_endpoint = iam_endpoint_override or _IAM_ENDPOINT.format(principal)

body = json.dumps(body).encode("utf-8")

Expand Down Expand Up @@ -185,6 +190,7 @@ def __init__(
delegates=None,
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
quota_project_id=None,
iam_endpoint_override=None,
):
"""
Args:
Expand All @@ -209,6 +215,9 @@ def __init__(
quota_project_id (Optional[str]): The project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
iam_endpoint_override (Optiona[str]): The full IAM endpoint override
with the target_principal embedded. This is useful when supporting
impersonation with regional endpoints.
"""

super(Credentials, self).__init__()
Expand All @@ -226,6 +235,7 @@ def __init__(
self.token = None
self.expiry = _helpers.utcnow()
self._quota_project_id = quota_project_id
self._iam_endpoint_override = iam_endpoint_override

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
Expand Down Expand Up @@ -260,6 +270,7 @@ def _update_token(self, request):
principal=self._target_principal,
headers=headers,
body=body,
iam_endpoint_override=self._iam_endpoint_override,
)

def sign_bytes(self, message):
Expand Down Expand Up @@ -302,6 +313,7 @@ def with_quota_project(self, quota_project_id):
delegates=self._delegates,
lifetime=self._lifetime,
quota_project_id=quota_project_id,
iam_endpoint_override=self._iam_endpoint_override,
)


Expand Down
Loading