diff --git a/internal/pkg/opentelemetry/controller.go b/internal/pkg/opentelemetry/controller.go index ef0d62474..e5954cad5 100644 --- a/internal/pkg/opentelemetry/controller.go +++ b/internal/pkg/opentelemetry/controller.go @@ -19,19 +19,6 @@ import ( type Controller struct { logger *slog.Logger tracerProvider trace.TracerProvider - tracersMap map[tracerID]trace.Tracer -} - -type tracerID struct{ name, version, schema string } - -func (c *Controller) getTracer(name, version, schema string) trace.Tracer { - tID := tracerID{name: name, version: version, schema: schema} - t, exists := c.tracersMap[tID] - if !exists { - t = c.tracerProvider.Tracer(name, trace.WithInstrumentationVersion(version), trace.WithSchemaURL(schema)) - c.tracersMap[tID] = t - } - return t } // Trace creates a trace span for event. @@ -43,7 +30,12 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { kvs []attribute.KeyValue ) - t := c.getTracer(ss.Scope().Name(), ss.Scope().Version(), ss.SchemaUrl()) + tracer := c.tracerProvider.Tracer( + ss.Scope().Name(), + trace.WithInstrumentationVersion(ss.Scope().Version()), + trace.WithInstrumentationAttributes(attrs(ss.Scope().Attributes())...), + trace.WithSchemaURL(ss.SchemaUrl()), + ) for k := 0; k < ss.Spans().Len(); k++ { pSpan := ss.Spans().At(k) @@ -51,7 +43,7 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { c.logger.Debug("dropping invalid span", "name", pSpan.Name()) continue } - c.logger.Debug("handling span", "tracer", t, "span", pSpan) + c.logger.Debug("handling span", "tracer", tracer, "span", pSpan) ctx := context.Background() if !pSpan.ParentSpanID().IsEmpty() { @@ -71,7 +63,7 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) { trace.WithTimestamp(pSpan.StartTimestamp().AsTime()), trace.WithLinks(c.links(pSpan.Links())...), ) - _, span := t.Start(ctx, pSpan.Name(), startOpts...) + _, span := tracer.Start(ctx, pSpan.Name(), startOpts...) startOpts = startOpts[:0] kvs = kvs[:0] @@ -96,7 +88,6 @@ func NewController(logger *slog.Logger, tracerProvider trace.TracerProvider) (*C return &Controller{ logger: logger, tracerProvider: tracerProvider, - tracersMap: make(map[tracerID]trace.Tracer), }, nil } diff --git a/internal/pkg/opentelemetry/controller_test.go b/internal/pkg/opentelemetry/controller_test.go index 9a3a7bd3c..60566f439 100644 --- a/internal/pkg/opentelemetry/controller_test.go +++ b/internal/pkg/opentelemetry/controller_test.go @@ -10,6 +10,7 @@ import ( "runtime" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -314,35 +315,6 @@ func TestTrace(t *testing.T) { } } -func TestGetTracer(t *testing.T) { - exporter := tracetest.NewInMemoryExporter() - tp := sdktrace.NewTracerProvider( - sdktrace.WithSampler(sdktrace.AlwaysSample()), - sdktrace.WithBatcher(exporter), - sdktrace.WithResource(instResource()), - ) - defer func() { - err := tp.Shutdown(context.Background()) - assert.NoError(t, err) - }() - - ctrl, err := NewController(slog.Default(), tp) - assert.NoError(t, err) - - t1 := ctrl.getTracer("test", "v1", "schema") - assert.Equal(t, t1, ctrl.tracersMap[tracerID{name: "test", version: "v1", schema: "schema"}]) - - t2 := ctrl.getTracer("net/http", "", "") - assert.Equal(t, t2, ctrl.tracersMap[tracerID{name: "net/http", version: "", schema: ""}]) - - t3 := ctrl.getTracer("test", "v1", "schema") - assert.Same(t, t1, t3) - - t4 := ctrl.getTracer("net/http", "", "") - assert.Same(t, t2, t4) - assert.Equal(t, len(ctrl.tracersMap), 2) -} - type shutdownExporter struct { sdktrace.SpanExporter @@ -390,3 +362,38 @@ func TestShutdown(t *testing.T) { assert.True(t, exporter.called, "Exporter not shutdown") assert.Equal(t, uint32(nSpan), exporter.exported.Load(), "Pending spans not flushed") } + +func TestControllerTraceConcurrentSafe(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.AlwaysSample()), + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(instResource()), + ) + defer func() { + err := tp.Shutdown(context.Background()) + assert.NoError(t, err) + }() + + ctrl, err := NewController(slog.Default(), tp) + assert.NoError(t, err) + + const goroutines = 10 + + var wg sync.WaitGroup + for n := 0; n < goroutines; n++ { + wg.Add(1) + go func() { + defer wg.Done() + + data := ptrace.NewScopeSpans() + data.Scope().SetName(fmt.Sprintf("tracer-%d", n%(goroutines/2))) + data.Scope().SetVersion("v1") + data.SetSchemaUrl("url") + data.Spans().AppendEmpty().SetName("test") + ctrl.Trace(data) + }() + } + + wg.Wait() +}