Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions
from google.api_core import retry as retries
from google.api_core import rest_helpers
from google.api_core import rest_streaming
from google.api_core import path_template
from google.api_core import gapic_v1

{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
Expand Down Expand Up @@ -66,7 +68,7 @@ class {{ service.name }}RestInterceptor:

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
{% for _, method in service.methods|dictsort if not method.client_streaming %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata
Expand All @@ -82,7 +84,7 @@ class {{ service.name }}RestInterceptor:


"""
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
{% for method in service.methods.values()|sort(attribute="name") if not method.client_streaming %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Expand All @@ -92,7 +94,11 @@ class {{ service.name }}RestInterceptor:
return request, metadata

{% if not method.void %}
{% if not method.server_streaming %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
{% else %}
def post_{{ method.name|snake_case }}(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator:
{% endif %}
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
Expand Down Expand Up @@ -248,8 +254,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __hash__(self):
return hash("{{method.name}}")


{% if not (method.server_streaming or method.client_streaming) %}
{% if not method.client_streaming %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
Expand All @@ -262,15 +267,15 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def _get_unset_required_fields(cls, message_dict):
return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
{% endif %}{# required fields #}
{% endif %}{# not (method.server_streaming or method.client_streaming) #}
{% endif %}{# not method.client_streaming #}

def __call__(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
){% if not method.void %} -> {{method.output.ident}}{% endif %}:
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming %}
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand Down Expand Up @@ -360,6 +365,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if method.lro %}
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% elif method.server_streaming %}
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
{% else %}
resp = {{method.output.ident}}.from_json(
response.content,
Expand All @@ -370,14 +377,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
return resp

{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% else %}{# method.http_options and not method.client_streaming #}
{% if not method.http_options %}
raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")

{% elif method.server_streaming or method.client_streaming %}
{% elif method.client_streaming %}
raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
"Client streaming over REST is not yet defined for python client")

{% else %}
raise NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import mock
import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
import json
{% endif %}
import math
Expand Down Expand Up @@ -861,8 +862,8 @@ def test_{{ method_name }}_raw_page_lro():
{% endfor %} {# method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
{# TODO(kbandes): remove this if condition when client streaming are supported. #}
{% if not method.client_streaming %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand All @@ -884,8 +885,6 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
Expand All @@ -905,6 +904,8 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
req.return_value.request = PreparedRequest()
{% if method.void %}
json_return_value = ''
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -914,6 +915,10 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
# the request over the wire, so an empty request is fine.
{% if method.client_streaming %}
client.{{ method_name }}(iter([requests]))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -950,8 +955,6 @@ def test_{{ method.name|snake_case }}_rest(request_type):
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
Expand All @@ -974,13 +977,19 @@ def test_{{ method.name|snake_case }}_rest(request_type):
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value
{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand All @@ -991,6 +1000,11 @@ def test_{{ method.name|snake_case }}_rest(request_type):

{% endif %}

{% if method.server_streaming %}
assert isinstance(response, Iterable)
response = next(response)
{% endif %}

# Establish that the response is the type that we expect.
{% if method.void %}
assert response is None
Expand Down Expand Up @@ -1085,8 +1099,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
Expand Down Expand Up @@ -1114,6 +1126,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -1122,6 +1136,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide

{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1248,8 +1266,6 @@ def test_{{ method.name|snake_case }}_rest_flattened():
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
Expand All @@ -1261,6 +1277,8 @@ def test_{{ method.name|snake_case }}_rest_flattened():
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -1281,7 +1299,14 @@ def test_{{ method.name|snake_case }}_rest_flattened():
{% endfor %}
)
mock_args.update(sample_request)

{% if method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
client.{{ method_name }}(**mock_args)
{% else %}
client.{{ method_name }}(**mock_args)
{% endif %}

# Establish that the underlying call was made with the expected
# request object values.
Expand Down Expand Up @@ -1385,6 +1410,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
{% if method.server_streaming %}
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
{% endif %}
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
req.side_effect = return_values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions
from google.api_core import retry as retries
from google.api_core import rest_helpers
from google.api_core import rest_streaming
from google.api_core import path_template
from google.api_core import gapic_v1

{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
Expand Down Expand Up @@ -66,7 +68,7 @@ class {{ service.name }}RestInterceptor:

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
{% for _, method in service.methods|dictsort if not method.client_streaming %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata
Expand All @@ -82,7 +84,7 @@ class {{ service.name }}RestInterceptor:


"""
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
{% for method in service.methods.values()|sort(attribute="name") if not method.client_streaming %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Expand All @@ -92,7 +94,11 @@ class {{ service.name }}RestInterceptor:
return request, metadata

{% if not method.void %}
{% if not method.server_streaming %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
{% else %}
def post_{{ method.name|snake_case }}(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator:
{% endif %}
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
Expand Down Expand Up @@ -248,8 +254,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __hash__(self):
return hash("{{method.name}}")


{% if not (method.server_streaming or method.client_streaming) %}
{% if not method.client_streaming %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
Expand All @@ -262,15 +267,15 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def _get_unset_required_fields(cls, message_dict):
return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
{% endif %}{# required fields #}
{% endif %}{# not (method.server_streaming or method.client_streaming) #}
{% endif %}{# not method.client_streaming #}

def __call__(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
){% if not method.void %} -> {{method.output.ident}}{% endif %}:
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming %}
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand Down Expand Up @@ -360,6 +365,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if method.lro %}
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% elif method.server_streaming %}
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
{% else %}
resp = {{method.output.ident}}.from_json(
response.content,
Expand All @@ -370,14 +377,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
return resp

{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% else %}{# method.http_options and not method.client_streaming #}
{% if not method.http_options %}
raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")

{% elif method.server_streaming or method.client_streaming %}
{% elif method.client_streaming %}
raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
"Client streaming over REST is not yet defined for python client")

{% else %}
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ setuptools.setup(
install_requires=(
{# TODO(dovs): remove when 1.x deprecation is complete #}
{% if 'rest' in opts.transport %}
'google-api-core[grpc] >= 2.3.0, < 3.0.0dev',
'google-api-core[grpc] >= 2.4.0, < 3.0.0dev',
{% else %}
'google-api-core[grpc] >= 1.28.0, < 3.0.0dev',
{% endif %}
Expand Down
Loading