diff --git a/router-tests/modules/context-error/module.go b/router-tests/modules/context-error/module.go new file mode 100644 index 0000000000..6b1337e60e --- /dev/null +++ b/router-tests/modules/context-error/module.go @@ -0,0 +1,89 @@ +package context_error + +import ( + "net/http" + + "github.com/wundergraph/cosmo/router/core" +) + +const myModuleID = "contextErrorModule" + +type ContextErrorModule struct { + ErrorValue error +} + +type headerCapturingWriter struct { + http.ResponseWriter + ctx core.RequestContext + statusCode int + moduleReference *ContextErrorModule + hasError bool + headerWritten bool +} + +func (w *headerCapturingWriter) checkAndSetError() { + if !w.hasError { + if err := w.ctx.Error(); err != nil { + w.moduleReference.ErrorValue = err + w.hasError = true + w.Header().Set("X-Has-Error", "true") + } + } +} + +func (w *headerCapturingWriter) WriteHeader(statusCode int) { + if !w.headerWritten { + w.statusCode = statusCode + w.checkAndSetError() + w.headerWritten = true + w.ResponseWriter.WriteHeader(statusCode) + } +} + +func (w *headerCapturingWriter) Write(b []byte) (int, error) { + if !w.headerWritten { + w.checkAndSetError() + w.headerWritten = true + } + + return w.ResponseWriter.Write(b) +} + +// Flush implements http.Flusher to support streaming responses +func (w *headerCapturingWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (m *ContextErrorModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) { + // Wrap the response writer to intercept writes + wrappedWriter := &headerCapturingWriter{ + ResponseWriter: ctx.ResponseWriter(), + ctx: ctx, + statusCode: 0, + moduleReference: m, + } + + // Call the next handler with the wrapped writer + // This wrapped writer will be passed through to all subsequent handlers, + // including the pre-handler where authentication happens + next.ServeHTTP(wrappedWriter, ctx.Request()) +} + +func (m *ContextErrorModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &ContextErrorModule{} + }, + } +} + +// Interface guard +var ( + _ core.RouterOnRequestHandler = (*ContextErrorModule)(nil) +) diff --git a/router-tests/modules/context_error_field_test.go b/router-tests/modules/context_error_field_test.go new file mode 100644 index 0000000000..ea9ebf2b39 --- /dev/null +++ b/router-tests/modules/context_error_field_test.go @@ -0,0 +1,125 @@ +package module_test + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" + integration "github.com/wundergraph/cosmo/router-tests" + contexterror "github.com/wundergraph/cosmo/router-tests/modules/context-error" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestContextErrorModule(t *testing.T) { + t.Parallel() + + t.Run("error is captured in context when authentication fails", func(t *testing.T) { + t.Parallel() + + authenticators, _ := integration.ConfigureAuth(t) + accessController, err := core.NewAccessController(core.AccessControllerOptions{ + Authenticators: authenticators, + AuthenticationRequired: true, + SkipIntrospectionQueries: false, + IntrospectionSkipSecret: "", + }) + require.NoError(t, err) + + cfg := config.Config{ + Modules: map[string]interface{}{ + "contextErrorModule": contexterror.ContextErrorModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(accessController), + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&contexterror.ContextErrorModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with an invalid token should fail + header := http.Header{ + "Authorization": []string{"Bearer invalid"}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(`{"query":"{ employees { id } }"}`)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Contains(t, string(data), "unauthorized") + + // Verify the X-Has-Error header is set when authentication fails + require.Equal(t, "true", res.Header.Get("X-Has-Error")) + }) + }) + + t.Run("error is captured in context when subgraph fails", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Modules: map[string]interface{}{ + "contextErrorModule": contexterror.ContextErrorModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&contexterror.ContextErrorModule{}), + }, + Subgraphs: testenv.SubgraphsConfig{ + Products: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"errors":[{"message":"Internal server error","extensions":{"code":"INTERNAL_SERVER_ERROR"}}]}`)) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id details { forename surname } notes } }`, + }) + + // Verify the response contains errors from the subgraph failure + require.Contains(t, res.Body, "errors") + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + + // Verify the X-Has-Error header is set when subgraph fails + require.Equal(t, "true", res.Response.Header.Get("X-Has-Error")) + }) + }) + + t.Run("no error in context when request succeeds", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Modules: map[string]interface{}{ + "contextErrorModule": contexterror.ContextErrorModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&contexterror.ContextErrorModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query MyQuery { employee(id: 1) { id } }`, + }) + + require.Equal(t, `{"data":{"employee":{"id":1}}}`, res.Body) + + // Verify the X-Has-Error header is NOT set when request succeeds + require.Empty(t, res.Response.Header.Get("X-Has-Error")) + }) + }) +} diff --git a/router/core/context.go b/router/core/context.go index 8f567df0d1..264e166c32 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -139,6 +139,9 @@ type RequestContext interface { // SetForceSha256Compute forces the computation of the Sha256Hash of the operation // This is useful if the Sha256Hash is needed in custom modules but not used anywhere else SetForceSha256Compute() + + // Error returns the error associated with the request, if any + Error() error } var metricAttrsPool = sync.Pool{ @@ -473,6 +476,11 @@ func (c *requestContext) SetForceSha256Compute() { c.forceSha256Compute = true } +// Error returns the error associated with the request, if any +func (c *requestContext) Error() error { + return c.error +} + type OperationContext interface { // Name is the name of the operation Name() string diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index e53489c059..cad6df1466 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -175,13 +175,10 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resolveCtx = WithResponseHeaderPropagation(resolveCtx) } - defer propagateSubgraphErrors(resolveCtx) - respBuf := bytes.Buffer{} resp, err := h.executor.Resolver.ResolveGraphQLResponse(resolveCtx, p.Response, nil, &respBuf) reqCtx.dataSourceNames = getSubgraphNames(p.Response.DataSources) - if err != nil { trackFinalResponseError(resolveCtx.Context(), err) h.WriteError(resolveCtx, err, p.Response, w) @@ -189,6 +186,7 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if errs := resolveCtx.SubgraphErrors(); errs != nil { + trackFinalResponseError(resolveCtx.Context(), errs) w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate") }