Skip to content
148 changes: 145 additions & 3 deletions router-tests/prometheus_improved_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package integration

import (
rmetric "github.com/wundergraph/cosmo/router/pkg/metric"
"regexp"
"testing"

rmetric "github.com/wundergraph/cosmo/router/pkg/metric"

"github.com/prometheus/client_golang/prometheus"
io_prometheus_client "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -36,7 +37,8 @@ func TestPrometheusSchemaUsage(t *testing.T) {
PrometheusRegistry: promRegistry,
MetricOptions: testenv.MetricOptions{
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
Enabled: true,
SampleRate: 1.0,
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
Expand Down Expand Up @@ -128,7 +130,8 @@ query myQuery {
PrometheusRegistry: promRegistry,
MetricOptions: testenv.MetricOptions{
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
Enabled: true,
SampleRate: 1.0,
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
Expand Down Expand Up @@ -203,6 +206,7 @@ query myQuery {
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
IncludeOperationSha: false,
SampleRate: 1.0,
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
Expand Down Expand Up @@ -240,6 +244,7 @@ query myQuery {
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
IncludeOperationSha: true,
SampleRate: 1.0,
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
Expand Down Expand Up @@ -283,6 +288,7 @@ query myQuery {
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
IncludeOperationSha: false,
SampleRate: 1.0,
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
Expand All @@ -308,6 +314,142 @@ query myQuery {
}
})
})

t.Run("sampling reduces tracked requests", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
promRegistry := prometheus.NewRegistry()

testenv.Run(t, &testenv.Config{
MetricReader: metricReader,
PrometheusRegistry: promRegistry,
MetricOptions: testenv.MetricOptions{
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
SampleRate: 0.1, // 10% sampling
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Make 100 requests
for i := 0; i < 100; i++ {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query myQuery { employee(id: 1) { id } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1}}}`, res.Body)
}

mf, err := promRegistry.Gather()
require.NoError(t, err)

schemaUsage := findMetricFamilyByName(mf, SchemaFieldUsageMetricName)
assert.NotNil(t, schemaUsage)

schemaUsageMetrics := schemaUsage.GetMetric()

require.Greater(t, len(schemaUsageMetrics), 0, "At least 1 request should be sampled")

// With 10% sampling and 100 requests, each sampled request increments two field counters (`employee` and `id`).
// 100% sampling would produce 200 total field counts (100 requests * 2 fields), so a reduced total confirms sampling worked.
totalFieldCounts := 0.0
for _, m := range schemaUsageMetrics {
counter := m.GetCounter()
require.NotNil(t, counter)
totalFieldCounts += counter.GetValue()
}

require.Greater(t, totalFieldCounts, 0.0, "At least one sampled field is expected with a 10% sample rate")
require.Less(t, totalFieldCounts, 200.0, "Sampling should record fewer than 100% of requests (200 total field counts)")

// Verify that the sampled metrics have correct structure
for _, m := range schemaUsageMetrics {
assertLabelValue(t, m.Label, otel.WgOperationName, "myQuery")
assertLabelValue(t, m.Label, otel.WgOperationType, "query")
}
})
})
Comment thread
StarpTech marked this conversation as resolved.

t.Run("100% sample rate tracks all requests", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
promRegistry := prometheus.NewRegistry()

testenv.Run(t, &testenv.Config{
MetricReader: metricReader,
PrometheusRegistry: promRegistry,
MetricOptions: testenv.MetricOptions{
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
SampleRate: 1.0, // 100% sampling (default)
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Make 10 requests
for range 10 {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query myQuery { employee(id: 1) { id } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1}}}`, res.Body)
}

mf, err := promRegistry.Gather()
require.NoError(t, err)

schemaUsage := findMetricFamilyByName(mf, SchemaFieldUsageMetricName)
assert.NotNil(t, schemaUsage)

schemaUsageMetrics := schemaUsage.GetMetric()

// With 100% sampling and 10 requests, we expect 2 metrics (employee, id)
// The counter values should be 10 for each field
require.Len(t, schemaUsageMetrics, 2)

for _, metric := range schemaUsageMetrics {
assertLabelValue(t, metric.Label, otel.WgOperationName, "myQuery")
assertLabelValue(t, metric.Label, otel.WgOperationType, "query")

// Each field should have been counted 10 times (once per request)
assert.InEpsilon(t, 10.0, *metric.Counter.Value, 0.0001)
}
})
})

