From 0510a618c2fae0bcec83d719fcbc0fb449b2423d Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Sun, 12 Oct 2025 17:42:53 +0200 Subject: [PATCH 1/7] mpc: add configured header attribtues to metrics and spans Signed-off-by: Ignasi Barrera --- cmd/extproc/mainlib/main.go | 8 +-- internal/mcpproxy/handlers.go | 39 ++++++++------ internal/mcpproxy/handlers_test.go | 6 +-- internal/mcpproxy/mcpproxy.go | 29 +++++++--- internal/mcpproxy/mcpproxy_test.go | 11 ++-- internal/mcpproxy/session_test.go | 3 +- internal/metrics/mcp_metrics.go | 81 ++++++++++++++++++++-------- internal/metrics/mcp_metrics_test.go | 38 ++++++++++--- internal/tracing/api/mcp.go | 5 +- internal/tracing/mcp.go | 22 ++++++-- internal/tracing/mcp_test.go | 18 +++++-- internal/tracing/tracing.go | 2 +- 12 files changed, 181 insertions(+), 81 deletions(-) diff --git a/cmd/extproc/mainlib/main.go b/cmd/extproc/mainlib/main.go index 9453b09db6..d9e493d0b6 100644 --- a/cmd/extproc/mainlib/main.go +++ b/cmd/extproc/mainlib/main.go @@ -233,7 +233,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { chatCompletionMetrics := metrics.NewChatCompletion(meter, metricsRequestHeaderAttributes) completionMetrics := metrics.NewCompletion(meter, metricsRequestHeaderAttributes) embeddingsMetrics := metrics.NewEmbeddings(meter, metricsRequestHeaderAttributes) - mcpMetrics := metrics.NewMCP(meter) + mcpMetrics := metrics.NewMCP(meter, metricsRequestHeaderAttributes) tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes) if err != nil { @@ -264,13 +264,13 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { seed, fallbackSeed, _ := strings.Cut(flags.mcpSessionEncryptionSeed, ",") mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(seed, fallbackSeed) var mcpProxyMux *http.ServeMux - var mcpProxy *mcpproxy.MCPProxy - mcpProxy, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics, + var mcpProxyConfig *mcpproxy.ProxyConfig + mcpProxyConfig, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics, tracing.MCPTracer(), mcpSessionCrypto) if err != nil { return fmt.Errorf("failed to create MCP proxy: %w", err) } - if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxy, l, time.Second*5); err != nil { + if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxyConfig, l, time.Second*5); err != nil { return fmt.Errorf("failed to start config watcher: %w", err) } diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index d9e9b44eb7..69874c9176 100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -190,7 +190,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { switch msg.Method { case "notifications/roots/list_changed": p := &mcp.RootsListChangedParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -199,7 +199,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleNotificationsRootsListChanged(ctx, s, w, msg, span) case "completion/complete": p := &mcp.CompleteParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -209,7 +209,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { case "notifications/progress": m.metrics.RecordProgress(ctx) p := &mcp.ProgressNotificationParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -219,7 +219,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { case "initialize": // The very first request from the client to establish a session. p := &mcp.InitializeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal initialize params", slog.String("error", err.Error())) @@ -243,7 +243,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) case "logging/setLevel": p := &mcp.SetLoggingLevelParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal set logging level params", slog.String("error", err.Error())) @@ -256,7 +256,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handlePing(ctx, w, msg) case "prompts/list": p := &mcp.ListPromptsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -265,7 +265,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handlePromptListRequest(ctx, s, w, msg, p, span) case "prompts/get": p := &mcp.GetPromptParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -274,7 +274,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handlePromptGetRequest(ctx, s, w, msg, p) case "tools/call": p := &mcp.CallToolParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error())) @@ -284,7 +284,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleToolCallRequest(ctx, s, w, msg, p, span) case "tools/list": p := &mcp.ListToolsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -293,7 +293,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleToolsListRequest(ctx, s, w, msg, p, span) case "resources/list": p := &mcp.ListResourcesParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -302,7 +302,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleResourceListRequest(ctx, s, w, msg, p, span) case "resources/read": p := &mcp.ReadResourceParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -311,7 +311,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleResourceReadRequest(ctx, s, w, msg, p) case "resources/templates/list": p := &mcp.ListResourceTemplatesParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -320,7 +320,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, p, span) case "resources/subscribe": p := &mcp.SubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -329,7 +329,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, p, span) case "resources/unsubscribe": p := &mcp.UnsubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -372,7 +372,6 @@ func errorType(err error) metrics.MCPErrorType { // handleInitializeRequest handles the "initialize" JSON-RPC method. func (m *MCPProxy) handleInitializeRequest(ctx context.Context, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.InitializeParams, route, subject string, span tracing.MCPSpan) error { m.metrics.RecordClientCapabilities(ctx, p.Capabilities) - s, err := m.newSession(ctx, p, route, subject, span) if err != nil { m.l.Error("failed to create new session", slog.String("error", err.Error())) @@ -1223,7 +1222,7 @@ func sendToAllBackendsAndAggregateResponsesImpl[responseType any](ctx context.Co } // parseParamsAndMaybeStartSpan parses the params from the JSON-RPC request and starts a tracing span if params is non-nil. -func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType) (tracing.MCPSpan, error) { +func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType, headers http.Header) (tracing.MCPSpan, error) { if req.Params == nil { return nil, nil } @@ -1233,7 +1232,13 @@ func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m * return nil, err } - span := m.tracer.StartSpanAndInjectMeta(ctx, req, p) + // TODO: headers with multiple values + headerMap := make(map[string]string, len(headers)) + for header := range headers { + headerMap[header] = headers.Get(header) + } + + span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headerMap) return span, nil } diff --git a/internal/mcpproxy/handlers_test.go b/internal/mcpproxy/handlers_test.go index 747293f1b3..e6fec797c0 100644 --- a/internal/mcpproxy/handlers_test.go +++ b/internal/mcpproxy/handlers_test.go @@ -71,7 +71,7 @@ func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *MCPProxy { func newTestMCPProxyWithOTEL(mr *sdkmetric.ManualReader, tracer tracingapi.MCPTracer) *MCPProxy { mcpProxy := newTestMCPProxyWithTracer(tracer) meter := sdkmetric.NewMeterProvider(sdkmetric.WithReader(mr)).Meter("test") - mcpProxy.metrics = metrics.NewMCP(meter) + mcpProxy.metrics = metrics.NewMCP(meter, nil) return mcpProxy } @@ -1598,7 +1598,7 @@ func Test_parseParamsAndMaybeStartSpan(t *testing.T) { trace, err := tracing.NewTracingFromEnv(t.Context(), t.Output(), nil) require.NoError(t, err) m.tracer = trace.MCPTracer() - s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p) + s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p, nil) require.NoError(t, err) require.NotNil(t, s) // Make sure that traceparent is not empty, that's span started. @@ -1611,7 +1611,7 @@ func Test_parseParamsAndMaybeStartSpan_NilParam(t *testing.T) { } p := &mcp.GetPromptParams{} m := newTestMCPProxy() - s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p) + s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p, nil) require.NoError(t, err) require.Nil(t, s) } diff --git a/internal/mcpproxy/mcpproxy.go b/internal/mcpproxy/mcpproxy.go index edf639aa0b..2ef7f96081 100644 --- a/internal/mcpproxy/mcpproxy.go +++ b/internal/mcpproxy/mcpproxy.go @@ -29,6 +29,11 @@ import ( ) type ( + // ProxyConfig holds the main MCP proxy configuration. + ProxyConfig struct { + *mcpProxyConfig + } + // MCPProxy serves /mcp endpoint. // // This implements [extproc.ConfigReceiver] to gets the up-to-date configuration. @@ -77,8 +82,8 @@ func (f *toolSelector) allows(tool string) bool { } // NewMCPProxy creates a new MCPProxy instance. -func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*MCPProxy, *http.ServeMux, error) { - p := &MCPProxy{l: l, metrics: mcpMetrics, tracer: tracer, sessionCrypto: sessionCrypto} +func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*ProxyConfig, *http.ServeMux, error) { + cfg := &ProxyConfig{} mux := http.NewServeMux() mux.HandleFunc( // Must match all paths since the route selection happens at Envoy level and the "route" header is already @@ -87,23 +92,31 @@ func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.M // For example, if we mistakenly set /mcp here, only the route with prefix /mcp will be matched, and other routes // with different prefixes will not be matched, which is not what we want. "/", func(w http.ResponseWriter, r *http.Request) { + proxy := &MCPProxy{ + mcpProxyConfig: cfg.mcpProxyConfig, + l: l, + metrics: mcpMetrics.WithRequestAttributes(r), + tracer: tracer, + sessionCrypto: sessionCrypto, + } + switch r.Method { case http.MethodGet: - p.serveGET(w, r) + proxy.serveGET(w, r) case http.MethodPost: - p.servePOST(w, r) + proxy.servePOST(w, r) case http.MethodDelete: - p.serverDELETE(w, r) + proxy.serverDELETE(w, r) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } }) - return p, mux, nil + return cfg, mux, nil } // LoadConfig implements [extproc.ConfigReceiver.LoadConfig] which will be called // when the configuration is updated on the file system. -func (m *MCPProxy) LoadConfig(_ context.Context, config *filterapi.Config) error { +func (p *ProxyConfig) LoadConfig(_ context.Context, config *filterapi.Config) error { newConfig := &mcpProxyConfig{} mcpConfig := config.MCPConfig if config.MCPConfig == nil { @@ -145,7 +158,7 @@ func (m *MCPProxy) LoadConfig(_ context.Context, config *filterapi.Config) error newConfig.routes[route.Name] = r } - m.mcpProxyConfig = newConfig // This is racy, but we don't care. + p.mcpProxyConfig = newConfig // This is racy, but we don't care. return nil } diff --git a/internal/mcpproxy/mcpproxy_test.go b/internal/mcpproxy/mcpproxy_test.go index 1a78b2c8e9..e377bf13e2 100644 --- a/internal/mcpproxy/mcpproxy_test.go +++ b/internal/mcpproxy/mcpproxy_test.go @@ -50,7 +50,7 @@ type fakeTracer struct { span *fakeSpan } -func (f *fakeTracer) StartSpanAndInjectMeta(_ context.Context, _ *jsonrpc.Request, _ mcp.Params) tracing.MCPSpan { +func (f *fakeTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, map[string]string) tracing.MCPSpan { if f.span == nil { f.span = &fakeSpan{} } @@ -66,8 +66,6 @@ func TestNewMCPProxy(t *testing.T) { require.NoError(t, err) require.NotNil(t, proxy) require.NotNil(t, mux) - require.Equal(t, l, proxy.l) - require.NotNil(t, proxy.metrics) } func TestMCPProxy_HTTPMethods(t *testing.T) { @@ -86,11 +84,12 @@ func TestMCPProxy_HTTPMethods(t *testing.T) { } func TestLoadConfig_NilMCPConfig(t *testing.T) { - proxy := newTestMCPProxy() - config := &filterapi.Config{MCPConfig: nil} + proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, DefaultSessionCrypto("test")) + require.NoError(t, err) - err := proxy.LoadConfig(t.Context(), config) + config := &filterapi.Config{MCPConfig: nil} + err = proxy.LoadConfig(t.Context(), config) require.NoError(t, err) } diff --git a/internal/mcpproxy/session_test.go b/internal/mcpproxy/session_test.go index e34f503d85..d7f07d320b 100644 --- a/internal/mcpproxy/session_test.go +++ b/internal/mcpproxy/session_test.go @@ -30,7 +30,8 @@ import ( // stubMetrics implements metrics.MCPMetrics with no-ops. type stubMetrics struct{} -func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time) {} +func (s stubMetrics) WithRequestAttributes(_ *http.Request) metrics.MCPMetrics { return s } +func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time) {} func (stubMetrics) RecordRequestErrorDuration(_ context.Context, _ *time.Time, _ metrics.MCPErrorType) { } func (stubMetrics) RecordMethodCount(_ context.Context, _ string) {} diff --git a/internal/metrics/mcp_metrics.go b/internal/metrics/mcp_metrics.go index 86eb1cff64..ce9534faac 100644 --- a/internal/metrics/mcp_metrics.go +++ b/internal/metrics/mcp_metrics.go @@ -7,6 +7,7 @@ package metrics import ( "context" + "net/http" "time" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" @@ -102,6 +103,8 @@ const ( // MCPMetrics holds metrics for MCP. type MCPMetrics interface { + // WithRequestAttributes returns a new MCPMetrics instance with default attributes extracted from the HTTP request. + WithRequestAttributes(req *http.Request) MCPMetrics // RecordRequestDuration records the duration of a success MCP request. RecordRequestDuration(ctx context.Context, startAt *time.Time) // RecordRequestErrorDuration records the duration of an MCP request that resulted in an error. @@ -121,16 +124,19 @@ type MCPMetrics interface { } type mcp struct { - requestDuration metric.Float64Histogram - methodCount metric.Float64Counter - initializationDuration metric.Float64Histogram - capabilitiesNegotiated metric.Float64Counter - progressNotifications metric.Float64Counter + requestDuration metric.Float64Histogram + methodCount metric.Float64Counter + initializationDuration metric.Float64Histogram + capabilitiesNegotiated metric.Float64Counter + progressNotifications metric.Float64Counter + requestHeaderLabelMapping map[string]string // maps HTTP headers to metric label names. + defaultAttributes []attribute.KeyValue } // NewMCP creates a new mcp metrics instance. -func NewMCP(meter metric.Meter) MCPMetrics { +func NewMCP(meter metric.Meter, requestHeaderLabelMapping map[string]string) MCPMetrics { return &mcp{ + requestHeaderLabelMapping: requestHeaderLabelMapping, requestDuration: mustRegisterHistogram(meter, mcpRequestDuration, metric.WithDescription("Duration of MCP requests"), @@ -159,13 +165,38 @@ func NewMCP(meter metric.Meter) MCPMetrics { } } +// WithRequestAttributes returns a new MCPMetrics instance with default attributes extracted from +// the HTTP request headers. +func (m *mcp) WithRequestAttributes(req *http.Request) MCPMetrics { + withAttrs := &mcp{ + requestDuration: m.requestDuration, + methodCount: m.methodCount, + initializationDuration: m.initializationDuration, + capabilitiesNegotiated: m.capabilitiesNegotiated, + progressNotifications: m.progressNotifications, + requestHeaderLabelMapping: m.requestHeaderLabelMapping, + } + + // Apply header-to-attribute mapping if configured. + for headerName, attrName := range m.requestHeaderLabelMapping { + if headerValue := req.Header.Get(headerName); headerValue != "" { + withAttrs.defaultAttributes = append( + withAttrs.defaultAttributes, + attribute.String(attrName, headerValue), + ) + } + } + + return withAttrs +} + // RecordMethodCount implements [MCPMetrics.RecordMethodCount]. func (m *mcp) RecordMethodCount(ctx context.Context, methodName string) { if methodName == "" { return } m.methodCount.Add(ctx, 1, - metric.WithAttributes( + m.withDefaultAttributes( attribute.Key(mcpAttributeMethodName).String(methodName), attribute.String(mcpAttributeStatusName, string(mcpStatusSuccess)), )) @@ -174,7 +205,7 @@ func (m *mcp) RecordMethodCount(ctx context.Context, methodName string) { // RecordMethodErrorCount implements [MCPMetrics.RecordMethodErrorCount]. func (m *mcp) RecordMethodErrorCount(ctx context.Context) { m.methodCount.Add(ctx, 1, - metric.WithAttributes( + m.withDefaultAttributes( attribute.String(mcpAttributeStatusName, string(mcpStatusError)), )) } @@ -184,9 +215,8 @@ func (m *mcp) RecordRequestDuration(ctx context.Context, startAt *time.Time) { if startAt == nil { return } - duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration) + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes()) } // RecordRequestErrorDuration implements [MCPMetrics.RecordRequestErrorDuration]. @@ -196,7 +226,7 @@ func (m *mcp) RecordRequestErrorDuration(ctx context.Context, startAt *time.Time } duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration, metric.WithAttributes( + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes( attribute.Key(mcpAttributeErrorType).String(string(errType)), )) } @@ -207,12 +237,12 @@ func (m *mcp) RecordInitializationDuration(ctx context.Context, startAt *time.Ti return } duration := time.Since(*startAt).Seconds() - m.initializationDuration.Record(ctx, duration) + m.initializationDuration.Record(ctx, duration, m.withDefaultAttributes()) } // RecordProgress implements [MCPMetrics.RecordProgress]. func (m *mcp) RecordProgress(ctx context.Context) { - m.progressNotifications.Add(ctx, 1) + m.progressNotifications.Add(ctx, 1, m.withDefaultAttributes()) } // RecordClientCapabilities implements [MCPMetrics.RecordClientCapabilities]. @@ -223,28 +253,28 @@ func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk side := string(mcpCapabilitySideClient) if l := len(capabilities.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Roots.ListChanged { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeRoots)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Sampling != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeSampling)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Elicitation != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeElicitation)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) @@ -259,44 +289,49 @@ func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.S side := string(mcpCapabilitySideServer) if serverCapa.Completions != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeCompletions)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if l := len(serverCapa.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Logging != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeLogging)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Prompts != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypePrompts)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Resources != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeResources)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Tools != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeTools)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } } + +// withDefaultAttributes appends default attributes to the provided attributes. +func (m *mcp) withDefaultAttributes(attrs ...attribute.KeyValue) metric.MeasurementOption { + return metric.WithAttributes(append(m.defaultAttributes, attrs...)...) +} diff --git a/internal/metrics/mcp_metrics_test.go b/internal/metrics/mcp_metrics_test.go index 82e8dda0b7..50a3873e51 100644 --- a/internal/metrics/mcp_metrics_test.go +++ b/internal/metrics/mcp_metrics_test.go @@ -6,6 +6,7 @@ package metrics import ( + "net/http" "testing" "time" @@ -21,15 +22,38 @@ func TestNewMCP(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) } +func TestRecordMetricWithCustomAttributes(t *testing.T) { + mr := metric.NewManualReader() + meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") + + m := NewMCP(meter, map[string]string{"x-custom-attr": "attr.custom"}) + require.NotNil(t, m) + + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + req.Header.Set("x-custom-attr", "test") // should be included in metrics + req.Header.Set("x-other-attr", "other") // should be ignored + + m = m.WithRequestAttributes(req) + + startAt := time.Now().Add(-1 * time.Minute) + m.RecordRequestDuration(t.Context(), &startAt) + + count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, + attribute.NewSet(attribute.String("attr.custom", "test"))) + require.Equal(t, uint64(1), count) + require.Equal(t, 60, int(sum)) +} + func TestRecordRequestDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-1 * time.Minute) m.RecordRequestDuration(t.Context(), &startAt) @@ -43,7 +67,7 @@ func TestRecordRequestErrorDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-30 * time.Second) m.RecordRequestErrorDuration(t.Context(), &startAt, MCPErrorUnsupportedProtocolVersion) @@ -59,7 +83,7 @@ func TestRecordMethodCount(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) m.RecordMethodCount(t.Context(), "test_method_name") @@ -82,7 +106,7 @@ func TestRecordInitializationDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-45 * time.Second) @@ -97,7 +121,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) m.RecordClientCapabilities(t.Context(), &mcpsdk.ClientCapabilities{ @@ -152,7 +176,7 @@ func TestRecordProgressNotifications(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) m.RecordProgress(t.Context()) diff --git a/internal/tracing/api/mcp.go b/internal/tracing/api/mcp.go index dbcf5716a4..9082c6f774 100644 --- a/internal/tracing/api/mcp.go +++ b/internal/tracing/api/mcp.go @@ -23,9 +23,10 @@ type MCPTracer interface { // - ctx: might include a parent span context. // - req: Incoming MCP request message. // - param: Incoming MCP parameter used to extract parent trace context. + // - headers: Request HTTP request headers. // // Returns nil unless the span is sampled. - StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params) MCPSpan + StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers map[string]string) MCPSpan } // MCPSpan represents an MCP span. @@ -45,6 +46,6 @@ var _ MCPTracer = NoopMCPTracer{} type NoopMCPTracer struct{} // StartSpanAndInjectMeta implements [MCPTracer.StartSpanAndInjectMeta]. -func (NoopMCPTracer) StartSpanAndInjectMeta(_ context.Context, _ *jsonrpc.Request, _ mcp.Params) MCPSpan { +func (NoopMCPTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, map[string]string) MCPSpan { return nil } diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index 3e66ec0a99..eade1489cf 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -57,16 +57,21 @@ func (s mcpSpan) EndSpan() { // mcpTracer is an implementation of [tracing.MCPTracer]. type mcpTracer struct { - tracer trace.Tracer - propagator propagation.TextMapPropagator + tracer trace.Tracer + propagator propagation.TextMapPropagator + headerAttributes map[string]string } -func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator) tracing.MCPTracer { - return mcpTracer{tracer: tracer, propagator: propagator} +func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, headerAttributes map[string]string) tracing.MCPTracer { + return mcpTracer{ + tracer: tracer, + propagator: propagator, + headerAttributes: headerAttributes, + } } // StartSpanAndInjectMeta implements [tracing.MCPTracer.StartSpanAndInjectMeta]. -func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params) tracing.MCPSpan { +func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers map[string]string) tracing.MCPSpan { attrs := []attribute.KeyValue{ attribute.String("mcp.protocol.version", "2025-06-18"), attribute.String("mcp.transport", "http"), @@ -75,6 +80,13 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ } attrs = append(attrs, getMCPParamsAsAttributes(param)...) + // Apply header-to-attribute mapping if configured. + for headerName, attrName := range m.headerAttributes { + if headerValue, ok := headers[headerName]; ok { + attrs = append(attrs, attribute.String(attrName, headerValue)) + } + } + // Extract trace context from incoming meta. mutableMeta := param.GetMeta() if mutableMeta == nil { diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index d12a9a546e..c5a51c2b97 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -23,17 +23,27 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { exporter := tracetest.NewInMemoryExporter() tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) - tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator()) + tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), + map[string]string{"x-custom-attr": "custom.attr"}) reqID, _ := jsonrpc.MakeID("id") r := &jsonrpc.Request{ID: reqID, Method: "initialize"} p := &mcp.InitializeParams{} - span := tracer.StartSpanAndInjectMeta(t.Context(), r, p) + span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, map[string]string{ + "x-custom-attr": "custom-value", + }) require.NotNil(t, span) meta := p.GetMeta() require.NotNil(t, meta) require.NotNil(t, meta["traceparent"]) + + // End the span to export it + span.EndSpan() + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + require.Contains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value")) } func Test_getMCPAttributes(t *testing.T) { @@ -221,12 +231,12 @@ func TestMCPTracer_SpanName(t *testing.T) { exporter := tracetest.NewInMemoryExporter() tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) - tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator()) + tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), nil) reqID, _ := jsonrpc.MakeID("test-id") req := &jsonrpc.Request{ID: reqID, Method: tt.method} - span := tracer.StartSpanAndInjectMeta(context.Background(), req, tt.params) + span := tracer.StartSpanAndInjectMeta(context.Background(), req, tt.params, nil) require.NotNil(t, span) span.EndSpan() diff --git a/internal/tracing/tracing.go b/internal/tracing/tracing.go index acc123f4c3..c2810ff563 100644 --- a/internal/tracing/tracing.go +++ b/internal/tracing/tracing.go @@ -179,7 +179,7 @@ func NewTracingFromEnv(ctx context.Context, stdout io.Writer, headerAttributeMap embeddingsRecorder, headerAttrs, ), - mcpTracer: newMCPTracer(tracer, propagator), + mcpTracer: newMCPTracer(tracer, propagator, headerAttrs), shutdown: tp.Shutdown, // we have to shut down what we create. }, nil } From d870aaaee75c98d6072da40e8a8921ad5eab7d58 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Mon, 13 Oct 2025 11:52:49 +0200 Subject: [PATCH 2/7] rename variables to attribute Signed-off-by: Ignasi Barrera --- internal/metrics/base_metrics.go | 26 +++++++++---------- .../metrics/chat_completion_metrics_test.go | 2 +- internal/metrics/embeddings_metrics.go | 4 +-- internal/metrics/embeddings_metrics_test.go | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/metrics/base_metrics.go b/internal/metrics/base_metrics.go index d76bde3f97..67e18ecee9 100644 --- a/internal/metrics/base_metrics.go +++ b/internal/metrics/base_metrics.go @@ -26,21 +26,21 @@ type baseMetrics struct { // requestModel is the original model from the request body. requestModel string // responseModel is the model that ultimately generated the response (may differ due to backend override). - responseModel string - backend string - requestHeaderLabelMapping map[string]string // maps HTTP headers to metric label names. + responseModel string + backend string + requestHeaderAttributeMapping map[string]string // maps HTTP headers to metric attribute names. } // newBaseMetrics creates a new baseMetrics instance with the specified operation. -func newBaseMetrics(meter metric.Meter, operation string, requestHeaderLabelMapping map[string]string) baseMetrics { +func newBaseMetrics(meter metric.Meter, operation string, requestHeaderAttributeMapping map[string]string) baseMetrics { return baseMetrics{ - metrics: newGenAI(meter), - operation: operation, - originalModel: "unknown", - requestModel: "unknown", - responseModel: "unknown", - backend: "unknown", - requestHeaderLabelMapping: requestHeaderLabelMapping, + metrics: newGenAI(meter), + operation: operation, + originalModel: "unknown", + requestModel: "unknown", + responseModel: "unknown", + backend: "unknown", + requestHeaderAttributeMapping: requestHeaderAttributeMapping, } } @@ -85,13 +85,13 @@ func (b *baseMetrics) buildBaseAttributes(headers map[string]string) attribute.S origModel := attribute.Key(genaiAttributeOriginalModel).String(b.originalModel) reqModel := attribute.Key(genaiAttributeRequestModel).String(b.requestModel) respModel := attribute.Key(genaiAttributeResponseModel).String(b.responseModel) - if len(b.requestHeaderLabelMapping) == 0 { + if len(b.requestHeaderAttributeMapping) == 0 { return attribute.NewSet(opt, provider, origModel, reqModel, respModel) } // Add header values as attributes based on the header mapping if headers are provided. attrs := []attribute.KeyValue{opt, provider, origModel, reqModel, respModel} - for headerName, labelName := range b.requestHeaderLabelMapping { + for headerName, labelName := range b.requestHeaderAttributeMapping { if headerValue, exists := headers[headerName]; exists { attrs = append(attrs, attribute.Key(labelName).String(headerValue)) } diff --git a/internal/metrics/chat_completion_metrics_test.go b/internal/metrics/chat_completion_metrics_test.go index c5bd515d5f..04a72fffd3 100644 --- a/internal/metrics/chat_completion_metrics_test.go +++ b/internal/metrics/chat_completion_metrics_test.go @@ -213,7 +213,7 @@ func TestHeaderLabelMapping(t *testing.T) { pm.RecordTokenUsage(t.Context(), 10, 5, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, pm.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, pm.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. attrs := attribute.NewSet( diff --git a/internal/metrics/embeddings_metrics.go b/internal/metrics/embeddings_metrics.go index c0571ff332..f42114210e 100644 --- a/internal/metrics/embeddings_metrics.go +++ b/internal/metrics/embeddings_metrics.go @@ -44,9 +44,9 @@ type EmbeddingsMetrics interface { } // NewEmbeddings creates a new Embeddings instance. -func NewEmbeddings(meter metric.Meter, requestHeaderLabelMapping map[string]string) EmbeddingsMetrics { +func NewEmbeddings(meter metric.Meter, requestHeaderAttributeMapping map[string]string) EmbeddingsMetrics { return &embeddings{ - baseMetrics: newBaseMetrics(meter, genaiOperationEmbedding, requestHeaderLabelMapping), + baseMetrics: newBaseMetrics(meter, genaiOperationEmbedding, requestHeaderAttributeMapping), } } diff --git a/internal/metrics/embeddings_metrics_test.go b/internal/metrics/embeddings_metrics_test.go index a1df67b39b..0159a14eb8 100644 --- a/internal/metrics/embeddings_metrics_test.go +++ b/internal/metrics/embeddings_metrics_test.go @@ -101,7 +101,7 @@ func TestEmbeddings_HeaderLabelMapping(t *testing.T) { em.RecordTokenUsage(t.Context(), 10, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, em.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, em.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. attrs := attribute.NewSet( From 5c53798025f58b817b46fa41b13e2d6905442e29 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Mon, 13 Oct 2025 11:53:06 +0200 Subject: [PATCH 3/7] propagate custom metric and trace attributes from _meta Signed-off-by: Ignasi Barrera --- internal/mcpproxy/handlers.go | 105 +++++++++++++------------ internal/mcpproxy/mcpproxy.go | 4 +- internal/mcpproxy/session_test.go | 21 ++--- internal/metrics/mcp_metrics.go | 111 +++++++++++++++------------ internal/metrics/mcp_metrics_test.go | 36 +++++---- internal/tracing/mcp.go | 25 +++--- internal/tracing/mcp_test.go | 6 +- 7 files changed, 170 insertions(+), 138 deletions(-) diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index 69874c9176..503c965717 100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -107,6 +107,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { errType metrics.MCPErrorType requestMethod string span tracing.MCPSpan + params mcp.Params ) defer func() { if m.l.Enabled(ctx, slog.LevelDebug) { @@ -119,17 +120,17 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { if span != nil { span.EndSpanOnError(string(errType), err) } - m.metrics.RecordMethodErrorCount(ctx) - m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType) + m.metrics.RecordMethodErrorCount(ctx, params) + m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType, params) return } if span != nil { span.EndSpan() } - m.metrics.RecordRequestDuration(ctx, &startAt) + m.metrics.RecordRequestDuration(ctx, &startAt, params) // TODO: should we special case when this request is "Response" where method is empty? - m.metrics.RecordMethodCount(ctx, requestMethod) + m.metrics.RecordMethodCount(ctx, requestMethod, params) }() if sessionID := r.Header.Get(sessionIDHeader); sessionID != "" { s, err = m.sessionFromID(secureClientToGatewaySessionID(sessionID), secureClientToGatewayEventID(r.Header.Get(lastEventIDHeader))) @@ -189,8 +190,8 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { switch msg.Method { case "notifications/roots/list_changed": - p := &mcp.RootsListChangedParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.RootsListChangedParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -198,28 +199,28 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { } err = m.handleNotificationsRootsListChanged(ctx, s, w, msg, span) case "completion/complete": - p := &mcp.CompleteParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.CompleteParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleCompletionComplete(ctx, s, w, msg, p, span) + err = m.handleCompletionComplete(ctx, s, w, msg, params.(*mcp.CompleteParams), span) case "notifications/progress": - m.metrics.RecordProgress(ctx) - p := &mcp.ProgressNotificationParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ProgressNotificationParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) + m.metrics.RecordProgress(ctx, params) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, p, span) + err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, params.(*mcp.ProgressNotificationParams), span) case "initialize": // The very first request from the client to establish a session. - p := &mcp.InitializeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.InitializeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal initialize params", slog.String("error", err.Error())) @@ -235,107 +236,107 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { onErrorResponse(w, http.StatusInternalServerError, "missing route header") return } - err = m.handleInitializeRequest(ctx, w, msg, p, route, extractSubject(r), span) + err = m.handleInitializeRequest(ctx, w, msg, params.(*mcp.InitializeParams), route, extractSubject(r), span) case "notifications/initialized": // According to the MCP spec, when the server receives a JSON-RPC response or notification from the client // and accepts it, the server MUST return HTTP 202 Accepted with an empty body. // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server w.WriteHeader(http.StatusAccepted) case "logging/setLevel": - p := &mcp.SetLoggingLevelParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.SetLoggingLevelParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal set logging level params", slog.String("error", err.Error())) onErrorResponse(w, http.StatusBadRequest, "invalid set logging level params") return } - err = m.handleSetLoggingLevel(ctx, s, w, msg, p, span) + err = m.handleSetLoggingLevel(ctx, s, w, msg, params.(*mcp.SetLoggingLevelParams), span) case "ping": // Ping is intentionally not traced as it's a lightweight health check. err = m.handlePing(ctx, w, msg) case "prompts/list": - p := &mcp.ListPromptsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ListPromptsParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handlePromptListRequest(ctx, s, w, msg, p, span) + err = m.handlePromptListRequest(ctx, s, w, msg, params.(*mcp.ListPromptsParams), span) case "prompts/get": - p := &mcp.GetPromptParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.GetPromptParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handlePromptGetRequest(ctx, s, w, msg, p) + err = m.handlePromptGetRequest(ctx, s, w, msg, params.(*mcp.GetPromptParams)) case "tools/call": - p := &mcp.CallToolParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.CallToolParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error())) onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleToolCallRequest(ctx, s, w, msg, p, span) + err = m.handleToolCallRequest(ctx, s, w, msg, params.(*mcp.CallToolParams), span) case "tools/list": - p := &mcp.ListToolsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ListToolsParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleToolsListRequest(ctx, s, w, msg, p, span) + err = m.handleToolsListRequest(ctx, s, w, msg, params.(*mcp.ListToolsParams), span) case "resources/list": - p := &mcp.ListResourcesParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ListResourcesParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourceListRequest(ctx, s, w, msg, p, span) + err = m.handleResourceListRequest(ctx, s, w, msg, params.(*mcp.ListResourcesParams), span) case "resources/read": - p := &mcp.ReadResourceParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ReadResourceParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourceReadRequest(ctx, s, w, msg, p) + err = m.handleResourceReadRequest(ctx, s, w, msg, params.(*mcp.ReadResourceParams)) case "resources/templates/list": - p := &mcp.ListResourceTemplatesParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.ListResourceTemplatesParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, params.(*mcp.ListResourceTemplatesParams), span) case "resources/subscribe": - p := &mcp.SubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.SubscribeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, params.(*mcp.SubscribeParams), span) case "resources/unsubscribe": - p := &mcp.UnsubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p, r.Header) + params = &mcp.UnsubscribeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, params.(*mcp.UnsubscribeParams), span) case "notifications/cancelled": // The responsibility of cancelling the operation on server side is optional, so we just ignore it for now. // https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/cancellation#behavior-requirements @@ -371,7 +372,7 @@ func errorType(err error) metrics.MCPErrorType { // handleInitializeRequest handles the "initialize" JSON-RPC method. func (m *MCPProxy) handleInitializeRequest(ctx context.Context, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.InitializeParams, route, subject string, span tracing.MCPSpan) error { - m.metrics.RecordClientCapabilities(ctx, p.Capabilities) + m.metrics.RecordClientCapabilities(ctx, p.Capabilities, p) s, err := m.newSession(ctx, p, route, subject, span) if err != nil { m.l.Error("failed to create new session", slog.String("error", err.Error())) @@ -788,7 +789,11 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) { case "notifications/resources/list_changed": case "notifications/resources/updated": case "notifications/progress": - m.metrics.RecordProgress(ctx) + params := &mcp.ProgressNotificationParams{} + if err := json.Unmarshal(msg.Params, ¶ms); err != nil { + m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error())) + } + m.metrics.RecordProgress(ctx, params) case "notifications/message": case "notifications/tools/list_changed": case "roots/list": @@ -796,11 +801,11 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) { case "elicitation/create": default: knownMethod = false - m.metrics.RecordMethodErrorCount(ctx) + m.metrics.RecordMethodErrorCount(ctx, nil) m.l.Warn("Unsupported MCP request method from server", slog.String("method", msg.Method)) } if knownMethod { - m.metrics.RecordMethodCount(ctx, msg.Method) + m.metrics.RecordMethodCount(ctx, msg.Method, nil) } default: m.l.Warn("unexpected message type in MCP response", slog.Any("message", msg)) diff --git a/internal/mcpproxy/mcpproxy.go b/internal/mcpproxy/mcpproxy.go index 2ef7f96081..05689a51a9 100644 --- a/internal/mcpproxy/mcpproxy.go +++ b/internal/mcpproxy/mcpproxy.go @@ -199,7 +199,7 @@ func (m *MCPProxy) newSession(ctx context.Context, p *mcp.InitializeParams, rout // TODO: should we record a metric for this? return } - m.metrics.RecordInitializationDuration(ctx, &startAt) + m.metrics.RecordInitializationDuration(ctx, &startAt, p) if m.l.Enabled(ctx, slog.LevelDebug) { m.l.Debug("created MCP session", slog.String("backend", backend.Name), slog.String("session_id", string(initResult.sessionID))) } @@ -352,7 +352,7 @@ func (m *MCPProxy) initializeSession(ctx context.Context, routeName filterapi.MC if m.l.Enabled(ctx, slog.LevelDebug) { m.l.Debug("MCP session initialized", slog.Any("capabilities", initResult.Capabilities)) } - m.metrics.RecordServerCapabilities(ctx, initResult.Capabilities) + m.metrics.RecordServerCapabilities(ctx, initResult.Capabilities, p) } // Need to invoke "notifications/initialized" to complete the initialization. diff --git a/internal/mcpproxy/session_test.go b/internal/mcpproxy/session_test.go index d7f07d320b..b63714e79c 100644 --- a/internal/mcpproxy/session_test.go +++ b/internal/mcpproxy/session_test.go @@ -30,16 +30,19 @@ import ( // stubMetrics implements metrics.MCPMetrics with no-ops. type stubMetrics struct{} -func (s stubMetrics) WithRequestAttributes(_ *http.Request) metrics.MCPMetrics { return s } -func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time) {} -func (stubMetrics) RecordRequestErrorDuration(_ context.Context, _ *time.Time, _ metrics.MCPErrorType) { +func (s stubMetrics) WithRequestAttributes(_ *http.Request) metrics.MCPMetrics { return s } +func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time, _ mcpsdk.Params) {} +func (stubMetrics) RecordRequestErrorDuration(_ context.Context, _ *time.Time, _ metrics.MCPErrorType, _ mcpsdk.Params) { } -func (stubMetrics) RecordMethodCount(_ context.Context, _ string) {} -func (stubMetrics) RecordMethodErrorCount(_ context.Context) {} -func (stubMetrics) RecordInitializationDuration(_ context.Context, _ *time.Time) {} -func (stubMetrics) RecordClientCapabilities(_ context.Context, _ *mcpsdk.ClientCapabilities) {} -func (stubMetrics) RecordServerCapabilities(_ context.Context, _ *mcpsdk.ServerCapabilities) {} -func (stubMetrics) RecordProgress(_ context.Context) {} +func (stubMetrics) RecordMethodCount(_ context.Context, _ string, _ mcpsdk.Params) {} +func (stubMetrics) RecordMethodErrorCount(_ context.Context, _ mcpsdk.Params) {} +func (stubMetrics) RecordInitializationDuration(_ context.Context, _ *time.Time, _ mcpsdk.Params) {} +func (stubMetrics) RecordClientCapabilities(_ context.Context, _ *mcpsdk.ClientCapabilities, _ mcpsdk.Params) { +} + +func (stubMetrics) RecordServerCapabilities(_ context.Context, _ *mcpsdk.ServerCapabilities, _ mcpsdk.Params) { +} +func (stubMetrics) RecordProgress(_ context.Context, _ mcpsdk.Params) {} func TestBackendSessionIDs_Success(t *testing.T) { backendA := "backendA" diff --git a/internal/metrics/mcp_metrics.go b/internal/metrics/mcp_metrics.go index ce9534faac..56e4956ff8 100644 --- a/internal/metrics/mcp_metrics.go +++ b/internal/metrics/mcp_metrics.go @@ -7,6 +7,7 @@ package metrics import ( "context" + "fmt" "net/http" "time" @@ -106,37 +107,37 @@ type MCPMetrics interface { // WithRequestAttributes returns a new MCPMetrics instance with default attributes extracted from the HTTP request. WithRequestAttributes(req *http.Request) MCPMetrics // RecordRequestDuration records the duration of a success MCP request. - RecordRequestDuration(ctx context.Context, startAt *time.Time) + RecordRequestDuration(ctx context.Context, startAt *time.Time, meta mcpsdk.Params) // RecordRequestErrorDuration records the duration of an MCP request that resulted in an error. - RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType) + RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType, meta mcpsdk.Params) // RecordMethodCount records the count of method invocations. - RecordMethodCount(ctx context.Context, methodName string) + RecordMethodCount(ctx context.Context, methodName string, meta mcpsdk.Params) // RecordMethodErrorCount records the count of method invocations with error status. - RecordMethodErrorCount(ctx context.Context) + RecordMethodErrorCount(ctx context.Context, meta mcpsdk.Params) // RecordInitializationDuration records the duration of MCP initialization. - RecordInitializationDuration(ctx context.Context, startAt *time.Time) + RecordInitializationDuration(ctx context.Context, startAt *time.Time, meta mcpsdk.Params) // RecordClientCapabilities records the negotiated client capabilities. - RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities) + RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities, meta mcpsdk.Params) // RecordServerCapabilities records the negotiated server capabilities. - RecordServerCapabilities(ctx context.Context, capabilities *mcpsdk.ServerCapabilities) + RecordServerCapabilities(ctx context.Context, capabilities *mcpsdk.ServerCapabilities, meta mcpsdk.Params) // RecordProgress records a progress notification sent/received. - RecordProgress(ctx context.Context) + RecordProgress(ctx context.Context, meta mcpsdk.Params) } type mcp struct { - requestDuration metric.Float64Histogram - methodCount metric.Float64Counter - initializationDuration metric.Float64Histogram - capabilitiesNegotiated metric.Float64Counter - progressNotifications metric.Float64Counter - requestHeaderLabelMapping map[string]string // maps HTTP headers to metric label names. - defaultAttributes []attribute.KeyValue + requestDuration metric.Float64Histogram + methodCount metric.Float64Counter + initializationDuration metric.Float64Histogram + capabilitiesNegotiated metric.Float64Counter + progressNotifications metric.Float64Counter + requestHeaderAttributeMapping map[string]string // maps HTTP headers to metric attribute names. + defaultAttributes []attribute.KeyValue } // NewMCP creates a new mcp metrics instance. -func NewMCP(meter metric.Meter, requestHeaderLabelMapping map[string]string) MCPMetrics { +func NewMCP(meter metric.Meter, requestHeaderAttributeMapping map[string]string) MCPMetrics { return &mcp{ - requestHeaderLabelMapping: requestHeaderLabelMapping, + requestHeaderAttributeMapping: requestHeaderAttributeMapping, requestDuration: mustRegisterHistogram(meter, mcpRequestDuration, metric.WithDescription("Duration of MCP requests"), @@ -169,16 +170,16 @@ func NewMCP(meter metric.Meter, requestHeaderLabelMapping map[string]string) MCP // the HTTP request headers. func (m *mcp) WithRequestAttributes(req *http.Request) MCPMetrics { withAttrs := &mcp{ - requestDuration: m.requestDuration, - methodCount: m.methodCount, - initializationDuration: m.initializationDuration, - capabilitiesNegotiated: m.capabilitiesNegotiated, - progressNotifications: m.progressNotifications, - requestHeaderLabelMapping: m.requestHeaderLabelMapping, + requestDuration: m.requestDuration, + methodCount: m.methodCount, + initializationDuration: m.initializationDuration, + capabilitiesNegotiated: m.capabilitiesNegotiated, + progressNotifications: m.progressNotifications, + requestHeaderAttributeMapping: m.requestHeaderAttributeMapping, } // Apply header-to-attribute mapping if configured. - for headerName, attrName := range m.requestHeaderLabelMapping { + for headerName, attrName := range m.requestHeaderAttributeMapping { if headerValue := req.Header.Get(headerName); headerValue != "" { withAttrs.defaultAttributes = append( withAttrs.defaultAttributes, @@ -191,90 +192,90 @@ func (m *mcp) WithRequestAttributes(req *http.Request) MCPMetrics { } // RecordMethodCount implements [MCPMetrics.RecordMethodCount]. -func (m *mcp) RecordMethodCount(ctx context.Context, methodName string) { +func (m *mcp) RecordMethodCount(ctx context.Context, methodName string, params mcpsdk.Params) { if methodName == "" { return } m.methodCount.Add(ctx, 1, - m.withDefaultAttributes( + m.withDefaultAttributes(params, attribute.Key(mcpAttributeMethodName).String(methodName), attribute.String(mcpAttributeStatusName, string(mcpStatusSuccess)), )) } // RecordMethodErrorCount implements [MCPMetrics.RecordMethodErrorCount]. -func (m *mcp) RecordMethodErrorCount(ctx context.Context) { +func (m *mcp) RecordMethodErrorCount(ctx context.Context, params mcpsdk.Params) { m.methodCount.Add(ctx, 1, - m.withDefaultAttributes( + m.withDefaultAttributes(params, attribute.String(mcpAttributeStatusName, string(mcpStatusError)), )) } // RecordRequestDuration implements [MCPMetrics.RecordRequestDuration]. -func (m *mcp) RecordRequestDuration(ctx context.Context, startAt *time.Time) { +func (m *mcp) RecordRequestDuration(ctx context.Context, startAt *time.Time, params mcpsdk.Params) { if startAt == nil { return } duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration, m.withDefaultAttributes()) + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes(params)) } // RecordRequestErrorDuration implements [MCPMetrics.RecordRequestErrorDuration]. -func (m *mcp) RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType) { +func (m *mcp) RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType, params mcpsdk.Params) { if startAt == nil { return } duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration, m.withDefaultAttributes( + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes(params, attribute.Key(mcpAttributeErrorType).String(string(errType)), )) } // RecordInitializationDuration implements [MCPMetrics.RecordInitializationDuration]. -func (m *mcp) RecordInitializationDuration(ctx context.Context, startAt *time.Time) { +func (m *mcp) RecordInitializationDuration(ctx context.Context, startAt *time.Time, params mcpsdk.Params) { if startAt == nil { return } duration := time.Since(*startAt).Seconds() - m.initializationDuration.Record(ctx, duration, m.withDefaultAttributes()) + m.initializationDuration.Record(ctx, duration, m.withDefaultAttributes(params)) } // RecordProgress implements [MCPMetrics.RecordProgress]. -func (m *mcp) RecordProgress(ctx context.Context) { - m.progressNotifications.Add(ctx, 1, m.withDefaultAttributes()) +func (m *mcp) RecordProgress(ctx context.Context, params mcpsdk.Params) { + m.progressNotifications.Add(ctx, 1, m.withDefaultAttributes(params)) } // RecordClientCapabilities implements [MCPMetrics.RecordClientCapabilities]. -func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities) { +func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities, params mcpsdk.Params) { if capabilities == nil { return } side := string(mcpCapabilitySideClient) if l := len(capabilities.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Roots.ListChanged { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeRoots)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Sampling != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeSampling)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Elicitation != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeElicitation)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) @@ -282,49 +283,49 @@ func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk } // RecordServerCapabilities implements [MCPMetrics.RecordServerCapabilities]. -func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.ServerCapabilities) { +func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.ServerCapabilities, params mcpsdk.Params) { if serverCapa == nil { return } side := string(mcpCapabilitySideServer) if serverCapa.Completions != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeCompletions)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if l := len(serverCapa.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Logging != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeLogging)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Prompts != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypePrompts)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Resources != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeResources)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Tools != nil { - m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeTools)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) @@ -332,6 +333,16 @@ func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.S } // withDefaultAttributes appends default attributes to the provided attributes. -func (m *mcp) withDefaultAttributes(attrs ...attribute.KeyValue) metric.MeasurementOption { - return metric.WithAttributes(append(m.defaultAttributes, attrs...)...) +func (m *mcp) withDefaultAttributes(params mcpsdk.Params, attrs ...attribute.KeyValue) metric.MeasurementOption { + all := make([]attribute.KeyValue, 0, len(m.defaultAttributes)+len(m.requestHeaderAttributeMapping)+len(attrs)) + all = append(all, m.defaultAttributes...) + if params != nil { + for k, v := range params.GetMeta() { + if target, ok := m.requestHeaderAttributeMapping[k]; ok { + all = append(all, attribute.String(target, fmt.Sprintf("%v", v))) + } + } + } + all = append(all, attrs...) + return metric.WithAttributes(all...) } diff --git a/internal/metrics/mcp_metrics_test.go b/internal/metrics/mcp_metrics_test.go index 50a3873e51..15454d6af2 100644 --- a/internal/metrics/mcp_metrics_test.go +++ b/internal/metrics/mcp_metrics_test.go @@ -30,21 +30,29 @@ func TestRecordMetricWithCustomAttributes(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter, map[string]string{"x-custom-attr": "attr.custom"}) + m := NewMCP(meter, map[string]string{ + "x-tracing-enrichment-user-region": "user.region", + "x-session-id": "session.id", + }) require.NotNil(t, m) req, err := http.NewRequest("GET", "https://example.com", nil) require.NoError(t, err) - req.Header.Set("x-custom-attr", "test") // should be included in metrics - req.Header.Set("x-other-attr", "other") // should be ignored + req.Header.Set("x-tracing-enrichment-user-region", "us-east-1") // should be included in metrics + req.Header.Set("x-other-attr", "other") // should be ignored m = m.WithRequestAttributes(req) startAt := time.Now().Add(-1 * time.Minute) - m.RecordRequestDuration(t.Context(), &startAt) + m.RecordRequestDuration(t.Context(), &startAt, &mcpsdk.InitializeParams{ + Meta: map[string]any{"x-session-id": "sess-1234"}, + }) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, - attribute.NewSet(attribute.String("attr.custom", "test"))) + attribute.NewSet( + attribute.String("user.region", "us-east-1"), + attribute.String("session.id", "sess-1234"), + )) require.Equal(t, uint64(1), count) require.Equal(t, 60, int(sum)) } @@ -56,7 +64,7 @@ func TestRecordRequestDuration(t *testing.T) { m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-1 * time.Minute) - m.RecordRequestDuration(t.Context(), &startAt) + m.RecordRequestDuration(t.Context(), &startAt, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, attribute.NewSet()) require.Equal(t, uint64(1), count) @@ -70,7 +78,7 @@ func TestRecordRequestErrorDuration(t *testing.T) { m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-30 * time.Second) - m.RecordRequestErrorDuration(t.Context(), &startAt, MCPErrorUnsupportedProtocolVersion) + m.RecordRequestErrorDuration(t.Context(), &startAt, MCPErrorUnsupportedProtocolVersion, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, attribute.NewSet( attribute.Key(mcpAttributeErrorType).String(string(MCPErrorUnsupportedProtocolVersion)), @@ -86,7 +94,7 @@ func TestRecordMethodCount(t *testing.T) { m := NewMCP(meter, nil) require.NotNil(t, m) - m.RecordMethodCount(t.Context(), "test_method_name") + m.RecordMethodCount(t.Context(), "test_method_name", nil) attrs := attribute.NewSet( attribute.Key(mcpAttributeMethodName).String("test_method_name"), attribute.Key(mcpAttributeStatusName).String(string(mcpStatusSuccess)), @@ -94,7 +102,7 @@ func TestRecordMethodCount(t *testing.T) { val := testotel.GetCounterValue(t, mr, mcpMethodCount, attrs) require.Equal(t, float64(1), val) - m.RecordMethodErrorCount(t.Context()) + m.RecordMethodErrorCount(t.Context(), nil) attrs = attribute.NewSet( attribute.Key(mcpAttributeStatusName).String(string(mcpStatusError)), ) @@ -110,7 +118,7 @@ func TestRecordInitializationDuration(t *testing.T) { require.NotNil(t, m) startAt := time.Now().Add(-45 * time.Second) - m.RecordInitializationDuration(t.Context(), &startAt) + m.RecordInitializationDuration(t.Context(), &startAt, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpInitializationDuration, attribute.NewSet()) require.Equal(t, uint64(1), count) @@ -133,7 +141,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { ListChanged bool "json:\"listChanged,omitempty\"" }{ListChanged: true}, Sampling: &mcpsdk.SamplingCapabilities{}, - }) + }, nil) m.RecordServerCapabilities(t.Context(), &mcpsdk.ServerCapabilities{ Experimental: map[string]any{ "exp1": struct{}{}, @@ -143,7 +151,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { Prompts: &mcpsdk.PromptCapabilities{ListChanged: true}, Resources: &mcpsdk.ResourceCapabilities{ListChanged: true, Subscribe: true}, Tools: &mcpsdk.ToolCapabilities{ListChanged: true}, - }) + }, nil) require.Equal(t, float64(2), testotel.GetCounterValue(t, mr, mcpCapabilitiesNegotiated, attribute.NewSet( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(string(mcpCapabilitySideClient)), @@ -179,11 +187,11 @@ func TestRecordProgressNotifications(t *testing.T) { m := NewMCP(meter, nil) require.NotNil(t, m) - m.RecordProgress(t.Context()) + m.RecordProgress(t.Context(), nil) val := testotel.GetCounterValue(t, mr, mpcProgressNotifications, attribute.NewSet()) require.Equal(t, float64(1), val) - m.RecordProgress(t.Context()) + m.RecordProgress(t.Context(), nil) val = testotel.GetCounterValue(t, mr, mpcProgressNotifications, attribute.NewSet()) require.Equal(t, float64(2), val) } diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index eade1489cf..e4c5a22bcc 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -57,16 +57,16 @@ func (s mcpSpan) EndSpan() { // mcpTracer is an implementation of [tracing.MCPTracer]. type mcpTracer struct { - tracer trace.Tracer - propagator propagation.TextMapPropagator - headerAttributes map[string]string + tracer trace.Tracer + propagator propagation.TextMapPropagator + attributeMappings map[string]string } -func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, headerAttributes map[string]string) tracing.MCPTracer { +func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, attributeMappings map[string]string) tracing.MCPTracer { return mcpTracer{ - tracer: tracer, - propagator: propagator, - headerAttributes: headerAttributes, + tracer: tracer, + propagator: propagator, + attributeMappings: attributeMappings, } } @@ -81,9 +81,14 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ attrs = append(attrs, getMCPParamsAsAttributes(param)...) // Apply header-to-attribute mapping if configured. - for headerName, attrName := range m.headerAttributes { - if headerValue, ok := headers[headerName]; ok { - attrs = append(attrs, attribute.String(attrName, headerValue)) + for srcName, targetName := range m.attributeMappings { + // Check if the attribute is present in the metadata first, as this is the common place to add custom attributes + // in MCP requests. Fall back to headers if not found in metadata. + // If the attribute is not found there, check if there is any custom header to map. + if metaValue, ok := param.GetMeta()[srcName]; ok { + attrs = append(attrs, attribute.String(targetName, fmt.Sprintf("%v", metaValue))) + } else if headerValue, ok := headers[srcName]; ok { + attrs = append(attrs, attribute.String(targetName, headerValue)) } } diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index c5a51c2b97..ba64b762c3 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -24,13 +24,13 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), - map[string]string{"x-custom-attr": "custom.attr"}) + map[string]string{"x-tracing-enrichment-user-region": "user.region"}) reqID, _ := jsonrpc.MakeID("id") r := &jsonrpc.Request{ID: reqID, Method: "initialize"} p := &mcp.InitializeParams{} span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, map[string]string{ - "x-custom-attr": "custom-value", + "x-tracing-enrichment-user-region": "us-east-1", }) require.NotNil(t, span) @@ -43,7 +43,7 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { spans := exporter.GetSpans() require.Len(t, spans, 1) actualSpan := spans[0] - require.Contains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value")) + require.Contains(t, actualSpan.Attributes, attribute.String("user.region", "us-east-1")) } func Test_getMCPAttributes(t *testing.T) { From 7f35ed41da3452f0ee01034ac162a80d82aaba93 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Mon, 13 Oct 2025 15:43:14 +0200 Subject: [PATCH 4/7] add test sfor attribute precedence Signed-off-by: Ignasi Barrera --- internal/metrics/mcp_metrics_test.go | 1 + internal/tracing/mcp_test.go | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/metrics/mcp_metrics_test.go b/internal/metrics/mcp_metrics_test.go index 15454d6af2..42aedb4aa0 100644 --- a/internal/metrics/mcp_metrics_test.go +++ b/internal/metrics/mcp_metrics_test.go @@ -40,6 +40,7 @@ func TestRecordMetricWithCustomAttributes(t *testing.T) { require.NoError(t, err) req.Header.Set("x-tracing-enrichment-user-region", "us-east-1") // should be included in metrics req.Header.Set("x-other-attr", "other") // should be ignored + req.Header.Set("x-session-id", "123") // should be ignored as the value in the metadata takes precedence m = m.WithRequestAttributes(req) diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index ba64b762c3..4cc67555b1 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -24,13 +24,17 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), - map[string]string{"x-tracing-enrichment-user-region": "user.region"}) + map[string]string{ + "x-tracing-enrichment-user-region": "user.region", + "x-session-id": "session.id", + }) reqID, _ := jsonrpc.MakeID("id") r := &jsonrpc.Request{ID: reqID, Method: "initialize"} - p := &mcp.InitializeParams{} + p := &mcp.InitializeParams{Meta: map[string]any{"x-session-id": "sess-1234"}} span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, map[string]string{ "x-tracing-enrichment-user-region": "us-east-1", + "x-session-id": "123", // should be ignored as the value in the metadata takes precedence }) require.NotNil(t, span) @@ -44,6 +48,8 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { require.Len(t, spans, 1) actualSpan := spans[0] require.Contains(t, actualSpan.Attributes, attribute.String("user.region", "us-east-1")) + require.Contains(t, actualSpan.Attributes, attribute.String("session.id", "sess-1234")) + require.NotContains(t, actualSpan.Attributes, attribute.String("session.id", "123")) } func Test_getMCPAttributes(t *testing.T) { From 8ae4f39ac94e842db0cad3b9cdf3f8a38cb57687 Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Tue, 14 Oct 2025 09:07:56 -0700 Subject: [PATCH 5/7] drift Signed-off-by: Takeshi Yoneda --- internal/metrics/completion_metrics_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/metrics/completion_metrics_test.go b/internal/metrics/completion_metrics_test.go index dc4758e52b..f167d06367 100644 --- a/internal/metrics/completion_metrics_test.go +++ b/internal/metrics/completion_metrics_test.go @@ -216,7 +216,7 @@ func TestCompletion_HeaderLabelMapping(t *testing.T) { pm.RecordTokenUsage(t.Context(), 10, 5, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, pm.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, pm.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. attrs := attribute.NewSet( From 3b448be50642f925e282dd037133739965271a9c Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Tue, 14 Oct 2025 23:11:39 +0100 Subject: [PATCH 6/7] address review comments Signed-off-by: Ignasi Barrera --- internal/mcpproxy/handlers.go | 8 +------ internal/mcpproxy/mcpproxy_test.go | 4 ++-- internal/tracing/api/mcp.go | 5 ++-- internal/tracing/mcp.go | 38 ++++++++++++++++++++++++++---- internal/tracing/mcp_test.go | 20 ++++++++++++---- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index 503c965717..cf6fe1e5e8 100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -1237,13 +1237,7 @@ func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m * return nil, err } - // TODO: headers with multiple values - headerMap := make(map[string]string, len(headers)) - for header := range headers { - headerMap[header] = headers.Get(header) - } - - span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headerMap) + span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headers) return span, nil } diff --git a/internal/mcpproxy/mcpproxy_test.go b/internal/mcpproxy/mcpproxy_test.go index e377bf13e2..f15f042a51 100644 --- a/internal/mcpproxy/mcpproxy_test.go +++ b/internal/mcpproxy/mcpproxy_test.go @@ -50,7 +50,7 @@ type fakeTracer struct { span *fakeSpan } -func (f *fakeTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, map[string]string) tracing.MCPSpan { +func (f *fakeTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, http.Header) tracing.MCPSpan { if f.span == nil { f.span = &fakeSpan{} } @@ -84,7 +84,7 @@ func TestMCPProxy_HTTPMethods(t *testing.T) { } func TestLoadConfig_NilMCPConfig(t *testing.T) { - proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, DefaultSessionCrypto("test")) + proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, DefaultSessionCrypto("test", "")) require.NoError(t, err) config := &filterapi.Config{MCPConfig: nil} diff --git a/internal/tracing/api/mcp.go b/internal/tracing/api/mcp.go index 9082c6f774..913610fa1b 100644 --- a/internal/tracing/api/mcp.go +++ b/internal/tracing/api/mcp.go @@ -9,6 +9,7 @@ package api import ( "context" + "net/http" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -26,7 +27,7 @@ type MCPTracer interface { // - headers: Request HTTP request headers. // // Returns nil unless the span is sampled. - StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers map[string]string) MCPSpan + StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers http.Header) MCPSpan } // MCPSpan represents an MCP span. @@ -46,6 +47,6 @@ var _ MCPTracer = NoopMCPTracer{} type NoopMCPTracer struct{} // StartSpanAndInjectMeta implements [MCPTracer.StartSpanAndInjectMeta]. -func (NoopMCPTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, map[string]string) MCPSpan { +func (NoopMCPTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, http.Header) MCPSpan { return nil } diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index e4c5a22bcc..b524677791 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -8,6 +8,11 @@ package tracing import ( "context" "fmt" + "maps" + "net/http" + "slices" + "sort" + "strings" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -71,7 +76,7 @@ func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, } // StartSpanAndInjectMeta implements [tracing.MCPTracer.StartSpanAndInjectMeta]. -func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers map[string]string) tracing.MCPSpan { +func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers http.Header) tracing.MCPSpan { attrs := []attribute.KeyValue{ attribute.String("mcp.protocol.version", "2025-06-18"), attribute.String("mcp.transport", "http"), @@ -85,9 +90,9 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ // Check if the attribute is present in the metadata first, as this is the common place to add custom attributes // in MCP requests. Fall back to headers if not found in metadata. // If the attribute is not found there, check if there is any custom header to map. - if metaValue, ok := param.GetMeta()[srcName]; ok { - attrs = append(attrs, attribute.String(targetName, fmt.Sprintf("%v", metaValue))) - } else if headerValue, ok := headers[srcName]; ok { + if metaValue := caseInsensitiveValue(param.GetMeta(), srcName); metaValue != "" { + attrs = append(attrs, attribute.String(targetName, metaValue)) + } else if headerValue := headers.Get(srcName); headerValue != "" { // this is case-insensitive attrs = append(attrs, attribute.String(targetName, headerValue)) } } @@ -121,6 +126,31 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ return nil } +// caseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner. +// If the same key is present in different cases, the first one in alphabetical order +// that matches is returned. +// If the key is not found, it returns an empty string. +func caseInsensitiveValue(meta map[string]any, key string) string { + if meta == nil { + return "" + } + + if v, ok := meta[key]; ok { + return fmt.Sprintf("%v", v) + } + + keys := slices.Collect(maps.Keys(meta)) + sort.Strings(keys) + + for _, k := range keys { + if strings.EqualFold(k, key) { + return fmt.Sprintf("%v", meta[k]) + } + } + + return "" +} + func getMCPParamsAsAttributes(p mcp.Params) []attribute.KeyValue { var attrs []attribute.KeyValue switch params := p.(type) { diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index 4cc67555b1..bc5c8793f9 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -7,6 +7,7 @@ package tracing import ( "context" + "net/http" "testing" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -27,15 +28,22 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { map[string]string{ "x-tracing-enrichment-user-region": "user.region", "x-session-id": "session.id", + "CustomAttr": "custom.attr", }) + headers := make(http.Header) + headers.Add("X-Tracing-Enrichment-User-Region", "us-east-1") + headers.Add("X-Session-Id", "123") // should be ignored as the value in the metadata takes precedence + reqID, _ := jsonrpc.MakeID("id") r := &jsonrpc.Request{ID: reqID, Method: "initialize"} - p := &mcp.InitializeParams{Meta: map[string]any{"x-session-id": "sess-1234"}} - span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, map[string]string{ - "x-tracing-enrichment-user-region": "us-east-1", - "x-session-id": "123", // should be ignored as the value in the metadata takes precedence - }) + p := &mcp.InitializeParams{Meta: map[string]any{ + "x-session-id": "sess-1234", // alphabetical order wins when multiple values match case-insensitively + "X-SESSION-ID": "sess-4567", + "customattr": "custom-value1", // exact match should win over case-insensitive match + "CustomAttr": "custom-value2", + }} + span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, headers) require.NotNil(t, span) meta := p.GetMeta() @@ -49,7 +57,9 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { actualSpan := spans[0] require.Contains(t, actualSpan.Attributes, attribute.String("user.region", "us-east-1")) require.Contains(t, actualSpan.Attributes, attribute.String("session.id", "sess-1234")) + require.Contains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value2")) require.NotContains(t, actualSpan.Attributes, attribute.String("session.id", "123")) + require.NotContains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value1")) } func Test_getMCPAttributes(t *testing.T) { From 296fa4024d5dd042f6cd5103b96c9135fdf90d7b Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Tue, 14 Oct 2025 23:55:33 +0100 Subject: [PATCH 7/7] case-insensitive params in metrics as well Signed-off-by: Ignasi Barrera --- internal/lang/maps.go | 36 ++++++++++++++++++ internal/lang/maps_test.go | 57 ++++++++++++++++++++++++++++ internal/metrics/mcp_metrics.go | 6 ++- internal/metrics/mcp_metrics_test.go | 19 +++++++--- internal/tracing/mcp.go | 32 +--------------- internal/tracing/mcp_test.go | 4 +- 6 files changed, 114 insertions(+), 40 deletions(-) create mode 100644 internal/lang/maps.go create mode 100644 internal/lang/maps_test.go diff --git a/internal/lang/maps.go b/internal/lang/maps.go new file mode 100644 index 0000000000..eff005d870 --- /dev/null +++ b/internal/lang/maps.go @@ -0,0 +1,36 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package lang + +import ( + "fmt" + "maps" + "slices" + "strings" +) + +// CaseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner. +// If the same key is present in different cases, the first one in alphabetical order +// that matches is returned. +// If the key is not found, it returns an empty string. +func CaseInsensitiveValue(m map[string]any, key string) string { + if m == nil { + return "" + } + + if v, ok := m[key]; ok { + return fmt.Sprintf("%v", v) + } + + keys := slices.Sorted(maps.Keys(m)) + for _, k := range keys { + if strings.EqualFold(k, key) { + return fmt.Sprintf("%v", m[k]) + } + } + + return "" +} diff --git a/internal/lang/maps_test.go b/internal/lang/maps_test.go new file mode 100644 index 0000000000..c9666920ac --- /dev/null +++ b/internal/lang/maps_test.go @@ -0,0 +1,57 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package lang + +import "testing" + +func TestCaseInsensitiveValue(t *testing.T) { + tests := []struct { + name string + m map[string]any + key string + want string + }{ + { + name: "nil map", + m: nil, + key: "anything", + want: "", + }, + { + name: "exact match returns value", + m: map[string]any{"Foo": "bar", "foo": "should-not-be-used"}, + key: "Foo", + want: "bar", + }, + { + name: "case-insensitive match when exact not present", + m: map[string]any{"FOO": "baz"}, + key: "foo", + want: "baz", + }, + { + name: "multiple case variants - alphabetical first chosen", + m: map[string]any{"ALPHA": 2, "Alpha": 1}, + key: "alpha", + want: "2", // ALPHA is alphabetically first + }, + { + name: "nil value formatted", + m: map[string]any{"key": nil}, + key: "key", + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := CaseInsensitiveValue(tc.m, tc.key) + if got != tc.want { + t.Fatalf("CaseInsensitiveValue(%v, %q) = %q; want %q", tc.m, tc.key, got, tc.want) + } + }) + } +} diff --git a/internal/metrics/mcp_metrics.go b/internal/metrics/mcp_metrics.go index 56e4956ff8..f4a4801c85 100644 --- a/internal/metrics/mcp_metrics.go +++ b/internal/metrics/mcp_metrics.go @@ -14,6 +14,8 @@ import ( mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + + "github.com/envoyproxy/ai-gateway/internal/lang" ) // nolint: godot @@ -337,8 +339,8 @@ func (m *mcp) withDefaultAttributes(params mcpsdk.Params, attrs ...attribute.Key all := make([]attribute.KeyValue, 0, len(m.defaultAttributes)+len(m.requestHeaderAttributeMapping)+len(attrs)) all = append(all, m.defaultAttributes...) if params != nil { - for k, v := range params.GetMeta() { - if target, ok := m.requestHeaderAttributeMapping[k]; ok { + for src, target := range m.requestHeaderAttributeMapping { + if v := lang.CaseInsensitiveValue(params.GetMeta(), src); v != "" { all = append(all, attribute.String(target, fmt.Sprintf("%v", v))) } } diff --git a/internal/metrics/mcp_metrics_test.go b/internal/metrics/mcp_metrics_test.go index 42aedb4aa0..4fcd471f5c 100644 --- a/internal/metrics/mcp_metrics_test.go +++ b/internal/metrics/mcp_metrics_test.go @@ -32,27 +32,34 @@ func TestRecordMetricWithCustomAttributes(t *testing.T) { m := NewMCP(meter, map[string]string{ "x-tracing-enrichment-user-region": "user.region", - "x-session-id": "session.id", + "X-Session-Id": "session.id", + "CustomAttr": "custom.attr", }) require.NotNil(t, m) req, err := http.NewRequest("GET", "https://example.com", nil) require.NoError(t, err) - req.Header.Set("x-tracing-enrichment-user-region", "us-east-1") // should be included in metrics - req.Header.Set("x-other-attr", "other") // should be ignored - req.Header.Set("x-session-id", "123") // should be ignored as the value in the metadata takes precedence + req.Header.Set("X-Tracing-Enrichment-User-Region", "us-east-1") // should be included in metrics + req.Header.Set("X-Other-Attr", "other") // should be ignored + req.Header.Set("X-Session-Id", "123") // should be ignored as the value in the metadata takes precedence m = m.WithRequestAttributes(req) startAt := time.Now().Add(-1 * time.Minute) m.RecordRequestDuration(t.Context(), &startAt, &mcpsdk.InitializeParams{ - Meta: map[string]any{"x-session-id": "sess-1234"}, + Meta: map[string]any{ + "x-session-id": "sess-1234", // alphabetical order wins when multiple values match case-insensitively + "X-SESSION-ID": "sess-4567", + "customattr": "custom-value1", // exact match should win over case-insensitive match + "CustomAttr": "custom-value2", + }, }) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, attribute.NewSet( attribute.String("user.region", "us-east-1"), - attribute.String("session.id", "sess-1234"), + attribute.String("session.id", "sess-4567"), + attribute.String("custom.attr", "custom-value2"), )) require.Equal(t, uint64(1), count) require.Equal(t, 60, int(sum)) diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index b524677791..c0910adb70 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -8,11 +8,7 @@ package tracing import ( "context" "fmt" - "maps" "net/http" - "slices" - "sort" - "strings" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -21,6 +17,7 @@ import ( "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" + "github.com/envoyproxy/ai-gateway/internal/lang" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) @@ -90,7 +87,7 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ // Check if the attribute is present in the metadata first, as this is the common place to add custom attributes // in MCP requests. Fall back to headers if not found in metadata. // If the attribute is not found there, check if there is any custom header to map. - if metaValue := caseInsensitiveValue(param.GetMeta(), srcName); metaValue != "" { + if metaValue := lang.CaseInsensitiveValue(param.GetMeta(), srcName); metaValue != "" { attrs = append(attrs, attribute.String(targetName, metaValue)) } else if headerValue := headers.Get(srcName); headerValue != "" { // this is case-insensitive attrs = append(attrs, attribute.String(targetName, headerValue)) @@ -126,31 +123,6 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ return nil } -// caseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner. -// If the same key is present in different cases, the first one in alphabetical order -// that matches is returned. -// If the key is not found, it returns an empty string. -func caseInsensitiveValue(meta map[string]any, key string) string { - if meta == nil { - return "" - } - - if v, ok := meta[key]; ok { - return fmt.Sprintf("%v", v) - } - - keys := slices.Collect(maps.Keys(meta)) - sort.Strings(keys) - - for _, k := range keys { - if strings.EqualFold(k, key) { - return fmt.Sprintf("%v", meta[k]) - } - } - - return "" -} - func getMCPParamsAsAttributes(p mcp.Params) []attribute.KeyValue { var attrs []attribute.KeyValue switch params := p.(type) { diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index bc5c8793f9..c11c16571f 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -27,7 +27,7 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), map[string]string{ "x-tracing-enrichment-user-region": "user.region", - "x-session-id": "session.id", + "X-Session-Id": "session.id", "CustomAttr": "custom.attr", }) @@ -56,7 +56,7 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { require.Len(t, spans, 1) actualSpan := spans[0] require.Contains(t, actualSpan.Attributes, attribute.String("user.region", "us-east-1")) - require.Contains(t, actualSpan.Attributes, attribute.String("session.id", "sess-1234")) + require.Contains(t, actualSpan.Attributes, attribute.String("session.id", "sess-4567")) require.Contains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value2")) require.NotContains(t, actualSpan.Attributes, attribute.String("session.id", "123")) require.NotContains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value1"))