Skip to content
Merged
126 changes: 126 additions & 0 deletions router-tests/modules/custom-trace-propagator/module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package custom_trace_propagator

import (
"context"
"fmt"

"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"

"github.com/wundergraph/cosmo/router/core"
)

func init() {
// Register your module here
core.RegisterModule(&CustomTracePropagatorModule{})
}

const myModuleID = "tracePropagatorModule"

// CustomTracePropagatorModule is a simple module that provides a custom trace propagator for the router
type CustomTracePropagatorModule struct {
Value uint64 `mapstructure:"value"`
Propagator *customPropagator
Logger *zap.Logger
}

func (m *CustomTracePropagatorModule) Module() core.ModuleInfo {
return core.ModuleInfo{
// This is the ID of your module, it must be unique
ID: myModuleID,
// The priority of your module, lower numbers are executed first
Priority: 1,
New: func() core.Module {
return &CustomTracePropagatorModule{
Propagator: &customPropagator{},
}
},
}
}

type ctxKeyCustomPropagator string

const ctxKey = "CustomPropagator"

func (m *CustomTracePropagatorModule) TracePropagators() []propagation.TextMapPropagator {
return []propagation.TextMapPropagator{m.Propagator}
}

type customPropagator struct {
InjectCalled int
ExtractCalled int
}

type info struct {
injectCalled int
extractCalled int
}

func parse(s string) *info {
var i info

_, err := fmt.Sscanf(s, "injectCalled:%d, extractCalled:%d", &i.injectCalled, &i.extractCalled)
if err != nil {
return nil
}
return &i
}

func (i *info) String() string {
return fmt.Sprintf("injectCalled:%d, extractCalled:%d", i.injectCalled, i.extractCalled)
}

func (c *customPropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) {
c.InjectCalled++
var i info

switch v := ctx.Value(ctxKeyCustomPropagator(ctxKey)).(type) {
case *info:
i = *v
default:
}

i.injectCalled = c.InjectCalled
carrier.Set(ctxKey, i.String())
}

func (c *customPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
c.ExtractCalled++

cStr := carrier.Get(ctxKey)

i := parse(cStr)
if i == nil {
return ctx
}

// create a fantasy trace ID for testing purposes
sID := "acde00000000000000000000eeeeffff"

tid, err := trace.TraceIDFromHex(sID)
if err != nil {
return ctx
}

sc := trace.SpanFromContext(ctx).SpanContext()

ssc := trace.NewSpanContext(trace.SpanContextConfig{
TraceID: tid,
SpanID: sc.SpanID(),
TraceFlags: sc.TraceFlags(),
TraceState: sc.TraceState(),
Remote: sc.IsRemote(),
})

i.extractCalled = c.ExtractCalled
ctx = context.WithValue(ctx, ctxKeyCustomPropagator(ctxKey), i)

return trace.ContextWithSpanContext(ctx, ssc)
}

func (c *customPropagator) Fields() []string {
return []string{ctxKey}
}

var _ propagation.TextMapPropagator = (*customPropagator)(nil)
71 changes: 71 additions & 0 deletions router-tests/modules/custom_trace_propagator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package module

import (
"encoding/json"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

custom_trace_propagator "github.com/wundergraph/cosmo/router-tests/modules/custom-trace-propagator"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/cmd/custom/module"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
"github.com/wundergraph/cosmo/router/pkg/trace/tracetest"
)

