Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -89,6 +90,10 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return next(ctx, request)
}
}
labeler, found := LabelerFromContext(ctx)
if !found {
ctx = ContextWithLabeler(ctx, labeler)
}
attributeFilter := i.config.filterAttribute.filter
isClient := request.Spec().IsClient
name := strings.TrimLeft(request.Spec().Procedure, "/")
Expand Down Expand Up @@ -175,7 +180,12 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
span.SetStatus(serverSpanStatus(protocol, err))
}
span.SetAttributes(attributes...)
attributesSet := attribute.NewSet(attributes...)
var attributesSet attribute.Set
if labelerAttrs := labeler.Get(); len(labelerAttrs) > 0 {
attributesSet = attribute.NewSet(slices.Concat(attributes, labelerAttrs)...)
} else {
attributesSet = attribute.NewSet(attributes...)
}
instrumentation.duration.Record(ctx, i.config.now().Sub(requestStartTime).Milliseconds(), metric.WithAttributeSet(attributesSet))
instrumentation.requestSize.Record(ctx, int64(requestSize), metric.WithAttributeSet(attributesSet))
instrumentation.requestsPerRPC.Record(ctx, 1, metric.WithAttributeSet(attributesSet))
Expand All @@ -193,6 +203,10 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn
return next(ctx, spec)
}
}
labeler, found := LabelerFromContext(ctx)
if !found {
ctx = ContextWithLabeler(ctx, labeler)
}
requestStartTime := i.config.now()
name := strings.TrimLeft(spec.Procedure, "/")
// Span is closed on context cancelation or when the stream is closed.
Expand All @@ -215,6 +229,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn
i.config.omitTraceEvents,
instrumentation.responseSize,
instrumentation.requestSize,
labeler,
)
var requestOnce sync.Once
setRequestAttributes := func() {
Expand Down Expand Up @@ -246,7 +261,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn
}
span.SetStatus(clientSpanStatus(protocol, state.error))
span.End()
attributeSet := attribute.NewSet(state.attributes...)
attributeSet := attribute.NewSet(state.metricAttributes()...)
instrumentation.requestsPerRPC.Record(ctx, state.sentCounter, metric.WithAttributeSet(attributeSet))
instrumentation.responsesPerRPC.Record(ctx, state.receivedCounter, metric.WithAttributeSet(attributeSet))
duration := i.config.now().Sub(requestStartTime).Milliseconds()
Expand Down Expand Up @@ -282,6 +297,10 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
return next(ctx, conn)
}
}
labeler, found := LabelerFromContext(ctx)
if !found {
ctx = ContextWithLabeler(ctx, labeler)
}
Comment thread
emcfarlane marked this conversation as resolved.
name := strings.TrimLeft(conn.Spec().Procedure, "/")
protocol := protocolToSemConv(conn.Peer().Protocol, i.config.rpcSystem)
state := newStreamingState(
Expand All @@ -292,6 +311,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
i.config.omitTraceEvents,
instrumentation.requestSize,
instrumentation.responseSize,
labeler,
)
// extract any request headers into the context
carrier := propagation.HeaderCarrier(conn.RequestHeader())
Expand Down Expand Up @@ -342,7 +362,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
span.SetAttributes(headerAttributes(protocol, responseKey, conn.ResponseHeader(), i.config.responseHeaderKeys)...)
}
span.SetStatus(serverSpanStatus(protocol, err))
attributeSet := attribute.NewSet(state.attributes...)
attributeSet := attribute.NewSet(state.metricAttributes()...)
instrumentation.requestsPerRPC.Record(ctx, state.receivedCounter, metric.WithAttributeSet(attributeSet))
instrumentation.responsesPerRPC.Record(ctx, state.sentCounter, metric.WithAttributeSet(attributeSet))
duration := i.config.now().Sub(requestStartTime).Milliseconds()
Expand Down
202 changes: 202 additions & 0 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2608,3 +2608,205 @@ func assertUsableTraceparent(t *testing.T, header http.Header) {
assert.NotEmpty(t, spanContext.SpanID(), "span ID should not be empty")
assert.NotEmpty(t, spanContext.TraceID(), "trace ID should not be empty")
}

