Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions router-tests/modules/context-error/module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package context_error

import (
"net/http"

"github.com/wundergraph/cosmo/router/core"
)

const myModuleID = "contextErrorModule"

type ContextErrorModule struct {
ErrorValue error
}

type headerCapturingWriter struct {
http.ResponseWriter
ctx core.RequestContext
statusCode int
moduleReference *ContextErrorModule
hasError bool
headerWritten bool
}

func (w *headerCapturingWriter) checkAndSetError() {
if !w.hasError {
if err := w.ctx.Error(); err != nil {
w.moduleReference.ErrorValue = err
w.hasError = true
w.Header().Set("X-Has-Error", "true")
}
}
}

func (w *headerCapturingWriter) WriteHeader(statusCode int) {
if !w.headerWritten {
w.statusCode = statusCode
w.checkAndSetError()
w.headerWritten = true
w.ResponseWriter.WriteHeader(statusCode)
}
}

func (w *headerCapturingWriter) Write(b []byte) (int, error) {
if !w.headerWritten {
w.checkAndSetError()
w.headerWritten = true
}

return w.ResponseWriter.Write(b)
}

// Flush implements http.Flusher to support streaming responses
func (w *headerCapturingWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

func (m *ContextErrorModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) {
// Wrap the response writer to intercept writes
wrappedWriter := &headerCapturingWriter{
ResponseWriter: ctx.ResponseWriter(),
ctx: ctx,
statusCode: 0,
moduleReference: m,
}

// Call the next handler with the wrapped writer
// This wrapped writer will be passed through to all subsequent handlers,
// including the pre-handler where authentication happens
next.ServeHTTP(wrappedWriter, ctx.Request())
}

func (m *ContextErrorModule) Module() core.ModuleInfo {
return core.ModuleInfo{
// This is the ID of your module, it must be unique
ID: myModuleID,
// The priority of your module, lower numbers are executed first
Priority: 1,
New: func() core.Module {
return &ContextErrorModule{}
},
}
}

// Interface guard
var (
_ core.RouterOnRequestHandler = (*ContextErrorModule)(nil)
)
125 changes: 125 additions & 0 deletions router-tests/modules/context_error_field_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package module_test

import (
"io"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/require"
integration "github.com/wundergraph/cosmo/router-tests"
contexterror "github.com/wundergraph/cosmo/router-tests/modules/context-error"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
)

func TestContextErrorModule(t *testing.T) {
t.Parallel()

t.Run("error is captured in context when authentication fails", func(t *testing.T) {
t.Parallel()

authenticators, _ := integration.ConfigureAuth(t)
accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
SkipIntrospectionQueries: false,
IntrospectionSkipSecret: "",
})
require.NoError(t, err)

cfg := config.Config{
Modules: map[string]interface{}{
"contextErrorModule": contexterror.ContextErrorModule{},
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
core.WithModulesConfig(cfg.Modules),
core.WithCustomModules(&contexterror.ContextErrorModule{}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Operations with an invalid token should fail
header := http.Header{
"Authorization": []string{"Bearer invalid"},
}
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(`{"query":"{ employees { id } }"}`))
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Contains(t, string(data), "unauthorized")

// Verify the X-Has-Error header is set when authentication fails
require.Equal(t, "true", res.Header.Get("X-Has-Error"))
})
})

t.Run("error is captured in context when subgraph fails", func(t *testing.T) {
t.Parallel()

cfg := config.Config{
Modules: map[string]interface{}{
"contextErrorModule": contexterror.ContextErrorModule{},
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithModulesConfig(cfg.Modules),
core.WithCustomModules(&contexterror.ContextErrorModule{}),
},
Subgraphs: testenv.SubgraphsConfig{
Products: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"errors":[{"message":"Internal server error","extensions":{"code":"INTERNAL_SERVER_ERROR"}}]}`))
})
},
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employees { id details { forename surname } notes } }`,
})

// Verify the response contains errors from the subgraph failure
require.Contains(t, res.Body, "errors")
require.Contains(t, res.Body, "Failed to fetch from Subgraph")

// Verify the X-Has-Error header is set when subgraph fails
require.Equal(t, "true", res.Response.Header.Get("X-Has-Error"))
})
})

t.Run("no error in context when request succeeds", func(t *testing.T) {
t.Parallel()

cfg := config.Config{
Modules: map[string]interface{}{
"contextErrorModule": contexterror.ContextErrorModule{},
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithModulesConfig(cfg.Modules),
core.WithCustomModules(&contexterror.ContextErrorModule{}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query MyQuery { employee(id: 1) { id } }`,
})

require.Equal(t, `{"data":{"employee":{"id":1}}}`, res.Body)

// Verify the X-Has-Error header is NOT set when request succeeds
require.Empty(t, res.Response.Header.Get("X-Has-Error"))
})
})
}
8 changes: 8 additions & 0 deletions router/core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ type RequestContext interface {
// SetForceSha256Compute forces the computation of the Sha256Hash of the operation
// This is useful if the Sha256Hash is needed in custom modules but not used anywhere else
SetForceSha256Compute()

// Error returns the error associated with the request, if any
Error() error
Comment thread
SkArchon marked this conversation as resolved.
}

var metricAttrsPool = sync.Pool{
Expand Down Expand Up @@ -473,6 +476,11 @@ func (c *requestContext) SetForceSha256Compute() {
c.forceSha256Compute = true
}

// Error returns the error associated with the request, if any
func (c *requestContext) Error() error {
Comment thread
SkArchon marked this conversation as resolved.
return c.error
}

type OperationContext interface {
// Name is the name of the operation
Name() string
Expand Down
4 changes: 1 addition & 3 deletions router/core/graphql_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,18 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resolveCtx = WithResponseHeaderPropagation(resolveCtx)
}

defer propagateSubgraphErrors(resolveCtx)

respBuf := bytes.Buffer{}

resp, err := h.executor.Resolver.ResolveGraphQLResponse(resolveCtx, p.Response, nil, &respBuf)
reqCtx.dataSourceNames = getSubgraphNames(p.Response.DataSources)

if err != nil {
trackFinalResponseError(resolveCtx.Context(), err)
h.WriteError(resolveCtx, err, p.Response, w)
return
}

if errs := resolveCtx.SubgraphErrors(); errs != nil {
trackFinalResponseError(resolveCtx.Context(), errs)
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
}

Expand Down
Loading