From 7983956e39bd3340117e1a63d2967a78fa92c915 Mon Sep 17 00:00:00 2001 From: Guillaume Fillon Date: Fri, 4 Nov 2016 17:39:58 +0100 Subject: [PATCH] Add usage of http.Request's Context --- README.md | 4 + .../gengateway/generator.go | 11 ++- .../gengateway/template.go | 98 ++++++++++++++++++- protoc-gen-grpc-gateway/main.go | 5 +- 4 files changed, 108 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 718047f3234..33f1bdddbde 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,10 @@ Make sure that your `$GOPATH/bin` is in your `$PATH`. `protoc-gen-grpc-gateway` supports custom mapping from Protobuf `import` to Golang import path. They are compatible to [the parameters with same names in `protoc-gen-go`](https://github.com/golang/protobuf#parameters). +In addition we also support the `request_context` parameter in order to use the `http.Request`'s Context (only for Go 1.7 and above). +This parameter can be useful to pass request scoped context between the gateway and the gRPC service. +**WARNING**: using `request_context` has breaking API: `context.Context` is removed from all `Register${SvcName}Handler` functions. + `protoc-gen-grpc-gateway` also supports some more command line flags to control logging. You can give these flags together with parameters above. Run `protoc-gen-grpc-gateway --help` for more details about the flags. ## More Examples diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index b4cce8695eb..44d01187207 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -20,12 +20,13 @@ var ( ) type generator struct { - reg *descriptor.Registry - baseImports []descriptor.GoPackage + reg *descriptor.Registry + baseImports []descriptor.GoPackage + useRequestContext bool } // New returns a new generator which generates grpc gateway files. -func New(reg *descriptor.Registry) gen.Generator { +func New(reg *descriptor.Registry, useRequestContext bool) gen.Generator { var imports []descriptor.GoPackage for _, pkgpath := range []string{ "io", @@ -54,7 +55,7 @@ func New(reg *descriptor.Registry) gen.Generator { } imports = append(imports, pkg) } - return &generator{reg: reg, baseImports: imports} + return &generator{reg: reg, baseImports: imports, useRequestContext: useRequestContext} } func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) { @@ -107,5 +108,5 @@ func (g *generator) generate(file *descriptor.File) (string, error) { imports = append(imports, pkg) } } - return applyTemplate(param{File: file, Imports: imports}) + return applyTemplate(param{File: file, Imports: imports, UseRequestContext: g.useRequestContext}) } diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 3c3da539b95..5fec3501fa3 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -13,7 +13,8 @@ import ( type param struct { *descriptor.File - Imports []descriptor.GoPackage + Imports []descriptor.GoPackage + UseRequestContext bool } type binding struct { @@ -86,8 +87,14 @@ func applyTemplate(p param) (string, error) { if !methodSeen { return "", errNoTargetService } - if err := trailerTemplate.Execute(w, p.Services); err != nil { - return "", err + if p.UseRequestContext { + if err := trailerTemplate17.Execute(w, p.Services); err != nil { + return "", err + } + } else { + if err := trailerTemplate.Execute(w, p.Services); err != nil { + return "", err + } } return w.String(), nil } @@ -365,6 +372,91 @@ var ( {{end}} ) +var ( + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}} + {{end}} + {{end}} +) +{{end}}`)) + + // trailerTemplate17 is the Go1.7 (and above) version of trailerTemplate. + trailerTemplate17 = template.Must(template.New("trailer-1.7").Parse(` +{{range $svc := .}} +// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Printf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return Register{{$svc.GetName}}Handler(ctx, mux, conn) +} + +// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func Register{{$svc.GetName}}Handler(_ context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + client := New{{$svc.GetName}}Client(conn) + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + if cn, ok := w.(http.CloseNotifier); ok { + go func(done <-chan struct{}, closed <-chan bool) { + select { + case <-done: + case <-closed: + cancel() + } + }(ctx.Done(), cn.CloseNotify()) + } + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, req) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + } + resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, outboundMarshaler, w, req, err) + return + } + {{if $m.GetServerStreaming}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) + {{else}} + forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + {{end}} + }) + {{end}} + {{end}} + return nil +} + +var ( + {{range $m := $svc.Methods}} + {{range $b := $m.Bindings}} + pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) + {{end}} + {{end}} +) + var ( {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 1c01d8756ce..0e2c54f8b26 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -23,7 +23,8 @@ import ( ) var ( - importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + useRequestContext = flag.Bool("request_context", false, "determine whether to use http.Request's context or not") ) func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { @@ -73,7 +74,7 @@ func main() { } } - g := gengateway.New(reg) + g := gengateway.New(reg, *useRequestContext) reg.SetPrefix(*importPrefix) if err := reg.Load(req); err != nil {