func TestModuleCustomPropagator(t *testing.T) {
t.Run("Should set custom trace propagator with a custom trace ID", func(t *testing.T) {
t.Parallel()

cfg := config.Config{
Graph: config.Graph{},
Modules: map[string]interface{}{
"myModule": module.MyModule{Value: 1},
"custom_trace": custom_trace_propagator.CustomTracePropagatorModule{Value: 2},
},
}

exporter := tracetest.NewInMemoryExporter(t)

testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
GlobalMiddleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
require.Equal(t, "injectCalled:1, extractCalled:1", request.Header.Get("CustomPropagator"))
handler.ServeHTTP(writer, request)
})
},
},
RouterOptions: []core.Option{
core.WithModulesConfig(cfg.Modules),
core.WithTracing(&rtrace.Config{
Enabled: true,
}),
},
TraceExporter: exporter,
}, func(t *testing.T, xEnv *testenv.Environment) {
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `query MyQuery { employees { id } }`,
OperationName: json.RawMessage(`"MyQuery"`),
Header: map[string][]string{
"CustomPropagator": {"injectCalled:0, extractCalled:0"},
},
})
require.NoError(t, err)
assert.Equal(t, 200, res.Response.StatusCode)
assert.JSONEq(t, res.Body, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`)

spans := exporter.GetSpans().Snapshots()

// check that our fantasy trace ID is set
for _, s := range spans {
traceIDStr := s.SpanContext().TraceID().String()
require.Equal(t, "acde00000000000000000000eeeeffff", traceIDStr)
}
})
})
}
6 changes: 3 additions & 3 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
otelhttp.WithTracerProvider(s.tracerProvider),
}

if s.tracePropagators != nil {
middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.tracePropagators))
if s.compositePropagator != nil {
middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.compositePropagator))
}

traceHandler := rtrace.NewMiddleware(
Expand Down Expand Up @@ -879,7 +879,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
},
},
TracerProvider: s.tracerProvider,
TracePropagators: s.tracePropagators,
TracePropagators: s.compositePropagator,
LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker,
Logger: s.logger,
},
Expand Down
9 changes: 9 additions & 0 deletions router/core/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
stdContext "context"
"fmt"
"go.opentelemetry.io/otel/propagation"
"math"
"net/http"
"sort"
Expand Down Expand Up @@ -123,6 +124,14 @@ type EnginePostOriginHandler interface {
OnOriginResponse(resp *http.Response, ctx RequestContext) *http.Response
}

// TracePropagationProvider is an interface that allows you to provide custom trace propagators.
// The trace propagators are used to inject and extract trace information from the request.
// The provided propagators will be used in addition to the configured propagators.
type TracePropagationProvider interface {
// TracePropagators returns the custom trace propagators which should be used by the router.
TracePropagators() []propagation.TextMapPropagator
}

// Provisioner is called before the server starts
// It allows you to initialize your module e.g. create a database connection
// or load a configuration file
Expand Down
28 changes: 20 additions & 8 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ type (
tlsServerConfig *tls.Config
tlsConfig *TlsConfig
telemetryAttributes []config.CustomAttribute
tracePropagators propagation.TextMapPropagator
tracePropagators []propagation.TextMapPropagator
compositePropagator propagation.TextMapPropagator
// Poller
configPoller configpoller.ConfigPoller
selfRegister selfregister.SelfRegister
Expand Down Expand Up @@ -442,18 +443,12 @@ func NewRouter(opts ...Option) (*Router, error) {

if r.traceConfig.Enabled {
if len(r.traceConfig.Propagators) > 0 {
propagators, err := rtrace.NewCompositePropagator(r.traceConfig.Propagators...)
propagators, err := rtrace.BuildPropagators(r.traceConfig.Propagators...)
if err != nil {
r.logger.Error("creating propagators", zap.Error(err))
return nil, err
}

// Don't set it globally when we use the router in tests.
// In practice, setting it globally only makes sense for module development.
if r.traceConfig.TestMemoryExporter == nil {
otel.SetTextMapPropagator(propagators)
}

r.tracePropagators = propagators
}

Expand Down Expand Up @@ -674,6 +669,13 @@ func (r *Router) initModules(ctx context.Context) error {
r.postOriginHandlers = append(r.postOriginHandlers, handler.OnOriginResponse)
}

if handler, ok := moduleInstance.(TracePropagationProvider); ok {
modulePropagators := handler.TracePropagators()
if len(modulePropagators) > 0 {
r.tracePropagators = append(r.tracePropagators, modulePropagators...)
}
}

r.modules = append(r.modules, moduleInstance)

r.logger.Info("Module registered",
Expand Down Expand Up @@ -884,6 +886,16 @@ func (r *Router) bootstrap(ctx context.Context) error {
return fmt.Errorf("failed to init user modules: %w", err)
}

if r.traceConfig.Enabled && len(r.tracePropagators) > 0 {
r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...)

// Don't set it globally when we use the router in tests.
// In practice, setting it globally only makes sense for module development.
if r.traceConfig.TestMemoryExporter == nil {
otel.SetTextMapPropagator(r.compositePropagator)
}
}

return nil
}

Expand Down
4 changes: 2 additions & 2 deletions router/pkg/trace/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"go.opentelemetry.io/otel/propagation"
)

func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropagator, error) {
func BuildPropagators(propagators ...Propagator) ([]propagation.TextMapPropagator, error) {
var allPropagators []propagation.TextMapPropagator
for _, p := range propagators {
switch p {
Expand All @@ -26,5 +26,5 @@ func NewCompositePropagator(propagators ...Propagator) (propagation.TextMapPropa
return nil, fmt.Errorf("unknown trace propagator: %s", p)
}
}
return propagation.NewCompositeTextMapPropagator(allPropagators...), nil
return allPropagators, nil
}