diff --git a/router-tests/modules/custom-trace-propagator/module.go b/router-tests/modules/custom-trace-propagator/module.go new file mode 100644 index 0000000000..6a7b10a6c3 --- /dev/null +++ b/router-tests/modules/custom-trace-propagator/module.go @@ -0,0 +1,126 @@ +package custom_trace_propagator + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" +) + +func init() { + // Register your module here + core.RegisterModule(&CustomTracePropagatorModule{}) +} + +const myModuleID = "tracePropagatorModule" + +// CustomTracePropagatorModule is a simple module that provides a custom trace propagator for the router +type CustomTracePropagatorModule struct { + Value uint64 `mapstructure:"value"` + Propagator *customPropagator + Logger *zap.Logger +} + +func (m *CustomTracePropagatorModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &CustomTracePropagatorModule{ + Propagator: &customPropagator{}, + } + }, + } +} + +type ctxKeyCustomPropagator string + +const ctxKey = "CustomPropagator" + +func (m *CustomTracePropagatorModule) TracePropagators() []propagation.TextMapPropagator { + return []propagation.TextMapPropagator{m.Propagator} +} + +type customPropagator struct { + InjectCalled int + ExtractCalled int +} + +type info struct { + injectCalled int + extractCalled int +} + +func parse(s string) *info { + var i info + + _, err := fmt.Sscanf(s, "injectCalled:%d, extractCalled:%d", &i.injectCalled, &i.extractCalled) + if err != nil { + return nil + } + return &i +} + +func (i *info) String() string { + return fmt.Sprintf("injectCalled:%d, extractCalled:%d", i.injectCalled, i.extractCalled) +} + +func (c *customPropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) { + c.InjectCalled++ + var i info + + switch v := ctx.Value(ctxKeyCustomPropagator(ctxKey)).(type) { + case *info: + i = *v + default: + } + + i.injectCalled = c.InjectCalled + carrier.Set(ctxKey, i.String()) +} + +func (c *customPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { + c.ExtractCalled++ + + cStr := carrier.Get(ctxKey) + + i := parse(cStr) + if i == nil { + return ctx + } + + // create a fantasy trace ID for testing purposes + sID := "acde00000000000000000000eeeeffff" + + tid, err := trace.TraceIDFromHex(sID) + if err != nil { + return ctx + } + + sc := trace.SpanFromContext(ctx).SpanContext() + + ssc := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: tid, + SpanID: sc.SpanID(), + TraceFlags: sc.TraceFlags(), + TraceState: sc.TraceState(), + Remote: sc.IsRemote(), + }) + + i.extractCalled = c.ExtractCalled + ctx = context.WithValue(ctx, ctxKeyCustomPropagator(ctxKey), i) + + return trace.ContextWithSpanContext(ctx, ssc) +} + +func (c *customPropagator) Fields() []string { + return []string{ctxKey} +} + +var _ propagation.TextMapPropagator = (*customPropagator)(nil) diff --git a/router-tests/modules/custom_trace_propagator_test.go b/router-tests/modules/custom_trace_propagator_test.go new file mode 100644 index 0000000000..0b2e51f0b9 --- /dev/null +++ b/router-tests/modules/custom_trace_propagator_test.go @@ -0,0 +1,71 @@ +package module + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + custom_trace_propagator "github.com/wundergraph/cosmo/router-tests/modules/custom-trace-propagator" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/cmd/custom/module" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + rtrace "github.com/wundergraph/cosmo/router/pkg/trace" + "github.com/wundergraph/cosmo/router/pkg/trace/tracetest" +) + +func TestModuleCustomPropagator(t *testing.T) { + t.Run("Should set custom trace propagator with a custom trace ID", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "myModule": module.MyModule{Value: 1}, + "custom_trace": custom_trace_propagator.CustomTracePropagatorModule{Value: 2}, + }, + } + + exporter := tracetest.NewInMemoryExporter(t) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + require.Equal(t, "injectCalled:1, extractCalled:1", request.Header.Get("CustomPropagator")) + handler.ServeHTTP(writer, request) + }) + }, + }, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithTracing(&rtrace.Config{ + Enabled: true, + }), + }, + TraceExporter: exporter, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + Header: map[string][]string{ + "CustomPropagator": {"injectCalled:0, extractCalled:0"}, + }, + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + assert.JSONEq(t, res.Body, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`) + + spans := exporter.GetSpans().Snapshots() + + // check that our fantasy trace ID is set + for _, s := range spans { + traceIDStr := s.SpanContext().TraceID().String() + require.Equal(t, "acde00000000000000000000eeeeffff", traceIDStr) + } + }) + }) +} diff --git a/router/core/graph_server.go b/router/core/graph_server.go index e06dec4c42..6920395ad1 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -781,8 +781,8 @@ func (s *graphServer) buildGraphMux(ctx context.Context, otelhttp.WithTracerProvider(s.tracerProvider), } - if s.tracePropagators != nil { - middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.tracePropagators)) + if s.compositePropagator != nil { + middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.compositePropagator)) } traceHandler := rtrace.NewMiddleware( @@ -879,7 +879,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context, }, }, TracerProvider: s.tracerProvider, - TracePropagators: s.tracePropagators, + TracePropagators: s.compositePropagator, LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker, Logger: s.logger, }, diff --git a/router/core/modules.go b/router/core/modules.go index 91428e8a2d..434344bd9e 100644 --- a/router/core/modules.go +++ b/router/core/modules.go @@ -3,6 +3,7 @@ package core import ( stdContext "context" "fmt" + "go.opentelemetry.io/otel/propagation" "math" "net/http" "sort" @@ -123,6 +124,14 @@ type EnginePostOriginHandler interface { OnOriginResponse(resp *http.Response, ctx RequestContext) *http.Response } +// TracePropagationProvider is an interface that allows you to provide custom trace propagators. +// The trace propagators are used to inject and extract trace information from the request. +// The provided propagators will be used in addition to the configured propagators. +type TracePropagationProvider interface { + // TracePropagators returns the custom trace propagators which should be used by the router. + TracePropagators() []propagation.TextMapPropagator +} + // Provisioner is called before the server starts // It allows you to initialize your module e.g. create a database connection // or load a configuration file diff --git a/router/core/router.go b/router/core/router.go index f3e1a8ad9f..4971893f05 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -206,7 +206,8 @@ type ( tlsServerConfig *tls.Config tlsConfig *TlsConfig telemetryAttributes []config.CustomAttribute - tracePropagators propagation.TextMapPropagator + tracePropagators []propagation.TextMapPropagator + compositePropagator propagation.TextMapPropagator // Poller configPoller configpoller.ConfigPoller selfRegister selfregister.SelfRegister @@ -442,18 +443,12 @@ func NewRouter(opts ...Option) (*Router, error) { if r.traceConfig.Enabled { if len(r.traceConfig.Propagators) > 0 { - propagators, err := rtrace.NewCompositePropagator(r.traceConfig.Propagators...) + propagators, err := rtrace.BuildPropagators(r.traceConfig.Propagators...) if err != nil { r.logger.Error("creating propagators", zap.Error(err)) return nil, err } - // Don't set it globally when we use the router in tests. - // In practice, setting it globally only makes sense for module development. - if r.traceConfig.TestMemoryExporter == nil { - otel.SetTextMapPropagator(propagators) - } - r.tracePropagators = propagators } @@ -674,6 +669,13 @@ func (r *Router) initModules(ctx context.Context) error { r.postOriginHandlers = append(r.postOriginHandlers, handler.OnOriginResponse) } + if handler, ok := moduleInstance.(TracePropagationProvider); ok { + modulePropagators := handler.TracePropagators() + if len(modulePropagators) > 0 { + r.tracePropagators = append(r.tracePropagators, modulePropagators...) + } + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", @@ -884,6 +886,16 @@ func (r *Router) bootstrap(ctx context.Context) error { return fmt.Errorf("failed to init user modules: %w", err) } + if r.traceConfig.Enabled && len(r.tracePropagators) > 0 { + r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...) + + // Don't set it globally when we use the router in tests. + // In practice, setting it globally only makes sense for module development. + if r.traceConfig.TestMemoryExporter == nil { + otel.SetTextMapPropagator(r.compositePropagator) + } + } + return nil } diff --git a/router/pkg/trace/propagation.go b/router/pkg/trace/propagation.go index dfaf317c8b..d62a92c65a 100644 --- a/router/pkg/trace/propagation.go +++ b/router/pkg/trace/propagation.go @@ -8,7 +8,7 @@ import ( "go.opentelemetry.io/otel/propagation" ) -func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropagator, error) { +func BuildPropagators(propagators ...Propagator) ([]propagation.TextMapPropagator, error) { var allPropagators []propagation.TextMapPropagator for _, p := range propagators { switch p { @@ -26,5 +26,5 @@ func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropa return nil, fmt.Errorf("unknown trace propagator: %s", p) } } - return propagation.NewCompositeTextMapPropagator(allPropagators...), nil + return allPropagators, nil }