Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions router-tests/prometheus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4238,6 +4238,50 @@ func TestPrometheus(t *testing.T) {
})
})

t.Run("Authentication failure records correct HTTP status code in metrics", func(t *testing.T) {
t.Parallel()

authenticators, _ := ConfigureAuth(t)
accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
})
require.NoError(t, err)

metricReader := metric.NewManualReader()
promRegistry := prometheus.NewRegistry()

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
MetricReader: metricReader,
PrometheusRegistry: promRegistry,
}, func(t *testing.T, xEnv *testenv.Environment) {
// Make unauthenticated request - should get 401
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", nil,
strings.NewReader(`{"query":"{ employees { id } }"}`))
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)

// Verify metrics record 401, not 200
mf, err := promRegistry.Gather()
require.NoError(t, err)

requestTotal := findMetricFamilyByName(mf, "router_http_requests_total")
require.NotNil(t, requestTotal, "expected router_http_requests_total metric to exist")

// Find metrics with http_status_code=401
metricsWithStatus401 := findMetricsByLabel(requestTotal, "http_status_code", "401")
require.Len(t, metricsWithStatus401, 1, "expected exactly one metric with http_status_code=401")

// Ensure there are no metrics with http_status_code=200
metricsWithStatus200 := findMetricsByLabel(requestTotal, "http_status_code", "200")
require.Len(t, metricsWithStatus200, 0, "expected no metrics with http_status_code=200 for auth failure")
})
})

}

func getPort(connectionTotal *io_prometheus_client.Metric) string {
Expand Down
46 changes: 46 additions & 0 deletions router-tests/telemetry/telemetry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7853,6 +7853,52 @@ func TestFlakyTelemetry(t *testing.T) {
})
})

t.Run("Authentication failure records correct HTTP status code in metrics", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
authenticators, _ := integration.ConfigureAuth(t)
accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
})
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
MetricReader: metricReader,
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Make unauthenticated request - should get 401
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", nil,
strings.NewReader(`{"query":"{ employees { id } }"}`))
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)

rm := metricdata.ResourceMetrics{}
err = metricReader.Collect(context.Background(), &rm)
require.NoError(t, err)

scopeMetric := *integration.GetMetricScopeByName(rm.ScopeMetrics, "cosmo.router")

Comment thread
SkArchon marked this conversation as resolved.
statusCode401 := semconv.HTTPStatusCode(http.StatusUnauthorized)

// Verify http_status_code=401 on router.http.requests
requestsMetric := integration.GetMetricByName(&scopeMetric, "router.http.requests")
require.NotNil(t, requestsMetric)
requestsData := requestsMetric.Data.(metricdata.Sum[int64])
require.True(t, integration.HasDataPointWithAttribute(requestsData.DataPoints, statusCode401))

// Verify http_status_code=401 on router.http.request.duration_milliseconds
durationMetric := integration.GetMetricByName(&scopeMetric, "router.http.request.duration_milliseconds")
require.NotNil(t, durationMetric)
durationData := durationMetric.Data.(metricdata.Histogram[float64])
require.True(t, integration.HasHistogramDataPointWithAttribute(durationData.DataPoints, statusCode401))
})
})

t.Run("Operation parsing errors are tracked", func(t *testing.T) {
t.Parallel()

Expand Down
22 changes: 22 additions & 0 deletions router-tests/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,25 @@ func GetMetricScopeByName(metrics []metricdata.ScopeMetrics, name string) *metri
// ToPtr returns a pointer to a freshly allocated copy of v.
// It is handy for taking the address of literals and other
// non-addressable values.
func ToPtr[T any](v T) *T {
	ptr := new(T)
	*ptr = v
	return ptr
}

// HasDataPointWithAttribute reports whether at least one data point of a Sum or
// Gauge metric carries an attribute with the same key and the same value as attr.
// It returns false for an empty or nil slice.
func HasDataPointWithAttribute[N int64 | float64](dataPoints []metricdata.DataPoint[N], attr attribute.KeyValue) bool {
	for i := range dataPoints {
		got, found := dataPoints[i].Attributes.Value(attr.Key)
		if !found {
			// This data point does not have the key at all; keep looking.
			continue
		}
		if got == attr.Value {
			return true
		}
	}
	return false
}

// HasHistogramDataPointWithAttribute reports whether at least one data point of a
// Histogram metric carries an attribute with the same key and the same value as attr.
// It returns false for an empty or nil slice.
func HasHistogramDataPointWithAttribute[N int64 | float64](dataPoints []metricdata.HistogramDataPoint[N], attr attribute.KeyValue) bool {
	for i := range dataPoints {
		got, found := dataPoints[i].Attributes.Value(attr.Key)
		if !found {
			// This data point does not have the key at all; keep looking.
			continue
		}
		if got == attr.Value {
			return true
		}
	}
	return false
}
34 changes: 17 additions & 17 deletions router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

var (
// In GraphQL the statusCode does not always express the error state of the request
// we use this flag to determine if we have an error for the request metrics
writtenBytes int
statusCode = http.StatusOK
traceTimings *art.TraceTimings
)

// Wrap the response writer early so that all paths (including early returns
// for auth failures, bad requests, etc.) have the actual HTTP status code
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

requestContext := getRequestContext(r.Context())
requestLogger := requestContext.logger

Expand Down Expand Up @@ -249,6 +249,11 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldOperationServices, requestContext.dataSourceNames)
requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldGraphQLErrorCodes, requestContext.graphQLErrorCodes)

