Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import collections
import copy
import dataclasses
import functools
import json
import keyword
import re
Expand Down Expand Up @@ -1035,10 +1036,20 @@ def _to_regex(self, path_template: str) -> Pattern:
"""
return re.compile(f"^{self._convert_to_regex(path_template)}$")

# Use caching to avoid repeated computation
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2161):
# Use `@functools.cache` instead of `@functools.lru_cache` once python 3.8 is dropped.
# https://docs.python.org/3/library/functools.html#functools.cache
@functools.lru_cache(maxsize=None)
def to_regex(self) -> Pattern:
return self._to_regex(self.path_template)

@property
# Use caching to avoid repeated computation
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2161):
# Use `@functools.cache` instead of `@functools.lru_cache` once python 3.8 is dropped.
# https://docs.python.org/3/library/functools.html#functools.cache
@functools.lru_cache(maxsize=None)
def key(self) -> Union[str, None]:
if self.path_template == "":
return self.field
Expand Down Expand Up @@ -1067,6 +1078,69 @@ def try_parse_routing_rule(cls, routing_rule: routing_pb2.RoutingRule) -> Option
params = [RoutingParameter(x.field, x.path_template) for x in params]
return cls(params)

@classmethod
def resolve(cls, routing_rule: routing_pb2.RoutingRule, request: Union[dict, str]) -> dict:
"""Resolves the routing header which should be sent along with the request.
The routing header is determined based on the given routing rule and request.
See the following link for more information on explicit routing headers:
https://google.aip.dev/client-libraries/4222#explicit-routing-headers-googleapirouting

Args:
routing_rule(routing_pb2.RoutingRule): A collection of Routing Parameter specifications
defined by `routing_pb2.RoutingRule`.
See https://github.com/googleapis/googleapis/blob/cb39bdd75da491466f6c92bc73cd46b0fbd6ba9a/google/api/routing.proto#L391
request(Union[dict, str]): The request for which the routine rule should be resolved.
The format can be either a dictionary or json string representing the request.

Returns(dict):
A dictionary containing the resolved routing header to the sent along with the given request.
"""

def _get_field(request, field_path: str):
segments = field_path.split(".")

# Either json string or dictionary is supported
if isinstance(request, str):
current = json.loads(request)
else:
current = request

# This is to cater for the case where the `field_path` contains a
# dot-separated path of field names leading to a field in a sub-message.
for x in segments:
current = current.get(x, None)
# Break if the sub-message does not exist
if current is None:
break
return current

header_params = {}
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2160): Move this logic to
# `google-api-core` so that the shared code can be used in both `wrappers.py` and GAPIC clients
# via Jinja templates.
for routing_param in routing_rule.routing_parameters:
request_field_value = _get_field(request, routing_param.field)
# Only resolve the header for routing parameter fields which are populated in the request
if request_field_value is not None:
# If there is a path_template for a given routing parameter field, the value of the field must match
# If multiple `routing_param`s describe the same key
# (via the `path_template` field or via the `field` field when
# `path_template` is not provided), the "last one wins" rule
# determines which parameter gets used. See https://google.aip.dev/client-libraries/4222.
routing_parameter_key = routing_param.key
if routing_param.path_template:
routing_param_regex = routing_param.to_regex()
regex_match = routing_param_regex.match(
request_field_value
)
if regex_match:
header_params[routing_parameter_key] = regex_match.group(
routing_parameter_key
)
else: # No need to match
header_params[routing_parameter_key] = request_field_value
return header_params


@dataclasses.dataclass(frozen=True)
class HttpRule:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,12 +414,11 @@ def test_{{ method.name|snake_case }}_routing_parameters():

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
_, args, kw = call.mock_calls[0]
assert args[0] == request

_, _, kw = call.mock_calls[0]
# This test doesn't assert anything useful.
assert kw['metadata']
expected_headers = {{ method.routing_rule.resolve(method.routing_rule, routing_param.sample_request) }}
assert gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw['metadata']
{% endfor %}
{% endif %}

Expand Down
54 changes: 25 additions & 29 deletions tests/unit/schema/wrappers/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from gapic.schema import wrappers

import json
import proto
import pytest

Expand All @@ -23,31 +24,6 @@ class RoutingTestRequest(proto.Message):
app_profile_id = proto.Field(proto.STRING, number=2)


def resolve(rule, request):
"""This function performs dynamic header resolution, identical to what's in client.py.j2."""

def _get_field(request, field_path: str):
segments = field_path.split(".")
cur = request
for x in segments:
cur = getattr(cur, x)
return cur

header_params = {}
for routing_param in rule.routing_parameters:
# This may raise exception (which we show to clients).
request_field_value = _get_field(request, routing_param.field)
if routing_param.path_template:
routing_param_regex = routing_param.to_regex()
regex_match = routing_param_regex.match(request_field_value)
if regex_match:
header_params[routing_param.key] = regex_match.group(
routing_param.key)
else: # No need to match
header_params[routing_param.key] = request_field_value
return header_params


@pytest.mark.parametrize(
"req, expected",
[
Expand All @@ -63,7 +39,10 @@ def _get_field(request, field_path: str):
def test_routing_rule_resolve_simple_extraction(req, expected):
rule = wrappers.RoutingRule(
[wrappers.RoutingParameter("app_profile_id", "")])
assert resolve(rule, req) == expected
assert wrappers.RoutingRule.resolve(
rule,
RoutingTestRequest.to_dict(req)
) == expected


@pytest.mark.parametrize(
Expand All @@ -82,7 +61,10 @@ def test_routing_rule_resolve_rename_extraction(req, expected):
rule = wrappers.RoutingRule(
[wrappers.RoutingParameter("app_profile_id", "{routing_id=**}")]
)
assert resolve(rule, req) == expected
assert wrappers.RoutingRule.resolve(
rule,
RoutingTestRequest.to_dict(req)
) == expected


@pytest.mark.parametrize(
Expand Down Expand Up @@ -111,7 +93,10 @@ def test_routing_rule_resolve_field_match(req, expected):
),
]
)
assert resolve(rule, req) == expected
assert wrappers.RoutingRule.resolve(
rule,
RoutingTestRequest.to_dict(req)
) == expected


@pytest.mark.parametrize(
Expand All @@ -135,6 +120,9 @@ def test_routing_rule_resolve_field_match(req, expected):
wrappers.RoutingParameter(
"table_name", "projects/*/{instance_id=instances/*}/**"
),
wrappers.RoutingParameter(
"doesnotexist", "projects/*/{instance_id=instances/*}/**"
),
],
RoutingTestRequest(
table_name="projects/100/instances/200/tables/300"),
Expand All @@ -144,7 +132,15 @@ def test_routing_rule_resolve_field_match(req, expected):
)
def test_routing_rule_resolve(routing_parameters, req, expected):
rule = wrappers.RoutingRule(routing_parameters)
got = resolve(rule, req)
got = wrappers.RoutingRule.resolve(
rule, RoutingTestRequest.to_dict(req)
)
assert got == expected

rule = wrappers.RoutingRule(routing_parameters)
got = wrappers.RoutingRule.resolve(
rule, json.dumps(RoutingTestRequest.to_dict(req))
)
assert got == expected


Expand Down