Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __str__(self) -> str:
# This module is from a different proto package
# Most commonly happens for a common proto
# https://pypi.org/project/googleapis-common-protos/
if not self.proto_package.startswith(self.api_naming.proto_package):
if self.is_external_type:
module_name = f'{self.module}_pb2'

# Return the dot-separated Python identifier.
Expand All @@ -102,6 +102,10 @@ def __str__(self) -> str:
# Return the Python identifier.
return '.'.join(self.parent + (self.name,))

@property
def is_external_type(self):
return not self.proto_package.startswith(self.api_naming.proto_package)

@cached_property
def __cached_string_repr(self):
return "({})".format(
Expand Down
28 changes: 22 additions & 6 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,19 @@ def recursive_mock_original_type(field):
# Not worth the hassle, just return an empty map.
return {}

msg_dict = {
f.name: recursive_mock_original_type(f)
for f in field.message.fields.values()
}
adr = field.type.meta.address
if adr.name == "Any" and adr.package == ("google", "protobuf"):
# If it is Any type pack a random but validly encoded type,
# Duration in this specific case.
msg_dict = {
"type_url": "type.googleapis.com/google.protobuf.Duration",
"value": b'\x08\x0c\x10\xdb\x07',
}
else:
msg_dict = {
f.name: recursive_mock_original_type(f)
for f in field.message.fields.values()
}

return [msg_dict] if field.repeated else msg_dict

Expand Down Expand Up @@ -237,9 +246,16 @@ def primitive_mock(self, suffix: int = 0) -> Union[bool, str, bytes, int, float,
if self.type.python_type == bool:
answer = True
elif self.type.python_type == str:
answer = f"{self.name}_value_{suffix}" if suffix else f"{self.name}_value"
if self.name == "type_url":
# It is most likely a mock for Any type. We don't really care
# which mock value to put, so lets put a value which makes
# Any deserializer happy, which will wtill work even if it
# is not Any.
answer = "type.googleapis.com/google.protobuf.Empty"
else:
answer = f"{self.name}_value{suffix}" if suffix else f"{self.name}_value"
elif self.type.python_type == bytes:
answer_str = f"{self.name}_blob_{suffix}" if suffix else f"{self.name}_blob"
answer_str = f"{self.name}_blob{suffix}" if suffix else f"{self.name}_blob"
answer = bytes(answer_str, encoding="utf-8")
elif self.type.python_type == int:
answer = sum([ord(i) for i in self.name]) + suffix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ from google.api_core import rest_streaming
from google.api_core import path_template
from google.api_core import gapic_v1

from google.protobuf import json_format
{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
{% endif %}
from requests import __version__ as requests_version
import dataclasses
Expand Down Expand Up @@ -328,20 +328,19 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endfor %}{# rule in method.http_options #}
]
request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = {{method.input.ident}}.pb(request)
{% endif %}
transcoded_request = path_template.transcode(http_options, pb_request)

{% set body_spec = method.http_options[0].body %}
{%- if body_spec %}
# Jsonify the request body
body = {% if body_spec == '*' -%}
{{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['body']),
{% else -%}
{{method.input.fields[body_spec].type.ident}}.to_json(
{{method.input.fields[body_spec].type.ident}}(transcoded_request['body']),
{% endif %}{# body_spec == "*" #}

body = json_format.MessageToJson(
transcoded_request['body'],
including_default_value_fields=False,
use_integers_for_enums={{ opts.rest_numeric_enums }}
)
Expand All @@ -351,12 +350,11 @@ class {{service.name}}RestTransport({{service.name}}Transport):
method = transcoded_request['method']

# Jsonify the query params
query_params = json.loads({{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['query_params']),
query_params = json.loads(json_format.MessageToJson(
transcoded_request['query_params'],
including_default_value_fields=False,
use_integers_for_enums={{ opts.rest_numeric_enums }}
use_integers_for_enums={{ opts.rest_numeric_enums }},
))

{% if method.input.required_fields %}
query_params.update(self._get_unset_required_fields(query_params))
{% endif %}{# required fields #}
Expand Down Expand Up @@ -391,10 +389,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% elif method.server_streaming %}
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
{% else %}
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)
resp = {{method.output.ident}}()
{% if method.output.ident.is_external_type %}
pb_resp = resp
{% else %}
pb_resp = {{method.output.ident}}.pb(resp)
{% endif %}

json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
from google.protobuf import json_format
import json
{% endif %}
import math
Expand Down Expand Up @@ -51,9 +52,6 @@ from google.api_core import future
from google.api_core import operation
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% if "rest" in opts.transport %}
from google.protobuf import json_format
{% endif %}{# rest transport #}
{% endif %}{# lro #}
{% if api.has_location_mixin %}
from google.cloud.location import locations_pb2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def test_{{ method_name }}_rest(request_type):
request_init["{{ field.name }}"] = {{ field.merged_mock_value(method.http_options[0].sample_request(method).get(field.name)) }}
{% endif %}
{% endfor %}
request = request_type(request_init)
request = request_type(**request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -902,11 +902,19 @@ def test_{{ method_name }}_rest(request_type):
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}

{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value
{% if method.client_streaming %}
Expand Down Expand Up @@ -965,20 +973,25 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
jsonified_request = json.loads(request_type.to_json(
request,
request = request_type(**request_init)
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = request_type.pb(request)
{% endif %}
jsonified_request = json.loads(json_format.MessageToJson(
pb_request,
including_default_value_fields=False,
use_integers_for_enums=False
))
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.transport_safe_name | snake_case }}._get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
Expand All @@ -994,7 +1007,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.transport_safe_name | snake_case }}._get_unset_required_fields(jsonified_request)
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params|sort %}"{{param}}", {% endfor %}))
Expand All @@ -1014,7 +1027,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
credentials=ga_credentials.AnonymousCredentials(),
transport='rest',
)
request = request_type(request_init)
request = request_type(**request_init)

# Designate an appropriate value for the returned response.
{% if method.void %}
Expand All @@ -1032,13 +1045,18 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
with mock.patch.object(path_template, 'transcode') as transcode:
# A uri without fields and an empty body will force all the
# request fields to show up in the query_params.
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = request_type.pb(request)
{% endif %}
transcode_result = {
'uri': 'v1/sample_method',
'method': "{{ method.http_options[0].method }}",
'query_params': request_init,
'query_params': pb_request,
}
{% if method.http_options[0].body %}
transcode_result['body'] = {}
transcode_result['body'] = pb_request
{% endif %}
transcode.return_value = transcode_result

Expand All @@ -1048,11 +1066,19 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)

{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}
{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

Expand Down Expand Up @@ -1115,8 +1141,17 @@ def test_{{ method_name }}_rest_interceptors(null_interceptor):
{% if not method.void %}
post.assert_not_called()
{% endif %}

transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}
{% if method.input.ident.is_external_type %}
pb_message = {{ method.input.ident }}()
{% else %}
pb_message = {{ method.input.ident }}.pb({{ method.input.ident }}())
{% endif %}
transcode.return_value = {
"method": "post",
"uri": "my_uri",
"body": pb_message,
"query_params": pb_message,
}

req.return_value = Response()
req.return_value.status_code = 200
Expand Down Expand Up @@ -1164,7 +1199,7 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
request_init["{{ field.name }}"] = {{ field.merged_mock_value(method.http_options[0].sample_request(method).get(field.name)) }}
{% endif %}
{% endfor %}
request = request_type(request_init)
request = request_type(**request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -1222,12 +1257,17 @@ def test_{{ method_name }}_rest_flattened():
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}
{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

Expand Down Expand Up @@ -1342,7 +1382,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
{% if method.server_streaming %}
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
response_val = "[{}]".format(response_val)
{% endif %}
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ async def sample_list_log_entries():

# Initialize request argument(s)
request = logging_v2.ListLogEntriesRequest(
resource_names=['resource_names_value_1', 'resource_names_value_2'],
resource_names=['resource_names_value1', 'resource_names_value2'],
)

# Make the request
Expand Down Expand Up @@ -867,7 +867,7 @@ async def sample_tail_log_entries():

# Initialize request argument(s)
request = logging_v2.TailLogEntriesRequest(
resource_names=['resource_names_value_1', 'resource_names_value_2'],
resource_names=['resource_names_value1', 'resource_names_value2'],
)

# This method expects an iterator which contains
Expand Down
Loading