From bb00f271ae03a4954c14dd663c425f593df372f0 Mon Sep 17 00:00:00 2001 From: Thomas Jackson Date: Wed, 18 Apr 2018 16:26:26 -0700 Subject: [PATCH] Support cases where the request is done with transfer-encoding chunked (#589) * Support cases where the request is done with transfer-encoding chunked PR #527 was put in place to fix an issue where an empty request body would cause an empty message to be sent down to the GRPC service (instead of failing). The fix at the time was to check that the ContentLength was >0, but this doesn't take into consideration transfer-encoding chunked POSTs. Since this patch all chunked POSTs no longer unmarshal the message (as the content-length was 0). My proposed fix is instead to always call Decode and simply ignore EOF errors (as we still want to pass the un-filled struct down). I have tested that things such as partial json blobs (something like '{') don' t return EOF (they have their own json error). --- .../examplepb/a_bit_of_everything.pb.gw.go | 36 ++++------ examples/examplepb/echo_service.pb.gw.go | 6 +- examples/examplepb/flow_combination.pb.gw.go | 72 +++++++------------ examples/examplepb/wrappers.pb.gw.go | 6 +- examples/integration_test.go | 15 ++++ .../gengateway/template.go | 6 +- 6 files changed, 57 insertions(+), 84 deletions(-) diff --git a/examples/examplepb/a_bit_of_everything.pb.gw.go b/examples/examplepb/a_bit_of_everything.pb.gw.go index 5fc7bbdb8c0..8c9d4fe98d8 100644 --- a/examples/examplepb/a_bit_of_everything.pb.gw.go +++ b/examples/examplepb/a_bit_of_everything.pb.gw.go @@ -224,10 +224,8 @@ func request_ABitOfEverythingService_CreateBody_0(ctx context.Context, marshaler var protoReq ABitOfEverything var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } msg, err := client.CreateBody(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) @@ -266,10 +264,8 @@ func request_ABitOfEverythingService_Update_0(ctx context.Context, marshaler run var protoReq ABitOfEverything var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -388,10 +384,8 @@ func request_ABitOfEverythingService_Echo_1(ctx context.Context, marshaler runti var protoReq sub.StringMessage var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.Value); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.Value); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } msg, err := client.Echo(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) @@ -420,10 +414,8 @@ func request_ABitOfEverythingService_DeepPathEcho_0(ctx context.Context, marshal var protoReq ABitOfEverything var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -471,10 +463,8 @@ func request_ABitOfEverythingService_GetMessageWithBody_0(ctx context.Context, m var protoReq MessageWithBody var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.Data); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.Data); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -504,10 +494,8 @@ func request_ABitOfEverythingService_PostWithEmptyBody_0(ctx context.Context, ma var protoReq Body var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( diff --git a/examples/examplepb/echo_service.pb.gw.go b/examples/examplepb/echo_service.pb.gw.go index bc766f7b808..878808b266d 100644 --- a/examples/examplepb/echo_service.pb.gw.go +++ b/examples/examplepb/echo_service.pb.gw.go @@ -105,10 +105,8 @@ func request_EchoService_EchoBody_0(ctx context.Context, marshaler runtime.Marsh var protoReq SimpleMessage var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } msg, err := client.EchoBody(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) diff --git a/examples/examplepb/flow_combination.pb.gw.go b/examples/examplepb/flow_combination.pb.gw.go index adca0fd727d..6f743dd9427 100644 --- a/examples/examplepb/flow_combination.pb.gw.go +++ b/examples/examplepb/flow_combination.pb.gw.go @@ -151,10 +151,8 @@ func request_FlowCombination_RpcBodyRpc_0(ctx context.Context, marshaler runtime var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } msg, err := client.RpcBodyRpc(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) @@ -232,10 +230,8 @@ func request_FlowCombination_RpcBodyRpc_3(ctx context.Context, marshaler runtime var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -280,10 +276,8 @@ func request_FlowCombination_RpcBodyRpc_4(ctx context.Context, marshaler runtime var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_FlowCombination_RpcBodyRpc_4); err != nil { @@ -303,10 +297,8 @@ func request_FlowCombination_RpcBodyRpc_5(ctx context.Context, marshaler runtime var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -414,10 +406,8 @@ func request_FlowCombination_RpcPathNestedRpc_0(ctx context.Context, marshaler r var protoReq NestedProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -501,10 +491,8 @@ func request_FlowCombination_RpcPathNestedRpc_2(ctx context.Context, marshaler r var protoReq NestedProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -538,10 +526,8 @@ func request_FlowCombination_RpcBodyStream_0(ctx context.Context, marshaler runt var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } stream, err := client.RpcBodyStream(ctx, &protoReq) @@ -643,10 +629,8 @@ func request_FlowCombination_RpcBodyStream_3(ctx context.Context, marshaler runt var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -699,10 +683,8 @@ func request_FlowCombination_RpcBodyStream_4(ctx context.Context, marshaler runt var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_FlowCombination_RpcBodyStream_4); err != nil { @@ -730,10 +712,8 @@ func request_FlowCombination_RpcBodyStream_5(ctx context.Context, marshaler runt var protoReq NonEmptyProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -865,10 +845,8 @@ func request_FlowCombination_RpcPathNestedStream_0(ctx context.Context, marshale var protoReq NestedProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( @@ -968,10 +946,8 @@ func request_FlowCombination_RpcPathNestedStream_2(ctx context.Context, marshale var protoReq NestedProto var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq.C); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } var ( diff --git a/examples/examplepb/wrappers.pb.gw.go b/examples/examplepb/wrappers.pb.gw.go index bd57b237e34..a940b090174 100644 --- a/examples/examplepb/wrappers.pb.gw.go +++ b/examples/examplepb/wrappers.pb.gw.go @@ -32,10 +32,8 @@ func request_WrappersService_Create_0(ctx context.Context, marshaler runtime.Mar var protoReq Wrappers var metadata runtime.ServerMetadata - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } msg, err := client.Create(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) diff --git a/examples/integration_test.go b/examples/integration_test.go index 405726258fc..5cc2ed70e50 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -15,6 +15,7 @@ import ( "time" "context" + "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/empty" @@ -667,6 +668,20 @@ func testAdditionalBindings(t *testing.T, port int) { } return resp }, + func() *http.Response { + r, w := io.Pipe() + go func() { + defer w.Close() + w.Write([]byte(`"hello"`)) + }() + url := fmt.Sprintf("http://localhost:%d/v2/example/echo", port) + resp, err := http.Post(url, "application/json", r) + if err != nil { + t.Errorf("http.Post(%q, %q, %q) failed with %v; want success", url, "application/json", `"hello"`, err) + return nil + } + return resp + }, func() *http.Response { url := fmt.Sprintf("http://localhost:%d/v2/example/echo?value=hello", port) resp, err := http.Get(url) diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index c662e7b29fc..435a1405100 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -207,10 +207,8 @@ var ( var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} var metadata runtime.ServerMetadata {{if .Body}} - if req.ContentLength > 0 { - if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq"}}); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } + if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq"}}); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } {{end}} {{if .PathParams}}