diff --git a/router-tests/telemetry/span_error_status_test.go b/router-tests/telemetry/span_error_status_test.go index 88b20a0875..2a02f2a607 100644 --- a/router-tests/telemetry/span_error_status_test.go +++ b/router-tests/telemetry/span_error_status_test.go @@ -3,17 +3,21 @@ package telemetry import ( "context" "net/http" + "net/http/httptest" "strings" "testing" "time" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/trace/tracetest" "go.opentelemetry.io/otel/codes" sdkmetric "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/metric/metricdata" sdktrace "go.opentelemetry.io/otel/sdk/trace" + otelsdktracetest "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -26,10 +30,32 @@ func rootSpan(spans []sdktrace.ReadOnlySpan) sdktrace.ReadOnlySpan { return spans[len(spans)-1] } +// waitForStableSpans polls exporter.GetSpans() until the snapshot count stops +// changing for at least 200ms, guaranteeing the trace tree has finished exporting +// before assertions run across all spans. 
+func waitForStableSpans(t *testing.T, exporter *otelsdktracetest.InMemoryExporter) []sdktrace.ReadOnlySpan { + t.Helper() + var ( + spans []sdktrace.ReadOnlySpan + lastCount = -1 + stableSince time.Time + ) + require.Eventually(t, func() bool { + spans = exporter.GetSpans().Snapshots() + if len(spans) != lastCount { + lastCount = len(spans) + stableSince = time.Now() + return false + } + return len(spans) > 0 && time.Since(stableSince) >= 200*time.Millisecond + }, 5*time.Second, 50*time.Millisecond, "expected span snapshot to stabilize") + return spans +} + func TestClientDisconnectionBehavior(t *testing.T) { t.Parallel() - t.Run("span status is not error but exception event is recorded", func(t *testing.T) { + t.Run("root span is not marked as error but records exception event on client disconnect", func(t *testing.T) { t.Parallel() exporter := tracetest.NewInMemoryExporter(t) @@ -74,7 +100,7 @@ func TestClientDisconnectionBehavior(t *testing.T) { }) }) - t.Run("error metrics are not inflated but request count is recorded", func(t *testing.T) { + t.Run("error metrics are not inflated but request count is recorded on client disconnect", func(t *testing.T) { t.Parallel() metricReader := sdkmetric.NewManualReader() @@ -116,7 +142,7 @@ func TestClientDisconnectionBehavior(t *testing.T) { }) }) - t.Run("log level is info not error", func(t *testing.T) { + t.Run("context canceled is not logged at error level on client disconnect", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ @@ -151,7 +177,208 @@ func TestClientDisconnectionBehavior(t *testing.T) { }) }) - t.Run("other errors still mark span as error", func(t *testing.T) { + t.Run("subgraph fetch span is not marked as error on client disconnect", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.DebugLevel, + }, + 
Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Delay: 2 * time.Second, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Since the subgraph takes 2 seconds, and this takes 200 milliseconds which is less than the subgraph + // this will ensure that the request is cancelled due to a timeout + ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, xEnv.GraphQLRequestURL(), + strings.NewReader(`{"query":"{ employees { id } }"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.Error(t, err) + require.Nil(t, resp, "client should not receive any response when it disconnects") + + spans := waitForStableSpans(t, exporter) + + // No span in the entire trace should be marked as ERROR for client disconnections + for _, s := range spans { + require.NotEqual(t, codes.Error, s.Status().Code, + "span %q should not be marked as error when client disconnects", s.Name()) + } + + // Find the "Engine - Fetch" span and verify the cancellation is still recorded as an event + var fetchSpan sdktrace.ReadOnlySpan + for _, s := range spans { + if s.Name() == "Engine - Fetch" { + fetchSpan = s + break + } + } + require.NotNil(t, fetchSpan, "expected Engine - Fetch span to be exported") + + hasExceptionEvent := false + for _, event := range fetchSpan.Events() { + if event.Name == "exception" { + hasExceptionEvent = true + break + } + } + require.True(t, hasExceptionEvent, + "subgraph fetch span should have an exception event recorded for client disconnections") + + // The wg.request.error attribute should not be set on the fetch span + for _, attr := range fetchSpan.Attributes() { + if attr.Key == "wg.request.error" { + require.False(t, attr.Value.AsBool(), + "wg.request.error should not be true on fetch span for client disconnects") + } + } + + // 
Verify no 500 status code was written — the server should not produce + // an error response when the client has disconnected + requestLogs := xEnv.Observer().FilterField(zap.Int("status", 500)).All() + require.Empty(t, requestLogs, + "server should not write a 500 response for client disconnections") + }) + }) + + t.Run("persisted operation fetch span is not marked as error on client disconnect", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + // Create a slow CDN server that delays persisted operation responses long enough + // for the client's request context to be canceled mid-fetch. + cdnServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[]`)) + return + } + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusNotFound) + })) + defer cdnServer.Close() + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.DebugLevel, + }, + CdnSever: cdnServer, + ModifyCDNConfig: func(cfg *config.CDNConfiguration) { + cfg.CacheSize = 0 // Disable cache so every request hits the CDN + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, xEnv.GraphQLRequestURL(), + strings.NewReader(`{"operationName":"Employees","extensions":{"persistedQuery":{"version":1,"sha256Hash":"dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("graphql-client-name", "my-client") + + client := &http.Client{} + resp, err := client.Do(req) + require.Error(t, err) + require.Nil(t, resp, "client should not receive any response when it disconnects") + + spans := 
waitForStableSpans(t, exporter) + + // No span should be marked as ERROR for client disconnections + for _, s := range spans { + require.NotEqual(t, codes.Error, s.Status().Code, + "span %q should not be marked as error when client disconnects during persisted op fetch", s.Name()) + } + + // Verify the "Load Persisted Operation" span exists and has the exception event + var poSpan sdktrace.ReadOnlySpan + for _, s := range spans { + if s.Name() == "Load Persisted Operation" { + poSpan = s + break + } + } + require.NotNil(t, poSpan, "expected Load Persisted Operation span to be exported") + + hasExceptionEvent := false + for _, event := range poSpan.Events() { + if event.Name == "exception" { + hasExceptionEvent = true + break + } + } + require.True(t, hasExceptionEvent, + "Load Persisted Operation span should have an exception event for client disconnections") + + // Verify no 500 status code was written + requestLogs := xEnv.Observer().FilterField(zap.Int("status", 500)).All() + require.Empty(t, requestLogs, + "server should not write a 500 response for client disconnections during persisted op fetch") + }) + }) + + t.Run("batched request spans are not marked as error on client disconnect", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.DebugLevel, + }, + BatchingConfig: config.BatchingConfig{ + Enabled: true, + MaxConcurrency: 10, + MaxEntriesPerBatch: 100, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Delay: 2 * time.Second, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond) + defer cancel() + + res, err := xEnv.MakeGraphQLBatchedRequestRequestWithContext(ctx, []testenv.GraphQLRequest{ + {Query: `query employees { employees { id } }`}, + {Query: `query employee { 
employees { isAvailable } }`}, + }, nil) + require.Error(t, err) + require.Nil(t, res, "client should not receive any response when it disconnects") + + spans := waitForStableSpans(t, exporter) + + // No span should be marked as ERROR for client disconnections + for _, s := range spans { + require.NotEqual(t, codes.Error, s.Status().Code, + "span %q should not be marked as error when client disconnects during batch request", s.Name()) + } + + // Verify no 500 status code was written + requestLogs := xEnv.Observer().FilterField(zap.Int("status", 500)).All() + require.Empty(t, requestLogs, + "server should not write a 500 response for client disconnections during batch request") + }) + }) + + t.Run("root span is marked as error on subgraph failure", func(t *testing.T) { t.Parallel() exporter := tracetest.NewInMemoryExporter(t) diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go index 0ac327e1cf..20416beca8 100644 --- a/router-tests/telemetry/telemetry_test.go +++ b/router-tests/telemetry/telemetry_test.go @@ -7737,7 +7737,7 @@ func TestFlakyTelemetry(t *testing.T) { require.Equal(t, "Engine - Fetch", sn[8].Name()) require.Equal(t, trace.SpanKindInternal, sn[8].SpanKind()) require.Equal(t, codes.Error, sn[8].Status().Code) - require.Lenf(t, sn[8].Attributes(), 14, "expected 14 attributes, got %d", len(sn[8].Attributes())) + require.Lenf(t, sn[8].Attributes(), 15, "expected 15 attributes, got %d", len(sn[8].Attributes())) require.Contains(t, sn[8].Status().Description, "connect: connection refused\nFailed to fetch from Subgraph 'products' at Path: 'employees'.") events := sn[8].Events() @@ -7809,7 +7809,7 @@ func TestFlakyTelemetry(t *testing.T) { require.Equal(t, "Engine - Fetch", sn[8].Name()) require.Equal(t, trace.SpanKindInternal, sn[8].SpanKind()) - require.Lenf(t, sn[8].Attributes(), 14, "expected 14 attributes, got %d", len(sn[6].Attributes())) + require.Lenf(t, sn[8].Attributes(), 15, "expected 15 attributes, got 
%d", len(sn[8].Attributes())) given = attribute.NewSet(sn[8].Attributes()...) want = attribute.NewSet([]attribute.KeyValue{ @@ -7827,6 +7827,7 @@ func TestFlakyTelemetry(t *testing.T) { otel.WgOperationType.String("query"), otel.WgOperationProtocol.String("http"), otel.WgOperationHash.String("13939103824696605913"), + otel.WgRequestError.Bool(true), }...) require.True(t, given.Equals(&want)) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 117c16b22f..0702ec6243 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -1804,23 +1804,33 @@ func SetupCDNServer(t testing.TB) (cdnServer *httptest.Server, port int) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" { requestLog, err := json.Marshal(cdnRequestLog) - require.NoError(t, err) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "application/json") _, err = w.Write(requestLog) - require.NoError(t, err) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } return } + cdnRequestLog = append(cdnRequestLog, r.Method+" "+r.URL.Path) // Ensure we have an authorization header with a valid token authorization := r.Header.Get("Authorization") - if authorization == "" { - require.NotEmpty(t, authorization, "missing authorization header") + token, ok := strings.CutPrefix(authorization, "Bearer ") + if !ok { + http.Error(w, "missing or malformed Bearer token", http.StatusUnauthorized) + return } - token := authorization[len("Bearer "):] parsedClaims := make(jwt.MapClaims) jwtParser := new(jwt.Parser) - _, _, err := jwtParser.ParseUnverified(token, parsedClaims) - require.NoError(t, err) + if _, _, err := jwtParser.ParseUnverified(token, parsedClaims); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } cdnFileServer.ServeHTTP(w, r) }) cdnServer = httptest.NewServer(handler) 
diff --git a/router/core/batch.go b/router/core/batch.go index 4df68be802..cd8670aafc 100644 --- a/router/core/batch.go +++ b/router/core/batch.go @@ -181,7 +181,12 @@ func processBatchedRequest(w http.ResponseWriter, r *http.Request, handlerOpts H } func processBatchError(w http.ResponseWriter, r *http.Request, err error, requestLogger *zap.Logger) { - ctrace.AttachErrToSpanFromContext(r.Context(), err) + if errors.Is(err, context.Canceled) { + span := trace.SpanFromContext(r.Context()) + span.RecordError(err) + } else { + ctrace.AttachErrToSpanFromContext(r.Context(), err) + } requestError := graphqlerrors.RequestError{ Message: err.Error(), diff --git a/router/core/engine_loader_hooks.go b/router/core/engine_loader_hooks.go index 1d284f06ed..13c4a6da1a 100644 --- a/router/core/engine_loader_hooks.go +++ b/router/core/engine_loader_hooks.go @@ -238,72 +238,96 @@ func (f *engineLoaderHooks) OnFinished(ctx context.Context, ds resolve.DataSourc } } - if responseInfo.Err != nil { + if responseInfo.Err != nil && !errors.Is(responseInfo.Err, context.Canceled) { f.accessLogger.Error(path, fields) } else { f.accessLogger.Info(path, fields) } } + measureSliceAttrs := reqContext.telemetry.metricSliceAttrs + if responseInfo.Err != nil { - // Set error status. 
This is the fetch error from the engine - // Downstream errors are extracted from the subgraph response - rtrace.SetSanitizedSpanStatus(span, codes.Error, responseInfo.Err.Error()) - span.RecordError(responseInfo.Err) - - var errorCodesAttr []string - - if unwrapped, ok := responseInfo.Err.(multiError); ok { - errs := unwrapped.Unwrap() - for _, e := range errs { - var subgraphError *resolve.SubgraphError - if errors.As(e, &subgraphError) { - for i, downstreamError := range subgraphError.DownstreamErrors { - var errorCode string - if downstreamError.Extensions != nil { - if value := downstreamError.Extensions.Get("code"); value != nil { - errorCode = string(value.GetStringBytes()) - } - } - - if errorCode != "" { - errorCodesAttr = append(errorCodesAttr, errorCode) - span.AddEvent(fmt.Sprintf("Downstream error %d", i+1), - trace.WithAttributes( - rotel.WgSubgraphErrorExtendedCode.String(errorCode), - rotel.WgSubgraphErrorMessage.String(downstreamError.Message), - ), - ) - } + // Client disconnections (context.Canceled) are not server-side errors. + // Record the error for observability but don't set the span status to ERROR + // and don't count it as a request error in metrics. + if errors.Is(responseInfo.Err, context.Canceled) { + span.RecordError(responseInfo.Err) + } else { + errorSliceAttrs := *reqContext.telemetry.AcquireAttributes() + defer reqContext.telemetry.ReleaseAttributes(&errorSliceAttrs) + errorSliceAttrs = append(errorSliceAttrs, reqContext.telemetry.metricSliceAttrs...) + + measureSliceAttrs, metricAddOpt = f.recordFetchError(ctx, span, responseInfo.Err, reqContext, metricAttrs, metricAddOpt, errorSliceAttrs) + } + } + + f.metricStore.MeasureRequestCount(ctx, measureSliceAttrs, metricAddOpt) + f.metricStore.MeasureLatency(ctx, latency, measureSliceAttrs, metricAddOpt) + + span.SetAttributes(traceAttrs...) 
+} + +// recordFetchError sets the span status to ERROR, extracts downstream error codes, +// records the request error metric, and returns the enriched slice attributes and +// measurement option for the caller to use in MeasureRequestCount/MeasureLatency. +func (f *engineLoaderHooks) recordFetchError( + ctx context.Context, + span trace.Span, + fetchErr error, + reqContext *requestContext, + metricAttrs []attribute.KeyValue, + metricAddOpt otelmetric.AddOption, + metricSliceAttrs []attribute.KeyValue, +) ([]attribute.KeyValue, otelmetric.MeasurementOption) { + rtrace.SetSanitizedSpanStatus(span, codes.Error, fetchErr.Error()) + span.SetAttributes(rotel.WgRequestError.Bool(true)) + span.RecordError(fetchErr) + + // Extract downstream error codes from subgraph errors + var errorCodesAttr []string + + if unwrapped, ok := fetchErr.(multiError); ok { + for _, e := range unwrapped.Unwrap() { + var subgraphError *resolve.SubgraphError + if !errors.As(e, &subgraphError) { + continue + } + + for i, downstreamError := range subgraphError.DownstreamErrors { + var errorCode string + if downstreamError.Extensions != nil { + if value := downstreamError.Extensions.Get("code"); value != nil { + errorCode = string(value.GetStringBytes()) } } + + if errorCode == "" { + continue + } + + errorCodesAttr = append(errorCodesAttr, errorCode) + span.AddEvent(fmt.Sprintf("Downstream error %d", i+1), + trace.WithAttributes( + rotel.WgSubgraphErrorExtendedCode.String(errorCode), + rotel.WgSubgraphErrorMessage.String(downstreamError.Message), + ), + ) } } errorCodesAttr = unique.SliceElements(errorCodesAttr) - // Reduce cardinality of error codes slices.Sort(errorCodesAttr) + } - metricSliceAttrs := *reqContext.telemetry.AcquireAttributes() - defer reqContext.telemetry.ReleaseAttributes(&metricSliceAttrs) - metricSliceAttrs = append(metricSliceAttrs, reqContext.telemetry.metricSliceAttrs...) 
- - // We can't add this earlier because this is done per subgraph response - if v, ok := reqContext.telemetry.metricSetAttrs[ContextFieldGraphQLErrorCodes]; ok && len(errorCodesAttr) > 0 { - metricSliceAttrs = append(metricSliceAttrs, attribute.StringSlice(v, errorCodesAttr)) - } - - f.metricStore.MeasureRequestError(ctx, metricSliceAttrs, metricAddOpt) + if v, ok := reqContext.telemetry.metricSetAttrs[ContextFieldGraphQLErrorCodes]; ok && len(errorCodesAttr) > 0 { + metricSliceAttrs = append(metricSliceAttrs, attribute.StringSlice(v, errorCodesAttr)) + } - metricAttrs = append(metricAttrs, rotel.WgRequestError.Bool(true)) + f.metricStore.MeasureRequestError(ctx, metricSliceAttrs, metricAddOpt) - attrOpt := otelmetric.WithAttributeSet(attribute.NewSet(metricAttrs...)) - f.metricStore.MeasureRequestCount(ctx, metricSliceAttrs, attrOpt) - f.metricStore.MeasureLatency(ctx, latency, metricSliceAttrs, attrOpt) - } else { - f.metricStore.MeasureRequestCount(ctx, reqContext.telemetry.metricSliceAttrs, metricAddOpt) - f.metricStore.MeasureLatency(ctx, latency, reqContext.telemetry.metricSliceAttrs, metricAddOpt) - } + metricAttrs = append(metricAttrs, rotel.WgRequestError.Bool(true)) + attrOpt := otelmetric.WithAttributeSet(attribute.NewSet(metricAttrs...)) - span.SetAttributes(traceAttrs...) 
+ return metricSliceAttrs, attrOpt } diff --git a/router/core/engine_loader_hooks_test.go b/router/core/engine_loader_hooks_test.go new file mode 100644 index 0000000000..c81de6a0ef --- /dev/null +++ b/router/core/engine_loader_hooks_test.go @@ -0,0 +1,426 @@ +package core + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/astjson" + rcontext "github.com/wundergraph/cosmo/router/internal/context" + rotel "github.com/wundergraph/cosmo/router/pkg/otel" + "github.com/wundergraph/cosmo/router/pkg/trace/tracetest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + otelmetric "go.opentelemetry.io/otel/metric" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +func setupTestContext(t *testing.T, tp *sdktrace.TracerProvider) (context.Context, *requestContext) { + t.Helper() + + req := httptest.NewRequest(http.MethodPost, "/graphql", nil) + rc := buildRequestContext(requestContextOptions{r: req}) + rc.operation = &operationContext{} + + ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, rc) + + tracer := tp.Tracer("test") + ctx, _ = tracer.Start(ctx, "Engine - Fetch") + ctx = context.WithValue(ctx, rcontext.EngineLoaderHooksContextKey, &engineLoaderHooksRequestContext{ + startTime: time.Now(), + }) + + return ctx, rc +} + +func TestOnFinished_ClientDisconnect(t *testing.T) { + t.Parallel() + + ds := resolve.DataSourceInfo{ + ID: "subgraph-1", + Name: "products", + } + + t.Run("context.Canceled does not set span ERROR status", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + + store := &spyMetricStore{} + hooks := NewEngineRequestHooks(store, nil, tp, nil, nil, nil, false, nil) + + ctx, _ := setupTestContext(t, tp) + + hooks.OnFinished(ctx, ds, 
&resolve.ResponseInfo{ + Err: context.Canceled, + }) + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + // Span status should NOT be Error for client disconnects + require.NotEqual(t, codes.Error, spans[0].Status().Code, + "client disconnect should not set span status to Error") + + // The error should still be recorded as an event for observability + require.Len(t, spans[0].Events(), 1, "context.Canceled should be recorded as a span event") + + // MeasureRequestError should NOT be called + require.False(t, store.requestErrorCalled, + "MeasureRequestError should not be called for client disconnects") + }) + + t.Run("real error sets span ERROR status", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + + store := &spyMetricStore{} + hooks := NewEngineRequestHooks(store, nil, tp, nil, nil, nil, false, nil) + + ctx, _ := setupTestContext(t, tp) + + hooks.OnFinished(ctx, ds, &resolve.ResponseInfo{ + Err: errors.New("connection refused"), + }) + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + // Span status should be Error for real errors + require.Equal(t, codes.Error, spans[0].Status().Code, + "real errors should set span status to Error") + + // MeasureRequestError should be called + require.True(t, store.requestErrorCalled, + "MeasureRequestError should be called for real errors") + }) + + t.Run("wrapped context.Canceled does not set span ERROR status", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + + store := &spyMetricStore{} + hooks := NewEngineRequestHooks(store, nil, tp, nil, nil, nil, false, nil) + + ctx, _ := setupTestContext(t, tp) + + // Simulate a wrapped context.Canceled error (as would happen through net/http) + wrappedErr := fmt.Errorf("fetch failed: %w", context.Canceled) + hooks.OnFinished(ctx, 
ds, &resolve.ResponseInfo{ + Err: wrappedErr, + }) + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + require.NotEqual(t, codes.Error, spans[0].Status().Code, + "wrapped context.Canceled should not set span status to Error") + require.False(t, store.requestErrorCalled, + "MeasureRequestError should not be called for wrapped context.Canceled") + }) +} + +func TestRecordFetchError(t *testing.T) { + t.Parallel() + + t.Run("sets span status to error and records error event", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + tracer := tp.Tracer("test") + + ctx, span := tracer.Start(context.Background(), "test-span") + + store := &spyMetricStore{} + hooks := &engineLoaderHooks{metricStore: store} + + rc := buildRequestContext(requestContextOptions{ + r: httptest.NewRequest(http.MethodPost, "/graphql", nil), + }) + rc.operation = &operationContext{} + + fetchErr := errors.New("connection refused") + metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet()) + + hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, nil) + span.End() + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + require.Equal(t, codes.Error, spans[0].Status().Code) + require.Equal(t, "connection refused", spans[0].Status().Description) + + // Should have an exception event + require.Len(t, spans[0].Events(), 1) + require.Equal(t, "exception", spans[0].Events()[0].Name) + + require.True(t, store.requestErrorCalled) + }) + + t.Run("calls MeasureRequestError", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + tracer := tp.Tracer("test") + + _, span := tracer.Start(context.Background(), "test-span") + defer span.End() + + store := &spyMetricStore{} + hooks := &engineLoaderHooks{metricStore: store} + + rc := 
buildRequestContext(requestContextOptions{ + r: httptest.NewRequest(http.MethodPost, "/graphql", nil), + }) + rc.operation = &operationContext{} + + metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet()) + hooks.recordFetchError(context.Background(), span, errors.New("fail"), rc, nil, metricAddOpt, nil) + + require.True(t, store.requestErrorCalled, "should call MeasureRequestError") + }) + + t.Run("extracts downstream error codes from subgraph errors", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + tracer := tp.Tracer("test") + + ctx, span := tracer.Start(context.Background(), "test-span") + + store := &spyMetricStore{} + hooks := &engineLoaderHooks{metricStore: store} + + rc := buildRequestContext(requestContextOptions{ + r: httptest.NewRequest(http.MethodPost, "/graphql", nil), + }) + rc.operation = &operationContext{} + + // Build a SubgraphError with downstream errors containing extension codes + subErr := resolve.NewSubgraphError(resolve.DataSourceInfo{Name: "products"}, "query.products", "upstream error", 500) + + parser := astjson.Parser{} + ext, _ := parser.Parse(`{"code":"PRODUCT_NOT_FOUND"}`) + subErr.AppendDownstreamError(&resolve.GraphQLError{ + Message: "product not found", + Extensions: ext, + }) + + ext2, _ := parser.Parse(`{"code":"INVALID_INPUT"}`) + subErr.AppendDownstreamError(&resolve.GraphQLError{ + Message: "invalid input", + Extensions: ext2, + }) + + // Wrap as a multi-error (how the engine returns subgraph errors) + fetchErr := errors.Join(subErr) + + metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet()) + hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, nil) + span.End() + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + // Should have: 1 exception event + 2 downstream error events + require.Len(t, spans[0].Events(), 3) + require.Equal(t, "exception", 
spans[0].Events()[0].Name) + require.Equal(t, "Downstream error 1", spans[0].Events()[1].Name) + require.Equal(t, "Downstream error 2", spans[0].Events()[2].Name) + + // Verify downstream error attributes + event1Attrs := spans[0].Events()[1].Attributes + require.Contains(t, event1Attrs, rotel.WgSubgraphErrorExtendedCode.String("PRODUCT_NOT_FOUND")) + require.Contains(t, event1Attrs, rotel.WgSubgraphErrorMessage.String("product not found")) + + event2Attrs := spans[0].Events()[2].Attributes + require.Contains(t, event2Attrs, rotel.WgSubgraphErrorExtendedCode.String("INVALID_INPUT")) + require.Contains(t, event2Attrs, rotel.WgSubgraphErrorMessage.String("invalid input")) + }) + + t.Run("handles errors without downstream codes", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + tracer := tp.Tracer("test") + + ctx, span := tracer.Start(context.Background(), "test-span") + + store := &spyMetricStore{} + hooks := &engineLoaderHooks{metricStore: store} + + rc := buildRequestContext(requestContextOptions{ + r: httptest.NewRequest(http.MethodPost, "/graphql", nil), + }) + rc.operation = &operationContext{} + + // SubgraphError with a downstream error that has no extension code + subErr := resolve.NewSubgraphError(resolve.DataSourceInfo{Name: "products"}, "query.products", "upstream error", 500) + subErr.AppendDownstreamError(&resolve.GraphQLError{ + Message: "something went wrong", + // No Extensions + }) + + fetchErr := errors.Join(subErr) + metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet()) + hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, nil) + span.End() + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 1) + + // Only the exception event, no downstream error events (no codes to report) + require.Len(t, spans[0].Events(), 1) + require.Equal(t, "exception", spans[0].Events()[0].Name) + + require.True(t, 
store.requestErrorCalled) + }) + + t.Run("deduplicates and sorts error codes", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + tracer := tp.Tracer("test") + + ctx, span := tracer.Start(context.Background(), "test-span") + + store := &spyMetricStore{} + hooks := &engineLoaderHooks{metricStore: store} + + rc := buildRequestContext(requestContextOptions{ + r: httptest.NewRequest(http.MethodPost, "/graphql", nil), + }) + rc.operation = &operationContext{} + rc.telemetry.metricSetAttrs = map[string]string{ + ContextFieldGraphQLErrorCodes: "graphql.error.codes", + } + + parser := astjson.Parser{} + + // Two subgraph errors with duplicate and unsorted codes + subErr1 := resolve.NewSubgraphError(resolve.DataSourceInfo{Name: "products"}, "query.products", "err1", 500) + ext1, _ := parser.Parse(`{"code":"ZEBRA_ERROR"}`) + subErr1.AppendDownstreamError(&resolve.GraphQLError{Message: "z", Extensions: ext1}) + ext2, _ := parser.Parse(`{"code":"ALPHA_ERROR"}`) + subErr1.AppendDownstreamError(&resolve.GraphQLError{Message: "a", Extensions: ext2}) + + subErr2 := resolve.NewSubgraphError(resolve.DataSourceInfo{Name: "users"}, "query.users", "err2", 500) + ext3, _ := parser.Parse(`{"code":"ALPHA_ERROR"}`) // duplicate + subErr2.AppendDownstreamError(&resolve.GraphQLError{Message: "a2", Extensions: ext3}) + + fetchErr := errors.Join(subErr1, subErr2) + metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet()) + hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, nil) + span.End() + + // Find the error codes attribute captured by the spy metric store + var foundCodes []string + for _, attr := range store.requestErrorSliceAttr { + if string(attr.Key) == "graphql.error.codes" { + foundCodes = attr.Value.AsStringSlice() + } + } + + require.Equal(t, []string{"ALPHA_ERROR", "ZEBRA_ERROR"}, foundCodes, + "error codes should be deduplicated and sorted") + }) + + 
+	t.Run("preserves pre-populated slice attrs alongside error codes", func(t *testing.T) {
+		t.Parallel()
+
+		exporter := tracetest.NewInMemoryExporter(t)
+		tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+		tracer := tp.Tracer("test")
+
+		ctx, span := tracer.Start(context.Background(), "test-span")
+
+		store := &spyMetricStore{}
+		hooks := &engineLoaderHooks{metricStore: store}
+
+		rc := buildRequestContext(requestContextOptions{
+			r: httptest.NewRequest(http.MethodPost, "/graphql", nil),
+		})
+		rc.operation = &operationContext{}
+		rc.telemetry.metricSetAttrs = map[string]string{
+			ContextFieldGraphQLErrorCodes: "graphql.error.codes",
+		}
+
+		parser := astjson.Parser{}
+		subErr := resolve.NewSubgraphError(resolve.DataSourceInfo{Name: "products"}, "query.products", "err", 500)
+		ext, _ := parser.Parse(`{"code":"SOME_CODE"}`) // literal is valid JSON; parse error deliberately ignored
+		subErr.AppendDownstreamError(&resolve.GraphQLError{Message: "m", Extensions: ext})
+
+		fetchErr := errors.Join(subErr)
+		metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet())
+
+		// Simulate the caller pattern: pass a pre-populated slice (like AcquireAttributes + append base attrs)
+		prePopulated := []attribute.KeyValue{
+			attribute.String("existing.attr", "value"),
+		}
+
+		resultSlice, _ := hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, prePopulated)
+		span.End()
+
+		// The returned slice must carry both the caller's attribute and the error-codes attribute
+		var hasExisting, hasErrorCodes bool
+		for _, attr := range resultSlice {
+			if string(attr.Key) == "existing.attr" {
+				hasExisting = true
+			}
+			if string(attr.Key) == "graphql.error.codes" {
+				hasErrorCodes = true
+			}
+		}
+		require.True(t, hasExisting, "pre-populated attrs should be preserved")
+		require.True(t, hasErrorCodes, "error codes should be appended")
+	})
+
+	t.Run("plain error without multi-error wrapper", func(t *testing.T) {
+		t.Parallel()
+
+		exporter := tracetest.NewInMemoryExporter(t)
+		tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+		tracer := tp.Tracer("test")
+
+		ctx, span := tracer.Start(context.Background(), "test-span")
+
+		store := &spyMetricStore{}
+		hooks := &engineLoaderHooks{metricStore: store}
+
+		rc := buildRequestContext(requestContextOptions{
+			r: httptest.NewRequest(http.MethodPost, "/graphql", nil),
+		})
+		rc.operation = &operationContext{}
+
+		// A plain error (not a multi-error, not a SubgraphError)
+		fetchErr := errors.New("dial tcp: connection refused")
+		metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet())
+		hooks.recordFetchError(ctx, span, fetchErr, rc, nil, metricAddOpt, nil)
+		span.End()
+
+		spans := exporter.GetSpans().Snapshots()
+		require.Len(t, spans, 1)
+
+		require.Equal(t, codes.Error, spans[0].Status().Code)
+		// Only the exception event, no downstream error events
+		require.Len(t, spans[0].Events(), 1)
+		require.True(t, store.requestErrorCalled)
+	})
+}
diff --git a/router/core/errors.go b/router/core/errors.go
index 7c5de91f65..c287eb2ea9 100644
--- a/router/core/errors.go
+++ b/router/core/errors.go
@@ -342,6 +342,17 @@ func writeOperationError(r *http.Request, w http.ResponseWriter, requestLogger *
 	var httpErr HttpError
 	var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError
 	switch {
+	case errors.Is(err, context.Canceled):
+		// Checked before the branches below: a canceled request gets its own REQUEST_CANCELED error with HTTP 200.
+		newErr := NewHttpGraphqlError("request canceled", "REQUEST_CANCELED", http.StatusOK)
+		writeRequestErrors(writeRequestErrorsParams{
+			request:           r,
+			writer:            w,
+			statusCode:        http.StatusOK,
+			requestErrors:     requestErrorsFromHttpError(newErr),
+			logger:            requestLogger,
+			headerPropagation: propagation,
+		})
 	case errors.As(err, &httpErr):
 		writeRequestErrors(writeRequestErrorsParams{
 			request:           r,
diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go
index 8c6bbb0af6..1363d637bc 100644
--- a/router/core/graphql_prehandler.go
+++ b/router/core/graphql_prehandler.go
@@ -419,14 +419,19 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
 		})
 		if err != nil {
 			requestContext.SetError(err)
-			// Mark the root span of the router as failed, so we can easily identify failed requests
-			rtrace.AttachErrToSpan(routerSpan, err)
-
-			if h.operationProcessor.costControl != nil && h.operationProcessor.costControl.ExposeHeaders &&
-				// Report the estimated cost in case of errors.
-				// The actual cost is only available for successful requests.
-				requestContext.operation != nil && requestContext.operation.costEstimatedSet {
-				ww.Header().Set(CostEstimatedHeader, strconv.Itoa(requestContext.operation.costEstimated))
+			if errors.Is(err, context.Canceled) {
+				// Client disconnections are not server-side errors: record the error as a span event only and
+				// leave the root span status untouched. writeOperationError below still runs for this error.
+				routerSpan.RecordError(err)
+			} else {
+				// Mark the root span of the router as failed, so we can easily identify failed requests.
+				rtrace.AttachErrToSpan(routerSpan, err)
+				if h.operationProcessor.costControl != nil && h.operationProcessor.costControl.ExposeHeaders &&
+					// Report the estimated cost in case of errors.
+					// The actual cost is only available for successful requests.
+					requestContext.operation != nil && requestContext.operation.costEstimatedSet {
+					ww.Header().Set(CostEstimatedHeader, strconv.Itoa(requestContext.operation.costEstimated))
+				}
 			}
 
 			writeOperationError(r, ww, requestLogger, err, h.headerPropagation)
@@ -624,7 +629,9 @@ func (h *PreHandler) handleOperation(req *http.Request, httpOperation *httpOpera
 	span.SetAttributes(otel.WgEnginePersistedOperationCacheHit.Bool(operationKit.parsedOperation.PersistedOperationCacheHit))
 	if err != nil {
 		span.RecordError(err)
-		rtrace.SetSanitizedSpanStatus(span, codes.Error, err.Error())
+		if !errors.Is(err, context.Canceled) {
+			rtrace.SetSanitizedSpanStatus(span, codes.Error, err.Error())
+		}
 
 		var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError
 		if h.operationBlocker.logUnknownOperationsEnabled && errors.As(err, &poNotFoundErr) {
diff --git a/router/core/operation_metrics_test.go b/router/core/operation_metrics_test.go
index 771b61ab37..ea3ff1ae6a 100644
--- a/router/core/operation_metrics_test.go
+++ b/router/core/operation_metrics_test.go
@@ -130,11 +130,14 @@ func (m *spyRouterMetrics) MetricStore() metric.Store {
 
 type spyMetricStore struct {
 	metric.NoopMetrics
-	requestErrorCalled bool
+	requestErrorCalled    bool
+	requestErrorSliceAttr []attribute.KeyValue
 }
 
-func (m *spyMetricStore) MeasureRequestError(_ context.Context, _ []attribute.KeyValue, _ otelmetric.AddOption) {
+// MeasureRequestError flags that it was called and accumulates the slice attributes of every call.
+func (m *spyMetricStore) MeasureRequestError(_ context.Context, sliceAttr []attribute.KeyValue, _ otelmetric.AddOption) {
 	m.requestErrorCalled = true
+	m.requestErrorSliceAttr = append(m.requestErrorSliceAttr, sliceAttr...)
 }
 
 func newTestRequestContext(t *testing.T) *requestContext {
diff --git a/router/pkg/trace/transport.go b/router/pkg/trace/transport.go
index cd2726ca0b..a15c34f735 100644
--- a/router/pkg/trace/transport.go
+++ b/router/pkg/trace/transport.go
@@ -1,12 +1,16 @@
 package trace
 
 import (
+	gocontext "context"
+	"errors"
 	"net/http"
 	"sync/atomic"
 	"time"
 
 	"github.com/wundergraph/cosmo/router/internal/context"
 	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
 )
 
 type TransportOption func(svr *transport)
@@ -50,7 +54,13 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
 	}
 
 	// In case of a roundtrip error the span status is set to error by the otelhttp.RoundTrip function.
-	// Also, status code >= 500 is considered an error
+	// Also, status code >= 500 is considered an error.
+	// Client disconnections (context.Canceled) are not server-side errors. Pre-set the span
+	// status to Ok so that otelhttp cannot override it with Error (per OTel spec, Ok is final).
+	if errors.Is(err, gocontext.Canceled) {
+		span := trace.SpanFromContext(r.Context())
+		span.SetStatus(codes.Ok, "client disconnected")
+	}
 
 	return res, err
 }
diff --git a/router/pkg/trace/transport_test.go b/router/pkg/trace/transport_test.go
index 1d4e582db1..4d73f8ac89 100644
--- a/router/pkg/trace/transport_test.go
+++ b/router/pkg/trace/transport_test.go
@@ -3,14 +3,16 @@ package trace
 import (
 	"bytes"
 	"context"
-	"github.com/wundergraph/cosmo/router/pkg/otel"
-	"github.com/wundergraph/cosmo/router/pkg/trace/tracetest"
 	"io"
 	"net/http"
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/wundergraph/cosmo/router/pkg/otel"
+	"github.com/wundergraph/cosmo/router/pkg/trace/tracetest"
 	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
 	"go.opentelemetry.io/otel/codes"
 	sdktrace "go.opentelemetry.io/otel/sdk/trace"
@@ -124,4 +126,51 @@ func TestTransport(t *testing.T) {
 		assert.Contains(t, sn[0].Attributes(), semconv.HTTPStatusCode(http.StatusInternalServerError))
 		assert.Contains(t, sn[0].Attributes(), otel.WgComponentName.String("test"))
 	})
+
+	t.Run("context canceled does not set span status to error", func(t *testing.T) {
+		exporter := tracetest.NewInMemoryExporter(t)
+
+		// Slow server that takes longer than the client will wait
+		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			select {
+			case <-r.Context().Done():
+				// Client is gone; returning here keeps ts.Close() from waiting out a fixed sleep.
+			case <-time.After(2 * time.Second):
+				w.WriteHeader(http.StatusOK)
+			}
+		}))
+		defer ts.Close()
+
+		// Use WithCancel to simulate a client disconnect (produces context.Canceled)
+		ctx, cancel := context.WithCancel(context.Background())
+
+		r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/test", nil)
+		require.NoError(t, err)
+
+		tr := NewTransport(http.DefaultTransport, []otelhttp.Option{
+			otelhttp.WithSpanOptions(trace.WithAttributes(otel.WgComponentName.String("test"))),
+			otelhttp.WithTracerProvider(sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))),
+		})
+
+		// Cancel the context after a short delay to simulate client disconnect
+		time.AfterFunc(50*time.Millisecond, cancel)
+
+		c := http.Client{Transport: tr}
+		_, err = c.Do(r)
+		require.Error(t, err)
+
+		sn := exporter.GetSpans().Snapshots()
+		require.Len(t, sn, 1)
+
+		span := sn[0]
+
+		require.Equal(t, "HTTP GET", span.Name())
+
+		// The span should NOT be marked as Error for client disconnections.
+		// Our transport pre-sets Ok to prevent otelhttp from overriding with Error.
+		require.NotEqual(t, codes.Error, span.Status().Code,
+			"context.Canceled should not produce an Error span status")
+		require.Equal(t, codes.Ok, span.Status().Code,
+			"context.Canceled should produce an Ok span status (prevents otelhttp override)")
+	})
 }