diff --git a/cmd/extproc/mainlib/main.go b/cmd/extproc/mainlib/main.go index 9453b09db6..d9e493d0b6 100644 --- a/cmd/extproc/mainlib/main.go +++ b/cmd/extproc/mainlib/main.go @@ -233,7 +233,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { chatCompletionMetrics := metrics.NewChatCompletion(meter, metricsRequestHeaderAttributes) completionMetrics := metrics.NewCompletion(meter, metricsRequestHeaderAttributes) embeddingsMetrics := metrics.NewEmbeddings(meter, metricsRequestHeaderAttributes) - mcpMetrics := metrics.NewMCP(meter) + mcpMetrics := metrics.NewMCP(meter, metricsRequestHeaderAttributes) tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes) if err != nil { @@ -264,13 +264,13 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) { seed, fallbackSeed, _ := strings.Cut(flags.mcpSessionEncryptionSeed, ",") mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(seed, fallbackSeed) var mcpProxyMux *http.ServeMux - var mcpProxy *mcpproxy.MCPProxy - mcpProxy, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics, + var mcpProxyConfig *mcpproxy.ProxyConfig + mcpProxyConfig, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics, tracing.MCPTracer(), mcpSessionCrypto) if err != nil { return fmt.Errorf("failed to create MCP proxy: %w", err) } - if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxy, l, time.Second*5); err != nil { + if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxyConfig, l, time.Second*5); err != nil { return fmt.Errorf("failed to start config watcher: %w", err) } diff --git a/internal/lang/maps.go b/internal/lang/maps.go new file mode 100644 index 0000000000..eff005d870 --- /dev/null +++ b/internal/lang/maps.go @@ -0,0 +1,36 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE 
file at +// the root of the repo. + +package lang + +import ( + "fmt" + "maps" + "slices" + "strings" +) + +// CaseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner. +// If the same key is present in different cases, the first one in alphabetical order +// that matches is returned. +// If the key is not found, it returns an empty string. +func CaseInsensitiveValue(m map[string]any, key string) string { + if m == nil { + return "" + } + + if v, ok := m[key]; ok { + return fmt.Sprintf("%v", v) + } + + keys := slices.Sorted(maps.Keys(m)) + for _, k := range keys { + if strings.EqualFold(k, key) { + return fmt.Sprintf("%v", m[k]) + } + } + + return "" +} diff --git a/internal/lang/maps_test.go b/internal/lang/maps_test.go new file mode 100644 index 0000000000..c9666920ac --- /dev/null +++ b/internal/lang/maps_test.go @@ -0,0 +1,57 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ + package lang + + import "testing" + + func TestCaseInsensitiveValue(t *testing.T) { + tests := []struct { + name string + m map[string]any + key string + want string + }{ + { + name: "nil map", + m: nil, + key: "anything", + want: "", + }, + { + name: "exact match returns value", + m: map[string]any{"Foo": "bar", "foo": "should-not-be-used"}, + key: "Foo", + want: "bar", + }, + { + name: "case-insensitive match when exact not present", + m: map[string]any{"FOO": "baz"}, + key: "foo", + want: "baz", + }, + { + name: "multiple case variants - alphabetical first chosen", + m: map[string]any{"ALPHA": 2, "Alpha": 1}, + key: "alpha", + want: "2", // ALPHA is alphabetically first + }, + { + name: "nil value formatted", + m: map[string]any{"key": nil}, + key: "key", + want: "<nil>", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := CaseInsensitiveValue(tc.m, tc.key) + if got != tc.want { + t.Fatalf("CaseInsensitiveValue(%v, %q) = %q; want %q", tc.m, tc.key, got, tc.want) + } + }) + } +} diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index d9e9b44eb7..cf6fe1e5e8 100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -107,6 +107,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { errType metrics.MCPErrorType requestMethod string span tracing.MCPSpan + params mcp.Params ) defer func() { if m.l.Enabled(ctx, slog.LevelDebug) { @@ -119,17 +120,17 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { if span != nil { span.EndSpanOnError(string(errType), err) } - m.metrics.RecordMethodErrorCount(ctx) - m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType) + m.metrics.RecordMethodErrorCount(ctx, params) + m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType, params) return } if span != nil { span.EndSpan() } - m.metrics.RecordRequestDuration(ctx, &startAt) + m.metrics.RecordRequestDuration(ctx, &startAt, params) // TODO: should we special
case when this request is "Response" where method is empty? - m.metrics.RecordMethodCount(ctx, requestMethod) + m.metrics.RecordMethodCount(ctx, requestMethod, params) }() if sessionID := r.Header.Get(sessionIDHeader); sessionID != "" { s, err = m.sessionFromID(secureClientToGatewaySessionID(sessionID), secureClientToGatewayEventID(r.Header.Get(lastEventIDHeader))) @@ -189,8 +190,8 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { switch msg.Method { case "notifications/roots/list_changed": - p := &mcp.RootsListChangedParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.RootsListChangedParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") @@ -198,28 +199,28 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { } err = m.handleNotificationsRootsListChanged(ctx, s, w, msg, span) case "completion/complete": - p := &mcp.CompleteParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.CompleteParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleCompletionComplete(ctx, s, w, msg, p, span) + err = m.handleCompletionComplete(ctx, s, w, msg, params.(*mcp.CompleteParams), span) case "notifications/progress": - m.metrics.RecordProgress(ctx) - p := &mcp.ProgressNotificationParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ProgressNotificationParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) + m.metrics.RecordProgress(ctx, params) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = 
m.handleClientToServerNotificationsProgress(ctx, s, w, msg, p, span) + err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, params.(*mcp.ProgressNotificationParams), span) case "initialize": // The very first request from the client to establish a session. - p := &mcp.InitializeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.InitializeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal initialize params", slog.String("error", err.Error())) @@ -235,107 +236,107 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) { onErrorResponse(w, http.StatusInternalServerError, "missing route header") return } - err = m.handleInitializeRequest(ctx, w, msg, p, route, extractSubject(r), span) + err = m.handleInitializeRequest(ctx, w, msg, params.(*mcp.InitializeParams), route, extractSubject(r), span) case "notifications/initialized": // According to the MCP spec, when the server receives a JSON-RPC response or notification from the client // and accepts it, the server MUST return HTTP 202 Accepted with an empty body. 
// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server w.WriteHeader(http.StatusAccepted) case "logging/setLevel": - p := &mcp.SetLoggingLevelParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.SetLoggingLevelParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal set logging level params", slog.String("error", err.Error())) onErrorResponse(w, http.StatusBadRequest, "invalid set logging level params") return } - err = m.handleSetLoggingLevel(ctx, s, w, msg, p, span) + err = m.handleSetLoggingLevel(ctx, s, w, msg, params.(*mcp.SetLoggingLevelParams), span) case "ping": // Ping is intentionally not traced as it's a lightweight health check. err = m.handlePing(ctx, w, msg) case "prompts/list": - p := &mcp.ListPromptsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ListPromptsParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handlePromptListRequest(ctx, s, w, msg, p, span) + err = m.handlePromptListRequest(ctx, s, w, msg, params.(*mcp.ListPromptsParams), span) case "prompts/get": - p := &mcp.GetPromptParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.GetPromptParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handlePromptGetRequest(ctx, s, w, msg, p) + err = m.handlePromptGetRequest(ctx, s, w, msg, params.(*mcp.GetPromptParams)) case "tools/call": - p := &mcp.CallToolParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.CallToolParams{} + span, err = 
parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error())) onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleToolCallRequest(ctx, s, w, msg, p, span) + err = m.handleToolCallRequest(ctx, s, w, msg, params.(*mcp.CallToolParams), span) case "tools/list": - p := &mcp.ListToolsParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ListToolsParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleToolsListRequest(ctx, s, w, msg, p, span) + err = m.handleToolsListRequest(ctx, s, w, msg, params.(*mcp.ListToolsParams), span) case "resources/list": - p := &mcp.ListResourcesParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ListResourcesParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourceListRequest(ctx, s, w, msg, p, span) + err = m.handleResourceListRequest(ctx, s, w, msg, params.(*mcp.ListResourcesParams), span) case "resources/read": - p := &mcp.ReadResourceParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ReadResourceParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourceReadRequest(ctx, s, w, msg, p) + err = m.handleResourceReadRequest(ctx, s, w, msg, params.(*mcp.ReadResourceParams)) case "resources/templates/list": - p := &mcp.ListResourceTemplatesParams{} - span, 
err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.ListResourceTemplatesParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, params.(*mcp.ListResourceTemplatesParams), span) case "resources/subscribe": - p := &mcp.SubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.SubscribeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, params.(*mcp.SubscribeParams), span) case "resources/unsubscribe": - p := &mcp.UnsubscribeParams{} - span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p) + params = &mcp.UnsubscribeParams{} + span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header) if err != nil { errType = metrics.MCPErrorInvalidParam onErrorResponse(w, http.StatusBadRequest, "invalid params") return } - err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, p, span) + err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, params.(*mcp.UnsubscribeParams), span) case "notifications/cancelled": // The responsibility of cancelling the operation on server side is optional, so we just ignore it for now. // https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/cancellation#behavior-requirements @@ -371,8 +372,7 @@ func errorType(err error) metrics.MCPErrorType { // handleInitializeRequest handles the "initialize" JSON-RPC method. 
func (m *MCPProxy) handleInitializeRequest(ctx context.Context, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.InitializeParams, route, subject string, span tracing.MCPSpan) error { - m.metrics.RecordClientCapabilities(ctx, p.Capabilities) - + m.metrics.RecordClientCapabilities(ctx, p.Capabilities, p) s, err := m.newSession(ctx, p, route, subject, span) if err != nil { m.l.Error("failed to create new session", slog.String("error", err.Error())) @@ -789,7 +789,11 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) { case "notifications/resources/list_changed": case "notifications/resources/updated": case "notifications/progress": - m.metrics.RecordProgress(ctx) + params := &mcp.ProgressNotificationParams{} + if err := json.Unmarshal(msg.Params, &params); err != nil { + m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error())) + } + m.metrics.RecordProgress(ctx, params) case "notifications/message": case "notifications/tools/list_changed": case "roots/list": @@ -797,11 +801,11 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) { case "elicitation/create": default: knownMethod = false - m.metrics.RecordMethodErrorCount(ctx) + m.metrics.RecordMethodErrorCount(ctx, nil) m.l.Warn("Unsupported MCP request method from server", slog.String("method", msg.Method)) } if knownMethod { - m.metrics.RecordMethodCount(ctx, msg.Method) + m.metrics.RecordMethodCount(ctx, msg.Method, nil) } default: m.l.Warn("unexpected message type in MCP response", slog.Any("message", msg)) @@ -1223,7 +1227,7 @@ func sendToAllBackendsAndAggregateResponsesImpl[responseType any](ctx context.Co } // parseParamsAndMaybeStartSpan parses the params from the JSON-RPC request and starts a tracing span if params is non-nil.
-func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType) (tracing.MCPSpan, error) { +func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType, headers http.Header) (tracing.MCPSpan, error) { if req.Params == nil { return nil, nil } @@ -1233,7 +1237,7 @@ func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m * return nil, err } - span := m.tracer.StartSpanAndInjectMeta(ctx, req, p) + span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headers) return span, nil } diff --git a/internal/mcpproxy/handlers_test.go b/internal/mcpproxy/handlers_test.go index 747293f1b3..e6fec797c0 100644 --- a/internal/mcpproxy/handlers_test.go +++ b/internal/mcpproxy/handlers_test.go @@ -71,7 +71,7 @@ func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *MCPProxy { func newTestMCPProxyWithOTEL(mr *sdkmetric.ManualReader, tracer tracingapi.MCPTracer) *MCPProxy { mcpProxy := newTestMCPProxyWithTracer(tracer) meter := sdkmetric.NewMeterProvider(sdkmetric.WithReader(mr)).Meter("test") - mcpProxy.metrics = metrics.NewMCP(meter) + mcpProxy.metrics = metrics.NewMCP(meter, nil) return mcpProxy } @@ -1598,7 +1598,7 @@ func Test_parseParamsAndMaybeStartSpan(t *testing.T) { trace, err := tracing.NewTracingFromEnv(t.Context(), t.Output(), nil) require.NoError(t, err) m.tracer = trace.MCPTracer() - s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p) + s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p, nil) require.NoError(t, err) require.NotNil(t, s) // Make sure that traceparent is not empty, that's span started. 
@@ -1611,7 +1611,7 @@ func Test_parseParamsAndMaybeStartSpan_NilParam(t *testing.T) { } p := &mcp.GetPromptParams{} m := newTestMCPProxy() - s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p) + s, err := parseParamsAndMaybeStartSpan(t.Context(), m, req, p, nil) require.NoError(t, err) require.Nil(t, s) } diff --git a/internal/mcpproxy/mcpproxy.go b/internal/mcpproxy/mcpproxy.go index edf639aa0b..05689a51a9 100644 --- a/internal/mcpproxy/mcpproxy.go +++ b/internal/mcpproxy/mcpproxy.go @@ -29,6 +29,11 @@ import ( ) type ( + // ProxyConfig holds the main MCP proxy configuration. + ProxyConfig struct { + *mcpProxyConfig + } + // MCPProxy serves /mcp endpoint. // // This implements [extproc.ConfigReceiver] to gets the up-to-date configuration. @@ -77,8 +82,8 @@ func (f *toolSelector) allows(tool string) bool { } // NewMCPProxy creates a new MCPProxy instance. -func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*MCPProxy, *http.ServeMux, error) { - p := &MCPProxy{l: l, metrics: mcpMetrics, tracer: tracer, sessionCrypto: sessionCrypto} +func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.MCPTracer, sessionCrypto SessionCrypto) (*ProxyConfig, *http.ServeMux, error) { + cfg := &ProxyConfig{} mux := http.NewServeMux() mux.HandleFunc( // Must match all paths since the route selection happens at Envoy level and the "route" header is already @@ -87,23 +92,31 @@ func NewMCPProxy(l *slog.Logger, mcpMetrics metrics.MCPMetrics, tracer tracing.M // For example, if we mistakenly set /mcp here, only the route with prefix /mcp will be matched, and other routes // with different prefixes will not be matched, which is not what we want. 
"/", func(w http.ResponseWriter, r *http.Request) { + proxy := &MCPProxy{ + mcpProxyConfig: cfg.mcpProxyConfig, + l: l, + metrics: mcpMetrics.WithRequestAttributes(r), + tracer: tracer, + sessionCrypto: sessionCrypto, + } + switch r.Method { case http.MethodGet: - p.serveGET(w, r) + proxy.serveGET(w, r) case http.MethodPost: - p.servePOST(w, r) + proxy.servePOST(w, r) case http.MethodDelete: - p.serverDELETE(w, r) + proxy.serverDELETE(w, r) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } }) - return p, mux, nil + return cfg, mux, nil } // LoadConfig implements [extproc.ConfigReceiver.LoadConfig] which will be called // when the configuration is updated on the file system. -func (m *MCPProxy) LoadConfig(_ context.Context, config *filterapi.Config) error { +func (p *ProxyConfig) LoadConfig(_ context.Context, config *filterapi.Config) error { newConfig := &mcpProxyConfig{} mcpConfig := config.MCPConfig if config.MCPConfig == nil { @@ -145,7 +158,7 @@ func (m *MCPProxy) LoadConfig(_ context.Context, config *filterapi.Config) error newConfig.routes[route.Name] = r } - m.mcpProxyConfig = newConfig // This is racy, but we don't care. + p.mcpProxyConfig = newConfig // This is racy, but we don't care. return nil } @@ -186,7 +199,7 @@ func (m *MCPProxy) newSession(ctx context.Context, p *mcp.InitializeParams, rout // TODO: should we record a metric for this? 
return } - m.metrics.RecordInitializationDuration(ctx, &startAt) + m.metrics.RecordInitializationDuration(ctx, &startAt, p) if m.l.Enabled(ctx, slog.LevelDebug) { m.l.Debug("created MCP session", slog.String("backend", backend.Name), slog.String("session_id", string(initResult.sessionID))) } @@ -339,7 +352,7 @@ func (m *MCPProxy) initializeSession(ctx context.Context, routeName filterapi.MC if m.l.Enabled(ctx, slog.LevelDebug) { m.l.Debug("MCP session initialized", slog.Any("capabilities", initResult.Capabilities)) } - m.metrics.RecordServerCapabilities(ctx, initResult.Capabilities) + m.metrics.RecordServerCapabilities(ctx, initResult.Capabilities, p) } // Need to invoke "notifications/initialized" to complete the initialization. diff --git a/internal/mcpproxy/mcpproxy_test.go b/internal/mcpproxy/mcpproxy_test.go index 1a78b2c8e9..f15f042a51 100644 --- a/internal/mcpproxy/mcpproxy_test.go +++ b/internal/mcpproxy/mcpproxy_test.go @@ -50,7 +50,7 @@ type fakeTracer struct { span *fakeSpan } -func (f *fakeTracer) StartSpanAndInjectMeta(_ context.Context, _ *jsonrpc.Request, _ mcp.Params) tracing.MCPSpan { +func (f *fakeTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, http.Header) tracing.MCPSpan { if f.span == nil { f.span = &fakeSpan{} } @@ -66,8 +66,6 @@ func TestNewMCPProxy(t *testing.T) { require.NoError(t, err) require.NotNil(t, proxy) require.NotNil(t, mux) - require.Equal(t, l, proxy.l) - require.NotNil(t, proxy.metrics) } func TestMCPProxy_HTTPMethods(t *testing.T) { @@ -86,11 +84,12 @@ func TestMCPProxy_HTTPMethods(t *testing.T) { } func TestLoadConfig_NilMCPConfig(t *testing.T) { - proxy := newTestMCPProxy() - config := &filterapi.Config{MCPConfig: nil} + proxy, _, err := NewMCPProxy(slog.Default(), stubMetrics{}, noopTracer, DefaultSessionCrypto("test", "")) + require.NoError(t, err) - err := proxy.LoadConfig(t.Context(), config) + config := &filterapi.Config{MCPConfig: nil} + err = proxy.LoadConfig(t.Context(), config) 
require.NoError(t, err) } diff --git a/internal/mcpproxy/session_test.go b/internal/mcpproxy/session_test.go index e34f503d85..b63714e79c 100644 --- a/internal/mcpproxy/session_test.go +++ b/internal/mcpproxy/session_test.go @@ -30,15 +30,19 @@ import ( // stubMetrics implements metrics.MCPMetrics with no-ops. type stubMetrics struct{} -func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time) {} -func (stubMetrics) RecordRequestErrorDuration(_ context.Context, _ *time.Time, _ metrics.MCPErrorType) { +func (s stubMetrics) WithRequestAttributes(_ *http.Request) metrics.MCPMetrics { return s } +func (stubMetrics) RecordRequestDuration(_ context.Context, _ *time.Time, _ mcpsdk.Params) {} +func (stubMetrics) RecordRequestErrorDuration(_ context.Context, _ *time.Time, _ metrics.MCPErrorType, _ mcpsdk.Params) { } -func (stubMetrics) RecordMethodCount(_ context.Context, _ string) {} -func (stubMetrics) RecordMethodErrorCount(_ context.Context) {} -func (stubMetrics) RecordInitializationDuration(_ context.Context, _ *time.Time) {} -func (stubMetrics) RecordClientCapabilities(_ context.Context, _ *mcpsdk.ClientCapabilities) {} -func (stubMetrics) RecordServerCapabilities(_ context.Context, _ *mcpsdk.ServerCapabilities) {} -func (stubMetrics) RecordProgress(_ context.Context) {} +func (stubMetrics) RecordMethodCount(_ context.Context, _ string, _ mcpsdk.Params) {} +func (stubMetrics) RecordMethodErrorCount(_ context.Context, _ mcpsdk.Params) {} +func (stubMetrics) RecordInitializationDuration(_ context.Context, _ *time.Time, _ mcpsdk.Params) {} +func (stubMetrics) RecordClientCapabilities(_ context.Context, _ *mcpsdk.ClientCapabilities, _ mcpsdk.Params) { +} + +func (stubMetrics) RecordServerCapabilities(_ context.Context, _ *mcpsdk.ServerCapabilities, _ mcpsdk.Params) { +} +func (stubMetrics) RecordProgress(_ context.Context, _ mcpsdk.Params) {} func TestBackendSessionIDs_Success(t *testing.T) { backendA := "backendA" diff --git 
a/internal/metrics/base_metrics.go b/internal/metrics/base_metrics.go index d76bde3f97..67e18ecee9 100644 --- a/internal/metrics/base_metrics.go +++ b/internal/metrics/base_metrics.go @@ -26,21 +26,21 @@ type baseMetrics struct { // requestModel is the original model from the request body. requestModel string // responseModel is the model that ultimately generated the response (may differ due to backend override). - responseModel string - backend string - requestHeaderLabelMapping map[string]string // maps HTTP headers to metric label names. + responseModel string + backend string + requestHeaderAttributeMapping map[string]string // maps HTTP headers to metric attribute names. } // newBaseMetrics creates a new baseMetrics instance with the specified operation. -func newBaseMetrics(meter metric.Meter, operation string, requestHeaderLabelMapping map[string]string) baseMetrics { +func newBaseMetrics(meter metric.Meter, operation string, requestHeaderAttributeMapping map[string]string) baseMetrics { return baseMetrics{ - metrics: newGenAI(meter), - operation: operation, - originalModel: "unknown", - requestModel: "unknown", - responseModel: "unknown", - backend: "unknown", - requestHeaderLabelMapping: requestHeaderLabelMapping, + metrics: newGenAI(meter), + operation: operation, + originalModel: "unknown", + requestModel: "unknown", + responseModel: "unknown", + backend: "unknown", + requestHeaderAttributeMapping: requestHeaderAttributeMapping, } } @@ -85,13 +85,13 @@ func (b *baseMetrics) buildBaseAttributes(headers map[string]string) attribute.S origModel := attribute.Key(genaiAttributeOriginalModel).String(b.originalModel) reqModel := attribute.Key(genaiAttributeRequestModel).String(b.requestModel) respModel := attribute.Key(genaiAttributeResponseModel).String(b.responseModel) - if len(b.requestHeaderLabelMapping) == 0 { + if len(b.requestHeaderAttributeMapping) == 0 { return attribute.NewSet(opt, provider, origModel, reqModel, respModel) } // Add header values as 
attributes based on the header mapping if headers are provided. attrs := []attribute.KeyValue{opt, provider, origModel, reqModel, respModel} - for headerName, labelName := range b.requestHeaderLabelMapping { + for headerName, labelName := range b.requestHeaderAttributeMapping { if headerValue, exists := headers[headerName]; exists { attrs = append(attrs, attribute.Key(labelName).String(headerValue)) } diff --git a/internal/metrics/chat_completion_metrics_test.go b/internal/metrics/chat_completion_metrics_test.go index c5bd515d5f..04a72fffd3 100644 --- a/internal/metrics/chat_completion_metrics_test.go +++ b/internal/metrics/chat_completion_metrics_test.go @@ -213,7 +213,7 @@ func TestHeaderLabelMapping(t *testing.T) { pm.RecordTokenUsage(t.Context(), 10, 5, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, pm.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, pm.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. attrs := attribute.NewSet( diff --git a/internal/metrics/completion_metrics_test.go b/internal/metrics/completion_metrics_test.go index dc4758e52b..f167d06367 100644 --- a/internal/metrics/completion_metrics_test.go +++ b/internal/metrics/completion_metrics_test.go @@ -216,7 +216,7 @@ func TestCompletion_HeaderLabelMapping(t *testing.T) { pm.RecordTokenUsage(t.Context(), 10, 5, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, pm.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, pm.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. 
attrs := attribute.NewSet( diff --git a/internal/metrics/embeddings_metrics.go b/internal/metrics/embeddings_metrics.go index c0571ff332..f42114210e 100644 --- a/internal/metrics/embeddings_metrics.go +++ b/internal/metrics/embeddings_metrics.go @@ -44,9 +44,9 @@ type EmbeddingsMetrics interface { } // NewEmbeddings creates a new Embeddings instance. -func NewEmbeddings(meter metric.Meter, requestHeaderLabelMapping map[string]string) EmbeddingsMetrics { +func NewEmbeddings(meter metric.Meter, requestHeaderAttributeMapping map[string]string) EmbeddingsMetrics { return &embeddings{ - baseMetrics: newBaseMetrics(meter, genaiOperationEmbedding, requestHeaderLabelMapping), + baseMetrics: newBaseMetrics(meter, genaiOperationEmbedding, requestHeaderAttributeMapping), } } diff --git a/internal/metrics/embeddings_metrics_test.go b/internal/metrics/embeddings_metrics_test.go index a1df67b39b..0159a14eb8 100644 --- a/internal/metrics/embeddings_metrics_test.go +++ b/internal/metrics/embeddings_metrics_test.go @@ -101,7 +101,7 @@ func TestEmbeddings_HeaderLabelMapping(t *testing.T) { em.RecordTokenUsage(t.Context(), 10, requestHeaders) // Verify that the header mapping is set correctly. - assert.Equal(t, headerMapping, em.requestHeaderLabelMapping) + assert.Equal(t, headerMapping, em.requestHeaderAttributeMapping) // Verify that the metrics are recorded with the mapped header attributes. attrs := attribute.NewSet( diff --git a/internal/metrics/mcp_metrics.go b/internal/metrics/mcp_metrics.go index 86eb1cff64..f4a4801c85 100644 --- a/internal/metrics/mcp_metrics.go +++ b/internal/metrics/mcp_metrics.go @@ -7,11 +7,15 @@ package metrics import ( "context" + "fmt" + "net/http" "time" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + + "github.com/envoyproxy/ai-gateway/internal/lang" ) // nolint: godot @@ -102,35 +106,40 @@ const ( // MCPMetrics holds metrics for MCP. 
type MCPMetrics interface { + // WithRequestAttributes returns a new MCPMetrics instance with default attributes extracted from the HTTP request. + WithRequestAttributes(req *http.Request) MCPMetrics // RecordRequestDuration records the duration of a success MCP request. - RecordRequestDuration(ctx context.Context, startAt *time.Time) + RecordRequestDuration(ctx context.Context, startAt *time.Time, meta mcpsdk.Params) // RecordRequestErrorDuration records the duration of an MCP request that resulted in an error. - RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType) + RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType, meta mcpsdk.Params) // RecordMethodCount records the count of method invocations. - RecordMethodCount(ctx context.Context, methodName string) + RecordMethodCount(ctx context.Context, methodName string, meta mcpsdk.Params) // RecordMethodErrorCount records the count of method invocations with error status. - RecordMethodErrorCount(ctx context.Context) + RecordMethodErrorCount(ctx context.Context, meta mcpsdk.Params) // RecordInitializationDuration records the duration of MCP initialization. - RecordInitializationDuration(ctx context.Context, startAt *time.Time) + RecordInitializationDuration(ctx context.Context, startAt *time.Time, meta mcpsdk.Params) // RecordClientCapabilities records the negotiated client capabilities. - RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities) + RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities, meta mcpsdk.Params) // RecordServerCapabilities records the negotiated server capabilities. - RecordServerCapabilities(ctx context.Context, capabilities *mcpsdk.ServerCapabilities) + RecordServerCapabilities(ctx context.Context, capabilities *mcpsdk.ServerCapabilities, meta mcpsdk.Params) // RecordProgress records a progress notification sent/received. 
- RecordProgress(ctx context.Context) + RecordProgress(ctx context.Context, meta mcpsdk.Params) } type mcp struct { - requestDuration metric.Float64Histogram - methodCount metric.Float64Counter - initializationDuration metric.Float64Histogram - capabilitiesNegotiated metric.Float64Counter - progressNotifications metric.Float64Counter + requestDuration metric.Float64Histogram + methodCount metric.Float64Counter + initializationDuration metric.Float64Histogram + capabilitiesNegotiated metric.Float64Counter + progressNotifications metric.Float64Counter + requestHeaderAttributeMapping map[string]string // maps HTTP headers to metric attribute names. + defaultAttributes []attribute.KeyValue } // NewMCP creates a new mcp metrics instance. -func NewMCP(meter metric.Meter) MCPMetrics { +func NewMCP(meter metric.Meter, requestHeaderAttributeMapping map[string]string) MCPMetrics { return &mcp{ + requestHeaderAttributeMapping: requestHeaderAttributeMapping, requestDuration: mustRegisterHistogram(meter, mcpRequestDuration, metric.WithDescription("Duration of MCP requests"), @@ -159,92 +168,116 @@ func NewMCP(meter metric.Meter) MCPMetrics { } } +// WithRequestAttributes returns a new MCPMetrics instance with default attributes extracted from +// the HTTP request headers. +func (m *mcp) WithRequestAttributes(req *http.Request) MCPMetrics { + withAttrs := &mcp{ + requestDuration: m.requestDuration, + methodCount: m.methodCount, + initializationDuration: m.initializationDuration, + capabilitiesNegotiated: m.capabilitiesNegotiated, + progressNotifications: m.progressNotifications, + requestHeaderAttributeMapping: m.requestHeaderAttributeMapping, + } + + // Apply header-to-attribute mapping if configured. 
+ for headerName, attrName := range m.requestHeaderAttributeMapping { + if headerValue := req.Header.Get(headerName); headerValue != "" { + withAttrs.defaultAttributes = append( + withAttrs.defaultAttributes, + attribute.String(attrName, headerValue), + ) + } + } + + return withAttrs +} + // RecordMethodCount implements [MCPMetrics.RecordMethodCount]. -func (m *mcp) RecordMethodCount(ctx context.Context, methodName string) { +func (m *mcp) RecordMethodCount(ctx context.Context, methodName string, params mcpsdk.Params) { if methodName == "" { return } m.methodCount.Add(ctx, 1, - metric.WithAttributes( + m.withDefaultAttributes(params, attribute.Key(mcpAttributeMethodName).String(methodName), attribute.String(mcpAttributeStatusName, string(mcpStatusSuccess)), )) } // RecordMethodErrorCount implements [MCPMetrics.RecordMethodErrorCount]. -func (m *mcp) RecordMethodErrorCount(ctx context.Context) { +func (m *mcp) RecordMethodErrorCount(ctx context.Context, params mcpsdk.Params) { m.methodCount.Add(ctx, 1, - metric.WithAttributes( + m.withDefaultAttributes(params, attribute.String(mcpAttributeStatusName, string(mcpStatusError)), )) } // RecordRequestDuration implements [MCPMetrics.RecordRequestDuration]. -func (m *mcp) RecordRequestDuration(ctx context.Context, startAt *time.Time) { +func (m *mcp) RecordRequestDuration(ctx context.Context, startAt *time.Time, params mcpsdk.Params) { if startAt == nil { return } - duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration) + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes(params)) } // RecordRequestErrorDuration implements [MCPMetrics.RecordRequestErrorDuration]. 
-func (m *mcp) RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType) { +func (m *mcp) RecordRequestErrorDuration(ctx context.Context, startAt *time.Time, errType MCPErrorType, params mcpsdk.Params) { if startAt == nil { return } duration := time.Since(*startAt).Seconds() - m.requestDuration.Record(ctx, duration, metric.WithAttributes( + m.requestDuration.Record(ctx, duration, m.withDefaultAttributes(params, attribute.Key(mcpAttributeErrorType).String(string(errType)), )) } // RecordInitializationDuration implements [MCPMetrics.RecordInitializationDuration]. -func (m *mcp) RecordInitializationDuration(ctx context.Context, startAt *time.Time) { +func (m *mcp) RecordInitializationDuration(ctx context.Context, startAt *time.Time, params mcpsdk.Params) { if startAt == nil { return } duration := time.Since(*startAt).Seconds() - m.initializationDuration.Record(ctx, duration) + m.initializationDuration.Record(ctx, duration, m.withDefaultAttributes(params)) } // RecordProgress implements [MCPMetrics.RecordProgress]. -func (m *mcp) RecordProgress(ctx context.Context) { - m.progressNotifications.Add(ctx, 1) +func (m *mcp) RecordProgress(ctx context.Context, params mcpsdk.Params) { + m.progressNotifications.Add(ctx, 1, m.withDefaultAttributes(params)) } // RecordClientCapabilities implements [MCPMetrics.RecordClientCapabilities]. 
-func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities) { +func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk.ClientCapabilities, params mcpsdk.Params) { if capabilities == nil { return } side := string(mcpCapabilitySideClient) if l := len(capabilities.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Roots.ListChanged { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeRoots)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Sampling != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeSampling)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if capabilities.Elicitation != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeElicitation)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) @@ -252,51 +285,66 @@ func (m *mcp) RecordClientCapabilities(ctx context.Context, capabilities *mcpsdk } // RecordServerCapabilities implements [MCPMetrics.RecordServerCapabilities]. 
-func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.ServerCapabilities) { +func (m *mcp) RecordServerCapabilities(ctx context.Context, serverCapa *mcpsdk.ServerCapabilities, params mcpsdk.Params) { if serverCapa == nil { return } side := string(mcpCapabilitySideServer) if serverCapa.Completions != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeCompletions)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if l := len(serverCapa.Experimental); l > 0 { - m.capabilitiesNegotiated.Add(ctx, float64(l), metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, float64(l), m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Logging != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeLogging)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Prompts != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypePrompts)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Resources != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeResources)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } if serverCapa.Tools != nil { - m.capabilitiesNegotiated.Add(ctx, 1, metric.WithAttributes( + 
m.capabilitiesNegotiated.Add(ctx, 1, m.withDefaultAttributes(params, attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeTools)), attribute.Key(mcpAttributeCapabilitySide).String(side), )) } } + +// withDefaultAttributes appends default attributes to the provided attributes. +func (m *mcp) withDefaultAttributes(params mcpsdk.Params, attrs ...attribute.KeyValue) metric.MeasurementOption { + all := make([]attribute.KeyValue, 0, len(m.defaultAttributes)+len(m.requestHeaderAttributeMapping)+len(attrs)) + all = append(all, m.defaultAttributes...) + if params != nil { + for src, target := range m.requestHeaderAttributeMapping { + if v := lang.CaseInsensitiveValue(params.GetMeta(), src); v != "" { + all = append(all, attribute.String(target, fmt.Sprintf("%v", v))) + } + } + } + all = append(all, attrs...) + return metric.WithAttributes(all...) +} diff --git a/internal/metrics/mcp_metrics_test.go b/internal/metrics/mcp_metrics_test.go index 82e8dda0b7..4fcd471f5c 100644 --- a/internal/metrics/mcp_metrics_test.go +++ b/internal/metrics/mcp_metrics_test.go @@ -6,6 +6,7 @@ package metrics import ( + "net/http" "testing" "time" @@ -21,18 +22,57 @@ func TestNewMCP(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) } +func TestRecordMetricWithCustomAttributes(t *testing.T) { + mr := metric.NewManualReader() + meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") + + m := NewMCP(meter, map[string]string{ + "x-tracing-enrichment-user-region": "user.region", + "X-Session-Id": "session.id", + "CustomAttr": "custom.attr", + }) + require.NotNil(t, m) + + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + req.Header.Set("X-Tracing-Enrichment-User-Region", "us-east-1") // should be included in metrics + req.Header.Set("X-Other-Attr", "other") // should be ignored + 
req.Header.Set("X-Session-Id", "123") // should be ignored as the value in the metadata takes precedence + + m = m.WithRequestAttributes(req) + + startAt := time.Now().Add(-1 * time.Minute) + m.RecordRequestDuration(t.Context(), &startAt, &mcpsdk.InitializeParams{ + Meta: map[string]any{ + "x-session-id": "sess-1234", // alphabetical order wins when multiple values match case-insensitively + "X-SESSION-ID": "sess-4567", + "customattr": "custom-value1", // exact match should win over case-insensitive match + "CustomAttr": "custom-value2", + }, + }) + + count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, + attribute.NewSet( + attribute.String("user.region", "us-east-1"), + attribute.String("session.id", "sess-4567"), + attribute.String("custom.attr", "custom-value2"), + )) + require.Equal(t, uint64(1), count) + require.Equal(t, 60, int(sum)) +} + func TestRecordRequestDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-1 * time.Minute) - m.RecordRequestDuration(t.Context(), &startAt) + m.RecordRequestDuration(t.Context(), &startAt, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, attribute.NewSet()) require.Equal(t, uint64(1), count) @@ -43,10 +83,10 @@ func TestRecordRequestErrorDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-30 * time.Second) - m.RecordRequestErrorDuration(t.Context(), &startAt, MCPErrorUnsupportedProtocolVersion) + m.RecordRequestErrorDuration(t.Context(), &startAt, MCPErrorUnsupportedProtocolVersion, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpRequestDuration, attribute.NewSet( 
attribute.Key(mcpAttributeErrorType).String(string(MCPErrorUnsupportedProtocolVersion)), @@ -59,10 +99,10 @@ func TestRecordMethodCount(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) - m.RecordMethodCount(t.Context(), "test_method_name") + m.RecordMethodCount(t.Context(), "test_method_name", nil) attrs := attribute.NewSet( attribute.Key(mcpAttributeMethodName).String("test_method_name"), attribute.Key(mcpAttributeStatusName).String(string(mcpStatusSuccess)), @@ -70,7 +110,7 @@ func TestRecordMethodCount(t *testing.T) { val := testotel.GetCounterValue(t, mr, mcpMethodCount, attrs) require.Equal(t, float64(1), val) - m.RecordMethodErrorCount(t.Context()) + m.RecordMethodErrorCount(t.Context(), nil) attrs = attribute.NewSet( attribute.Key(mcpAttributeStatusName).String(string(mcpStatusError)), ) @@ -82,11 +122,11 @@ func TestRecordInitializationDuration(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) startAt := time.Now().Add(-45 * time.Second) - m.RecordInitializationDuration(t.Context(), &startAt) + m.RecordInitializationDuration(t.Context(), &startAt, nil) count, sum := testotel.GetHistogramValues(t, mr, mcpInitializationDuration, attribute.NewSet()) require.Equal(t, uint64(1), count) @@ -97,7 +137,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) m.RecordClientCapabilities(t.Context(), &mcpsdk.ClientCapabilities{ @@ -109,7 +149,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { ListChanged bool "json:\"listChanged,omitempty\"" }{ListChanged: true}, Sampling: &mcpsdk.SamplingCapabilities{}, - }) + }, nil) 
m.RecordServerCapabilities(t.Context(), &mcpsdk.ServerCapabilities{ Experimental: map[string]any{ "exp1": struct{}{}, @@ -119,7 +159,7 @@ func TestRecordCapabilitiesNegotiated(t *testing.T) { Prompts: &mcpsdk.PromptCapabilities{ListChanged: true}, Resources: &mcpsdk.ResourceCapabilities{ListChanged: true, Subscribe: true}, Tools: &mcpsdk.ToolCapabilities{ListChanged: true}, - }) + }, nil) require.Equal(t, float64(2), testotel.GetCounterValue(t, mr, mcpCapabilitiesNegotiated, attribute.NewSet( attribute.Key(mcpAttributeCapabilityType).String(string(mcpCapabilityTypeExperimental)), attribute.Key(mcpAttributeCapabilitySide).String(string(mcpCapabilitySideClient)), @@ -152,14 +192,14 @@ func TestRecordProgressNotifications(t *testing.T) { mr := metric.NewManualReader() meter := metric.NewMeterProvider(metric.WithReader(mr)).Meter("test") - m := NewMCP(meter) + m := NewMCP(meter, nil) require.NotNil(t, m) - m.RecordProgress(t.Context()) + m.RecordProgress(t.Context(), nil) val := testotel.GetCounterValue(t, mr, mpcProgressNotifications, attribute.NewSet()) require.Equal(t, float64(1), val) - m.RecordProgress(t.Context()) + m.RecordProgress(t.Context(), nil) val = testotel.GetCounterValue(t, mr, mpcProgressNotifications, attribute.NewSet()) require.Equal(t, float64(2), val) } diff --git a/internal/tracing/api/mcp.go b/internal/tracing/api/mcp.go index dbcf5716a4..913610fa1b 100644 --- a/internal/tracing/api/mcp.go +++ b/internal/tracing/api/mcp.go @@ -9,6 +9,7 @@ package api import ( "context" + "net/http" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -23,9 +24,10 @@ type MCPTracer interface { // - ctx: might include a parent span context. // - req: Incoming MCP request message. // - param: Incoming MCP parameter used to extract parent trace context. + // - headers: Request HTTP request headers. // // Returns nil unless the span is sampled. 
- StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params) MCPSpan + StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers http.Header) MCPSpan } // MCPSpan represents an MCP span. @@ -45,6 +47,6 @@ var _ MCPTracer = NoopMCPTracer{} type NoopMCPTracer struct{} // StartSpanAndInjectMeta implements [MCPTracer.StartSpanAndInjectMeta]. -func (NoopMCPTracer) StartSpanAndInjectMeta(_ context.Context, _ *jsonrpc.Request, _ mcp.Params) MCPSpan { +func (NoopMCPTracer) StartSpanAndInjectMeta(context.Context, *jsonrpc.Request, mcp.Params, http.Header) MCPSpan { return nil } diff --git a/internal/tracing/mcp.go b/internal/tracing/mcp.go index 3e66ec0a99..c0910adb70 100644 --- a/internal/tracing/mcp.go +++ b/internal/tracing/mcp.go @@ -8,6 +8,7 @@ package tracing import ( "context" "fmt" + "net/http" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -16,6 +17,7 @@ import ( "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" + "github.com/envoyproxy/ai-gateway/internal/lang" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) @@ -57,16 +59,21 @@ func (s mcpSpan) EndSpan() { // mcpTracer is an implementation of [tracing.MCPTracer]. type mcpTracer struct { - tracer trace.Tracer - propagator propagation.TextMapPropagator + tracer trace.Tracer + propagator propagation.TextMapPropagator + attributeMappings map[string]string } -func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator) tracing.MCPTracer { - return mcpTracer{tracer: tracer, propagator: propagator} +func newMCPTracer(tracer trace.Tracer, propagator propagation.TextMapPropagator, attributeMappings map[string]string) tracing.MCPTracer { + return mcpTracer{ + tracer: tracer, + propagator: propagator, + attributeMappings: attributeMappings, + } } // StartSpanAndInjectMeta implements [tracing.MCPTracer.StartSpanAndInjectMeta]. 
-func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params) tracing.MCPSpan { +func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Request, param mcp.Params, headers http.Header) tracing.MCPSpan { attrs := []attribute.KeyValue{ attribute.String("mcp.protocol.version", "2025-06-18"), attribute.String("mcp.transport", "http"), @@ -75,6 +82,18 @@ func (m mcpTracer) StartSpanAndInjectMeta(ctx context.Context, req *jsonrpc.Requ } attrs = append(attrs, getMCPParamsAsAttributes(param)...) + // Apply header-to-attribute mapping if configured. + for srcName, targetName := range m.attributeMappings { + // Check if the attribute is present in the metadata first, as this is the common place to add custom attributes + // in MCP requests. Fall back to headers if not found in metadata. + // If the attribute is not found there, check if there is any custom header to map. + if metaValue := lang.CaseInsensitiveValue(param.GetMeta(), srcName); metaValue != "" { + attrs = append(attrs, attribute.String(targetName, metaValue)) + } else if headerValue := headers.Get(srcName); headerValue != "" { // this is case-insensitive + attrs = append(attrs, attribute.String(targetName, headerValue)) + } + } + // Extract trace context from incoming meta. 
mutableMeta := param.GetMeta() if mutableMeta == nil { diff --git a/internal/tracing/mcp_test.go b/internal/tracing/mcp_test.go index d12a9a546e..c11c16571f 100644 --- a/internal/tracing/mcp_test.go +++ b/internal/tracing/mcp_test.go @@ -7,6 +7,7 @@ package tracing import ( "context" + "net/http" "testing" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -23,17 +24,42 @@ func TestTracer_StartSpanAndInjectMeta(t *testing.T) { exporter := tracetest.NewInMemoryExporter() tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) - tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator()) + tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), + map[string]string{ + "x-tracing-enrichment-user-region": "user.region", + "X-Session-Id": "session.id", + "CustomAttr": "custom.attr", + }) + + headers := make(http.Header) + headers.Add("X-Tracing-Enrichment-User-Region", "us-east-1") + headers.Add("X-Session-Id", "123") // should be ignored as the value in the metadata takes precedence reqID, _ := jsonrpc.MakeID("id") r := &jsonrpc.Request{ID: reqID, Method: "initialize"} - p := &mcp.InitializeParams{} - span := tracer.StartSpanAndInjectMeta(t.Context(), r, p) + p := &mcp.InitializeParams{Meta: map[string]any{ + "x-session-id": "sess-1234", // alphabetical order wins when multiple values match case-insensitively + "X-SESSION-ID": "sess-4567", + "customattr": "custom-value1", // exact match should win over case-insensitive match + "CustomAttr": "custom-value2", + }} + span := tracer.StartSpanAndInjectMeta(t.Context(), r, p, headers) require.NotNil(t, span) meta := p.GetMeta() require.NotNil(t, meta) require.NotNil(t, meta["traceparent"]) + + // End the span to export it + span.EndSpan() + spans := exporter.GetSpans() + require.Len(t, spans, 1) + actualSpan := spans[0] + require.Contains(t, actualSpan.Attributes, attribute.String("user.region", "us-east-1")) + require.Contains(t, actualSpan.Attributes, attribute.String("session.id", 
"sess-4567")) + require.Contains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value2")) + require.NotContains(t, actualSpan.Attributes, attribute.String("session.id", "123")) + require.NotContains(t, actualSpan.Attributes, attribute.String("custom.attr", "custom-value1")) } func Test_getMCPAttributes(t *testing.T) { @@ -221,12 +247,12 @@ func TestMCPTracer_SpanName(t *testing.T) { exporter := tracetest.NewInMemoryExporter() tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) - tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator()) + tracer := newMCPTracer(tp.Tracer("test"), autoprop.NewTextMapPropagator(), nil) reqID, _ := jsonrpc.MakeID("test-id") req := &jsonrpc.Request{ID: reqID, Method: tt.method} - span := tracer.StartSpanAndInjectMeta(context.Background(), req, tt.params) + span := tracer.StartSpanAndInjectMeta(context.Background(), req, tt.params, nil) require.NotNil(t, span) span.EndSpan() diff --git a/internal/tracing/tracing.go b/internal/tracing/tracing.go index acc123f4c3..c2810ff563 100644 --- a/internal/tracing/tracing.go +++ b/internal/tracing/tracing.go @@ -179,7 +179,7 @@ func NewTracingFromEnv(ctx context.Context, stdout io.Writer, headerAttributeMap embeddingsRecorder, headerAttrs, ), - mcpTracer: newMCPTracer(tracer, propagator), + mcpTracer: newMCPTracer(tracer, propagator, headerAttrs), shutdown: tp.Shutdown, // we have to shut down what we create. }, nil }