Skip to content

Commit

Permalink
Add concurrent safety to Controller.Trace
Browse files Browse the repository at this point in the history
Use TracerProvider.Tracer directly instead of caching Tracers in the
Controller. The API ensures that TracerProvider.Tracer is concurrent
safe and the default SDK already handles the caching of Tracers (no need
to duplicate that logic here).
  • Loading branch information
MrAlias committed Nov 6, 2024
1 parent 658f3f4 commit fed3782
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 46 deletions.
25 changes: 8 additions & 17 deletions internal/pkg/opentelemetry/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,6 @@ import (
type Controller struct {
logger *slog.Logger
tracerProvider trace.TracerProvider
tracersMap map[tracerID]trace.Tracer
}

type tracerID struct{ name, version, schema string }

func (c *Controller) getTracer(name, version, schema string) trace.Tracer {
tID := tracerID{name: name, version: version, schema: schema}
t, exists := c.tracersMap[tID]
if !exists {
t = c.tracerProvider.Tracer(name, trace.WithInstrumentationVersion(version), trace.WithSchemaURL(schema))
c.tracersMap[tID] = t
}
return t
}

// Trace creates a trace span for event.
Expand All @@ -43,15 +30,20 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) {
kvs []attribute.KeyValue
)

t := c.getTracer(ss.Scope().Name(), ss.Scope().Version(), ss.SchemaUrl())
tracer := c.tracerProvider.Tracer(
ss.Scope().Name(),
trace.WithInstrumentationVersion(ss.Scope().Version()),
trace.WithInstrumentationAttributes(attrs(ss.Scope().Attributes())...),
trace.WithSchemaURL(ss.SchemaUrl()),
)
for k := 0; k < ss.Spans().Len(); k++ {
pSpan := ss.Spans().At(k)

if pSpan.TraceID().IsEmpty() || pSpan.SpanID().IsEmpty() {
c.logger.Debug("dropping invalid span", "name", pSpan.Name())
continue
}
c.logger.Debug("handling span", "tracer", t, "span", pSpan)
c.logger.Debug("handling span", "tracer", tracer, "span", pSpan)

ctx := context.Background()
if !pSpan.ParentSpanID().IsEmpty() {
Expand All @@ -71,7 +63,7 @@ func (c *Controller) Trace(ss ptrace.ScopeSpans) {
trace.WithTimestamp(pSpan.StartTimestamp().AsTime()),
trace.WithLinks(c.links(pSpan.Links())...),
)
_, span := t.Start(ctx, pSpan.Name(), startOpts...)
_, span := tracer.Start(ctx, pSpan.Name(), startOpts...)
startOpts = startOpts[:0]
kvs = kvs[:0]

Expand All @@ -96,7 +88,6 @@ func NewController(logger *slog.Logger, tracerProvider trace.TracerProvider) (*C
return &Controller{
logger: logger,
tracerProvider: tracerProvider,
tracersMap: make(map[tracerID]trace.Tracer),
}, nil
}

Expand Down
65 changes: 36 additions & 29 deletions internal/pkg/opentelemetry/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -314,35 +315,6 @@ func TestTrace(t *testing.T) {
}
}

func TestGetTracer(t *testing.T) {
exporter := tracetest.NewInMemoryExporter()
tp := sdktrace.NewTracerProvider(
sdktrace.WithSampler(sdktrace.AlwaysSample()),
sdktrace.WithBatcher(exporter),
sdktrace.WithResource(instResource()),
)
defer func() {
err := tp.Shutdown(context.Background())
assert.NoError(t, err)
}()

ctrl, err := NewController(slog.Default(), tp)
assert.NoError(t, err)

t1 := ctrl.getTracer("test", "v1", "schema")
assert.Equal(t, t1, ctrl.tracersMap[tracerID{name: "test", version: "v1", schema: "schema"}])

t2 := ctrl.getTracer("net/http", "", "")
assert.Equal(t, t2, ctrl.tracersMap[tracerID{name: "net/http", version: "", schema: ""}])

t3 := ctrl.getTracer("test", "v1", "schema")
assert.Same(t, t1, t3)

t4 := ctrl.getTracer("net/http", "", "")
assert.Same(t, t2, t4)
assert.Equal(t, len(ctrl.tracersMap), 2)
}

type shutdownExporter struct {
sdktrace.SpanExporter

Expand Down Expand Up @@ -390,3 +362,38 @@ func TestShutdown(t *testing.T) {
assert.True(t, exporter.called, "Exporter not shutdown")
assert.Equal(t, uint32(nSpan), exporter.exported.Load(), "Pending spans not flushed")
}

func TestControllerTraceConcurrentSafe(t *testing.T) {
exporter := tracetest.NewInMemoryExporter()
tp := sdktrace.NewTracerProvider(
sdktrace.WithSampler(sdktrace.AlwaysSample()),
sdktrace.WithBatcher(exporter),
sdktrace.WithResource(instResource()),
)
defer func() {
err := tp.Shutdown(context.Background())
assert.NoError(t, err)
}()

ctrl, err := NewController(slog.Default(), tp)
assert.NoError(t, err)

const goroutines = 10

var wg sync.WaitGroup
for n := 0; n < goroutines; n++ {
wg.Add(1)
go func() {
defer wg.Done()

data := ptrace.NewScopeSpans()
data.Scope().SetName(fmt.Sprintf("tracer-%d", n%(goroutines/2)))
data.Scope().SetVersion("v1")
data.SetSchemaUrl("url")
data.Spans().AppendEmpty().SetName("test")
ctrl.Trace(data)
}()
}

wg.Wait()
}

0 comments on commit fed3782

Please sign in to comment.