From 7983956e39bd3340117e1a63d2967a78fa92c915 Mon Sep 17 00:00:00 2001 From: Guillaume Fillon Date: Fri, 4 Nov 2016 17:39:58 +0100 Subject: [PATCH 1/2] Add usage of http.Request's Context --- README.md | 4 + .../gengateway/generator.go | 11 ++- .../gengateway/template.go | 98 ++++++++++++++++++- protoc-gen-grpc-gateway/main.go | 5 +- 4 files changed, 108 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 718047f3234..33f1bdddbde 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,10 @@ Make sure that your `$GOPATH/bin` is in your `$PATH`. `protoc-gen-grpc-gateway` supports custom mapping from Protobuf `import` to Golang import path. They are compatible to [the parameters with same names in `protoc-gen-go`](https://github.com/golang/protobuf#parameters). +In addition we also support the `request_context` parameter in order to use the `http.Request`'s Context (only for Go 1.7 and above). +This parameter can be useful to pass request scoped context between the gateway and the gRPC service. +**WARNING**: using `request_context` has breaking API: `context.Context` is removed from all `Register${SvcName}Handler` functions. + `protoc-gen-grpc-gateway` also supports some more command line flags to control logging. You can give these flags together with parameters above. Run `protoc-gen-grpc-gateway --help` for more details about the flags. ## More Examples diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index b4cce8695eb..44d01187207 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -20,12 +20,13 @@ var ( ) type generator struct { - reg *descriptor.Registry - baseImports []descriptor.GoPackage + reg *descriptor.Registry + baseImports []descriptor.GoPackage + useRequestContext bool } // New returns a new generator which generates grpc gateway files. -func New(reg *descriptor.Registry) gen.Generator { +func New(reg *descriptor.Registry, useRequestContext bool) gen.Generator { var imports []descriptor.GoPackage for _, pkgpath := range []string{ "io", @@ -54,7 +55,7 @@ func New(reg *descriptor.Registry) gen.Generator { } imports = append(imports, pkg) } - return &generator{reg: reg, baseImports: imports} + return &generator{reg: reg, baseImports: imports, useRequestContext: useRequestContext} } func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) { @@ -107,5 +108,5 @@ func (g *generator) generate(file *descriptor.File) (string, error) { imports = append(imports, pkg) } } - return applyTemplate(param{File: file, Imports: imports}) + return applyTemplate(param{File: file, Imports: imports, UseRequestContext: g.useRequestContext}) } diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 3c3da539b95..5fec3501fa3 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -13,7 +13,8 @@ import ( type param struct { *descriptor.File - Imports []descriptor.GoPackage + Imports []descriptor.GoPackage + UseRequestContext bool } type binding struct { @@ -86,8 +87,14 @@ func applyTemplate(p param) (string, error) { if !methodSeen { return "", errNoTargetService } - if err := trailerTemplate.Execute(w, p.Services); err != nil { - return "", err + if p.UseRequestContext { + if err := trailerTemplate17.Execute(w, p.Services); err != nil { + return "", err + } + } else { + if err := trailerTemplate.Execute(w, p.Services); err != nil { + return "", err + } } return w.String(), nil } @@ -365,6 +372,91 @@ var ( {{end}} ) +var ( + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}} + {{end}} + {{end}} +) +{{end}}`)) + + // trailerTemplate17 is the Go1.7 (and above) version of trailerTemplate. + trailerTemplate17 = template.Must(template.New("trailer-1.7").Parse(` +{{range $svc := .}} +// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return Register{{$svc.GetName}}Handler(ctx, mux, conn) +} + +// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func Register{{$svc.GetName}}Handler(_ context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + client := New{{$svc.GetName}}Client(conn) + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + if cn, ok := w.(http.CloseNotifier); ok { + go func(done <-chan struct{}, closed <-chan bool) { + select { + case <-done: + case <-closed: + cancel() + } + }(ctx.Done(), cn.CloseNotify()) + } + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, req) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + } + resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + return + } + {{if $m.GetServerStreaming}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) + {{else}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + {{end}} + }) + {{end}} + {{end}} + return nil +} + +var ( + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) + {{end}} + {{end}} +) + var ( {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 1c01d8756ce..0e2c54f8b26 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -23,7 +23,8 @@ import ( ) var ( - importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + useRequestContext = flag.Bool("request_context", false, "determine whether to use http.Request's context or not") ) func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { @@ -73,7 +74,7 @@ func main() { } } - g := gengateway.New(reg) + g := gengateway.New(reg, *useRequestContext) reg.SetPrefix(*importPrefix) if err := reg.Load(req); err != nil { From 71ed796e0001a609c4ed9c38bcaf20145f2db4dd Mon Sep 17 00:00:00 2001 From: Travis Cline Date: Sat, 19 Nov 2016 12:08:05 -0800 Subject: [PATCH 2/2] add request_context flag --- .travis.yml | 5 +- Makefile | 7 +- README.md | 1 - .../gengateway/template.go | 102 ++---------------- 4 files changed, 18 insertions(+), 97 deletions(-) diff --git a/.travis.yml b/.travis.yml index 61ec5207f5d..daf891232c1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,10 +23,13 @@ before_script: - sh -c 'cd examples/browser && npm install' script: - make realclean && make examples SWAGGER_CODEGEN="java -jar $HOME/local/swagger-codegen-cli.jar" -- if ! go version | grep devel; then test -z "$(git status --porcelain)" || (git status; git diff; exit 1); fi +- if (go version | grep -qv devel) && [ -z "${GATEWAY_PLUGIN_FLAGS}" ]; then test -z "$(git status --porcelain)" || (git status; git diff; exit 1); fi - env GLOG_logtostderr=1 go test -race -v github.com/grpc-ecosystem/grpc-gateway/... - make lint - sh -c 'cd examples/browser && gulp' env: global: - "PATH=$PATH:$HOME/local/bin" + matrix: + - GATEWAY_PLUGIN_FLAGS= + - GATEWAY_PLUGIN_FLAGS=request_context=true diff --git a/Makefile b/Makefile index 756f678aeaa..01a05855453 100644 --- a/Makefile +++ b/Makefile @@ -35,6 +35,7 @@ GATEWAY_PLUGIN_SRC= utilities/doc.go \ protoc-gen-grpc-gateway/httprule/parse.go \ protoc-gen-grpc-gateway/httprule/types.go \ protoc-gen-grpc-gateway/main.go +GATEWAY_PLUGIN_FLAGS?= GOOGLEAPIS_DIR=third_party/googleapis OPTIONS_PROTO=$(GOOGLEAPIS_DIR)/google/api/annotations.proto $(GOOGLEAPIS_DIR)/google/api/http.proto @@ -45,6 +46,10 @@ RUNTIME_PROTO=runtime/internal/stream_chunk.proto RUNTIME_GO=$(RUNTIME_PROTO:.proto=.pb.go) PKGMAP=Mgoogle/protobuf/descriptor.proto=$(GO_PLUGIN_PKG)/descriptor,Mgoogle/api/annotations.proto=$(PKG)/$(GOOGLEAPIS_DIR)/google/api,Mexamples/sub/message.proto=$(PKG)/examples/sub +ADDITIONAL_FLAGS= +ifneq "$(GATEWAY_PLUGIN_FLAGS)" "" + ADDITIONAL_FLAGS=,$(GATEWAY_PLUGIN_FLAGS) +endif SWAGGER_EXAMPLES=examples/examplepb/echo_service.proto \ examples/examplepb/a_bit_of_everything.proto EXAMPLES=examples/examplepb/echo_service.proto \ @@ -102,7 +107,7 @@ $(EXAMPLE_DEPSRCS): $(GO_PLUGIN) $(EXAMPLE_DEPS) protoc -I $(PROTOC_INC_PATH) -I. --plugin=$(GO_PLUGIN) --go_out=$(PKGMAP),plugins=grpc:$(OUTPUT_DIR) $(@:.pb.go=.proto) cp $(OUTPUT_DIR)/$(PKG)/$@ $@ || cp $(OUTPUT_DIR)/$@ $@ $(EXAMPLE_GWSRCS): $(GATEWAY_PLUGIN) $(EXAMPLES) - protoc -I $(PROTOC_INC_PATH) -I. -I$(GOOGLEAPIS_DIR) --plugin=$(GATEWAY_PLUGIN) --grpc-gateway_out=logtostderr=true,$(PKGMAP):. $(EXAMPLES) + protoc -I $(PROTOC_INC_PATH) -I. -I$(GOOGLEAPIS_DIR) --plugin=$(GATEWAY_PLUGIN) --grpc-gateway_out=logtostderr=true,$(PKGMAP)$(ADDITIONAL_FLAGS):. $(EXAMPLES) $(EXAMPLE_SWAGGERSRCS): $(SWAGGER_PLUGIN) $(SWAGGER_EXAMPLES) protoc -I $(PROTOC_INC_PATH) -I. -I$(GOOGLEAPIS_DIR) --plugin=$(SWAGGER_PLUGIN) --swagger_out=logtostderr=true,$(PKGMAP):. $(SWAGGER_EXAMPLES) diff --git a/README.md b/README.md index 33f1bdddbde..bed084be863 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,6 @@ They are compatible to [the parameters with same names in `protoc-gen-go`](https In addition we also support the `request_context` parameter in order to use the `http.Request`'s Context (only for Go 1.7 and above). This parameter can be useful to pass request scoped context between the gateway and the gRPC service. -**WARNING**: using `request_context` has breaking API: `context.Context` is removed from all `Register${SvcName}Handler` functions. `protoc-gen-grpc-gateway` also supports some more command line flags to control logging. You can give these flags together with parameters above. Run `protoc-gen-grpc-gateway --help` for more details about the flags. diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 5fec3501fa3..f3c0b60bdb7 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -87,14 +87,8 @@ func applyTemplate(p param) (string, error) { if !methodSeen { return "", errNoTargetService } - if p.UseRequestContext { - if err := trailerTemplate17.Execute(w, p.Services); err != nil { - return "", err - } - } else { - if err := trailerTemplate.Execute(w, p.Services); err != nil { - return "", err - } + if err := trailerTemplate.Execute(w, p); err != nil { + return "", err } return w.String(), nil } @@ -298,7 +292,8 @@ var ( `)) trailerTemplate = template.Must(template.New("trailer").Parse(` -{{range $svc := .}} +{{$UseRequestContext := .UseRequestContext}} +{{range $svc := .Services}} // Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { @@ -331,92 +326,11 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - if cn, ok := w.(http.CloseNotifier); ok { - go func(done <-chan struct{}, closed <-chan bool) { - select { - case <-done: - case <-closed: - cancel() - } - }(ctx.Done(), cn.CloseNotify()) - } - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - rctx, err := runtime.AnnotateContext(ctx, req) - if err != nil { - runtime.HTTPError(ctx, outboundMarshaler, w, req, err) - } - resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, client, req, pathParams) - ctx = runtime.NewServerMetadataContext(ctx, md) - if err != nil { - runtime.HTTPError(ctx, outboundMarshaler, w, req, err) - return - } - {{if $m.GetServerStreaming}} - forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) - {{else}} - forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - {{end}} - }) - {{end}} - {{end}} - return nil -} - -var ( - {{range $m := $svc.Methods}} - {{range $b := $m.Bindings}} - pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) - {{end}} - {{end}} -) - -var ( - {{range $m := $svc.Methods}} - {{range $b := $m.Bindings}} - forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}} - {{end}} - {{end}} -) -{{end}}`)) - - // trailerTemplate17 is the Go1.7 (and above) version of trailerTemplate. - trailerTemplate17 = template.Must(template.New("trailer-1.7").Parse(` -{{range $svc := .}} -// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but -// automatically dials to "endpoint" and closes the connection when "ctx" gets done. -func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { - conn, err := grpc.Dial(endpoint, opts...) - if err != nil { - return err - } - defer func() { - if err != nil { - if cerr := conn.Close(); cerr != nil { - grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) - } - return - } - go func() { - <-ctx.Done() - if cerr := conn.Close(); cerr != nil { - grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) - } - }() - }() - - return Register{{$svc.GetName}}Handler(ctx, mux, conn) -} - -// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux". -// The handlers forward requests to the grpc endpoint over "conn". -func Register{{$svc.GetName}}Handler(_ context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { - client := New{{$svc.GetName}}Client(conn) - {{range $m := $svc.Methods}} - {{range $b := $m.Bindings}} - mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + {{- if $UseRequestContext }} ctx, cancel := context.WithCancel(req.Context()) + {{- else -}} + ctx, cancel := context.WithCancel(ctx) + {{- end }} defer cancel() if cn, ok := w.(http.CloseNotifier); ok { go func(done <-chan struct{}, closed <-chan bool) {