Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,116 @@ from google.api_core import retry as retries
from google.api_core import rest_helpers
from google.api_core import path_template
from google.api_core import gapic_v1

# TODO: Remove once my PR gets merged and released.
# Begin of ResponseIterator depedencies.
from collections import deque
import string
from typing import Deque
import requests

class ResponseIterator:
"""Iterator over REST API responses.

Args:
response (requests.Response): An API response object.
response_message_cls (Callable[proto.Message]): A proto
class expected to be returned from an API.
"""

def __init__(self, response: requests.Response, response_message_cls):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
self._response_itr = self._response.iter_content(decode_unicode=True)
# Contains a list of JSON responses ready to be sent to user.
self._ready_objs: Deque[str] = deque()
# Current JSON response being built.
self._obj = ""
# Keeps track of the nesting level within a JSON object.
self._level = 0
# Keeps track whether HTTP response is currently sending values
# inside of a string value.
self._in_string = False
# Whether an escape symbol "\" was encountered.
self._next_should_be_escaped = False

def cancel(self):
"""Cancel existing streaming operation.
"""
self._response.close()

def _process_chunk(self, chunk: str):
if self._level == 0:
if chunk[0] != "[":
raise ValueError(
"Can only parse array of JSON objects, instead got %s" % chunk
)
for char in chunk:
if char == "{":
if self._level == 1:
# Level 1 corresponds to the outermost JSON object
# (i.e. the one we care about).
self._obj = ""
if not self._in_string:
self._level += 1
self._obj += char
elif char == "}":
self._obj += char
if not self._in_string:
self._level -= 1
if not self._in_string and self._level == 1:
self._ready_objs.append(self._obj)
elif char == '"':
# Helps to deal with an escaped quotes inside of a string.
if not self._next_should_be_escaped:
self._in_string = not self._in_string
self._obj += char
elif char in string.whitespace:
if self._in_string:
self._obj += char
elif char == "[":
if self._level == 0:
self._level += 1
else:
self._obj += char
elif char == "]":
if self._level == 1:
self._level -= 1
else:
self._obj += char
else:
self._obj += char

if char == "\\":
# Escaping the "\".
if self._next_should_be_escaped:
self._next_should_be_escaped = False
else:
self._next_should_be_escaped = True
else:
self._next_should_be_escaped = False

def __next__(self):
while not self._ready_objs:
try:
chunk = next(self._response_itr)
self._process_chunk(chunk)
except StopIteration as e:
if self._level > 0:
raise ValueError("Unfinished stream: %s" % self._obj)
raise e
return self._grab()

def _grab(self):
# Add extra quotes to make json.loads happy.
return self._response_message_cls.from_json(self._ready_objs.popleft())

def __iter__(self):
return self

# End of ResponseIterator dependencies.

{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
Expand Down Expand Up @@ -179,7 +289,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __hash__(self):
return hash("{{method.name}}")

{% if not (method.server_streaming or method.client_streaming) %}
{% if not method.client_streaming %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
Expand All @@ -200,7 +310,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
{% if method.http_options and not method.client_streaming %}
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand Down Expand Up @@ -291,6 +401,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
{% elif method.server_streaming %}
return ResponseIterator(response, {{method.output.ident}})
{% else %}
return {{method.output.ident}}.from_json(
response.content,
Expand All @@ -299,14 +411,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):

{% endif %}{# method.lro #}
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% else %}{# method.http_options and not method.client_streaming #}
{% if not method.http_options %}
raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")

{% elif method.server_streaming or method.client_streaming %}
{% elif method.client_streaming %}
raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
"Client streaming over REST is not yet defined for python client")

{% else %}
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ setuptools.setup(
install_requires=(
{# TODO(dovs): remove when 1.x deprecation is complete #}
{% if 'rest' in opts.transport %}
'google-api-core[grpc] >= 2.3.0, < 3.0.0dev',
'google-api-core[grpc] >= 2.3.2, < 3.0.0dev',
{% else %}
'google-api-core[grpc] >= 1.28.0, < 3.0.0dev',
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import mock
import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
import json
{% endif %}
import math
Expand Down Expand Up @@ -1229,7 +1230,7 @@ def test_{{ method_name }}_raw_page_lro():

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
{% if not method.client_streaming %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -1317,8 +1318,6 @@ def test_{{ method_name }}_rest(request_type):
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
Expand All @@ -1341,13 +1340,20 @@ def test_{{ method_name }}_rest(request_type):
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

{% if method.client_streaming %}
response = client.{{ method_name }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand All @@ -1358,6 +1364,11 @@ def test_{{ method_name }}_rest(request_type):

{% endif %}

{% if method.server_streaming %}
assert isinstance(response, Iterable)
response = next(response)
{% endif %}

# Establish that the response is the type that we expect.
{% if method.void %}
assert response is None
Expand Down Expand Up @@ -1443,8 +1454,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
Expand Down Expand Up @@ -1472,6 +1481,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -1480,6 +1491,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide

{% if method.client_streaming %}
response = client.{{ method_name }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1550,40 +1565,47 @@ def test_{{ method_name }}_rest_flattened():
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}

# get arguments that satisfy an http rule for this method
sample_request = {{ method.http_options[0].sample_request(method) }}

# get truthy value for each flattened field
mock_args = dict(
{% for field in method.flattened_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
{{ field.name }}={{ field.mock_value }},
{% endif %}
{% endfor %}
)
mock_args.update(sample_request)

# Wrap the value into a proper Response obj
response_value = Response()
response_value.status_code = 200
{% if method.void %}
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

# get arguments that satisfy an http rule for this method
sample_request = {{ method.http_options[0].sample_request(method) }}

# get truthy value for each flattened field
mock_args = dict(
{% for field in method.flattened_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
{{ field.name }}={{ field.mock_value }},
{% endif %}
{% endfor %}
)
mock_args.update(sample_request)
{% if method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
client.{{ method_name }}(**mock_args)
{% else %}
client.{{ method_name }}(**mock_args)
{% endif %}

# Establish that the underlying call was made with the expected
# request object values.
Expand Down Expand Up @@ -1687,6 +1709,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
{% if method.server_streaming %}
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
{% endif %}
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
req.side_effect = return_values
Expand All @@ -1699,7 +1724,6 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
{% endif %}
{% endfor %}


pager = client.{{ method_name }}(request=sample_request)

{% if method.paged_result_field.map %}
Expand Down
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def showcase_unit(
session, templates="DEFAULT", other_opts: typing.Iterable[str] = (),
):
"""Run the generated unit tests against the Showcase library."""

with showcase_library(session, templates=templates, other_opts=other_opts) as lib:
session.chdir(lib)
run_showcase_unit_tests(session)
Expand Down