From 002a5e3498b7ed4aea8288fbbf32bd02565e305a Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Jan 2025 18:43:20 +0100 Subject: [PATCH 1/2] feat(router): add text map propagator interface to module system --- .../modules/custom-trace-propagator/module.go | 101 ++++++++++++++++++ .../modules/custom_trace_propagator_test.go | 54 ++++++++++ router/core/graph_server.go | 6 +- router/core/modules.go | 9 ++ router/core/router.go | 28 +++-- router/pkg/trace/propagation.go | 4 +- 6 files changed, 189 insertions(+), 13 deletions(-) create mode 100644 router-tests/modules/custom-trace-propagator/module.go create mode 100644 router-tests/modules/custom_trace_propagator_test.go diff --git a/router-tests/modules/custom-trace-propagator/module.go b/router-tests/modules/custom-trace-propagator/module.go new file mode 100644 index 0000000000..f9ad3a1507 --- /dev/null +++ b/router-tests/modules/custom-trace-propagator/module.go @@ -0,0 +1,101 @@ +package custom_trace_propagator + +import ( + "context" + "fmt" + "github.com/wundergraph/cosmo/router/core" + "go.opentelemetry.io/otel/propagation" + "go.uber.org/zap" +) + +func init() { + // Register your module here + core.RegisterModule(&CustomTracePropagatorModule{}) +} + +const myModuleID = "tracePropagatorModule" + +// CustomTracePropagatorModule is a simple module that provides a custom trace propagator for the router +type CustomTracePropagatorModule struct { + Value uint64 `mapstructure:"value"` + Propagator *customPropagator + Logger *zap.Logger +} + +func (m *CustomTracePropagatorModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &CustomTracePropagatorModule{ + Propagator: &customPropagator{}, + } + }, + } +} + +const customPropagatorKey = "customPropagator" + +func (m *CustomTracePropagatorModule) TracePropagators() []propagation.TextMapPropagator { + return []propagation.TextMapPropagator{m.Propagator} +} + +type customPropagator struct { + InjectCalled int + ExtractCalled int +} + +type info struct { + injectCalled int + extractCalled int +} + +func parse(s string) *info { + var i info + + _, err := fmt.Sscanf(s, "injectCalled:%d, extractCalled:%d", &i.injectCalled, &i.extractCalled) + if err != nil { + return nil + } + return &i +} + +func (i *info) String() string { + return fmt.Sprintf("injectCalled:%d, extractCalled:%d", i.injectCalled, i.extractCalled) +} + +func (c *customPropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) { + c.InjectCalled++ + var i info + + switch v := ctx.Value(customPropagatorKey).(type) { + case *info: + i = *v + default: + } + + i.injectCalled = c.InjectCalled + carrier.Set(customPropagatorKey, i.String()) +} + +func (c *customPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { + c.ExtractCalled++ + + cStr := carrier.Get(customPropagatorKey) + + i := parse(cStr) + if i == nil { + return ctx + } + + i.extractCalled = c.ExtractCalled + return context.WithValue(ctx, customPropagatorKey, i) +} + +func (c *customPropagator) Fields() []string { + return []string{customPropagatorKey} +} + +var _ propagation.TextMapPropagator = (*customPropagator)(nil) diff --git a/router-tests/modules/custom_trace_propagator_test.go b/router-tests/modules/custom_trace_propagator_test.go new file mode 100644 index 0000000000..9300022f3f --- /dev/null +++ b/router-tests/modules/custom_trace_propagator_test.go @@ -0,0 +1,54 @@ +package module + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + custom_trace_propagator "github.com/wundergraph/cosmo/router-tests/modules/custom-trace-propagator" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/cmd/custom/module" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + rtrace "github.com/wundergraph/cosmo/router/pkg/trace" +) + +func TestModuleCustomPropagator(t *testing.T) { + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "myModule": module.MyModule{Value: 1}, + "custom_trace": custom_trace_propagator.CustomTracePropagatorModule{Value: 2}, + }, + } + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + require.Equal(t, "injectCalled:1, extractCalled:1", request.Header.Get("CustomPropagator")) + handler.ServeHTTP(writer, request) + }) + }, + }, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithTracing(&rtrace.Config{ + Enabled: true, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + Header: map[string][]string{ + "CustomPropagator": {"injectCalled:0, extractCalled:0"}, + }, + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + assert.JSONEq(t, res.Body, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`) + }) +} diff --git a/router/core/graph_server.go b/router/core/graph_server.go index fa30d261c2..256bf054d7 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -769,8 +769,8 @@ func (s *graphServer) buildGraphMux(ctx context.Context, otelhttp.WithTracerProvider(s.tracerProvider), } - if s.tracePropagators != nil { - middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.tracePropagators)) + if s.compositePropagator != nil { + middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.compositePropagator)) } traceHandler := rtrace.NewMiddleware( @@ -867,7 +867,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context, }, }, TracerProvider: s.tracerProvider, - TracePropagators: s.tracePropagators, + TracePropagators: s.compositePropagator, LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker, Logger: s.logger, }, diff --git a/router/core/modules.go b/router/core/modules.go index 91428e8a2d..434344bd9e 100644 --- a/router/core/modules.go +++ b/router/core/modules.go @@ -3,6 +3,7 @@ package core import ( stdContext "context" "fmt" + "go.opentelemetry.io/otel/propagation" "math" "net/http" "sort" @@ -123,6 +124,14 @@ type EnginePostOriginHandler interface { OnOriginResponse(resp *http.Response, ctx RequestContext) *http.Response } +// TracePropagationProvider is an interface that allows you to provide custom trace propagators. +// The trace propagators are used to inject and extract trace information from the request. +// The provided propagators will be used in addition to the configured propagators. +type TracePropagationProvider interface { + // TracePropagators returns the custom trace propagators which should be used by the router. + TracePropagators() []propagation.TextMapPropagator +} + // Provisioner is called before the server starts // It allows you to initialize your module e.g. create a database connection // or load a configuration file diff --git a/router/core/router.go b/router/core/router.go index ac37f2fef7..e231bf26d8 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -209,7 +209,8 @@ type ( tlsServerConfig *tls.Config tlsConfig *TlsConfig telemetryAttributes []config.CustomAttribute - tracePropagators propagation.TextMapPropagator + tracePropagators []propagation.TextMapPropagator + compositePropagator propagation.TextMapPropagator // Poller configPoller configpoller.ConfigPoller selfRegister selfregister.SelfRegister @@ -435,18 +436,12 @@ func NewRouter(opts ...Option) (*Router, error) { if r.traceConfig.Enabled { if len(r.traceConfig.Propagators) > 0 { - propagators, err := rtrace.NewCompositePropagator(r.traceConfig.Propagators...) + propagators, err := rtrace.BuildPropagators(r.traceConfig.Propagators...) if err != nil { r.logger.Error("creating propagators", zap.Error(err)) return nil, err } - // Don't set it globally when we use the router in tests. - // In practice, setting it globally only makes sense for module development. - if r.traceConfig.TestMemoryExporter == nil { - otel.SetTextMapPropagator(propagators) - } - r.tracePropagators = propagators } @@ -677,6 +672,13 @@ func (r *Router) initModules(ctx context.Context) error { r.postOriginHandlers = append(r.postOriginHandlers, handler.OnOriginResponse) } + if handler, ok := moduleInstance.(TracePropagationProvider); ok { + modulePropagators := handler.TracePropagators() + if len(modulePropagators) > 0 { + r.tracePropagators = append(r.tracePropagators, modulePropagators...) + } + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", @@ -883,6 +885,16 @@ func (r *Router) bootstrap(ctx context.Context) error { return fmt.Errorf("failed to init user modules: %w", err) } + if r.traceConfig.Enabled && len(r.tracePropagators) > 0 { + r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...) + + // Don't set it globally when we use the router in tests. + // In practice, setting it globally only makes sense for module development. + if r.traceConfig.TestMemoryExporter == nil { + otel.SetTextMapPropagator(r.compositePropagator) + } + } + return nil } diff --git a/router/pkg/trace/propagation.go b/router/pkg/trace/propagation.go index dfaf317c8b..d62a92c65a 100644 --- a/router/pkg/trace/propagation.go +++ b/router/pkg/trace/propagation.go @@ -8,7 +8,7 @@ import ( "go.opentelemetry.io/otel/propagation" ) -func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropagator, error) { +func BuildPropagators(propagators ...Propagator) ([]propagation.TextMapPropagator, error) { var allPropagators []propagation.TextMapPropagator for _, p := range propagators { switch p { @@ -26,5 +26,5 @@ func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropa return nil, fmt.Errorf("unknown trace propagator: %s", p) } } - return propagation.NewCompositeTextMapPropagator(allPropagators...), nil + return allPropagators, nil } From 332dfe118e63b5bfd6221d0edaef382583a17a75 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Thu, 23 Jan 2025 10:33:42 +0100 Subject: [PATCH 2/2] test: include custom trace ID in tests --- .../modules/custom-trace-propagator/module.go | 39 +++++++-- .../modules/custom_trace_propagator_test.go | 79 +++++++++++-------- 2 files changed, 80 insertions(+), 38 deletions(-) diff --git a/router-tests/modules/custom-trace-propagator/module.go b/router-tests/modules/custom-trace-propagator/module.go index f9ad3a1507..6a7b10a6c3 100644 --- a/router-tests/modules/custom-trace-propagator/module.go +++ b/router-tests/modules/custom-trace-propagator/module.go @@ -3,9 +3,12 @@ package custom_trace_propagator import ( "context" "fmt" - "github.com/wundergraph/cosmo/router/core" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" ) func init() { @@ -36,7 +39,9 @@ func (m *CustomTracePropagatorModule) Module() core.ModuleInfo { } } -const customPropagatorKey = "customPropagator" +type ctxKeyCustomPropagator string + +const ctxKey = "CustomPropagator" func (m *CustomTracePropagatorModule) TracePropagators() []propagation.TextMapPropagator { return []propagation.TextMapPropagator{m.Propagator} @@ -70,32 +75,52 @@ func (c *customPropagator) Inject(ctx context.Context, carrier propagation.TextM c.InjectCalled++ var i info - switch v := ctx.Value(customPropagatorKey).(type) { + switch v := ctx.Value(ctxKeyCustomPropagator(ctxKey)).(type) { case *info: i = *v default: } i.injectCalled = c.InjectCalled - carrier.Set(customPropagatorKey, i.String()) + carrier.Set(ctxKey, i.String()) } func (c *customPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { c.ExtractCalled++ - cStr := carrier.Get(customPropagatorKey) + cStr := carrier.Get(ctxKey) i := parse(cStr) if i == nil { return ctx } + // create a fantasy trace ID for testing purposes + sID := "acde00000000000000000000eeeeffff" + + tid, err := trace.TraceIDFromHex(sID) + if err != nil { + return ctx + } + + sc := trace.SpanFromContext(ctx).SpanContext() + + ssc := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: tid, + SpanID: sc.SpanID(), + TraceFlags: sc.TraceFlags(), + TraceState: sc.TraceState(), + Remote: sc.IsRemote(), + }) + i.extractCalled = c.ExtractCalled - return context.WithValue(ctx, customPropagatorKey, i) + ctx = context.WithValue(ctx, ctxKeyCustomPropagator(ctxKey), i) + + return trace.ContextWithSpanContext(ctx, ssc) } func (c *customPropagator) Fields() []string { - return []string{customPropagatorKey} + return []string{ctxKey} } var _ propagation.TextMapPropagator = (*customPropagator)(nil) diff --git a/router-tests/modules/custom_trace_propagator_test.go b/router-tests/modules/custom_trace_propagator_test.go index 9300022f3f..0b2e51f0b9 100644 --- a/router-tests/modules/custom_trace_propagator_test.go +++ b/router-tests/modules/custom_trace_propagator_test.go @@ -7,48 +7,65 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + custom_trace_propagator "github.com/wundergraph/cosmo/router-tests/modules/custom-trace-propagator" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/cmd/custom/module" "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" + "github.com/wundergraph/cosmo/router/pkg/trace/tracetest" ) func TestModuleCustomPropagator(t *testing.T) { - cfg := config.Config{ - Graph: config.Graph{}, - Modules: map[string]interface{}{ - "myModule": module.MyModule{Value: 1}, - "custom_trace": custom_trace_propagator.CustomTracePropagatorModule{Value: 2}, - }, - } - - testenv.Run(t, &testenv.Config{ - Subgraphs: testenv.SubgraphsConfig{ - GlobalMiddleware: func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - require.Equal(t, "injectCalled:1, extractCalled:1", request.Header.Get("CustomPropagator")) - handler.ServeHTTP(writer, request) - }) + t.Run("Should set custom trace propagator with a custom trace ID", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "myModule": module.MyModule{Value: 1}, + "custom_trace": custom_trace_propagator.CustomTracePropagatorModule{Value: 2}, }, - }, - RouterOptions: []core.Option{ - core.WithModulesConfig(cfg.Modules), - core.WithTracing(&rtrace.Config{ - Enabled: true, - }), - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ - Query: `query MyQuery { employees { id } }`, - OperationName: json.RawMessage(`"MyQuery"`), - Header: map[string][]string{ - "CustomPropagator": {"injectCalled:0, extractCalled:0"}, + } + + exporter := tracetest.NewInMemoryExporter(t) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + require.Equal(t, "injectCalled:1, extractCalled:1", request.Header.Get("CustomPropagator")) + handler.ServeHTTP(writer, request) + }) + }, }, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithTracing(&rtrace.Config{ + Enabled: true, + }), + }, + TraceExporter: exporter, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + Header: map[string][]string{ + "CustomPropagator": {"injectCalled:0, extractCalled:0"}, + }, + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + assert.JSONEq(t, res.Body, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`) + + spans := exporter.GetSpans().Snapshots() + + // check that our fantasy trace ID is set + for _, s := range spans { + traceIDStr := s.SpanContext().TraceID().String() + require.Equal(t, "acde00000000000000000000eeeeffff", traceIDStr) + } }) - require.NoError(t, err) - assert.Equal(t, 200, res.Response.StatusCode) - assert.JSONEq(t, res.Body, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`) }) }