diff --git a/examples/integration/integration_test.go b/examples/integration/integration_test.go index 62db1063785..03aed1a0886 100644 --- a/examples/integration/integration_test.go +++ b/examples/integration/integration_test.go @@ -15,6 +15,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -263,6 +264,7 @@ func TestABE(t *testing.T) { testABECreate(t, 8080) testABECreateBody(t, 8080) testABEBulkCreate(t, 8080) + testABEBulkCreateWithError(t, 8080) testABELookup(t, 8080) testABELookupNotFound(t, 8080) testABEList(t, 8080) @@ -549,6 +551,65 @@ func testABEBulkCreate(t *testing.T, port int) { } } +func testABEBulkCreateWithError(t *testing.T, port int) { + count := 0 + r, w := io.Pipe() + go func(w io.WriteCloser) { + defer func() { + if cerr := w.Close(); cerr != nil { + t.Errorf("w.Close() failed with %v; want success", cerr) + } + }() + for _, val := range []string{ + "foo", "bar", "baz", "qux", "quux", + } { + time.Sleep(1 * time.Millisecond) + + want := gw.ABitOfEverything{ + StringValue: fmt.Sprintf("strprefix/%s", val), + } + var m jsonpb.Marshaler + if err := m.Marshal(w, &want); err != nil { + t.Fatalf("m.Marshal(%#v, w) failed with %v; want success", want, err) + } + if _, err := io.WriteString(w, "\n"); err != nil { + t.Errorf("w.Write(%q) failed with %v; want success", "\n", err) + return + } + count++ + } + }(w) + + apiURL := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything/bulk", port) + request, err := http.NewRequest("POST", apiURL, r) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "POST", apiURL, err) + } + request.Header.Add("Grpc-Metadata-error", "some error") + + resp, err := http.DefaultClient.Do(request) + if err != nil { + t.Errorf("http.Post(%q) failed with %v; want success", apiURL, err) + return + } + defer resp.Body.Close() + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) + return + } + + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + t.Logf("%s", buf) + } + + var msg errorBody + if err := json.Unmarshal(buf, &msg); err != nil { + t.Fatalf("json.Unmarshal(%s, &msg) failed with %v; want success", buf, err) + } +} + func testABELookup(t *testing.T, port int) { apiURL := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything", port) cresp, err := http.Post(apiURL, "application/json", strings.NewReader(` diff --git a/examples/proto/examplepb/flow_combination.pb.gw.go b/examples/proto/examplepb/flow_combination.pb.gw.go index e5de847c464..b29a6f3bbf9 100644 --- a/examples/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/proto/examplepb/flow_combination.pb.gw.go @@ -73,6 +73,9 @@ func request_FlowCombination_StreamEmptyRpc_0(ctx context.Context, marshaler run return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err = stream.Send(&protoReq); err != nil { + if err == io.EOF { + break + } grpclog.Infof("Failed to send request: %v", err) return nil, metadata, err } diff --git a/examples/proto/examplepb/stream.pb.gw.go b/examples/proto/examplepb/stream.pb.gw.go index 3b01d805071..ec7cb9217b6 100644 --- a/examples/proto/examplepb/stream.pb.gw.go +++ b/examples/proto/examplepb/stream.pb.gw.go @@ -49,6 +49,9 @@ func request_StreamService_BulkCreate_0(ctx context.Context, marshaler runtime.M return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err = stream.Send(&protoReq); err != nil { + if err == io.EOF { + break + } grpclog.Infof("Failed to send request: %v", err) return nil, metadata, err } diff --git a/examples/server/a_bit_of_everything.go b/examples/server/a_bit_of_everything.go index fc3acca0c1e..b2bcf41994d 100644 --- a/examples/server/a_bit_of_everything.go +++ b/examples/server/a_bit_of_everything.go @@ -65,8 +65,15 @@ func (s *_ABitOfEverythingServer) CreateBody(ctx context.Context, msg *examples. } func (s *_ABitOfEverythingServer) BulkCreate(stream examples.StreamService_BulkCreateServer) error { - count := 0 ctx := stream.Context() + + if header, ok := metadata.FromIncomingContext(ctx); ok { + if v, ok := header["error"]; ok { + return status.Errorf(codes.InvalidArgument, "error metadata: %v", v) + } + } + + count := 0 for { msg, err := stream.Recv() if err == io.EOF { diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index d5a4980d65c..97770d93bf0 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -254,6 +254,9 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err = stream.Send(&protoReq); err != nil { + if err == io.EOF { + break + } grpclog.Infof("Failed to send request: %v", err) return nil, metadata, err }