Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,20 @@ class {{ async_method_name_prefix }}{{ service.name }}RestInterceptor:
"""
return response

{% if not method.server_streaming %}
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: {{method.output.ident}}, {{ client_method_metadata_argument() }}) -> Tuple[{{method.output.ident}}, {{ client_method_metadata_type() }}]:
{% else %}
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_argument() }}) -> Tuple[rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_type() }}]:
{% endif %}
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response, metadata

{% endif %}{# not method.void #}
{% endfor %}

{% for name, signature in api.mixin_api_signatures.items() %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ class {{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
{% endif %}{# method.lro #}
{#- TODO(https://github.com/googleapis/gapic-generator-python/issues/2274): Add debug log before intercepting a request #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
{% if not method.server_streaming %}
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
json_format.Parse(content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# if method.server_streaming #}
resp = await self._interceptor.post_{{ method.name|snake_case }}(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = await self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
{% if not method.server_streaming %}
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2218,11 +2218,13 @@ def test_initialize_client_w_{{transport_name}}():
{% endif %}
{% if not method.void %}
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}_with_metadata") as post_with_metadata, \
{% endif %}
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
pre.assert_not_called()
{% if not method.void %}
post.assert_not_called()
post_with_metadata.assert_not_called()
{% endif %}
{% if method.input.ident.is_proto_plus_type %}
pb_message = {{ method.input.ident }}.pb({{ method.input.ident }}())
Expand Down Expand Up @@ -2265,13 +2267,15 @@ def test_initialize_client_w_{{transport_name}}():
pre.return_value = request, metadata
{% if not method.void %}
post.return_value = {{ method.output.ident }}()
post_with_metadata.return_value = {{ method.output.ident }}(), metadata
{% endif %}

{{await_prefix}}client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
{% if not method.void %}
post.assert_called_once()
post_with_metadata.assert_called_once()
{% endif %}
{% endif %}{# end 'grpc' in transport #}
{% endmacro%}{# inteceptor_class_test #}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def post_generate_access_token(self, response: common.GenerateAccessTokenRespons
"""
return response

def post_generate_access_token_with_metadata(self, response: common.GenerateAccessTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateAccessTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for generate_access_token

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_generate_id_token(self, request: common.GenerateIdTokenRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for generate_id_token

Expand All @@ -144,6 +153,15 @@ def post_generate_id_token(self, response: common.GenerateIdTokenResponse) -> co
"""
return response

def post_generate_id_token_with_metadata(self, response: common.GenerateIdTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for generate_id_token

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_sign_blob(self, request: common.SignBlobRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for sign_blob

Expand All @@ -161,6 +179,15 @@ def post_sign_blob(self, response: common.SignBlobResponse) -> common.SignBlobRe
"""
return response

def post_sign_blob_with_metadata(self, response: common.SignBlobResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for sign_blob

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_sign_jwt(self, request: common.SignJwtRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for sign_jwt

Expand All @@ -178,6 +205,15 @@ def post_sign_jwt(self, response: common.SignJwtResponse) -> common.SignJwtRespo
"""
return response

def post_sign_jwt_with_metadata(self, response: common.SignJwtResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for sign_jwt

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata


@dataclasses.dataclass
class IAMCredentialsRestStub:
Expand Down Expand Up @@ -375,6 +411,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_generate_access_token(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_generate_access_token_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.GenerateAccessTokenResponse.to_json(response)
Expand Down Expand Up @@ -495,6 +533,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_generate_id_token(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_generate_id_token_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.GenerateIdTokenResponse.to_json(response)
Expand Down Expand Up @@ -615,6 +655,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_sign_blob(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_sign_blob_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.SignBlobResponse.to_json(response)
Expand Down Expand Up @@ -735,6 +777,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_sign_jwt(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_sign_jwt_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.SignJwtResponse.to_json(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3162,9 +3162,11 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_access_token") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.GenerateAccessTokenRequest.pb(common.GenerateAccessTokenRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3186,11 +3188,13 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.GenerateAccessTokenResponse()
post_with_metadata.return_value = common.GenerateAccessTokenResponse(), metadata

client.generate_access_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_generate_id_token_rest_bad_request(request_type=common.GenerateIdTokenRequest):
Expand Down Expand Up @@ -3264,9 +3268,11 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_id_token") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.GenerateIdTokenRequest.pb(common.GenerateIdTokenRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3288,11 +3294,13 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.GenerateIdTokenResponse()
post_with_metadata.return_value = common.GenerateIdTokenResponse(), metadata

client.generate_id_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_sign_blob_rest_bad_request(request_type=common.SignBlobRequest):
Expand Down Expand Up @@ -3368,9 +3376,11 @@ def test_sign_blob_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_blob") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.SignBlobRequest.pb(common.SignBlobRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3392,11 +3402,13 @@ def test_sign_blob_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.SignBlobResponse()
post_with_metadata.return_value = common.SignBlobResponse(), metadata

client.sign_blob(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_sign_jwt_rest_bad_request(request_type=common.SignJwtRequest):
Expand Down Expand Up @@ -3472,9 +3484,11 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_jwt") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.SignJwtRequest.pb(common.SignJwtRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3496,11 +3510,13 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.SignJwtResponse()
post_with_metadata.return_value = common.SignJwtResponse(), metadata

client.sign_jwt(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()

def test_initialize_client_w_rest():
client = IAMCredentialsClient(
Expand Down
Loading