Skip to content
112 changes: 112 additions & 0 deletions router-tests/error_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,118 @@ func TestErrorLocations(t *testing.T) {
})
}
})

t.Run("OmitLocations strips locations while preserving extensions", func(t *testing.T) {
t.Parallel()

t.Run("passthrough mode", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySubgraphErrorPropagation: func(cfg *config.SubgraphErrorPropagationConfiguration) {
cfg.Enabled = true
cfg.Mode = config.SubgraphErrorPropagationModePassthrough
cfg.OmitLocations = true
cfg.OmitExtensions = false
},
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"errors":[{"message":"Unauthorized","locations":[{"line":1,"column":1}],"extensions":{"code":"UNAUTHORIZED"}}]}`))
})
},
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employees { id details { forename surname } notes } }`,
})
require.JSONEq(t, `{"errors":[{"message":"Unauthorized","extensions":{"code":"UNAUTHORIZED","statusCode":200}}],"data":{"employees":null}}`, res.Body)
})
})

t.Run("wrapped mode", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySubgraphErrorPropagation: func(cfg *config.SubgraphErrorPropagationConfiguration) {
cfg.Enabled = true
cfg.Mode = config.SubgraphErrorPropagationModeWrapped
cfg.OmitLocations = true
cfg.OmitExtensions = false
},
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"errors":[{"message":"Unauthorized","locations":[{"line":1,"column":1}],"extensions":{"code":"UNAUTHORIZED"}}]}`))
})
},
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employees { id details { forename surname } notes } }`,
})
require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'.","extensions":{"errors":[{"message":"Unauthorized","extensions":{"code":"UNAUTHORIZED"}}],"statusCode":200}}],"data":{"employees":null}}`, res.Body)
})
})

t.Run("passthrough mode with multiple errors", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySubgraphErrorPropagation: func(cfg *config.SubgraphErrorPropagationConfiguration) {
cfg.Enabled = true
cfg.Mode = config.SubgraphErrorPropagationModePassthrough
cfg.OmitLocations = true
cfg.OmitExtensions = false
},
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"errors":[{"message":"Error 1","locations":[{"line":1,"column":5}],"extensions":{"code":"ERR1"}},{"message":"Error 2","extensions":{"code":"ERR2"}},{"message":"Error 3","locations":[{"line":3,"column":10}],"extensions":{"code":"ERR3"}}]}`))
})
},
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employees { id details { forename surname } notes } }`,
})
require.JSONEq(t, `{"errors":[{"message":"Error 1","extensions":{"code":"ERR1","statusCode":200}},{"message":"Error 2","extensions":{"code":"ERR2","statusCode":200}},{"message":"Error 3","extensions":{"code":"ERR3","statusCode":200}}],"data":{"employees":null}}`, res.Body)
})
})

t.Run("wrapped mode with multiple errors", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySubgraphErrorPropagation: func(cfg *config.SubgraphErrorPropagationConfiguration) {
cfg.Enabled = true
cfg.Mode = config.SubgraphErrorPropagationModeWrapped
cfg.OmitLocations = true
cfg.OmitExtensions = false
},
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"errors":[{"message":"Error 1","locations":[{"line":1,"column":5}],"extensions":{"code":"ERR1"}},{"message":"Error 2","extensions":{"code":"ERR2"}},{"message":"Error 3","locations":[{"line":3,"column":10}],"extensions":{"code":"ERR3"}}]}`))
})
},
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employees { id details { forename surname } notes } }`,
})
require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'.","extensions":{"errors":[{"message":"Error 1","extensions":{"code":"ERR1"}},{"message":"Error 2","extensions":{"code":"ERR2"}},{"message":"Error 3","extensions":{"code":"ERR3"}}],"statusCode":200}}],"data":{"employees":null}}`, res.Body)
})
})
})
}

