diff --git a/interceptor.go b/interceptor.go index 911767d..46ded28 100644 --- a/interceptor.go +++ b/interceptor.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "sync" "time" @@ -89,6 +90,10 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return next(ctx, request) } } + labeler, found := LabelerFromContext(ctx) + if !found { + ctx = ContextWithLabeler(ctx, labeler) + } attributeFilter := i.config.filterAttribute.filter isClient := request.Spec().IsClient name := strings.TrimLeft(request.Spec().Procedure, "/") @@ -175,7 +180,12 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { span.SetStatus(serverSpanStatus(protocol, err)) } span.SetAttributes(attributes...) - attributesSet := attribute.NewSet(attributes...) + var attributesSet attribute.Set + if labelerAttrs := labeler.Get(); len(labelerAttrs) > 0 { + attributesSet = attribute.NewSet(slices.Concat(attributes, labelerAttrs)...) + } else { + attributesSet = attribute.NewSet(attributes...) + } instrumentation.duration.Record(ctx, i.config.now().Sub(requestStartTime).Milliseconds(), metric.WithAttributeSet(attributesSet)) instrumentation.requestSize.Record(ctx, int64(requestSize), metric.WithAttributeSet(attributesSet)) instrumentation.requestsPerRPC.Record(ctx, 1, metric.WithAttributeSet(attributesSet)) @@ -193,6 +203,10 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn return next(ctx, spec) } } + labeler, found := LabelerFromContext(ctx) + if !found { + ctx = ContextWithLabeler(ctx, labeler) + } requestStartTime := i.config.now() name := strings.TrimLeft(spec.Procedure, "/") // Span is closed on context cancelation or when the stream is closed. @@ -215,6 +229,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn i.config.omitTraceEvents, instrumentation.responseSize, instrumentation.requestSize, + labeler, ) var requestOnce sync.Once setRequestAttributes := func() { @@ -246,7 +261,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn } span.SetStatus(clientSpanStatus(protocol, state.error)) span.End() - attributeSet := attribute.NewSet(state.attributes...) + attributeSet := attribute.NewSet(state.metricAttributes()...) instrumentation.requestsPerRPC.Record(ctx, state.sentCounter, metric.WithAttributeSet(attributeSet)) instrumentation.responsesPerRPC.Record(ctx, state.receivedCounter, metric.WithAttributeSet(attributeSet)) duration := i.config.now().Sub(requestStartTime).Milliseconds() @@ -282,6 +297,10 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co return next(ctx, conn) } } + labeler, found := LabelerFromContext(ctx) + if !found { + ctx = ContextWithLabeler(ctx, labeler) + } name := strings.TrimLeft(conn.Spec().Procedure, "/") protocol := protocolToSemConv(conn.Peer().Protocol, i.config.rpcSystem) state := newStreamingState( @@ -292,6 +311,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co i.config.omitTraceEvents, instrumentation.requestSize, instrumentation.responseSize, + labeler, ) // extract any request headers into the context carrier := propagation.HeaderCarrier(conn.RequestHeader()) @@ -342,7 +362,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co span.SetAttributes(headerAttributes(protocol, responseKey, conn.ResponseHeader(), i.config.responseHeaderKeys)...) } span.SetStatus(serverSpanStatus(protocol, err)) - attributeSet := attribute.NewSet(state.attributes...) + attributeSet := attribute.NewSet(state.metricAttributes()...) instrumentation.requestsPerRPC.Record(ctx, state.receivedCounter, metric.WithAttributeSet(attributeSet)) instrumentation.responsesPerRPC.Record(ctx, state.sentCounter, metric.WithAttributeSet(attributeSet)) duration := i.config.now().Sub(requestStartTime).Milliseconds() diff --git a/interceptor_test.go b/interceptor_test.go index 27b3439..cb26a0a 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -2608,3 +2608,205 @@ func assertUsableTraceparent(t *testing.T, header http.Header) { assert.NotEmpty(t, spanContext.SpanID(), "span ID should not be empty") assert.NotEmpty(t, spanContext.TraceID(), "trace ID should not be empty") } + +// labelerInterceptor is a test interceptor that retrieves the Labeler from +// context and adds custom attributes. Used to test that labeler attributes +// appear in metrics but not in spans. +type labelerInterceptor struct { + attrs []attribute.KeyValue +} + +func (l labelerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + labeler, _ := LabelerFromContext(ctx) + labeler.Add(l.attrs...) + return next(ctx, req) + } +} + +func (l labelerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + labeler, _ := LabelerFromContext(ctx) + labeler.Add(l.attrs...) + return next(ctx, spec) + } +} + +func (l labelerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + labeler, _ := LabelerFromContext(ctx) + labeler.Add(l.attrs...) + return next(ctx, conn) + } +} + +func TestLabelerFromContext(t *testing.T) { + t.Parallel() + // LabelerFromContext on empty context returns a new Labeler and false. + labeler, ok := LabelerFromContext(context.Background()) + assert.False(t, ok) + require.NotNil(t, labeler) + // Add and Get should not panic even though the labeler is not in a context. + labeler.Add(attribute.String("key", "value")) + got := labeler.Get() + assert.Len(t, got, 1) + assert.Equal(t, attribute.String("key", "value"), got[0]) +} + +func TestLabelerUnary(t *testing.T) { + t.Parallel() + metricReader, meterProvider := setupMetrics() + spanRecorder := tracetest.NewSpanRecorder() + traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder)) + customAttrs := []attribute.KeyValue{ + attribute.String("custom.label", "test-value"), + } + interceptor, err := NewInterceptor( + WithMeterProvider(meterProvider), + WithTracerProvider(traceProvider), + ) + require.NoError(t, err) + client, _, _ := startServer(t, + []connect.HandlerOption{ + connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}), + }, + nil, + okayPingServer(), + ) + _, err = client.Ping(context.Background(), requestOfSize(1, 12)) + require.NoError(t, err) + // Verify custom attributes appear in metrics. + assertMetrics(t, metricReader, expectedMetrics{ + ServerDuration: true, + ServerRequestSize: true, + RequiredAttrs: map[string]attribute.Value{ + "custom.label": attribute.StringValue("test-value"), + }, + }) + // Verify custom attributes do NOT appear in spans. + require.Len(t, spanRecorder.Ended(), 1) + for _, attr := range spanRecorder.Ended()[0].Attributes() { + assert.NotEqual(t, attribute.Key("custom.label"), attr.Key, + "span should not contain labeler attributes") + } +} + +func TestLabelerStreaming(t *testing.T) { + t.Parallel() + metricReader, meterProvider := setupMetrics() + spanRecorder := tracetest.NewSpanRecorder() + traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder)) + customAttrs := []attribute.KeyValue{ + attribute.String("custom.label", "stream-value"), + } + interceptor, err := NewInterceptor( + WithMeterProvider(meterProvider), + WithTracerProvider(traceProvider), + ) + require.NoError(t, err) + client, _, _ := startServer(t, + []connect.HandlerOption{ + connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}), + }, + nil, + okayPingServer(), + ) + stream := client.PingStream(context.Background()) + require.NoError(t, stream.Send(&pingv1.PingStreamRequest{ + Data: []byte("Hello, otel!"), + })) + _, err = stream.Receive() + require.NoError(t, err) + require.NoError(t, stream.CloseRequest()) + require.NoError(t, stream.CloseResponse()) + // Verify custom attributes appear in metrics (including per-message and final). + assertMetrics(t, metricReader, expectedMetrics{ + ServerDuration: true, + ServerRequestSize: true, + RequiredAttrs: map[string]attribute.Value{ + "custom.label": attribute.StringValue("stream-value"), + }, + }) + // Verify custom attributes do NOT appear in spans. + require.Len(t, spanRecorder.Ended(), 1) + for _, attr := range spanRecorder.Ended()[0].Attributes() { + assert.NotEqual(t, attribute.Key("custom.label"), attr.Key, + "span should not contain labeler attributes") + } +} + +func TestLabelerUnaryClient(t *testing.T) { + t.Parallel() + metricReader, meterProvider := setupMetrics() + spanRecorder := tracetest.NewSpanRecorder() + traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder)) + customAttrs := []attribute.KeyValue{ + attribute.String("custom.label", "client-value"), + } + interceptor, err := NewInterceptor( + WithMeterProvider(meterProvider), + WithTracerProvider(traceProvider), + ) + require.NoError(t, err) + client, _, _ := startServer(t, + nil, + []connect.ClientOption{ + connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}), + }, + okayPingServer(), + ) + _, err = client.Ping(context.Background(), requestOfSize(1, 12)) + require.NoError(t, err) + assertMetrics(t, metricReader, expectedMetrics{ + ClientDuration: true, + RequiredAttrs: map[string]attribute.Value{ + "custom.label": attribute.StringValue("client-value"), + }, + }) + require.Len(t, spanRecorder.Ended(), 1) + for _, attr := range spanRecorder.Ended()[0].Attributes() { + assert.NotEqual(t, attribute.Key("custom.label"), attr.Key, + "span should not contain labeler attributes") + } +} + +func TestLabelerStreamingClient(t *testing.T) { + t.Parallel() + metricReader, meterProvider := setupMetrics() + spanRecorder := tracetest.NewSpanRecorder() + traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder)) + customAttrs := []attribute.KeyValue{ + attribute.String("custom.label", "client-stream-value"), + } + interceptor, err := NewInterceptor( + WithMeterProvider(meterProvider), + WithTracerProvider(traceProvider), + ) + require.NoError(t, err) + client, _, _ := startServer(t, + nil, + []connect.ClientOption{ + connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}), + }, + okayPingServer(), + ) + stream := client.PingStream(context.Background()) + require.NoError(t, stream.Send(&pingv1.PingStreamRequest{ + Data: []byte("Hello, otel!"), + })) + _, err = stream.Receive() + require.NoError(t, err) + require.NoError(t, stream.CloseRequest()) + require.NoError(t, stream.CloseResponse()) + assertMetrics(t, metricReader, expectedMetrics{ + ClientDuration: true, + RequiredAttrs: map[string]attribute.Value{ + "custom.label": attribute.StringValue("client-stream-value"), + }, + }) + require.Len(t, spanRecorder.Ended(), 1) + for _, attr := range spanRecorder.Ended()[0].Attributes() { + assert.NotEqual(t, attribute.Key("custom.label"), attr.Key, + "span should not contain labeler attributes") + } +} diff --git a/labeler.go b/labeler.go new file mode 100644 index 0000000..396b2b7 --- /dev/null +++ b/labeler.go @@ -0,0 +1,66 @@ +// Copyright 2022-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelconnect + +import ( + "context" + "slices" + "sync" + + "go.opentelemetry.io/otel/attribute" +) + +// Labeler is used to allow instrumented ConnectRPC handlers to add custom +// attributes to the metrics recorded by the instrumentation. +type Labeler struct { + mu sync.Mutex + attributes []attribute.KeyValue +} + +// Add attributes to a Labeler. +func (l *Labeler) Add(ls ...attribute.KeyValue) { + l.mu.Lock() + defer l.mu.Unlock() + l.attributes = append(l.attributes, ls...) +} + +// Get returns a copy of the attributes added to the Labeler. +func (l *Labeler) Get() []attribute.KeyValue { + l.mu.Lock() + defer l.mu.Unlock() + return slices.Clone(l.attributes) +} + +type labelerContextKey struct{} + +// ContextWithLabeler returns a new context with the provided Labeler instance. +// Attributes added to the specified labeler will be injected into metrics +// emitted by the instrumentation. Only one labeler can be injected into the +// context. Injecting it multiple times will override the previous calls. +func ContextWithLabeler(parent context.Context, l *Labeler) context.Context { + return context.WithValue(parent, labelerContextKey{}, l) +} + +// LabelerFromContext retrieves a Labeler instance from the provided context if +// one is available. If no Labeler was found in the provided context a new, empty +// Labeler is returned and the second return value is false. In this case it is +// safe to use the Labeler but any attributes added to it will not be used. +func LabelerFromContext(ctx context.Context) (*Labeler, bool) { + l, ok := ctx.Value(labelerContextKey{}).(*Labeler) + if !ok { + l = &Labeler{} + } + return l, ok +} diff --git a/streaming.go b/streaming.go index f1a9d01..992bf32 100644 --- a/streaming.go +++ b/streaming.go @@ -18,6 +18,7 @@ import ( "context" "errors" "io" + "slices" "sync" "connectrpc.com/connect" @@ -40,6 +41,7 @@ type streamingState struct { receivedCounter int64 receiveSize metric.Int64Histogram sendSize metric.Int64Histogram + labeler *Labeler } func newStreamingState( @@ -49,6 +51,7 @@ func newStreamingState( attributeFilter AttributeFilter, omitTraceEvents bool, receiveSize, sendSize metric.Int64Histogram, + labeler *Labeler, ) *streamingState { attributes := make([]attribute.KeyValue, 0, 6) // 5 max request attrs + status code attr attributes = attributeFilter.filter(spec, @@ -62,6 +65,7 @@ func newStreamingState( attributes: attributes, receiveSize: receiveSize, sendSize: sendSize, + labeler: labeler, } } @@ -74,6 +78,17 @@ func (s *streamingState) addAttributes(attributes ...attribute.KeyValue) { s.attributes = append(s.attributes, s.attributeFilter.filter(s.spec, attributes...)...) } +func (s *streamingState) metricAttributes() []attribute.KeyValue { + if s.labeler == nil { + return s.attributes + } + labelerAttrs := s.labeler.Get() + if len(labelerAttrs) == 0 { + return s.attributes + } + return slices.Concat(s.attributes, labelerAttrs) +} + func (s *streamingState) receive(ctx context.Context, msg any, conn sendReceiver) error { err := conn.Receive(msg) s.mu.Lock() @@ -95,7 +110,7 @@ func (s *streamingState) receive(ctx context.Context, msg any, conn sendReceiver if !s.omitTraceEvents { s.emitEvent(ctx, semconv.MessageTypeReceived, s.receivedCounter, size, ok) } - s.receiveSize.Record(ctx, int64(size), metric.WithAttributes(s.attributes...)) + s.receiveSize.Record(ctx, int64(size), metric.WithAttributes(s.metricAttributes()...)) return err } @@ -120,7 +135,7 @@ func (s *streamingState) send(ctx context.Context, msg any, conn sendReceiver) e if !s.omitTraceEvents { s.emitEvent(ctx, semconv.MessageTypeSent, s.sentCounter, size, ok) } - s.sendSize.Record(ctx, int64(size), metric.WithAttributes(s.attributes...)) + s.sendSize.Record(ctx, int64(size), metric.WithAttributes(s.metricAttributes()...)) return err }