diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index fd75e8fd5e..473b56e7eb 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -30,6 +30,7 @@ import collections import copy import dataclasses +import functools import json import keyword import re @@ -1035,10 +1036,20 @@ def _to_regex(self, path_template: str) -> Pattern: """ return re.compile(f"^{self._convert_to_regex(path_template)}$") + # Use caching to avoid repeated computation + # TODO(https://github.com/googleapis/gapic-generator-python/issues/2161): + # Use `@functools.cache` instead of `@functools.lru_cache` once python 3.8 is dropped. + # https://docs.python.org/3/library/functools.html#functools.cache + @functools.lru_cache(maxsize=None) def to_regex(self) -> Pattern: return self._to_regex(self.path_template) @property + # Use caching to avoid repeated computation + # TODO(https://github.com/googleapis/gapic-generator-python/issues/2161): + # Use `@functools.cache` instead of `@functools.lru_cache` once python 3.8 is dropped. + # https://docs.python.org/3/library/functools.html#functools.cache + @functools.lru_cache(maxsize=None) def key(self) -> Union[str, None]: if self.path_template == "": return self.field @@ -1067,6 +1078,69 @@ def try_parse_routing_rule(cls, routing_rule: routing_pb2.RoutingRule) -> Option params = [RoutingParameter(x.field, x.path_template) for x in params] return cls(params) + @classmethod + def resolve(cls, routing_rule: routing_pb2.RoutingRule, request: Union[dict, str]) -> dict: + """Resolves the routing header which should be sent along with the request. + The routing header is determined based on the given routing rule and request. + See the following link for more information on explicit routing headers: + https://google.aip.dev/client-libraries/4222#explicit-routing-headers-googleapirouting + + Args: + routing_rule(routing_pb2.RoutingRule): A collection of Routing Parameter specifications + defined by `routing_pb2.RoutingRule`. + See https://github.com/googleapis/googleapis/blob/cb39bdd75da491466f6c92bc73cd46b0fbd6ba9a/google/api/routing.proto#L391 + request(Union[dict, str]): The request for which the routine rule should be resolved. + The format can be either a dictionary or json string representing the request. + + Returns(dict): + A dictionary containing the resolved routing header to the sent along with the given request. + """ + + def _get_field(request, field_path: str): + segments = field_path.split(".") + + # Either json string or dictionary is supported + if isinstance(request, str): + current = json.loads(request) + else: + current = request + + # This is to cater for the case where the `field_path` contains a + # dot-separated path of field names leading to a field in a sub-message. + for x in segments: + current = current.get(x, None) + # Break if the sub-message does not exist + if current is None: + break + return current + + header_params = {} + # TODO(https://github.com/googleapis/gapic-generator-python/issues/2160): Move this logic to + # `google-api-core` so that the shared code can be used in both `wrappers.py` and GAPIC clients + # via Jinja templates. + for routing_param in routing_rule.routing_parameters: + request_field_value = _get_field(request, routing_param.field) + # Only resolve the header for routing parameter fields which are populated in the request + if request_field_value is not None: + # If there is a path_template for a given routing parameter field, the value of the field must match + # If multiple `routing_param`s describe the same key + # (via the `path_template` field or via the `field` field when + # `path_template` is not provided), the "last one wins" rule + # determines which parameter gets used. See https://google.aip.dev/client-libraries/4222. + routing_parameter_key = routing_param.key + if routing_param.path_template: + routing_param_regex = routing_param.to_regex() + regex_match = routing_param_regex.match( + request_field_value + ) + if regex_match: + header_params[routing_parameter_key] = regex_match.group( + routing_parameter_key + ) + else: # No need to match + header_params[routing_parameter_key] = request_field_value + return header_params + @dataclasses.dataclass(frozen=True) class HttpRule: diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 index d0c1c5c055..ea146e0f42 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 @@ -414,12 +414,11 @@ def test_{{ method.name|snake_case }}_routing_parameters(): # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] + _, args, kw = call.mock_calls[0] assert args[0] == request - _, _, kw = call.mock_calls[0] - # This test doesn't assert anything useful. - assert kw['metadata'] + expected_headers = {{ method.routing_rule.resolve(method.routing_rule, routing_param.sample_request) }} + assert gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw['metadata'] {% endfor %} {% endif %} diff --git a/tests/unit/schema/wrappers/test_routing.py b/tests/unit/schema/wrappers/test_routing.py index f93d6680a0..f4e6215bca 100644 --- a/tests/unit/schema/wrappers/test_routing.py +++ b/tests/unit/schema/wrappers/test_routing.py @@ -14,6 +14,7 @@ from gapic.schema import wrappers +import json import proto import pytest @@ -23,31 +24,6 @@ class RoutingTestRequest(proto.Message): app_profile_id = proto.Field(proto.STRING, number=2) -def resolve(rule, request): - """This function performs dynamic header resolution, identical to what's in client.py.j2.""" - - def _get_field(request, field_path: str): - segments = field_path.split(".") - cur = request - for x in segments: - cur = getattr(cur, x) - return cur - - header_params = {} - for routing_param in rule.routing_parameters: - # This may raise exception (which we show to clients). - request_field_value = _get_field(request, routing_param.field) - if routing_param.path_template: - routing_param_regex = routing_param.to_regex() - regex_match = routing_param_regex.match(request_field_value) - if regex_match: - header_params[routing_param.key] = regex_match.group( - routing_param.key) - else: # No need to match - header_params[routing_param.key] = request_field_value - return header_params - - @pytest.mark.parametrize( "req, expected", [ @@ -63,7 +39,10 @@ def _get_field(request, field_path: str): def test_routing_rule_resolve_simple_extraction(req, expected): rule = wrappers.RoutingRule( [wrappers.RoutingParameter("app_profile_id", "")]) - assert resolve(rule, req) == expected + assert wrappers.RoutingRule.resolve( + rule, + RoutingTestRequest.to_dict(req) + ) == expected @pytest.mark.parametrize( @@ -82,7 +61,10 @@ def test_routing_rule_resolve_rename_extraction(req, expected): rule = wrappers.RoutingRule( [wrappers.RoutingParameter("app_profile_id", "{routing_id=**}")] ) - assert resolve(rule, req) == expected + assert wrappers.RoutingRule.resolve( + rule, + RoutingTestRequest.to_dict(req) + ) == expected @pytest.mark.parametrize( @@ -111,7 +93,10 @@ def test_routing_rule_resolve_field_match(req, expected): ), ] ) - assert resolve(rule, req) == expected + assert wrappers.RoutingRule.resolve( + rule, + RoutingTestRequest.to_dict(req) + ) == expected @pytest.mark.parametrize( @@ -135,6 +120,9 @@ def test_routing_rule_resolve_field_match(req, expected): wrappers.RoutingParameter( "table_name", "projects/*/{instance_id=instances/*}/**" ), + wrappers.RoutingParameter( + "doesnotexist", "projects/*/{instance_id=instances/*}/**" + ), ], RoutingTestRequest( table_name="projects/100/instances/200/tables/300"), @@ -144,7 +132,15 @@ def test_routing_rule_resolve_field_match(req, expected): ) def test_routing_rule_resolve(routing_parameters, req, expected): rule = wrappers.RoutingRule(routing_parameters) - got = resolve(rule, req) + got = wrappers.RoutingRule.resolve( + rule, RoutingTestRequest.to_dict(req) + ) + assert got == expected + + rule = wrappers.RoutingRule(routing_parameters) + got = wrappers.RoutingRule.resolve( + rule, json.dumps(RoutingTestRequest.to_dict(req)) + ) assert got == expected