// Read the actual status code from the wrapped response writer.
// This captures the correct status code for all paths, including early returns.
statusCode := ww.Status()
writtenBytes := ww.BytesWritten()

metrics.Finish(
requestContext,
statusCode,
Expand All @@ -266,7 +271,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
requestContext.SetError(err)
writeRequestErrors(writeRequestErrorsParams{
request: r,
writer: w,
writer: ww,
statusCode: http.StatusBadRequest,
requestErrors: graphqlerrors.RequestErrorsFromError(err),
logger: requestLogger,
Expand All @@ -293,7 +298,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
message: "file upload disabled",
statusCode: http.StatusOK,
})
writeOperationError(r, w, requestLogger, requestContext.error, h.headerPropagation)
writeOperationError(r, ww, requestLogger, requestContext.error, h.headerPropagation)
return
}

Expand All @@ -312,7 +317,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
}
if err != nil {
requestContext.SetError(err)
writeOperationError(r, w, requestLogger, requestContext.error, h.headerPropagation)
writeOperationError(r, ww, requestLogger, requestContext.error, h.headerPropagation)
readMultiPartSpan.End()
return
}
Expand Down Expand Up @@ -349,7 +354,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
// e.g. too large body, slow client, aborted connection etc.
// The error is logged as debug log in the writeOperationError function

writeOperationError(r, w, requestLogger, err, h.headerPropagation)
writeOperationError(r, ww, requestLogger, err, h.headerPropagation)
readOperationBodySpan.End()
return
}
Expand All @@ -365,7 +370,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
trace.WithAttributes(requestContext.telemetry.traceAttrs...),
)

validatedReq, err := h.accessController.Access(w, r)
validatedReq, err := h.accessController.Access(ww, r)
if err != nil {
// Auth failed but introspection queries might be allowed to skip auth.
// At this early stage we don't know whether this query is an introspection query or not.
Expand All @@ -377,14 +382,14 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
// Reject the request since auth has failed
// and skipping auth for introspection queries is not allowed,
// so it does not matter whether this is an introspection query or not.
h.handleAuthenticationFailure(requestContext, requestLogger, err, routerSpan, authenticateSpan, r, w)
h.handleAuthenticationFailure(requestContext, requestLogger, err, routerSpan, authenticateSpan, r, ww)
authenticateSpan.End()
return
}

if h.accessController.IntrospectionSecretConfigured() {
if !h.accessController.IntrospectionAccess(r, body) {
h.handleAuthenticationFailure(requestContext, requestLogger, err, routerSpan, authenticateSpan, r, w)
h.handleAuthenticationFailure(requestContext, requestLogger, err, routerSpan, authenticateSpan, r, ww)
authenticateSpan.End()
return
}
Expand Down Expand Up @@ -418,7 +423,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
// Mark the root span of the router as failed, so we can easily identify failed requests
rtrace.AttachErrToSpan(routerSpan, err)

writeOperationError(r, w, requestLogger, err, h.headerPropagation)
writeOperationError(r, ww, requestLogger, err, h.headerPropagation)
return
}

Expand All @@ -438,8 +443,6 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
r = r.WithContext(resolve.SetRequest(r.Context(), reqData))
}

ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

// The request context needs to be updated with the latest request to ensure that the context is up to date
requestContext.request = r
requestContext.responseWriter = ww
Expand All @@ -448,9 +451,6 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
// and enrich the context to make it available in the request context as well for metrics etc.
next.ServeHTTP(ww, r)

statusCode = ww.Status()
writtenBytes = ww.BytesWritten()

// Mark the root span of the router as failed, so we can easily identify failed requests
if requestContext.error != nil {
rtrace.AttachErrToSpan(trace.SpanFromContext(r.Context()), requestContext.error)
Expand Down
Loading