// labelerInterceptor is a test interceptor that retrieves the Labeler from
// context and adds custom attributes. Used to test that labeler attributes
// appear in metrics but not in spans.
type labelerInterceptor struct {
attrs []attribute.KeyValue
}

func (l labelerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
labeler, _ := LabelerFromContext(ctx)
labeler.Add(l.attrs...)
return next(ctx, req)
}
}

func (l labelerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
labeler, _ := LabelerFromContext(ctx)
labeler.Add(l.attrs...)
return next(ctx, spec)
}
}

func (l labelerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
labeler, _ := LabelerFromContext(ctx)
labeler.Add(l.attrs...)
return next(ctx, conn)
}
}

func TestLabelerFromContext(t *testing.T) {
t.Parallel()
// LabelerFromContext on empty context returns a new Labeler and false.
labeler, ok := LabelerFromContext(context.Background())
assert.False(t, ok)
require.NotNil(t, labeler)
// Add and Get should not panic even though the labeler is not in a context.
labeler.Add(attribute.String("key", "value"))
got := labeler.Get()
assert.Len(t, got, 1)
assert.Equal(t, attribute.String("key", "value"), got[0])
}

func TestLabelerUnary(t *testing.T) {
t.Parallel()
metricReader, meterProvider := setupMetrics()
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
customAttrs := []attribute.KeyValue{
attribute.String("custom.label", "test-value"),
}
interceptor, err := NewInterceptor(
WithMeterProvider(meterProvider),
WithTracerProvider(traceProvider),
)
require.NoError(t, err)
client, _, _ := startServer(t,
[]connect.HandlerOption{
connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}),
},
nil,
okayPingServer(),
)
_, err = client.Ping(context.Background(), requestOfSize(1, 12))
require.NoError(t, err)
// Verify custom attributes appear in metrics.
assertMetrics(t, metricReader, expectedMetrics{
ServerDuration: true,
ServerRequestSize: true,
RequiredAttrs: map[string]attribute.Value{
"custom.label": attribute.StringValue("test-value"),
},
})
// Verify custom attributes do NOT appear in spans.
require.Len(t, spanRecorder.Ended(), 1)
for _, attr := range spanRecorder.Ended()[0].Attributes() {
assert.NotEqual(t, attribute.Key("custom.label"), attr.Key,
"span should not contain labeler attributes")
}
}

func TestLabelerStreaming(t *testing.T) {
t.Parallel()
metricReader, meterProvider := setupMetrics()
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
customAttrs := []attribute.KeyValue{
attribute.String("custom.label", "stream-value"),
}
interceptor, err := NewInterceptor(
WithMeterProvider(meterProvider),
WithTracerProvider(traceProvider),
)
require.NoError(t, err)
client, _, _ := startServer(t,
[]connect.HandlerOption{
connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}),
},
nil,
okayPingServer(),
)
stream := client.PingStream(context.Background())
require.NoError(t, stream.Send(&pingv1.PingStreamRequest{
Data: []byte("Hello, otel!"),
}))
_, err = stream.Receive()
require.NoError(t, err)
require.NoError(t, stream.CloseRequest())
require.NoError(t, stream.CloseResponse())
// Verify custom attributes appear in metrics (including per-message and final).
assertMetrics(t, metricReader, expectedMetrics{
ServerDuration: true,
ServerRequestSize: true,
RequiredAttrs: map[string]attribute.Value{
"custom.label": attribute.StringValue("stream-value"),
},
})
// Verify custom attributes do NOT appear in spans.
require.Len(t, spanRecorder.Ended(), 1)
for _, attr := range spanRecorder.Ended()[0].Attributes() {
assert.NotEqual(t, attribute.Key("custom.label"), attr.Key,
"span should not contain labeler attributes")
}
}

