From fed378223425185714c041004701762b73f6bdf9 Mon Sep 17 00:00:00 2001 From: Tyler Yahn Date: Wed, 6 Nov 2024 13:12:07 -0800 Subject: [PATCH] Add concurrent safety to Controller.Trace Use TracerProvider.Tracer directly instead of caching Tracers in the Controller. The API ensures that TracerProvider.Tracer is concurrent safe and the default SDK already handles the caching of Tracers (no need to duplicate that logic here). --- internal/pkg/opentelemetry/controller.go | 25 +++---- internal/pkg/opentelemetry/controller_test.go | 65 ++++++++++--------- 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/internal/pkg/opentelemetry/controller.go b/internal/pkg/opentelemetry/controller.go index ef0d62474..e5954cad5 100644 --- a/internal/pkg/opentelemetry/controller.go +++ b/internal/pkg/opentelemetry/controller.go @@ -19,19 +19,6 @@ import ( type Controller struct { logger *slog.Logger tracerProvider trace.TracerProvider - tracersMap map[tracerID]trace.Tracer -} - -type tracerID struct{ name, version, schema string } - -func (c *Controller) getTracer(name, version, schema string) trace.Tracer { - tID := tracerID{name: name, version: version, schema: schema} - t, exists := c.tracersMap[tID] - if !exists { - t = c.tracerProvider.Tracer(name, trace.WithInstrumentationVersion(version), trace.WithSchemaURL(schema)) - c.tracersMap[tID] = t - } - return t } // Trace creates a trace span for event. @@ -43,7 +30,12 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { kvs []attribute.KeyValue ) - t := c.getTracer(ss.Scope().Name(), ss.Scope().Version(), ss.SchemaUrl()) + tracer := c.tracerProvider.Tracer( + ss.Scope().Name(), + trace.WithInstrumentationVersion(ss.Scope().Version()), + trace.WithInstrumentationAttributes(attrs(ss.Scope().Attributes())...), + trace.WithSchemaURL(ss.SchemaUrl()), + ) for k := 0; k < ss.Spans().Len(); k++ { pSpan := ss.Spans().At(k) @@ -51,7 +43,7 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { c.logger.Debug("dropping invalid span", "name", pSpan.Name()) continue } - c.logger.Debug("handling span", "tracer", t, "span", pSpan) + c.logger.Debug("handling span", "tracer", tracer, "span", pSpan) ctx := context.Background() if !pSpan.ParentSpanID().IsEmpty() { @@ -71,7 +63,7 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { trace.WithTimestamp(pSpan.StartTimestamp().AsTime()), trace.WithLinks(c.links(pSpan.Links())...), ) - _, span := t.Start(ctx, pSpan.Name(), startOpts...) + _, span := tracer.Start(ctx, pSpan.Name(), startOpts...) startOpts = startOpts[:0] kvs = kvs[:0] @@ -96,7 +88,6 @@ func NewController(logger *slog.Logger, tracerProvider trace.TracerProvider) (*C return &Controller{ logger: logger, tracerProvider: tracerProvider, - tracersMap: make(map[tracerID]trace.Tracer), }, nil } diff --git a/internal/pkg/opentelemetry/controller_test.go b/internal/pkg/opentelemetry/controller_test.go index 9a3a7bd3c..60566f439 100644 --- a/internal/pkg/opentelemetry/controller_test.go +++ b/internal/pkg/opentelemetry/controller_test.go @@ -10,6 +10,7 @@ import ( "runtime" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -314,35 +315,6 @@ func TestTrace(t *testing.T) { } } -func TestGetTracer(t *testing.T) { - exporter := tracetest.NewInMemoryExporter() - tp := sdktrace.NewTracerProvider( - sdktrace.WithSampler(sdktrace.AlwaysSample()), - sdktrace.WithBatcher(exporter), - sdktrace.WithResource(instResource()), - ) - defer func() { - err := tp.Shutdown(context.Background()) - assert.NoError(t, err) - }() - - ctrl, err := NewController(slog.Default(), tp) - assert.NoError(t, err) - - t1 := ctrl.getTracer("test", "v1", "schema") - assert.Equal(t, t1, ctrl.tracersMap[tracerID{name: "test", version: "v1", schema: "schema"}]) - - t2 := ctrl.getTracer("net/http", "", "") - assert.Equal(t, t2, ctrl.tracersMap[tracerID{name: "net/http", version: "", schema: ""}]) - - t3 := ctrl.getTracer("test", "v1", "schema") - assert.Same(t, t1, t3) - - t4 := ctrl.getTracer("net/http", "", "") - assert.Same(t, t2, t4) - assert.Equal(t, len(ctrl.tracersMap), 2) -} - type shutdownExporter struct { sdktrace.SpanExporter @@ -390,3 +362,38 @@ func TestShutdown(t *testing.T) { assert.True(t, exporter.called, "Exporter not shutdown") assert.Equal(t, uint32(nSpan), exporter.exported.Load(), "Pending spans not flushed") } + +func TestControllerTraceConcurrentSafe(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.AlwaysSample()), + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(instResource()), + ) + defer func() { + err := tp.Shutdown(context.Background()) + assert.NoError(t, err) + }() + + ctrl, err := NewController(slog.Default(), tp) + assert.NoError(t, err) + + const goroutines = 10 + + var wg sync.WaitGroup + for n := 0; n < goroutines; n++ { + wg.Add(1) + go func() { + defer wg.Done() + + data := ptrace.NewScopeSpans() + data.Scope().SetName(fmt.Sprintf("tracer-%d", n%(goroutines/2))) + data.Scope().SetVersion("v1") + data.SetSchemaUrl("url") + data.Spans().AppendEmpty().SetName("test") + ctrl.Trace(data) + }() + } + + wg.Wait() +}