func TestSSEErrorResponseWriteFailures(t *testing.T) {
Expand Down
169 changes: 119 additions & 50 deletions router-tests/modules/custom-forbidden-handler/module.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
// Package custom_forbidden_handler standardizes 403 auth failure responses from subgraphs.
//
// Requirements addressed:
// 1. Replaces the default subgraph 403 response with a uniform format:
// {"errors":[{"message":"...","extensions":{"code":"FORBIDDEN"}}]}
// 2. If any subgraph returns 403, the entire response is replaced — no partial data.
// When a subgraph returns a 403 (either via HTTP status code or a GraphQL error with
// a forbidden indicator in extensions), the module:
// 1. Rewrites the subgraph response body to a uniform error format so the router's
// error pipeline (ALLOWED_EXTENSION_FIELDS, etc.) processes a clean input.
// 2. Short-circuits subsequent subgraph calls via OnOriginRequest once a 403 is detected.
// 3. Replaces the entire router response (via middleware) with a standardized error
// and data:null — no partial data is returned.
//
// How it works:
// - OnOriginResponse (EnginePostOriginHandler) runs per-subgraph and flags 403s on the context.
// It detects 403 in two ways:
// a. HTTP status code 403 on the subgraph response.
// b. GraphQL-level error with extensions.code == 403 (number or string) in the response body,
// even when the HTTP status is 200. The body is read, inspected, and restored so
// downstream handlers can still consume it.
// - Middleware buffers the response. After the engine finishes, if a 403 was flagged,
// the buffered response is discarded and replaced with the standardized error.
// # Acceptance Criteria
//
// - Subgraph returns HTTP 403 → single standardized forbidden error.
// - Subgraph returns 200 with GraphQL error code 403 → detected as forbidden.
// - Subgraph returns 200 with GraphQL errorCode "FORBIDDEN" → detected as forbidden.
// - One subgraph forbidden, others succeed → no partial data, single error.
// - All subgraphs forbidden → single error, not one per subgraph.
// - Sequential subgraph fetches after a 403 → subsequent calls are skipped.
// - Parallel subgraph fetches both forbidden → single error, no duplicates.
// - Non-403 errors → pass through normally, not intercepted.
// - Non-forbidden extension field filtering → unaffected by the module.
// - Streaming requests → not affected by the module.
package custom_forbidden_handler

import (
Expand All @@ -30,6 +36,10 @@ import (

const myModuleID = "forbiddenHandlerModule"

// forbiddenErrorBody is the standardised GraphQL error body written by the module
// when any subgraph returns a 403.
var forbiddenErrorBody = []byte(`{"errors":[{"message":"Insufficient permissions to fulfill the request.","extensions":{"errorCode":"FORBIDDEN"}}],"data":null}`)

type ForbiddenHandlerModule struct {
Logger *zap.Logger
}
Expand All @@ -43,99 +53,157 @@ func (m *ForbiddenHandlerModule) Cleanup() error {
return nil
}

// OnOriginResponse detects 403 from any subgraph and flags it on the context.
// It checks both the HTTP status code and the GraphQL response body for
// errors with extensions.code == 403.
// OnOriginRequest short-circuits subgraph calls when a 403 has already been
// detected from a previous subgraph. This avoids unnecessary network round-trips.
func (m *ForbiddenHandlerModule) OnOriginRequest(req *http.Request, ctx core.RequestContext) (*http.Request, *http.Response) {
if ctx.GetBool("streaming_request") {
return req, nil
}
if ctx.GetBool("forbidden_encountered") {
// Return an empty data response. The middleware will replace the entire
// response anyway, so the body content here does not matter much — it
// just needs to be valid JSON so the resolver does not panic.
return req, &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"data":null}`)),
ContentLength: 13,
}
}
return req, nil
}

// OnOriginResponse detects 403 from any subgraph and rewrites the response body
// to a standardised GraphQL error so the router's error pipeline processes clean
// input. It also flags the request context so the middleware and OnOriginRequest
// can act on it.
func (m *ForbiddenHandlerModule) OnOriginResponse(resp *http.Response, ctx core.RequestContext) *http.Response {
if resp == nil {
if resp == nil || ctx.GetBool("streaming_request") {
return nil
}

if resp.StatusCode == http.StatusForbidden {
ctx.Set("forbidden_encountered", true)
// If already flagged, still rewrite the body so the resolver does not
// process stale subgraph errors.
if ctx.GetBool("forbidden_encountered") {
resp.Body = io.NopCloser(bytes.NewReader(forbiddenErrorBody))
resp.ContentLength = int64(len(forbiddenErrorBody))
resp.StatusCode = http.StatusOK
return nil
}

// Also check if the response body contains a GraphQL error with code 403
if resp.Body != nil {
isForbidden := resp.StatusCode == http.StatusForbidden

if !isForbidden && resp.Body != nil {
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
// Always restore the body so downstream can still read it
resp.Body = io.NopCloser(bytes.NewReader(body))
if err == nil && hasForbiddenGraphQLError(body) {
ctx.Set("forbidden_encountered", true)
isForbidden = true
}
}

if isForbidden {
ctx.Set("forbidden_encountered", true)
resp.Body = io.NopCloser(bytes.NewReader(forbiddenErrorBody))
resp.ContentLength = int64(len(forbiddenErrorBody))
// Normalise to 200 so the resolver reads the body as a regular GraphQL
// response and processes the error through its pipeline.
resp.StatusCode = http.StatusOK
}

return nil
}

// hasForbiddenGraphQLError checks if a GraphQL response body contains an error
// with extensions.code == 403 (as a number).
// whose extensions indicate a 403/FORBIDDEN status. It inspects both "code" and
// "errorCode" extension fields and accepts the numeric value 403 as well as the
// strings "403" and "FORBIDDEN" (case-insensitive).
func hasForbiddenGraphQLError(body []byte) bool {
var result struct {
Errors []struct {
Extensions struct {
Code json.RawMessage `json:"code"`
} `json:"extensions"`
Extensions map[string]json.RawMessage `json:"extensions"`
} `json:"errors"`
}
if err := json.Unmarshal(body, &result); err != nil {
return false
}
for _, e := range result.Errors {
if len(e.Extensions.Code) == 0 {
continue
}
// Check as number
var code float64
if json.Unmarshal(e.Extensions.Code, &code) == nil && code == 403 {
return true
}
// Check as string
var codeStr string
if json.Unmarshal(e.Extensions.Code, &codeStr) == nil && codeStr == "403" {
if isForbiddenCode(e.Extensions["code"]) || isForbiddenCode(e.Extensions["errorCode"]) {
return true
}
}
return false
}

// isForbiddenCode returns true when raw represents the numeric value 403 or the
// strings "403" / "FORBIDDEN" (case-insensitive).
func isForbiddenCode(raw json.RawMessage) bool {
if len(raw) == 0 {
return false
}
var code float64
if json.Unmarshal(raw, &code) == nil && code == 403 {
return true
}
var codeStr string
if json.Unmarshal(raw, &codeStr) == nil {
return codeStr == "403" || strings.EqualFold(codeStr, "FORBIDDEN")
}
return false
}

// Middleware buffers the engine's response. After all subgraph calls complete,
// it checks the forbidden flag: on 403, the buffer is discarded and replaced
// with the standardized error; otherwise the buffered response is flushed.
// it checks the forbidden flag: on 403, the buffered response is discarded and
// replaced with the standardised forbiddenErrorBody (no partial data, no
// pipeline-decorated extensions — always the same clean error).
func (m *ForbiddenHandlerModule) Middleware(ctx core.RequestContext, next http.Handler) {
// Skip for streaming subscriptions (SSE/multipart) — they require
// http.Flusher and deliver data incrementally, so buffering does not apply.
// The flag is set before calling next so that OnOriginRequest and
// OnOriginResponse also skip their forbidden-handling logic.
if isStreamingRequest(ctx.Request()) {
ctx.Set("streaming_request", true)
next.ServeHTTP(ctx.ResponseWriter(), ctx.Request())
Comment thread
StarpTech marked this conversation as resolved.
return
}

// Save the real writer before calling next — the wrapper propagates the
// writer passed to ServeHTTP into reqContext.responseWriter, so after
// next.ServeHTTP(bw, ...) returns, ctx.ResponseWriter() may point to bw.
w := ctx.ResponseWriter()

bw := &bufferedWriter{
header: make(http.Header),
}

next.ServeHTTP(bw, ctx.Request())

if ctx.GetBool("forbidden_encountered") {
core.WriteResponseError(ctx, core.NewHttpGraphqlError(
"Insufficient permissions to fulfill the request.",
"FORBIDDEN",
http.StatusForbidden,
))
// A subgraph returned 403 — discard whatever the engine produced
// (which may contain pipeline-decorated errors with serviceName,
// statusCode, DOWNSTREAM_SERVICE_ERROR, etc.) and write the
// standardised forbidden response directly.
maps.Copy(w.Header(), bw.header)
w.Header().Del("Content-Length")
w.Header().Del("Content-Encoding")
w.Header().Del("Transfer-Encoding")
w.Header().Del("ETag")
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(forbiddenErrorBody)
Comment thread
StarpTech marked this conversation as resolved.
return
}

// Flush buffered response to the real writer
real := ctx.ResponseWriter()
maps.Copy(real.Header(), bw.header)
real.WriteHeader(bw.code)
_, _ = real.Write(bw.body.Bytes())
maps.Copy(w.Header(), bw.header)
if bw.code != 0 {
w.WriteHeader(bw.code)
}
_, _ = w.Write(bw.body.Bytes())
}

func isStreamingRequest(r *http.Request) bool {
accept := r.Header.Get("Accept")
accept := strings.ToLower(strings.TrimSpace(r.Header.Get("Accept")))
return strings.Contains(accept, "text/event-stream") ||
strings.Contains(accept, "multipart/mixed")
}
Expand Down Expand Up @@ -174,6 +242,7 @@ func (m *ForbiddenHandlerModule) Module() core.ModuleInfo {
// Interface guards
var (
_ core.RouterMiddlewareHandler = (*ForbiddenHandlerModule)(nil)
_ core.EnginePreOriginHandler = (*ForbiddenHandlerModule)(nil)
_ core.EnginePostOriginHandler = (*ForbiddenHandlerModule)(nil)
_ core.Provisioner = (*ForbiddenHandlerModule)(nil)
_ core.Cleaner = (*ForbiddenHandlerModule)(nil)
Expand Down
Loading
Loading