diff --git a/handler.go b/handler.go index 8a104a5..b9a647c 100644 --- a/handler.go +++ b/handler.go @@ -29,7 +29,7 @@ type Handler struct { playground bool rootObjectFn RootObjectFn resultCallbackFn ResultCallbackFn - formatErrorFn func(err gqlerrors.FormattedError) gqlerrors.FormattedError + formatErrorFn func(err error) gqlerrors.FormattedError } type RequestOptions struct { @@ -144,7 +144,7 @@ func (h *Handler) ContextHandler(ctx context.Context, w http.ResponseWriter, r * if formatErrorFn := h.formatErrorFn; formatErrorFn != nil && len(result.Errors) > 0 { formatted := make([]gqlerrors.FormattedError, len(result.Errors)) for i, formattedError := range result.Errors { - formatted[i] = formatErrorFn(formattedError) + formatted[i] = formatErrorFn(formattedError.OriginalError()) } result.Errors = formatted } @@ -203,7 +203,7 @@ type Config struct { Playground bool RootObjectFn RootObjectFn ResultCallbackFn ResultCallbackFn - FormatErrorFn func(err gqlerrors.FormattedError) gqlerrors.FormattedError + FormatErrorFn func(err error) gqlerrors.FormattedError } func NewConfig() *Config { diff --git a/handler_test.go b/handler_test.go index 7558159..4154b1a 100644 --- a/handler_test.go +++ b/handler_test.go @@ -2,7 +2,6 @@ package handler_test import ( "encoding/json" - "errors" "fmt" "io/ioutil" "net/http" @@ -215,15 +214,15 @@ func TestHandler_BasicQuery_WithRootObjFn(t *testing.T) { } type customError struct { - error + message string } func (e customError) Error() string { - return e.error.Error() + return fmt.Sprintf("%s", e.message) } func TestHandler_BasicQuery_WithFormatErrorFn(t *testing.T) { - resolverError := customError{error: errors.New("resolver error")} + resolverError := customError{message: "resolver error"} myNameQuery := graphql.NewObject(graphql.ObjectConfig{ Name: "Query", Fields: graphql.Fields{ @@ -252,9 +251,6 @@ func TestHandler_BasicQuery_WithFormatErrorFn(t *testing.T) { }, }, Path: []interface{}{"name"}, - Extensions: map[string]interface{}{ - "fromFormatFn": "FROM_FORMAT_FN", - }, } expected := &graphql.Result{ @@ -271,22 +267,16 @@ func TestHandler_BasicQuery_WithFormatErrorFn(t *testing.T) { h := handler.New(&handler.Config{ Schema: &myNameSchema, Pretty: true, - FormatErrorFn: func(err gqlerrors.FormattedError) gqlerrors.FormattedError { + FormatErrorFn: func(err error) gqlerrors.FormattedError { formatErrorFnCalled = true - originalError := err.OriginalError() - switch errType := originalError.(type) { - case customError: + var formatted gqlerrors.FormattedError + switch err := err.(type) { + case *gqlerrors.Error: + formatted = gqlerrors.FormatError(err) default: - t.Fatalf("unexpected error type: %v", reflect.TypeOf(errType)) - } - return gqlerrors.FormattedError{ - Message: err.Message, - Locations: err.Locations, - Path: err.Path, - Extensions: map[string]interface{}{ - "fromFormatFn": "FROM_FORMAT_FN", - }, + t.Fatalf("unexpected error type: %v", reflect.TypeOf(err)) } + return formatted }, }) result, resp := executeTest(t, h, req)