diff --git a/docs/_docs/customizingyourgateway.md b/docs/_docs/customizingyourgateway.md index 86e0f8d06fa..11190d832f0 100644 --- a/docs/_docs/customizingyourgateway.md +++ b/docs/_docs/customizingyourgateway.md @@ -234,24 +234,14 @@ if err := pb.RegisterMyServiceHandlerFromEndpoint(ctx, mux, serviceEndpoint, opt ``` ## Error handler -The gateway uses two different error handlers for non-streaming requests: - - * `runtime.HTTPError` is called for errors from backend calls - * `runtime.OtherErrorHandler` is called for errors from parsing and routing client requests - -To override all error handling for a `*runtime.ServeMux`, use the -`runtime.WithProtoErrorHandler` serve option. - -Alternatively, you can override the global default `HTTPError` handling by -setting `runtime.GlobalHTTPErrorHandler` to a custom function, and override -the global default `OtherErrorHandler` by setting `runtime.OtherErrorHandler` -to a custom function. - -You should not set `runtime.HTTPError` directly, because that might break -any `ServeMux` set up with the `WithProtoErrorHandler` option. +To override error handling for a `*runtime.ServeMux`, use the +`runtime.WithErrorHandler` option. This will configure all unary error +responses to pass through this error handler. See https://mycodesmells.com/post/grpc-gateway-error-handler for an example -of writing a custom error handler function. +of writing a custom error handler function. Note that this post targets +the v1 release of the gateway, and you no longer assign to `HTTPError` to +configure an error handler. ## Stream Error Handler The error handler described in the previous section applies only @@ -285,40 +275,33 @@ Here's an example custom handler: // handleStreamError overrides default behavior for computing an error // message for a server stream. // -// It uses a default "502 Bad Gateway" HTTP code; only emits "safe" -// messages; and does not set gRPC code or details fields (so they will +// It uses a default "502 Bad Gateway" HTTP code, only emits "safe" +// messages and does not set the details field (so it will // be omitted from the resulting JSON object that is sent to client). -func handleStreamError(ctx context.Context, err error) *runtime.StreamError { - code := http.StatusBadGateway +func handleStreamError(ctx context.Context, err error) *status.Status { + code := codes.Internal msg := "unexpected error" if s, ok := status.FromError(err); ok { - code = runtime.HTTPStatusFromCode(s.Code()) - // default message, based on the name of the gRPC code - msg = code.String() + code = s.Code() + // default message, based on the gRPC status + msg = s.Message() // see if error details include "safe" message to send // to external callers - for _, msg := s.Details() { + for _, msg := range s.Details() { if safe, ok := msg.(*SafeMessage); ok { msg = safe.Text break } } } - return &runtime.StreamError{ - HttpCode: int32(code), - HttpStatus: http.StatusText(code), - Message: msg, - } + return status.Errorf(code, msg) } ``` If no custom handler is provided, the default stream error handler will include any gRPC error attributes (code, message, detail messages), if the error being reported includes them. If the error does not have -these attributes, a gRPC code of `Unknown` (2) is reported. The default -handler will also include an HTTP code and status, which is derived -from the gRPC code (or set to `"500 Internal Server Error"` when -the source error has no gRPC attributes). +these attributes, a gRPC code of `Unknown` (2) is reported. ## Replace a response forwarder per method You might want to keep the behavior of the current marshaler but change only a message forwarding of a certain API method. diff --git a/examples/internal/integration/BUILD.bazel b/examples/internal/integration/BUILD.bazel index bcc837b56f1..9755e1b7f7c 100644 --- a/examples/internal/integration/BUILD.bazel +++ b/examples/internal/integration/BUILD.bazel @@ -7,7 +7,6 @@ go_test( "fieldmask_test.go", "integration_test.go", "main_test.go", - "proto_error_test.go", ], deps = [ "//examples/internal/clients/abe:go_default_library", @@ -23,9 +22,7 @@ go_test( "@com_github_golang_protobuf//descriptor:go_default_library_gen", "@com_github_golang_protobuf//jsonpb:go_default_library_gen", "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", "@com_github_google_go_cmp//cmp:go_default_library", - "@go_googleapis//google/rpc:errdetails_go_proto", "@go_googleapis//google/rpc:status_go_proto", "@io_bazel_rules_go//proto/wkt:empty_go_proto", "@io_bazel_rules_go//proto/wkt:field_mask_go_proto", diff --git a/examples/internal/integration/integration_test.go b/examples/internal/integration/integration_test.go index 075f6286e9f..c34062886b9 100644 --- a/examples/internal/integration/integration_test.go +++ b/examples/internal/integration/integration_test.go @@ -1250,8 +1250,8 @@ func testABERepeated(t *testing.T, port int) { "bar", }, PathRepeatedBytesValue: [][]byte{ - []byte{0x00}, - []byte{0xFF}, + {0x00}, + {0xFF}, }, PathRepeatedUint32Value: []uint32{ 0, @@ -1378,7 +1378,7 @@ func TestUnknownPath(t *testing.T) { } } -func TestMethodNotAllowed(t *testing.T) { +func TestNotImplemented(t *testing.T) { if testing.Short() { t.Skip() return @@ -1397,7 +1397,7 @@ func TestMethodNotAllowed(t *testing.T) { return } - if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { + if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { t.Errorf("resp.StatusCode = %d; want %d", got, want) t.Logf("%s", buf) } diff --git a/examples/internal/integration/proto_error_test.go b/examples/internal/integration/proto_error_test.go deleted file mode 100644 index 12a636c1a62..00000000000 --- a/examples/internal/integration/proto_error_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package integration_test - -import ( - "context" - "fmt" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "google.golang.org/genproto/googleapis/rpc/errdetails" - spb "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc/codes" -) - -func runServer(ctx context.Context, t *testing.T, port uint16) { - opt := runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler) - if err := runGateway(ctx, fmt.Sprintf(":%d", port), opt); err != nil { - t.Errorf("runGateway() failed with %v; want success", err) - } -} - -func TestWithProtoErrorHandler(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8082 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8082); err != nil { - t.Errorf("waitForGateway(ctx, 8082) failed with %v; want success", err) - } - testEcho(t, port, "v1", "application/json") - testEchoBody(t, port, "v1") -} - -func TestABEWithProtoErrorHandler(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8083 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8083); err != nil { - t.Errorf("waitForGateway(ctx, 8083) failed with %v; want success", err) - } - - testABECreate(t, port) - testABECreateBody(t, port) - testABEBulkCreate(t, port) - testABELookup(t, port) - testABELookupNotFoundWithProtoError(t, port) - testABELookupNotFoundWithProtoErrorIncludingDetails(t, port) - testABEList(t, port) - testABEBulkEcho(t, port) - testABEBulkEchoZeroLength(t, port) - testAdditionalBindings(t, port) -} - -func testABELookupNotFoundWithProtoError(t *testing.T, port uint16) { - url := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything", port) - uuid := "not_exist" - url = fmt.Sprintf("%s/%s", url, uuid) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Get(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotFound; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - return - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.NotFound); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if got, want := msg.Message, "not found"; got != want { - t.Errorf("msg.Message = %s; want %s", got, want) - return - } - - if got, want := resp.Header.Get("Grpc-Metadata-Uuid"), uuid; got != want { - t.Errorf("Grpc-Metadata-Uuid was %s, wanted %s", got, want) - } - if got, want := resp.Trailer.Get("Grpc-Trailer-Foo"), "foo2"; got != want { - t.Errorf("Grpc-Trailer-Foo was %q, wanted %q", got, want) - } - if got, want := resp.Trailer.Get("Grpc-Trailer-Bar"), "bar2"; got != want { - t.Errorf("Grpc-Trailer-Bar was %q, wanted %q", got, want) - } -} - -func testABELookupNotFoundWithProtoErrorIncludingDetails(t *testing.T, port uint16) { - uuid := "errorwithdetails" - url := fmt.Sprintf("http://localhost:%d/v2/example/%s", port, uuid) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Get(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusInternalServerError; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - return - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unknown); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if got, want := msg.Message, "with details"; got != want { - t.Errorf("msg.Message = %s; want %s", got, want) - return - } - - details := msg.Details - if got, want := len(details), 1; got != want { - t.Fatalf("got %q details, wanted %q", got, want) - } - - detail := errdetails.DebugInfo{} - if got, want := ptypes.UnmarshalAny(msg.Details[0], &detail), error(nil); got != want { - t.Errorf("unmarshaling any: got %q, wanted %q", got, want) - } - - if got, want := len(detail.StackEntries), 1; got != want { - t.Fatalf("got %d stack entries, expected %d", got, want) - } - if got, want := detail.StackEntries[0], "foo:1"; got != want { - t.Errorf("StackEntries[0]: got %q; want %q", got, want) - } -} - -func TestUnknownPathWithProtoError(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8084 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8084); err != nil { - t.Errorf("waitForGateway(ctx, 8084) failed with %v; want success", err) - } - - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Post(url, "application/json", strings.NewReader("{}")) - if err != nil { - t.Errorf("http.Post(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unimplemented); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if msg.Message == "" { - t.Errorf("msg.Message should not be empty") - return - } -} - -func TestMethodNotAllowedWithProtoError(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8085 - go runServer(ctx, t, port) - - // Waiting for the server's getting available. - // TODO(yugui) find a better way to wait - time.Sleep(100 * time.Millisecond) - - url := fmt.Sprintf("http://localhost:%d/v1/example/echo/myid", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Post(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unimplemented); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if msg.Message == "" { - t.Errorf("msg.Message should not be empty") - return - } -} diff --git a/repositories.bzl b/repositories.bzl index a2070e0051c..fa1a255418d 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -84,8 +84,8 @@ def go_repositories(): go_repository( name = "org_golang_google_grpc", importpath = "google.golang.org/grpc", - sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=", - version = "v1.27.1", + sum = "h1:EC2SB8S04d2r73uptxphDSUG+kTKVgjRPF+N3xpxRB4=", + version = "v1.29.1", ) go_repository( name = "org_golang_x_lint", @@ -163,8 +163,8 @@ def go_repositories(): go_repository( name = "com_github_envoyproxy_go_control_plane", importpath = "github.com/envoyproxy/go-control-plane", - sum = "h1:4cmBvAEBNJaGARUEs3/suWRyfyBfhf7I60WBZq+bv2w=", - version = "v0.9.1-0.20191026205805-5f8ba28d4473", + sum = "h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=", + version = "v0.9.4", ) go_repository( name = "com_github_envoyproxy_protoc_gen_validate", @@ -190,3 +190,9 @@ def go_repositories(): sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", version = "v0.0.0-20191204190536-9bdfabe68543", ) + go_repository( + name = "com_github_cncf_udpa_go", + importpath = "github.com/cncf/udpa/go", + sum = "h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=", + version = "v0.0.0-20191209042840-269d4d468f6f", + ) diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 745fa692547..b5295bd120a 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -20,7 +20,6 @@ go_library( "mux.go", "pattern.go", "proto2_convert.go", - "proto_errors.go", "query.go", ], importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/runtime", diff --git a/runtime/errors.go b/runtime/errors.go index 58c80eec857..1a47702d5ef 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -10,6 +10,12 @@ import ( "google.golang.org/grpc/status" ) +// ErrorHandlerFunc is the signature used to configure error handling. +type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error) + +// StreamErrorHandlerFunc is the signature used to configure stream error handling. +type StreamErrorHandlerFunc func(context.Context, error) *status.Status + // HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status. // See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto func HTTPStatusFromCode(code codes.Code) int { @@ -55,61 +61,19 @@ func HTTPStatusFromCode(code codes.Code) int { return http.StatusInternalServerError } -var ( - // HTTPError replies to the request with an error. - // - // HTTPError is called: - // - From generated per-endpoint gateway handler code, when calling the backend results in an error. - // - From gateway runtime code, when forwarding the response message results in an error. - // - // The default value for HTTPError calls the custom error handler configured on the ServeMux via the - // WithProtoErrorHandler serve option if that option was used, calling GlobalHTTPErrorHandler otherwise. - // - // To customize the error handling of a particular ServeMux instance, use the WithProtoErrorHandler - // serve option. - // - // To customize the error format for all ServeMux instances not using the WithProtoErrorHandler serve - // option, set GlobalHTTPErrorHandler to a custom function. - // - // Setting this variable directly to customize error format is deprecated. - HTTPError = MuxOrGlobalHTTPError - - // GlobalHTTPErrorHandler is the HTTPError handler for all ServeMux instances not using the - // WithProtoErrorHandler serve option. - // - // You can set a custom function to this variable to customize error format. - GlobalHTTPErrorHandler = DefaultHTTPError - - // OtherErrorHandler handles gateway errors from parsing and routing client requests for all - // ServeMux instances not using the WithProtoErrorHandler serve option. - // - // It returns the following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest - // - // To customize parsing and routing error handling of a particular ServeMux instance, use the - // WithProtoErrorHandler serve option. - // - // To customize parsing and routing error handling of all ServeMux instances not using the - // WithProtoErrorHandler serve option, set a custom function to this variable. - OtherErrorHandler = DefaultOtherErrorHandler -) - -// MuxOrGlobalHTTPError uses the mux-configured error handler, falling back to GlobalErrorHandler. -func MuxOrGlobalHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { - if mux.protoErrorHandler != nil { - mux.protoErrorHandler(ctx, mux, marshaler, w, r, err) - } else { - GlobalHTTPErrorHandler(ctx, mux, marshaler, w, r, err) - } +// HTTPError uses the mux-configured error handler. +func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { + mux.errorHandler(ctx, mux, marshaler, w, r, err) } -// DefaultHTTPError is the default implementation of HTTPError. -// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. +// defaultHTTPErrorHandler is the default error handler. +// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode. // If otherwise, it replies with http.StatusInternalServerError. // -// The response body returned by this function is a JSON object, -// which contains a member whose key is "message" and whose value is err.Error(). -func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { - const fallback = `{"error": "failed to marshal error message"}` +// The response body written by this function is a Status message marshaled by the Marshaler. +func defaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { + // return Internal when Marshal failed + const fallback = `{"code": 13, "message": "failed to marshal error message"}` s := status.Convert(err) pb := s.Proto() @@ -117,7 +81,7 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w w.Header().Del("Trailer") contentType := marshaler.ContentType() - // Check marshaler on run time in order to keep backwards compatability + // Check marshaler at runtime in order to keep backwards compatibility. // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { @@ -151,8 +115,6 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w handleForwardResponseTrailer(w, md) } -// DefaultOtherErrorHandler is the default implementation of OtherErrorHandler. -// It simply writes a string representation of the given error into "w". -func DefaultOtherErrorHandler(w http.ResponseWriter, _ *http.Request, msg string, code int) { - http.Error(w, msg, code) +func defaultStreamErrorHandler(_ context.Context, err error) *status.Status { + return status.Convert(err) } diff --git a/runtime/errors_test.go b/runtime/errors_test.go index 83ec1666a03..5ba54eb97db 100644 --- a/runtime/errors_test.go +++ b/runtime/errors_test.go @@ -62,7 +62,8 @@ func TestDefaultHTTPError(t *testing.T) { } { w := httptest.NewRecorder() req, _ := http.NewRequest("", "", nil) // Pass in an empty request to match the signature - runtime.DefaultHTTPError(ctx, &runtime.ServeMux{}, &runtime.JSONPb{}, w, req, spec.err) + mux := runtime.NewServeMux() + runtime.HTTPError(ctx, mux, &runtime.JSONPb{}, w, req, spec.err) if got, want := w.Header().Get("Content-Type"), "application/json"; got != want { t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err) diff --git a/runtime/handler.go b/runtime/handler.go index c519d5d098d..4cbf8f46259 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -186,7 +186,7 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re } func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) { - st := status.Convert(err) + st := mux.streamErrorHandler(ctx, err) if !wroteHeader { w.WriteHeader(HTTPStatusFromCode(st.Code())) } diff --git a/runtime/mux.go b/runtime/mux.go index eb7e435365e..70caae1dade 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -16,14 +16,6 @@ import ( // A HandlerFunc handles a specific pair of path pattern and HTTP method. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) -// ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when -// a request is received with a URI path that does not match any registered -// service method. -// -// Since gRPC servers return an "Unimplemented" code for requests with an -// unrecognized URI path, this error also has a gRPC "Unimplemented" code. -var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented)) - // ServeMux is a request multiplexer for grpc-gateway. // It matches http requests to patterns and invokes the corresponding handler. type ServeMux struct { @@ -34,7 +26,8 @@ type ServeMux struct { incomingHeaderMatcher HeaderMatcherFunc outgoingHeaderMatcher HeaderMatcherFunc metadataAnnotators []func(context.Context, *http.Request) metadata.MD - protoErrorHandler ProtoErrorHandlerFunc + errorHandler ErrorHandlerFunc + streamErrorHandler StreamErrorHandlerFunc disablePathLengthFallback bool lastMatchWins bool } @@ -110,14 +103,26 @@ func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) Se } } -// WithProtoErrorHandler returns a ServeMuxOption for configuring a custom error handler. +// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler. // -// This can be used to handle an error as general proto message defined by gRPC. -// When this option is used, the mux uses the configured error handler instead of HTTPError and -// OtherErrorHandler. -func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption { +// This can be used to configure a custom error response. +func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption { return func(serveMux *ServeMux) { - serveMux.protoErrorHandler = fn + serveMux.errorHandler = fn + } +} + +// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream +// error handler, which allows for customizing the error trailer for server-streaming +// calls. +// +// For stream errors that occur before any response has been written, the mux's +// ErrorHandler will be invoked. However, once data has been written, the errors must +// be handled differently: they must be included in the response body. The response body's +// final message will include the error details returned by the stream error handler. +func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.streamErrorHandler = fn } } @@ -143,6 +148,8 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux { handlers: make(map[string][]handler), forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), marshalers: makeMarshalerMIMERegistry(), + errorHandler: defaultHTTPErrorHandler, + streamErrorHandler: defaultStreamErrorHandler, } for _, opt := range opts { @@ -177,28 +184,23 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path if !strings.HasPrefix(path, "/") { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest)) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest)) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } components := strings.Split(path[1:], "/") l := len(components) var verb string - if idx := strings.LastIndex(components[l-1], ":"); idx == 0 { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) - } + idx := strings.LastIndex(components[l-1], ":") + if idx == 0 { + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.NotFound, http.StatusText(http.StatusNotFound)) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return - } else if idx > 0 { + } + if idx > 0 { c := components[l-1] components[l-1], verb = c[:idx], c[idx+1:] } @@ -206,13 +208,9 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) { r.Method = strings.ToUpper(override) if err := r.ParseForm(); err != nil { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, err.Error()) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, err.Error()) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } } @@ -226,7 +224,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // lookup other methods to handle fallback from GET to POST and - // to determine if it is MethodNotAllowed or NotFound. + // to determine if it is NotImplemented or NotFound. for m, handlers := range s.handlers { if m == r.Method { continue @@ -239,34 +237,25 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // X-HTTP-Method-Override is optional. Always allow fallback to POST. if s.isPathLengthFallback(r) { if err := r.ParseForm(); err != nil { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, err.Error()) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, err.Error()) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } h.h(w, r, pathParams) return } - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + // codes.Unimplemented is the closes we have to MethodNotAllowed + sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented)) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } } - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.NotFound, http.StatusText(http.StatusNotFound)) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) } // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. diff --git a/runtime/mux_test.go b/runtime/mux_test.go index 21ea5d143f5..da74f1410ee 100644 --- a/runtime/mux_test.go +++ b/runtime/mux_test.go @@ -2,16 +2,14 @@ package runtime_test import ( "bytes" - "context" "fmt" "net/http" "net/http/httptest" + "strconv" "testing" "github.com/grpc-ecosystem/grpc-gateway/v2/internal/utilities" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) func TestMuxServeHTTP(t *testing.T) { @@ -21,7 +19,7 @@ func TestMuxServeHTTP(t *testing.T) { pool []string verb string } - for _, spec := range []struct { + for i, spec := range []struct { patterns []stubPattern patternOpts []runtime.PatternOpt @@ -33,7 +31,6 @@ func TestMuxServeHTTP(t *testing.T) { respContent string disablePathLengthFallback bool - errHandler runtime.ProtoErrorHandlerFunc muxOpts []runtime.ServeMuxOption }{ { @@ -112,7 +109,7 @@ func TestMuxServeHTTP(t *testing.T) { }, reqMethod: "DELETE", reqPath: "/foo", - respStatus: http.StatusMethodNotAllowed, + respStatus: http.StatusNotImplemented, }, { patterns: []stubPattern{ @@ -143,8 +140,7 @@ func TestMuxServeHTTP(t *testing.T) { headers: map[string]string{ "Content-Type": "application/x-www-form-urlencoded", }, - respStatus: http.StatusMethodNotAllowed, - respContent: "Method Not Allowed\n", + respStatus: http.StatusNotImplemented, disablePathLengthFallback: true, }, { @@ -204,7 +200,7 @@ func TestMuxServeHTTP(t *testing.T) { headers: map[string]string{ "Content-Type": "application/json", }, - respStatus: http.StatusMethodNotAllowed, + respStatus: http.StatusNotImplemented, }, { patterns: []stubPattern{ @@ -245,38 +241,6 @@ func TestMuxServeHTTP(t *testing.T) { respStatus: http.StatusOK, respContent: "GET /foo/{id=*}:verb", }, - { - // mux identifying invalid path results in 'Not Found' status - // (with custom handler looking for ErrUnknownURI) - patterns: []stubPattern{ - { - method: "GET", - ops: []int{int(utilities.OpLitPush), 0}, - pool: []string{"unimplemented"}, - }, - }, - reqMethod: "GET", - reqPath: "/foobar", - respStatus: http.StatusNotFound, - respContent: "GET /foobar", - errHandler: unknownPathIs404, - }, - { - // server returning unimplemented results in 'Not Implemented' code - // even when using custom error handler - patterns: []stubPattern{ - { - method: "GET", - ops: []int{int(utilities.OpLitPush), 0}, - pool: []string{"unimplemented"}, - }, - }, - reqMethod: "GET", - reqPath: "/unimplemented", - respStatus: http.StatusNotImplemented, - respContent: `GET /unimplemented`, - errHandler: unknownPathIs404, - }, { patterns: []stubPattern{ { @@ -336,64 +300,45 @@ func TestMuxServeHTTP(t *testing.T) { muxOpts: []runtime.ServeMuxOption{runtime.WithLastMatchWins()}, }, } { - opts := spec.muxOpts - if spec.disablePathLengthFallback { - opts = append(opts, runtime.WithDisablePathLengthFallback()) - } - if spec.errHandler != nil { - opts = append(opts, runtime.WithProtoErrorHandler(spec.errHandler)) - } - mux := runtime.NewServeMux(opts...) - for _, p := range spec.patterns { - func(p stubPattern) { - pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb, spec.patternOpts...) - if err != nil { - t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) - } - mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { - if r.URL.Path == "/unimplemented" { - // simulate method returning "unimplemented" error - _, m := runtime.MarshalerForRequest(mux, r) - runtime.HTTPError(r.Context(), mux, m, w, r, status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))) - w.WriteHeader(http.StatusNotImplemented) - return + t.Run(strconv.Itoa(i), func(t *testing.T) { + opts := spec.muxOpts + if spec.disablePathLengthFallback { + opts = append(opts, runtime.WithDisablePathLengthFallback()) + } + mux := runtime.NewServeMux(opts...) + for _, p := range spec.patterns { + func(p stubPattern) { + pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb, spec.patternOpts...) + if err != nil { + t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) } - fmt.Fprintf(w, "%s %s", p.method, pat.String()) - }) - }(p) - } - - url := fmt.Sprintf("http://host.example%s", spec.reqPath) - r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) - if err != nil { - t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) - } - for name, value := range spec.headers { - r.Header.Set(name, value) - } - w := httptest.NewRecorder() - mux.ServeHTTP(w, r) + mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + fmt.Fprintf(w, "%s %s", p.method, pat.String()) + }) + }(p) + } - if got, want := w.Code, spec.respStatus; got != want { - t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) - } - if spec.respContent != "" { - if got, want := w.Body.String(), spec.respContent; got != want { - t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + url := fmt.Sprintf("http://host.example%s", spec.reqPath) + r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) } - } - } -} + for name, value := range spec.headers { + r.Header.Set(name, value) + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) -func unknownPathIs404(ctx context.Context, mux *runtime.ServeMux, m runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) { - if err == runtime.ErrUnknownURI { - w.WriteHeader(http.StatusNotFound) - } else { - c := status.Convert(err).Code() - w.WriteHeader(runtime.HTTPStatusFromCode(c)) + if got, want := w.Code, spec.respStatus; got != want { + t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) + } + if spec.respContent != "" { + if got, want := w.Body.String(), spec.respContent; got != want { + t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + } + } + }) } - - fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) } var defaultHeaderMatcherTests = []struct { diff --git a/runtime/pattern.go b/runtime/pattern.go index c2e4bf956b0..a1bd2496fc9 100644 --- a/runtime/pattern.go +++ b/runtime/pattern.go @@ -21,7 +21,8 @@ type op struct { operand int } -// Pattern is a template pattern of http request paths defined in github.com/googleapis/googleapis/google/api/http.proto. +// Pattern is a template pattern of http request paths defined in +// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto type Pattern struct { // ops is a list of operations ops []op diff --git a/runtime/proto_errors.go b/runtime/proto_errors.go deleted file mode 100644 index b0cf0d0bb3f..00000000000 --- a/runtime/proto_errors.go +++ /dev/null @@ -1,70 +0,0 @@ -package runtime - -import ( - "context" - "io" - "net/http" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/status" -) - -// ProtoErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request. -type ProtoErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error) - -var _ ProtoErrorHandlerFunc = DefaultHTTPProtoErrorHandler - -// DefaultHTTPProtoErrorHandler is an implementation of HTTPError. -// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. -// If otherwise, it replies with http.StatusInternalServerError. -// -// The response body returned by this function is a Status message marshaled by a Marshaler. -// -// Do not set this function to HTTPError variable directly, use WithProtoErrorHandler option instead. -func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { - // return Internal when Marshal failed - const fallback = `{"code": 13, "message": "failed to marshal error message"}` - - s, ok := status.FromError(err) - if !ok { - s = status.New(codes.Unknown, err.Error()) - } - - w.Header().Del("Trailer") - - contentType := marshaler.ContentType() - // Check marshaler on run time in order to keep backwards compatability - // An interface param needs to be added to the ContentType() function on - // the Marshal interface to be able to remove this check - if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { - pb := s.Proto() - contentType = typeMarshaler.ContentTypeFromMessage(pb) - } - w.Header().Set("Content-Type", contentType) - - buf, merr := marshaler.Marshal(s.Proto()) - if merr != nil { - grpclog.Infof("Failed to marshal error message %q: %v", s.Proto(), merr) - w.WriteHeader(http.StatusInternalServerError) - if _, err := io.WriteString(w, fallback); err != nil { - grpclog.Infof("Failed to write response: %v", err) - } - return - } - - md, ok := ServerMetadataFromContext(ctx) - if !ok { - grpclog.Infof("Failed to extract ServerMetadata from context") - } - - handleForwardResponseServerMetadata(w, mux, md) - handleForwardResponseTrailerHeader(w, md) - st := HTTPStatusFromCode(s.Code()) - w.WriteHeader(st) - if _, err := w.Write(buf); err != nil { - grpclog.Infof("Failed to write response: %v", err) - } - - handleForwardResponseTrailer(w, md) -}