t.Run("0% sample rate tracks no requests", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
promRegistry := prometheus.NewRegistry()

testenv.Run(t, &testenv.Config{
MetricReader: metricReader,
PrometheusRegistry: promRegistry,
MetricOptions: testenv.MetricOptions{
PrometheusSchemaFieldUsage: testenv.PrometheusSchemaFieldUsage{
Enabled: true,
SampleRate: 0.0, // 0% sampling
},
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Make 10 requests
for range 10 {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query myQuery { employee(id: 1) { id } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1}}}`, res.Body)
}

mf, err := promRegistry.Gather()
require.NoError(t, err)

schemaUsage := findMetricFamilyByName(mf, SchemaFieldUsageMetricName)

// With 0% sampling, no metrics should be recorded
if schemaUsage != nil {
require.Len(t, schemaUsage.GetMetric(), 0, "No metrics should be recorded with 0% sampling")
}
})
})
}

func assertLabelNotPresent(t *testing.T, labels []*io_prometheus_client.LabelPair, labelKey attribute.Key) {
Expand Down
2 changes: 2 additions & 0 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ type MetricOptions struct {
type PrometheusSchemaFieldUsage struct {
Enabled bool
IncludeOperationSha bool
SampleRate float64
}

type Config struct {
Expand Down Expand Up @@ -1516,6 +1517,7 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node
PromSchemaFieldUsage: rmetric.PrometheusSchemaFieldUsage{
Enabled: testConfig.MetricOptions.PrometheusSchemaFieldUsage.Enabled,
IncludeOperationSha: testConfig.MetricOptions.PrometheusSchemaFieldUsage.IncludeOperationSha,
SampleRate: testConfig.MetricOptions.PrometheusSchemaFieldUsage.SampleRate,
},
}
}
Expand Down
5 changes: 3 additions & 2 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,8 +919,9 @@ func (s *graphServer) buildGraphMux(
routerConfigVersion: opts.RouterConfigVersion,
logger: s.logger,

promSchemaUsageEnabled: s.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled,
promSchemaUsageIncludeOperationSha: s.metricConfig.Prometheus.PromSchemaFieldUsage.IncludeOperationSha,
promSchemaUsageEnabled: s.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled,
promSchemaUsageIncludeOpSha: s.metricConfig.Prometheus.PromSchemaFieldUsage.IncludeOperationSha,
promSchemaUsageSampleRate: s.metricConfig.Prometheus.PromSchemaFieldUsage.SampleRate,
})

baseLogFields := []zapcore.Field{
Expand Down
73 changes: 61 additions & 12 deletions router/core/operation_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package core

import (
"context"
"math/rand/v2"
"slices"
"time"

Expand Down Expand Up @@ -38,8 +39,14 @@ type OperationMetrics struct {
logger *zap.Logger
trackUsageInfo bool

promSchemaUsageEnabled bool
promSchemaUsageIncludeOperationSha bool
promSchemaUsageEnabled bool
promSchemaUsageIncludeOpSha bool
promSchemaUsageSampleRate float64
}

// usageKey identifies one schema field occurrence — the field's name plus
// its exact parent type — and is used as the aggregation key when counting
// per-field usage before emitting Prometheus schema usage metrics.
type usageKey struct {
	fieldName  string // last segment of the field's path (the field itself)
	parentType string // exact parent type name the field was selected on
}

func (m *OperationMetrics) Finish(reqContext *requestContext, statusCode int, responseSize int, exportSynchronous bool) {
Expand Down Expand Up @@ -82,28 +89,46 @@ func (m *OperationMetrics) Finish(reqContext *requestContext, statusCode int, re
}

// Prometheus usage metrics, disabled by default
if m.promSchemaUsageEnabled && reqContext.operation != nil && !reqContext.operation.executionOptions.SkipLoader {
if m.promSchemaUsageEnabled && reqContext.operation != nil {

if !m.shouldSampleOperation() {
return
}

opAttrs := []attribute.KeyValue{
rotel.WgOperationName.String(reqContext.operation.name),
rotel.WgOperationType.String(reqContext.operation.opType),
}

if m.promSchemaUsageIncludeOperationSha && reqContext.operation.sha256Hash != "" {
// Include operation SHA256 if enabled
if m.promSchemaUsageIncludeOpSha && reqContext.operation.sha256Hash != "" {
opAttrs = append(opAttrs, rotel.WgOperationSha256.String(reqContext.operation.sha256Hash))
}

usageCounts := make(map[usageKey]int)

for _, field := range reqContext.operation.typeFieldUsageInfo {
if field.ExactParentTypeName == "" {
if field.ExactParentTypeName == "" || len(field.Path) == 0 {
continue
}

key := usageKey{
fieldName: field.Path[len(field.Path)-1],
parentType: field.ExactParentTypeName,
}

usageCounts[key]++
}

for key, count := range usageCounts {
fieldAttrs := []attribute.KeyValue{
rotel.WgGraphQLFieldName.String(field.Path[len(field.Path)-1]),
rotel.WgGraphQLParentType.String(field.ExactParentTypeName),
rotel.WgGraphQLFieldName.String(key.fieldName),
rotel.WgGraphQLParentType.String(key.parentType),
}

rm.MeasureSchemaFieldUsage(ctx, 1, []attribute.KeyValue{}, otelmetric.WithAttributeSet(attribute.NewSet(slices.Concat(opAttrs, fieldAttrs)...)))
rm.MeasureSchemaFieldUsage(ctx, int64(count), []attribute.KeyValue{}, otelmetric.WithAttributeSet(attribute.NewSet(slices.Concat(opAttrs, fieldAttrs)...)))
Comment thread
StarpTech marked this conversation as resolved.
}

}
}

Expand All @@ -116,8 +141,9 @@ type OperationMetricsOptions struct {
Logger *zap.Logger
TrackUsageInfo bool

PrometheusSchemaUsageEnabled bool
PrometheusSchemaUsageIncludeSha bool
PrometheusSchemaUsageEnabled bool
PrometheusSchemaUsageIncludeOpSha bool
PrometheusSchemaUsageSampleRate float64
}

// newOperationMetrics creates a new OperationMetrics struct and starts the operation metrics.
Expand All @@ -135,7 +161,30 @@ func newOperationMetrics(opts OperationMetricsOptions) *OperationMetrics {
logger: opts.Logger,
trackUsageInfo: opts.TrackUsageInfo,

promSchemaUsageEnabled: opts.PrometheusSchemaUsageEnabled,
promSchemaUsageIncludeOperationSha: opts.PrometheusSchemaUsageIncludeSha,
promSchemaUsageEnabled: opts.PrometheusSchemaUsageEnabled,
promSchemaUsageIncludeOpSha: opts.PrometheusSchemaUsageIncludeOpSha,
promSchemaUsageSampleRate: opts.PrometheusSchemaUsageSampleRate,
}
}

// shouldSampleOperation determines if a request should be sampled for schema field usage metrics.
// Uses probabilistic random sampling to ensure uniform distribution across all operations.
//
// This ensures:
// - All operations get statistical coverage (~X% of requests per operation)
// - Uniform distribution regardless of request ID format
// - Supports ANY sample rate (0.0 to 1.0), including arbitrary values like 0.8, 0.156, etc.
//
// Note: Uses non-deterministic random sampling rather than hash-based sampling because
// sequential request IDs produce clustered hash values that are not distributed uniformly.
Comment on lines +178 to +179
Copy link
Copy Markdown
Contributor

@ysmolski ysmolski Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Kind of contradictory statement. "Uses non-deterministic" conflicts with "clustered hash values that break deterministic sampling".

Suggested change
// Note: Uses non-deterministic random sampling rather than hash-based sampling because
// sequential request IDs produce clustered hash values that break deterministic sampling.
// Note: Uses non-deterministic random sampling rather than hash-based sampling because
// sequential request IDs produce clustered hash values that are not distributed uniformly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't seen this. Thanks for the suggestions.

// shouldSampleOperation reports whether the current request should be
// recorded in the schema field usage metrics, based on the configured
// sample rate. Rates at or above 1.0 always sample, rates at or below
// 0.0 never sample, and anything in between is decided by a uniform
// random draw so that roughly that fraction of requests is tracked.
func (m *OperationMetrics) shouldSampleOperation() bool {
	rate := m.promSchemaUsageSampleRate
	switch {
	case rate >= 1.0:
		return true
	case rate <= 0.0:
		return false
	default:
		// Probabilistic sampling: simple, reliable, and guaranteed uniform distribution
		return rand.Float64() < rate
	}
}
1 change: 1 addition & 0 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,7 @@ func MetricConfigFromTelemetry(cfg *config.Telemetry) *rmetric.Config {
PromSchemaFieldUsage: rmetric.PrometheusSchemaFieldUsage{
Enabled: cfg.Metrics.Prometheus.SchemaFieldUsage.Enabled,
IncludeOperationSha: cfg.Metrics.Prometheus.SchemaFieldUsage.IncludeOperationSha,
SampleRate: cfg.Metrics.Prometheus.SchemaFieldUsage.SampleRate,
},
},
}
Expand Down
Loading
Loading