func TestLabelerUnaryClient(t *testing.T) {
t.Parallel()
metricReader, meterProvider := setupMetrics()
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
customAttrs := []attribute.KeyValue{
attribute.String("custom.label", "client-value"),
}
interceptor, err := NewInterceptor(
WithMeterProvider(meterProvider),
WithTracerProvider(traceProvider),
)
require.NoError(t, err)
client, _, _ := startServer(t,
nil,
[]connect.ClientOption{
connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}),
},
okayPingServer(),
)
_, err = client.Ping(context.Background(), requestOfSize(1, 12))
require.NoError(t, err)
assertMetrics(t, metricReader, expectedMetrics{
ClientDuration: true,
RequiredAttrs: map[string]attribute.Value{
"custom.label": attribute.StringValue("client-value"),
},
})
require.Len(t, spanRecorder.Ended(), 1)
for _, attr := range spanRecorder.Ended()[0].Attributes() {
assert.NotEqual(t, attribute.Key("custom.label"), attr.Key,
"span should not contain labeler attributes")
}
}

func TestLabelerStreamingClient(t *testing.T) {
t.Parallel()
metricReader, meterProvider := setupMetrics()
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
customAttrs := []attribute.KeyValue{
attribute.String("custom.label", "client-stream-value"),
}
interceptor, err := NewInterceptor(
WithMeterProvider(meterProvider),
WithTracerProvider(traceProvider),
)
require.NoError(t, err)
client, _, _ := startServer(t,
nil,
[]connect.ClientOption{
connect.WithInterceptors(interceptor, labelerInterceptor{attrs: customAttrs}),
},
okayPingServer(),
)
stream := client.PingStream(context.Background())
require.NoError(t, stream.Send(&pingv1.PingStreamRequest{
Data: []byte("Hello, otel!"),
}))
_, err = stream.Receive()
require.NoError(t, err)
require.NoError(t, stream.CloseRequest())
require.NoError(t, stream.CloseResponse())
assertMetrics(t, metricReader, expectedMetrics{
ClientDuration: true,
RequiredAttrs: map[string]attribute.Value{
"custom.label": attribute.StringValue("client-stream-value"),
},
})
require.Len(t, spanRecorder.Ended(), 1)
for _, attr := range spanRecorder.Ended()[0].Attributes() {
assert.NotEqual(t, attribute.Key("custom.label"), attr.Key,
"span should not contain labeler attributes")
}
}
66 changes: 66 additions & 0 deletions labeler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2022-2025 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package otelconnect

import (
"context"
"slices"
"sync"

"go.opentelemetry.io/otel/attribute"
)

// Labeler is used to allow instrumented ConnectRPC handlers to add custom
// attributes to the metrics recorded by the instrumentation.
type Labeler struct {
mu sync.Mutex
attributes []attribute.KeyValue
}

// Add attributes to a Labeler.
func (l *Labeler) Add(ls ...attribute.KeyValue) {
l.mu.Lock()
defer l.mu.Unlock()
l.attributes = append(l.attributes, ls...)
}

// Get returns a copy of the attributes added to the Labeler.
func (l *Labeler) Get() []attribute.KeyValue {
l.mu.Lock()
defer l.mu.Unlock()
return slices.Clone(l.attributes)
}

type labelerContextKey struct{}

// ContextWithLabeler returns a new context with the provided Labeler instance.
// Attributes added to the specified labeler will be injected into metrics
// emitted by the instrumentation. Only one labeler can be injected into the
// context. Injecting it multiple times will override the previous calls.
func ContextWithLabeler(parent context.Context, l *Labeler) context.Context {
return context.WithValue(parent, labelerContextKey{}, l)
}

// LabelerFromContext retrieves a Labeler instance from the provided context if
// one is available. If no Labeler was found in the provided context a new, empty
// Labeler is returned and the second return value is false. In this case it is
// safe to use the Labeler but any attributes added to it will not be used.
func LabelerFromContext(ctx context.Context) (*Labeler, bool) {
l, ok := ctx.Value(labelerContextKey{}).(*Labeler)
if !ok {
l = &Labeler{}
}
return l, ok
}
Loading