Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,11 @@ def test_{{ method_name }}_raw_page_lro():
{% with method_name = method.safe_name|snake_case + "_unary" if method.extended_lro and not full_extended_lro else method.name|snake_case, method_output = method.extended_lro.operation_type if method.extended_lro and not full_extended_lro else method.output %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when lro and client streaming are supported. #}
{% if not method.client_streaming %}
{# NOTE: This guard is added to avoid generating duplicate tests for methods which are tested elsewhere. As we implement each of the api methods
# in the `macro::call_success_test`, the case will be removed from this condition below.
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2143): Remove the test `test_{{ method_name }}_rest` from here once the linked issue is resolved.
#}
{% if method.server_streaming or method.lro or method.extended_lro or method.paged_result_field %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -1185,6 +1190,7 @@ def test_{{ method_name }}_rest(request_type):
json_return_value = json_format.MessageToJson(return_value)
{% else %}
{% if method.output.ident.is_proto_plus_type %}

# Convert return value to protobuf type
return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
Expand Down Expand Up @@ -1249,6 +1255,7 @@ def test_{{ method_name }}_rest(request_type):
{% endfor %}
{% endif %}

{% endif %}{# if method.server_streaming or method.lro or method.extended_lro or method.paged_result_field #}
def test_{{ method_name }}_rest_use_cached_wrapped_rpc():
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
# instead of constructing them on each call
Expand Down Expand Up @@ -1953,6 +1960,7 @@ def test_unsupported_parameter_rest_asyncio():
{{ rest_method_not_implemented_error(service, method, transport, is_async) }}
{% else %}
{{ bad_request_test(service, method, transport, is_async) }}
{{ call_success_test(service, method, transport, is_async) }}
{% endif %}{# is_rest_unsupported_method(method, is_async) == 'False' and method.http_options #}
{% endfor %}
{{ initialize_client_with_transport_test(service, transport, is_async) }}
Expand Down Expand Up @@ -2056,3 +2064,206 @@ def test_initialize_client_w_{{transport_name}}():

{% endif %}{# if 'rest' in transport #}
{% endmacro %}

{# call_success_test generates tests for rest methods
# when they make a successful request.
# NOTE: Currently, this macro does not support the following method
# types: [method.server_streaming, method.lro, method.extended_lro, method.paged_result_field].
# As support is added for the above methods, the relevant guard can be removed from within the macro
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2142): Clean up `rest_required_tests` as we add support for each of the method types metioned above.
#}
{% macro call_success_test(service, method, transport, is_async) %}
{% if 'rest' in transport %}
{% set await_prefix = get_await_prefix(is_async) %}
{% set async_prefix = get_async_prefix(is_async) %}
{% set async_decorator = get_async_decorator(is_async) %}
{% set transport_name = get_transport_name(transport, is_async) %}
{% set method_name = method.name|snake_case %}
{# NOTE: set method_output to method.extended_lro.operation_type for the following method types:
# (method.extended_lro and not full_extended_lro)
#}
{% set method_output = method.output %}
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2143): Update the guard below as we add support for each method, and keep it in sync with the guard in
# `rest_required_tests`, which should be the exact opposite. Remove it once we have all the methods supported in async rest transport that are supported in sync rest transport.
#}
{% if not (method.server_streaming or method.lro or method.extended_lro or method.paged_result_field)%}
{{async_decorator}}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
])
{{async_prefix}}def test_{{method_name}}_{{transport_name}}_call_success(request_type):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the method_call_test from https://github.com/googleapis/gapic-generator-python/pull/2126/files be used here to reduce code duplication?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it can exactly be used as is but as I'm working through refactoring the test cases, I see that we can reduce code duplication by pulling out the logic to generate mock response according to the type of api method being tested into a separate macro.

If there's obvious code duplication that we're introducing or if sth could be easily resolved then we shall address that. Otherwise, I think it's alright to do this as a follow up. WDYT?

{% if transport_name == 'rest_asyncio' %}
if not HAS_GOOGLE_AUTH_AIO:
{# TODO(https://github.com/googleapis/google-auth-library-python/pull/1577): Update the version of google-auth once the linked PR is merged. #}
pytest.skip("google-auth > 2.x.x is required for async rest transport.")

{% endif %}
client = {{ get_client(service, is_async) }}(
credentials={{get_credentials(is_async)}},
transport="{{transport_name}}"
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.merged_mock_value(method.http_options[0].sample_request(method).get(field.name)) }}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
# See https://github.com/googleapis/gapic-generator-python/issues/1748

# Determine if the message type is proto-plus or protobuf
test_field = {{ method.input.ident }}.meta.fields["{{ field.name }}"]

def get_message_fields(field):
# Given a field which is a message (composite type), return a list with
# all the fields of the message.
# If the field is not a composite type, return an empty list.
message_fields = []

if hasattr(field, "message") and field.message:
is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR")

if is_field_type_proto_plus_type:
message_fields = field.message.meta.fields.values()
# Add `# pragma: NO COVER` because there may not be any `*_pb2` field types
else: # pragma: NO COVER
message_fields = field.message.DESCRIPTOR.fields
return message_fields

runtime_nested_fields = [
(field.name, nested_field.name)
for field in get_message_fields(test_field)
for nested_field in get_message_fields(field)
]

subfields_not_in_runtime = []

# For each item in the sample request, create a list of sub fields which are not present at runtime
# Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime
for field, value in request_init["{{ field.name }}"].items(): # pragma: NO COVER
result = None
is_repeated = False
# For repeated fields
if isinstance(value, list) and len(value):
is_repeated = True
result = value[0]
# For fields where the type is another message
if isinstance(value, dict):
result = value

if result and hasattr(result, "keys"):
for subfield in result.keys():
if (field, subfield) not in runtime_nested_fields:
subfields_not_in_runtime.append(
{"field": field, "subfield": subfield, "is_repeated": is_repeated}
)

# Remove fields from the sample request which are not present in the runtime version of the dependency
# Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime
for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER
field = subfield_to_delete.get("field")
field_repeated = subfield_to_delete.get("is_repeated")
subfield = subfield_to_delete.get("subfield")
if subfield:
if field_repeated:
for i in range(0, len(request_init["{{ field.name }}"][field])):
del request_init["{{ field.name }}"][field][i][subfield]
else:
del request_init["{{ field.name }}"][field][subfield]
{% endif %}
{% endfor %}
request = request_type(**request_init)

# Mock the http request call within the method and fake a response.
with mock.patch.object(type(client.transport._session), 'request') as req:
# Designate an appropriate value for the returned response.
{% if method.void %}
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.extended_lro %}
return_value = {{ method.extended_lro.operation_type.ident }}(
{% for field in method.extended_lro.operation_type.fields.values() | rejectattr('message')%}
{% if not field.oneof or field.proto3_optional %}
{{ field.name }}={{ field.mock_value }},
{% endif %}{% endfor %}
{# This is a hack to only pick one field #}
{% for oneof_fields in method.output.oneof_fields().values() %}
{% if (oneof_fields | rejectattr('message') | list) %}
{% with field = (oneof_fields | rejectattr('message') | first) %}
{{ field.name }}={{ field.mock_value }},
{% endwith %}
{% endif %}
{% endfor %}
)
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
{% if not field.oneof or field.proto3_optional %}
{{ field.name }}={{ field.mock_value }},
{% endif %}{% endfor %}
{# This is a hack to only pick one field #}
{% for oneof_fields in method.output.oneof_fields().values() %}
{% if (oneof_fields | rejectattr('message') | list) %}
{% with field = (oneof_fields | rejectattr('message') | first) %}
{{ field.name }}={{ field.mock_value }},
{% endwith %}
{% endif %}
{% endfor %}
)
{% endif %}{# method.void #}

# Wrap the value into a proper Response obj
response_value = mock.Mock()
response_value.status_code = 200
{% if method.void %}
json_return_value = ''
{% else %}{# method.void #}
{% if method.output.ident.is_proto_plus_type %}

# Convert return value to protobuf type
return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}{# method.output.ident.is_proto_plus_type #}
json_return_value = json_format.MessageToJson(return_value)
{% endif %}{# method.void #}
{% if is_async %}
response_value.read = mock.AsyncMock(return_value=b'{}')
{% else %}{# is_async #}
response_value.content = json_return_value.encode('UTF-8')
{% endif %}
req.return_value = response_value
response = {{ await_prefix }}client.{{ method_name }}(request)

# Establish that the response is the type that we expect.
{% if method.void %}
assert response is None
{% else %}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method_output.fields.values() | rejectattr('message') %}
{% if not field.oneof or field.proto3_optional %}
{% if field.field_pb.type in [1, 2] %}{# Use approx eq for floats #}
{% if field.repeated %}
for index in range(len(response.{{ field.name }})):
assert math.isclose(
response.{{ field.name }}[index],
{{ field.mock_value }}[index],
rel_tol=1e-6,
)
{% else %}{# field.repeated #}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% endif %}{# field.repeated #}
{% elif field.field_pb.type == 8 %}{# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else %}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif %}{# field.field_pb.type in [1, 2] #}
{% endif %}{# not field.oneof or field.proto3_optional #}
{% endfor %}{# field in method_output.fields.values() | rejectattr('message') #}
{% endif %}{# method.void #}

{% endif %}{# if not (method.server_streaming or method.lro or method.extended_lro or method.paged_result_field) #}
{% endif %}{# if 'rest' in transport #}
{% endmacro %}
Loading