From e63b987b2cf1019f5d6763ba830f83317513f800 Mon Sep 17 00:00:00 2001 From: Filipe Azevedo Date: Tue, 28 May 2019 20:12:09 +0100 Subject: [PATCH] Allow automatically resolved non root resources (#454) * Allow non root resources * code review * code review * code review * code review * add non root resource test * add non root resource test * add non root resource test * code review * code review --- go/grpcweb/helpers.go | 13 +++++++++++++ go/grpcweb/helpers_internal_test.go | 29 +++++++++++++++++++++++++++++ go/grpcweb/options.go | 15 +++++++++++++++ go/grpcweb/wrapper.go | 15 ++++++++++++++- go/grpcweb/wrapper_test.go | 21 +++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 go/grpcweb/helpers_internal_test.go diff --git a/go/grpcweb/helpers.go b/go/grpcweb/helpers.go index 86a1bf2a..d76d7ec7 100644 --- a/go/grpcweb/helpers.go +++ b/go/grpcweb/helpers.go @@ -7,10 +7,14 @@ import ( "fmt" "net/http" "net/url" + "regexp" + "strings" "google.golang.org/grpc" ) +var pathMatcher = regexp.MustCompile(`/[^/]*/[^/]*$`) + // ListGRPCResources is a helper function that lists all URLs that are registered on gRPC server. // // This makes it easy to register all the relevant routes in your HTTP router of choice. @@ -35,3 +39,12 @@ func WebsocketRequestOrigin(req *http.Request) (string, error) { } return parsed.Host, nil } + +func getGRPCEndpoint(req *http.Request) string { + endpoint := pathMatcher.FindString(strings.TrimRight(req.URL.Path, "/")) + if len(endpoint) == 0 { + return req.URL.Path + } + + return endpoint +} diff --git a/go/grpcweb/helpers_internal_test.go b/go/grpcweb/helpers_internal_test.go new file mode 100644 index 00000000..3842fa69 --- /dev/null +++ b/go/grpcweb/helpers_internal_test.go @@ -0,0 +1,29 @@ +package grpcweb + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetGRPCEndpoint(t *testing.T) { + cases := []struct { + input string + output string + }{ + {input: "/", output: "/"}, + {input: "/resource", output: "/resource"}, + {input: "/improbable.grpcweb.test.TestService/PingEmpty", output: "/improbable.grpcweb.test.TestService/PingEmpty"}, + {input: "/improbable.grpcweb.test.TestService/PingEmpty/", output: "/improbable.grpcweb.test.TestService/PingEmpty"}, + {input: "/a/b/c/improbable.grpcweb.test.TestService/PingEmpty", output: "/improbable.grpcweb.test.TestService/PingEmpty"}, + {input: "/a/b/c/improbable.grpcweb.test.TestService/PingEmpty/", output: "/improbable.grpcweb.test.TestService/PingEmpty"}, + } + + for _, c := range cases { + req := httptest.NewRequest("GET", c.input, nil) + result := getGRPCEndpoint(req) + + assert.Equal(t, c.output, result) + } +} diff --git a/go/grpcweb/options.go b/go/grpcweb/options.go index 441f4fa5..96aea91d 100644 --- a/go/grpcweb/options.go +++ b/go/grpcweb/options.go @@ -10,6 +10,7 @@ var ( allowedRequestHeaders: []string{"*"}, corsForRegisteredEndpointsOnly: true, originFunc: func(origin string) bool { return false }, + allowNonRootResources: false, } ) @@ -19,6 +20,7 @@ type options struct { originFunc func(origin string) bool enableWebsockets bool websocketOriginFunc func(req *http.Request) bool + allowNonRootResources bool } func evaluateOptions(opts []Option) *options { @@ -99,3 +101,16 @@ func WithWebsocketOriginFunc(websocketOriginFunc func(req *http.Request) bool) O o.websocketOriginFunc = websocketOriginFunc } } + +// WithAllowNonRootResource enables the gRPC wrapper to serve requests that have a path prefix +// added to the URL, before the service name and method placeholders. +// +// This should be set to false when exposing the endpoint as the root resource, to avoid +// the performance cost of path processing for every request. +// +// The default behaviour is `false`, i.e. always serves requests assuming there is no prefix to the gRPC endpoint. +func WithAllowNonRootResource(allowNonRootResources bool) Option { + return func(o *options) { + o.allowNonRootResources = allowNonRootResources + } +} diff --git a/go/grpcweb/wrapper.go b/go/grpcweb/wrapper.go index 44ac9932..2f1606e8 100644 --- a/go/grpcweb/wrapper.go +++ b/go/grpcweb/wrapper.go @@ -35,6 +35,7 @@ type WrappedGrpcServer struct { originFunc func(origin string) bool enableWebsockets bool websocketOriginFunc func(req *http.Request) bool + endpointFunc func(req *http.Request) string } // WrapServer takes a gRPC Server in Go and returns a WrappedGrpcServer that provides gRPC-Web Compatibility. @@ -56,6 +57,15 @@ func WrapServer(server *grpc.Server, options ...Option) *WrappedGrpcServer { if websocketOriginFunc == nil { websocketOriginFunc = defaultWebsocketOriginFunc } + + endpointFunc := func(req *http.Request) string { + return req.URL.Path + } + + if opts.allowNonRootResources { + endpointFunc = getGRPCEndpoint + } + return &WrappedGrpcServer{ server: server, opts: opts, @@ -63,6 +73,7 @@ func WrapServer(server *grpc.Server, options ...Option) *WrappedGrpcServer { originFunc: opts.originFunc, enableWebsockets: opts.enableWebsockets, websocketOriginFunc: websocketOriginFunc, + endpointFunc: endpointFunc, } } @@ -105,6 +116,7 @@ func (w *WrappedGrpcServer) IsGrpcWebSocketRequest(req *http.Request) bool { func (w *WrappedGrpcServer) HandleGrpcWebRequest(resp http.ResponseWriter, req *http.Request) { intReq, isTextFormat := hackIntoNormalGrpcRequest(req) intResp := newGrpcWebResponse(resp, isTextFormat) + req.URL.Path = w.endpointFunc(req) w.server.ServeHTTP(intResp, intReq) intResp.finishRequest(req) } @@ -161,6 +173,7 @@ func (w *WrappedGrpcServer) handleWebSocket(wsConn *websocket.Conn, req *http.Re grpclog.Errorf("web socket text format requests not yet supported") return } + req.URL.Path = w.endpointFunc(req) w.server.ServeHTTP(respWriter, interceptedRequest) } @@ -187,7 +200,7 @@ func (w *WrappedGrpcServer) IsAcceptableGrpcCorsRequest(req *http.Request) bool func (w *WrappedGrpcServer) isRequestForRegisteredEndpoint(req *http.Request) bool { registeredEndpoints := ListGRPCResources(w.server) - requestedEndpoint := req.URL.Path + requestedEndpoint := w.endpointFunc(req) for _, v := range registeredEndpoints { if v == requestedEndpoint { return true diff --git a/go/grpcweb/wrapper_test.go b/go/grpcweb/wrapper_test.go index bbd8d65c..1a1ff321 100644 --- a/go/grpcweb/wrapper_test.go +++ b/go/grpcweb/wrapper_test.go @@ -15,6 +15,7 @@ import ( "log" "net" "net/http" + "net/http/httptest" "net/textproto" "os" "strconv" @@ -67,6 +68,26 @@ func TestHttp1GrpcWebWrapperTestSuite(t *testing.T) { suite.Run(t, &GrpcWebWrapperTestSuite{httpMajorVersion: 1}) } +func TestNonRootResource(t *testing.T) { + grpcServer := grpc.NewServer() + testproto.RegisterTestServiceServer(grpcServer, &testServiceImpl{}) + wrappedServer := grpcweb.WrapServer(grpcServer, + grpcweb.WithAllowNonRootResource(true), + grpcweb.WithOriginFunc(func(origin string) bool { + return true + })) + + headers := http.Header{} + headers.Add("Access-Control-Request-Method", "POST") + headers.Add("Access-Control-Request-Headers", "origin, x-something-custom, x-grpc-web, accept") + req := httptest.NewRequest("OPTIONS", "http://host/grpc/improbable.grpcweb.test.TestService/Echo", nil) + req.Header = headers + resp := httptest.NewRecorder() + wrappedServer.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) +} + func (s *GrpcWebWrapperTestSuite) SetupTest() { var err error s.grpcServer = grpc.NewServer()