diff --git a/apisix/plugins/grpc-transcode/proto.lua b/apisix/plugins/grpc-transcode/proto.lua index 347ec39eae17..26c072aa59c3 100644 --- a/apisix/plugins/grpc-transcode/proto.lua +++ b/apisix/plugins/grpc-transcode/proto.lua @@ -20,6 +20,7 @@ local pb = require("pb") local protoc = require("protoc") local pcall = pcall local ipairs = ipairs +local pairs = pairs local decode_base64 = ngx.decode_base64 @@ -29,6 +30,94 @@ local lrucache_proto = core.lrucache.new({ }) local proto_fake_file = "filename for loaded" +local FIELD_TYPE_MESSAGE = 11 + +local function ensure_package_prefix(package) + if package and package ~= "" then + return "." .. package .. "." + end + + return "." +end + +local function create_field_descriptor(field) + return { + name = field.name, + label = field.label, + type = field.type, + type_name = field.type_name, + } +end + +local function register_message_descriptor(index, message, prefix) + local full_name = prefix .. message.name + local descriptor = { + full_name = full_name, + fields = {}, + map_entry = message.options and message.options.map_entry or false, + } + + for _, field in ipairs(message.field or {}) do + descriptor.fields[field.name] = create_field_descriptor(field) + end + + index[full_name] = descriptor + + local nested_prefix = full_name .. "." + for _, nested in ipairs(message.nested_type or {}) do + register_message_descriptor(index, nested, nested_prefix) + end +end + +local function build_message_index_from_file(index, file) + if not file then + return + end + + local prefix = ensure_package_prefix(file.package) + for _, message in ipairs(file.message_type or {}) do + register_message_descriptor(index, message, prefix) + end +end + +local function mark_map_fields(index) + for _, descriptor in pairs(index) do + if descriptor.map_entry and descriptor.fields then + descriptor.map_value_field = descriptor.fields.value + end + end + + for _, descriptor in pairs(index) do + for _, field in pairs(descriptor.fields) do + if field.type == FIELD_TYPE_MESSAGE and field.type_name then + local target = index[field.type_name] + if target and target.map_entry then + field.is_map = true + field.map_entry_descriptor = target + end + end + end + end +end + +local function build_message_index(files) + local index = {} + if not files then + return index + end + + if files.message_type or files.package then + build_message_index_from_file(index, files) + else + for _, file in ipairs(files) do + build_message_index_from_file(index, file) + end + end + + mark_map_fields(index) + + return index +end local function compile_proto_text(content) protoc.reload() @@ -60,6 +149,7 @@ local function compile_proto_text(content) end compiled[proto_fake_file].index = index + compiled.message_index = build_message_index(compiled[proto_fake_file]) return compiled end @@ -93,6 +183,7 @@ local function compile_proto_bin(content) local compiled = {} compiled[proto_fake_file] = {} compiled[proto_fake_file].index = index + compiled.message_index = build_message_index(files) return compiled end diff --git a/apisix/plugins/grpc-transcode/response.lua b/apisix/plugins/grpc-transcode/response.lua index 9dd6780f049d..ea258bc88f1f 100644 --- a/apisix/plugins/grpc-transcode/response.lua +++ b/apisix/plugins/grpc-transcode/response.lua @@ -23,6 +23,59 @@ local string = string local ngx_decode_base64 = ngx.decode_base64 local ipairs = ipairs local pcall = pcall +local type = type +local pairs = pairs +local setmetatable = setmetatable + +local _M = {} + +-- Protobuf repeated field label value +local PROTOBUF_REPEATED_LABEL = 3 +local repeated_label = PROTOBUF_REPEATED_LABEL +local FIELD_TYPE_MESSAGE = 11 + +local function set_default_array(tab, descriptor, message_index) + if type(tab) ~= "table" or not descriptor or not descriptor.fields then + return + end + + for field_name, field_info in pairs(descriptor.fields) do + local value = tab[field_name] + if value ~= nil and type(value) == "table" then + if field_info.label == repeated_label and not field_info.is_map then + setmetatable(value, core.json.array_mt) + end + + if field_info.type == FIELD_TYPE_MESSAGE then + if field_info.is_map then + local map_entry = field_info.map_entry_descriptor + local map_value_field = map_entry and map_entry.map_value_field + if map_value_field and map_value_field.type == FIELD_TYPE_MESSAGE then + local nested_desc = message_index and + message_index[map_value_field.type_name] + if nested_desc then + for _, map_val in pairs(value) do + set_default_array(map_val, nested_desc, message_index) + end + end + end + else + local nested_desc = message_index and + message_index[field_info.type_name] + if nested_desc then + if field_info.label == repeated_label then + for _, item in ipairs(value) do + set_default_array(item, nested_desc, message_index) + end + else + set_default_array(value, nested_desc, message_index) + end + end + end + end + end + end +end local function handle_error_response(status_detail_type, proto) @@ -93,7 +146,8 @@ local function handle_error_response(status_detail_type, proto) end -return function(ctx, proto, service, method, pb_option, show_status_in_body, status_detail_type) +local function transform_response(ctx, proto, service, method, pb_option, + show_status_in_body, status_detail_type) local buffer = core.response.hold_body_chunk(ctx) if not buffer then return nil @@ -132,6 +186,14 @@ return function(ctx, proto, service, method, pb_option, show_status_in_body, sta return err_msg end + local message_index = proto and proto.message_index + if message_index then + local output_descriptor = message_index[m.output_type] + if output_descriptor then + set_default_array(decoded, output_descriptor, message_index) + end + end + local response, err = core.json.encode(decoded) if not response then err_msg = "failed to json_encode response body" @@ -142,3 +204,13 @@ return function(ctx, proto, service, method, pb_option, show_status_in_body, sta ngx.arg[1] = response return nil end + +_M._TEST = { + set_default_array = set_default_array, +} + +return setmetatable(_M, { + __call = function(_, ...) + return transform_response(...) + end +}) diff --git a/t/plugin/grpc-transcode-arrays.t b/t/plugin/grpc-transcode-arrays.t new file mode 100644 index 000000000000..70455c76713b --- /dev/null +++ b/t/plugin/grpc-transcode-arrays.t @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +use t::APISIX 'no_plan'; + +no_long_string(); +no_shuffle(); +no_root_location(); + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!$block->request) { + $block->set_value("request", "GET /t"); + } +}); + +run_tests; + +__DATA__ + +=== TEST 1: repeated field name scoped to its message +--- config + location /t { + content_by_lua_block { + local response = require("apisix.plugins.grpc-transcode.response") + local helpers = response._TEST + + local message_index = { + [".demo.Company"] = { + fields = { + name = {label = 3, type = 9}, + }, + }, + [".demo.InnerName"] = { + fields = { + part = {label = 1, type = 9}, + }, + }, + [".demo.HelloRequest"] = { + fields = { + company = {label = 1, type = 11, type_name = ".demo.Company"}, + name = {label = 1, type = 11, type_name = ".demo.InnerName"}, + }, + } + } + + local data = { + company = { name = {"foo", "bar"} }, + name = { part = "ceo" }, + } + + helpers.set_default_array(data, + message_index[".demo.HelloRequest"], message_index) + + local array_mt = require("apisix.core").json.array_mt + + if getmetatable(data.company.name) ~= array_mt then + ngx.status = 500 + ngx.say("company.name isn't treated as an array") + return + end + + if getmetatable(data.name) == array_mt then + ngx.status = 500 + ngx.say("nested message incorrectly converted to array") + return + end + + ngx.say("passed") + } + } +--- response_body +passed + + +=== TEST 2: map values keep object semantics while nested arrays are applied +--- config + location /t { + content_by_lua_block { + local response = require("apisix.plugins.grpc-transcode.response") + local helpers = response._TEST + + local member_descriptor = { + fields = { + alias = {label = 3, type = 9}, + } + } + + local map_entry_descriptor = { + map_entry = true, + fields = { + key = {label = 1, type = 9}, + value = {label = 1, type = 11, type_name = ".demo.Member"}, + }, + } + map_entry_descriptor.map_value_field = map_entry_descriptor.fields.value + + local team_descriptor = { + fields = { + members = { + label = 3, + type = 11, + type_name = ".demo.Team.MemberEntry", + is_map = true, + map_entry_descriptor = map_entry_descriptor, + } + } + } + + local message_index = { + [".demo.Member"] = member_descriptor, + [".demo.Team.MemberEntry"] = map_entry_descriptor, + [".demo.Team"] = team_descriptor, + } + + local data = { + members = { + alice = { alias = {"aa", "ab"} }, + bob = { alias = {"ba"} }, + } + } + + helpers.set_default_array(data, message_index[".demo.Team"], message_index) + + local array_mt = require("apisix.core").json.array_mt + + if getmetatable(data.members) == array_mt then + ngx.status = 500 + ngx.say("map field should not become an array") + return + end + + if getmetatable(data.members.alice.alias) ~= array_mt then + ngx.status = 500 + ngx.say("map values should still apply nested arrays") + return + end + + ngx.say("passed") + } + } +--- response_body +passed diff --git a/t/plugin/grpc-transcode.t b/t/plugin/grpc-transcode.t index e261bf7bd554..c7c4d0d71ef6 100644 --- a/t/plugin/grpc-transcode.t +++ b/t/plugin/grpc-transcode.t @@ -761,3 +761,54 @@ POST /grpctest Content-Type: application/json --- response_body eval qr/"gender":2/ + + + +=== TEST 30: set route (return empty array from grpc server) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET", "POST"], + "uri": "/grpctest", + "plugins": { + "grpc-transcode": { + "proto_id": "1", + "service": "helloworld.Greeter", + "method": "SayHello" + } + }, + "upstream": { + "scheme": "grpc", + "type": "roundrobin", + "nodes": { + "127.0.0.1:10051": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- request +GET /t +--- response_body +passed + + + +=== TEST 31: hit route, response keeps empty array +--- request +POST /grpctest +{"name":"world","items":[]} +--- more_headers +Content-Type: application/json +--- response_body eval +qr/"items":\[\]/