From 7b8ad663425336609bbfb900a3f12843a444b5de Mon Sep 17 00:00:00 2001 From: Peter Edge Date: Mon, 14 Sep 2015 18:11:02 -0700 Subject: [PATCH 1/3] add runtime.WithForwardResponseOption --- .../examplepb/a_bit_of_everything.pb.gw.go | 24 ++++----- examples/examplepb/echo_service.pb.gw.go | 6 +-- examples/examplepb/flow_combination.pb.gw.go | 54 +++++++++---------- examples/integration_test.go | 33 ++++++++++-- examples/main.go | 8 +-- .../gengateway/template.go | 6 +-- runtime/handler.go | 51 ++++++++++++++---- runtime/mux.go | 37 +++++++++++-- 8 files changed, 151 insertions(+), 68 deletions(-) diff --git a/examples/examplepb/a_bit_of_everything.pb.gw.go b/examples/examplepb/a_bit_of_everything.pb.gw.go index d79901fa843..1383694f037 100644 --- a/examples/examplepb/a_bit_of_everything.pb.gw.go +++ b/examples/examplepb/a_bit_of_everything.pb.gw.go @@ -406,7 +406,7 @@ func request_ABitOfEverythingService_BulkEcho_0(ctx context.Context, client ABit // RegisterABitOfEverythingServiceHandlerFromEndpoint is same as RegisterABitOfEverythingServiceHandler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func RegisterABitOfEverythingServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string) (err error) { - conn, err := grpc.Dial(endpoint) + conn, err := grpc.Dial(endpoint, grpc.WithInsecure()) if err != nil { return err } @@ -440,7 +440,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Create_0(ctx, w, req, resp) + forward_ABitOfEverythingService_Create_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -451,7 +451,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_CreateBody_0(ctx, w, req, resp) + forward_ABitOfEverythingService_CreateBody_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -462,7 +462,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_BulkCreate_0(ctx, w, req, resp) + forward_ABitOfEverythingService_BulkCreate_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -473,7 +473,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Lookup_0(ctx, w, req, resp) + forward_ABitOfEverythingService_Lookup_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -484,7 +484,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_List_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_ABitOfEverythingService_List_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -495,7 +495,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Update_0(ctx, w, req, resp) + forward_ABitOfEverythingService_Update_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -506,7 +506,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Delete_0(ctx, w, req, resp) + forward_ABitOfEverythingService_Delete_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -517,7 +517,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Echo_0(ctx, w, req, resp) + forward_ABitOfEverythingService_Echo_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -528,7 +528,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Echo_1(ctx, w, req, resp) + forward_ABitOfEverythingService_Echo_1(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -539,7 +539,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_Echo_2(ctx, w, req, resp) + forward_ABitOfEverythingService_Echo_2(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -550,7 +550,7 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se return } - forward_ABitOfEverythingService_BulkEcho_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_ABitOfEverythingService_BulkEcho_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) diff --git a/examples/examplepb/echo_service.pb.gw.go b/examples/examplepb/echo_service.pb.gw.go index d9a2a9d02bf..13fa7a94d82 100644 --- a/examples/examplepb/echo_service.pb.gw.go +++ b/examples/examplepb/echo_service.pb.gw.go @@ -66,7 +66,7 @@ func request_EchoService_EchoBody_0(ctx context.Context, client EchoServiceClien // RegisterEchoServiceHandlerFromEndpoint is same as RegisterEchoServiceHandler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func RegisterEchoServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string) (err error) { - conn, err := grpc.Dial(endpoint) + conn, err := grpc.Dial(endpoint, grpc.WithInsecure()) if err != nil { return err } @@ -100,7 +100,7 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn return } - forward_EchoService_Echo_0(ctx, w, req, resp) + forward_EchoService_Echo_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -111,7 +111,7 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn return } - forward_EchoService_EchoBody_0(ctx, w, req, resp) + forward_EchoService_EchoBody_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) diff --git a/examples/examplepb/flow_combination.pb.gw.go b/examples/examplepb/flow_combination.pb.gw.go index da61a9d350e..d33becb0eba 100644 --- a/examples/examplepb/flow_combination.pb.gw.go +++ b/examples/examplepb/flow_combination.pb.gw.go @@ -786,7 +786,7 @@ func request_FlowCombination_RpcPathNestedStream_2(ctx context.Context, client F // RegisterFlowCombinationHandlerFromEndpoint is same as RegisterFlowCombinationHandler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func RegisterFlowCombinationHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string) (err error) { - conn, err := grpc.Dial(endpoint) + conn, err := grpc.Dial(endpoint, grpc.WithInsecure()) if err != nil { return err } @@ -820,7 +820,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcEmptyRpc_0(ctx, w, req, resp) + forward_FlowCombination_RpcEmptyRpc_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -831,7 +831,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcEmptyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcEmptyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -842,7 +842,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_StreamEmptyRpc_0(ctx, w, req, resp) + forward_FlowCombination_StreamEmptyRpc_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -853,7 +853,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_StreamEmptyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_StreamEmptyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -864,7 +864,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_0(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -875,7 +875,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_1(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_1(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -886,7 +886,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_2(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_2(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -897,7 +897,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_3(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_3(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -908,7 +908,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_4(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_4(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -919,7 +919,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_5(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_5(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -930,7 +930,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyRpc_6(ctx, w, req, resp) + forward_FlowCombination_RpcBodyRpc_6(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -941,7 +941,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathSingleNestedRpc_0(ctx, w, req, resp) + forward_FlowCombination_RpcPathSingleNestedRpc_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -952,7 +952,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedRpc_0(ctx, w, req, resp) + forward_FlowCombination_RpcPathNestedRpc_0(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -963,7 +963,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedRpc_1(ctx, w, req, resp) + forward_FlowCombination_RpcPathNestedRpc_1(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -974,7 +974,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedRpc_2(ctx, w, req, resp) + forward_FlowCombination_RpcPathNestedRpc_2(ctx, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -985,7 +985,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -996,7 +996,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_1(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_1(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1007,7 +1007,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_2(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_2(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1018,7 +1018,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_3(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_3(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1029,7 +1029,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_4(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_4(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1040,7 +1040,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_5(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_5(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1051,7 +1051,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcBodyStream_6(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcBodyStream_6(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1062,7 +1062,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathSingleNestedStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcPathSingleNestedStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1073,7 +1073,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcPathNestedStream_0(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1084,7 +1084,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedStream_1(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcPathNestedStream_1(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) @@ -1095,7 +1095,7 @@ func RegisterFlowCombinationHandler(ctx context.Context, mux *runtime.ServeMux, return } - forward_FlowCombination_RpcPathNestedStream_2(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_FlowCombination_RpcPathNestedStream_2(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) }) diff --git a/examples/integration_test.go b/examples/integration_test.go index 4382351ab27..22129ec3be7 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -12,9 +12,13 @@ import ( "testing" "time" + "golang.org/x/net/context" + gw "github.com/gengo/grpc-gateway/examples/examplepb" server "github.com/gengo/grpc-gateway/examples/server" sub "github.com/gengo/grpc-gateway/examples/sub" + "github.com/gengo/grpc-gateway/runtime" + "github.com/golang/protobuf/proto" ) func TestIntegration(t *testing.T) { @@ -30,14 +34,14 @@ func TestIntegration(t *testing.T) { } }() go func() { - if err := Run(); err != nil { + if err := Run(":8080"); err != nil { t.Errorf("gw.Run() failed with %v; want success", err) return } }() time.Sleep(100 * time.Millisecond) - testEcho(t) + testEcho(t, 8080, "application/json") testEchoBody(t) testABECreate(t) testABECreateBody(t) @@ -45,10 +49,27 @@ func TestIntegration(t *testing.T) { testABELookup(t) testABEList(t) testAdditionalBindings(t) + + go func() { + if err := Run( + ":8081", + runtime.WithForwardResponseOption( + func(_ context.Context, w http.ResponseWriter, _ proto.Message) error { + w.Header().Set("Content-Type", "application/vnd.docker.plugins.v1.1+json") + return nil + }, + ), + ); err != nil { + t.Errorf("gw.Run() failed with %v; want success", err) + return + } + }() + + testEcho(t, 8081, "application/vnd.docker.plugins.v1.1+json") } -func testEcho(t *testing.T) { - url := "http://localhost:8080/v1/example/echo/myid" +func testEcho(t *testing.T, port int, contentType string) { + url := fmt.Sprintf("http://localhost:%d/v1/example/echo/myid", port) resp, err := http.Post(url, "application/json", strings.NewReader("{}")) if err != nil { t.Errorf("http.Post(%q) failed with %v; want success", url, err) @@ -74,6 +95,10 @@ func testEcho(t *testing.T) { if got, want := msg.Id, "myid"; got != want { t.Errorf("msg.Id = %q; want %q", got, want) } + + if value := resp.Header.Get("Content-Type"); value != contentType { + t.Errorf("Content-Type was %s, wanted %s", value, contentType) + } } func testEchoBody(t *testing.T) { diff --git a/examples/main.go b/examples/main.go index facfd9168a9..10c528a35bc 100644 --- a/examples/main.go +++ b/examples/main.go @@ -16,12 +16,12 @@ var ( flowEndpoint = flag.String("flow_endpoint", "localhost:9090", "endpoint of FlowCombination") ) -func Run() error { +func Run(address string, opts ...runtime.ServeMuxOption) error { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() - mux := runtime.NewServeMux() + mux := runtime.NewServeMux(opts...) err := examplepb.RegisterEchoServiceHandlerFromEndpoint(ctx, mux, *echoEndpoint) if err != nil { return err @@ -35,7 +35,7 @@ func Run() error { return err } - http.ListenAndServe(":8080", mux) + http.ListenAndServe(address, mux) return nil } @@ -43,7 +43,7 @@ func main() { flag.Parse() defer glog.Flush() - if err := Run(); err != nil { + if err := Run(":8080"); err != nil { glog.Fatal(err) } } diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 66ed99a3153..cefb0253a8d 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -216,7 +216,7 @@ var ( // Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string) (err error) { - conn, err := grpc.Dial(endpoint) + conn, err := grpc.Dial(endpoint, grpc.WithInsecure()) if err != nil { return err } @@ -251,9 +251,9 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, return } {{if $m.GetServerStreaming}} - forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }) + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) {{else}} - forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, w, req, resp) + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, w, req, resp, mux.GetForwardResponseOptions()...) {{end}} }) {{end}} diff --git a/runtime/handler.go b/runtime/handler.go index 577a060f74d..658a9edeb5b 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -17,7 +17,7 @@ type responseStreamChunk struct { } // ForwardResponseStream forwards the stream from gRPC server to REST client. -func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error)) { +func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { f, ok := w.(http.Flusher) if !ok { glog.Errorf("Flush not supported in %T", w) @@ -27,6 +27,10 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Content-Type", "application/json") + if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } w.WriteHeader(http.StatusOK) f.Flush() for { @@ -35,15 +39,11 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http return } if err != nil { - buf, merr := json.Marshal(responseStreamChunk{Error: err.Error()}) - if merr != nil { - glog.Errorf("Failed to marshal an error: %v", merr) - return - } - if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil { - glog.Errorf("Failed to notify error to client: %v", werr) - return - } + handleForwardResponseStreamError(w, err) + return + } + if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { + handleForwardResponseStreamError(w, err) return } buf, err := json.Marshal(responseStreamChunk{Result: resp}) @@ -60,7 +60,7 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http } // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. -func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *http.Request, resp proto.Message) { +func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { buf, err := json.Marshal(resp) if err != nil { glog.Errorf("Marshal error: %v", err) @@ -69,7 +69,36 @@ func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *htt } w.Header().Set("Content-Type", "application/json") + if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { + HTTPError(ctx, w, err) + return + } if _, err = w.Write(buf); err != nil { glog.Errorf("Failed to write response: %v", err) } } + +func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { + if opts == nil || len(opts) == 0 { + return nil + } + for _, opt := range opts { + if err := opt(ctx, w, resp); err != nil { + glog.Errorf("Error handling ForwardResponseOptions: %v", err) + return err + } + } + return nil +} + +func handleForwardResponseStreamError(w http.ResponseWriter, err error) { + buf, merr := json.Marshal(responseStreamChunk{Error: err.Error()}) + if merr != nil { + glog.Errorf("Failed to marshal an error: %v", merr) + return + } + if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil { + glog.Errorf("Failed to notify error to client: %v", werr) + return + } +} diff --git a/runtime/mux.go b/runtime/mux.go index 3538fc3820f..da99a6b073c 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -4,7 +4,10 @@ import ( "net/http" "strings" + "golang.org/x/net/context" + "github.com/golang/glog" + "github.com/golang/protobuf/proto" ) // A HandlerFunc handles a specific pair of path pattern and HTTP method. @@ -14,14 +17,35 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str // It matches http requests to patterns and invokes the corresponding handler. type ServeMux struct { // handlers maps HTTP method to a list of handlers. - handlers map[string][]handler + handlers map[string][]handler + forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error +} + +// ServeMuxOption is an option that can be given to a ServeMux on construction. +type ServeMuxOption func(*ServeMux) + +// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption. +// +// forwardResponseOption is an option that will be called on the relevant context.Context, +// http.ResponseWriter, and proto.Message before every forwarded response. +// +// The message may be nil in the case where just a header is being sent. +func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption) + } } // NewServeMux returns a new MuxHandler whose internal mapping is empty. -func NewServeMux() *ServeMux { - return &ServeMux{ - handlers: make(map[string][]handler), +func NewServeMux(opts ...ServeMuxOption) *ServeMux { + serveMux := &ServeMux{ + handlers: make(map[string][]handler), + forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), + } + for _, opt := range opts { + opt(serveMux) } + return serveMux } // Handle associates "h" to the pair of HTTP method and path pattern. @@ -92,6 +116,11 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) } +// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. +func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error { + return s.forwardResponseOptions +} + func isPathLengthFallback(r *http.Request) bool { return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" } From 5931ef5eaed53277b40bc055adfa626a379cb9ca Mon Sep 17 00:00:00 2001 From: Peter Edge Date: Mon, 14 Sep 2015 18:13:09 -0700 Subject: [PATCH 2/3] move handleForwardResponseOptions before json.Marshal in ForwardResponseMessage --- runtime/handler.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/runtime/handler.go b/runtime/handler.go index 658a9edeb5b..5a5e69bb78c 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -61,18 +61,19 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { - buf, err := json.Marshal(resp) - if err != nil { - glog.Errorf("Marshal error: %v", err) + w.Header().Set("Content-Type", "application/json") + if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { HTTPError(ctx, w, err) return } - w.Header().Set("Content-Type", "application/json") - if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { + buf, err := json.Marshal(resp) + if err != nil { + glog.Errorf("Marshal error: %v", err) HTTPError(ctx, w, err) return } + if _, err = w.Write(buf); err != nil { glog.Errorf("Failed to write response: %v", err) } From 68e036171ac2e43bbbe08c7d37383e757887c26d Mon Sep 17 00:00:00 2001 From: Peter Edge Date: Mon, 28 Sep 2015 14:45:07 -0700 Subject: [PATCH 3/3] remove redundant nil check --- runtime/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/handler.go b/runtime/handler.go index 6fe0370c0bd..b1242e20ee7 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -88,7 +88,7 @@ func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *htt } func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { - if opts == nil || len(opts) == 0 { + if len(opts) == 0 { return nil } for _, opt := range opts {