diff --git a/lib/srv/mcp/audit.go b/lib/srv/mcp/audit.go index 292ef1fbe8b6e..57db00db2a6b6 100644 --- a/lib/srv/mcp/audit.go +++ b/lib/srv/mcp/audit.go @@ -188,7 +188,6 @@ func (a *sessionAuditor) emitRequestEvent(ctx context.Context, msg *mcputils.JSO a.emitEvent(ctx, event) } -//nolint:unused //TODO(greedy52) remove nolint func (a *sessionAuditor) emitListenSSEStreamEvent(ctx context.Context, err error) { event := &apievents.MCPSessionListenSSEStream{ Metadata: a.makeEventMetadata( @@ -210,7 +209,6 @@ func (a *sessionAuditor) emitListenSSEStreamEvent(ctx context.Context, err error a.emitEvent(ctx, event) } -//nolint:unused //TODO(greedy52) remove nolint func (a *sessionAuditor) emitInvalidHTTPRequest(ctx context.Context, r *http.Request) { body, _ := utils.GetAndReplaceRequestBody(r) event := &apievents.MCPSessionInvalidHTTPRequest{ diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go index c181df4881262..551276cd488db 100644 --- a/lib/srv/mcp/http.go +++ b/lib/srv/mcp/http.go @@ -242,7 +242,7 @@ func (t *streamableHTTPTransport) handleMCPMessage(r *http.Request) (*http.Respo case baseMessage.IsRequest(): mcpRequest := baseMessage.MakeRequest() // Only emit session start if "initialize" succeeded. - if mcpRequest.Method == "initialize" && respErrForAudit == nil { + if mcpRequest.Method == mcp.MethodInitialize && respErrForAudit == nil { t.emitStartEvent(t.parentCtx) } t.emitRequestEvent(t.parentCtx, mcpRequest, respErrForAudit) diff --git a/lib/srv/mcp/reporting.go b/lib/srv/mcp/reporting.go new file mode 100644 index 0000000000000..49569bcd8b4a1 --- /dev/null +++ b/lib/srv/mcp/reporting.go @@ -0,0 +1,139 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "slices" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/prometheus/client_golang/prometheus" + + "github.com/gravitational/teleport" +) + +var ( + setupErrors = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Name: "setup_errors_total", + Subsystem: "mcp", + Help: "Number of errors encountered when setting up MCP sessions", + }, + []string{"transport"}, + ) + + accumulatedSessions = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Name: "sessions_total", + Subsystem: "mcp", + Help: "Number of accumulated MCP sessions", + }, + []string{"transport"}, + ) + + activeSessions = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: teleport.MetricNamespace, + Name: "active_sessions_total", + Subsystem: "mcp", + Help: "Number of active MCP sessions", + }, + []string{"transport"}, + ) + + messagesFromClient = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Name: "messages_from_client_total", + Subsystem: "mcp", + Help: "Number of messages received from the MCP client", + }, + []string{"transport", "type", "method"}, + ) + + messagesFromServer = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Name: "messages_from_server_total", + Subsystem: "mcp", + Help: "Number of messages received from the MCP server", + }, + []string{"transport", "type", "method"}, + ) + + allPrometheusCollectors = []prometheus.Collector{ + setupErrors, + accumulatedSessions, activeSessions, + messagesFromClient, messagesFromServer, + } + + // knownNotificationMethods is a list of known method names for notifications. + // + // The list is obtained by searching these in addition to mcp-go: + // - https://github.com/modelcontextprotocol/modelcontextprotocol + // - https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/index.ts + knownNotificationMethods = []mcp.MCPMethod{ + //nolint:misspell // "cancelled" is "UK" spelling but our linter is set to use US locale + "notifications/cancelled", + "notifications/initialized", + "notifications/message", + "notifications/progress", + mcp.MethodNotificationPromptsListChanged, // notifications/prompts/list_changed + mcp.MethodNotificationResourcesListChanged, // notifications/resources/list_changed + mcp.MethodNotificationResourceUpdated, // notifications/resources/updated + mcp.MethodNotificationToolsListChanged, // notifications/tools/list_changed + "notifications/roots/list_changed", + } + + // knownRequestMethods is a list of known method names for requests. + // + // The list is obtained by searching these in addition to mcp-go: + // - https://github.com/modelcontextprotocol/modelcontextprotocol + // - https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/index.ts + knownRequestMethods = []mcp.MCPMethod{ + mcp.MethodInitialize, // initialize + mcp.MethodPing, // ping + mcp.MethodResourcesList, // resources/list + mcp.MethodResourcesTemplatesList, // resources/templates/list + mcp.MethodResourcesRead, // resources/read + mcp.MethodPromptsList, // prompts/list + mcp.MethodPromptsGet, // prompts/get + mcp.MethodToolsList, // tools/list + mcp.MethodToolsCall, // tools/call + mcp.MethodSetLogLevel, // logging/setLevel + mcp.MethodElicitationCreate, // elicitation/create + "roots/list", + "sampling/createMessage", + } +) + +func reportNotificationMethod(method mcp.MCPMethod) string { + if slices.Contains(knownNotificationMethods, method) { + return string(method) + } + return "unknown" +} + +func reportRequestMethod(method mcp.MCPMethod) string { + if slices.Contains(knownRequestMethods, method) { + return string(method) + } + return "unknown" +} diff --git a/lib/srv/mcp/server.go b/lib/srv/mcp/server.go index 04ae5fde991a8..73046fcae086a 100644 --- a/lib/srv/mcp/server.go +++ b/lib/srv/mcp/server.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" @@ -103,7 +104,6 @@ func (c *ServerConfig) CheckAndSetDefaults() error { } // Server handles forwarding client connections to MCP servers. -// TODO(greedy52) add server metrics. type Server struct { cfg ServerConfig @@ -116,6 +116,10 @@ func NewServer(cfg ServerConfig) (*Server, error) { return nil, trace.Wrap(err) } + if err := metrics.RegisterPrometheusCollectors(allPrometheusCollectors...); err != nil { + return nil, trace.Wrap(err) + } + cache, err := utils.NewFnCache(utils.FnCacheConfig{ TTL: 10 * time.Minute, Context: cfg.ParentContext, @@ -137,11 +141,16 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro if err := sessionCtx.checkAndSetDefaults(); err != nil { return trace.Wrap(err) } + + // Metrics. + accumulatedSessions.WithLabelValues(sessionCtx.transport).Inc() + activeSessions.WithLabelValues(sessionCtx.transport).Inc() + defer activeSessions.WithLabelValues(sessionCtx.transport).Dec() + if s.cfg.EnableDemoServer && isDemoServerApp(sessionCtx.App) { return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeDemoServerRunner)) } - transportType := types.GetMCPServerTransportType(sessionCtx.App.GetURI()) - switch transportType { + switch sessionCtx.transport { case types.MCPTransportStdio: return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner)) case types.MCPTransportSSE: @@ -149,7 +158,7 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro case types.MCPTransportHTTP: return trace.Wrap(s.handleStreamableHTTP(ctx, sessionCtx)) default: - return trace.BadParameter("unknown transport type: %v", transportType) + return trace.BadParameter("unknown transport type: %v", sessionCtx.transport) } } diff --git a/lib/srv/mcp/session.go b/lib/srv/mcp/session.go index 07fff770e914b..bd9dac8b7f50b 100644 --- a/lib/srv/mcp/session.go +++ b/lib/srv/mcp/session.go @@ -67,6 +67,9 @@ type SessionCtx struct { // traitsForRewriteHeaders are user traits used for rewriting headers. traitsForRewriteHeaders wrappers.Traits + + // transport is the transport type of the MCP server. + transport string } func (c *SessionCtx) checkAndSetDefaults() error { @@ -82,8 +85,11 @@ func (c *SessionCtx) checkAndSetDefaults() error { if c.Identity.Username == "" { c.Identity = c.AuthCtx.Identity.GetIdentity() } + if c.transport == "" { + c.transport = types.GetMCPServerTransportType(c.App.GetURI()) + } if c.sessionID == "" { - if types.MCPTransportHTTP == types.GetMCPServerTransportType(c.App.GetURI()) { + if types.MCPTransportHTTP == c.transport { // A single HTTP request is handled at a time so take session ID // from cert. c.sessionID = session.ID(c.Identity.RouteToApp.SessionID) @@ -197,6 +203,7 @@ func (s *sessionHandler) checkAccessToTool(ctx context.Context, toolName string) func (s *sessionHandler) processClientNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) { s.emitNotificationEvent(ctx, notification, nil) + messagesFromClient.WithLabelValues(s.transport, "notification", reportNotificationMethod(notification.Method)).Inc() } func (s *sessionHandler) onClientNotification(serverRequestWriter mcputils.MessageWriter) mcputils.HandleNotificationFunc { @@ -250,6 +257,8 @@ func (s *sessionHandler) processClientRequest(ctx context.Context, req *mcputils } func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, error) { + messagesFromClient.WithLabelValues(s.transport, "request", reportRequestMethod(req.Method)).Inc() + s.idTracker.PushRequest(req) switch req.Method { case mcp.MethodToolsCall: @@ -263,6 +272,8 @@ func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *m func (s *sessionHandler) processServerResponse(ctx context.Context, response *mcputils.JSONRPCResponse) mcp.JSONRPCMessage { method, _ := s.idTracker.PopByID(response.ID) + messagesFromServer.WithLabelValues(s.transport, "response", reportRequestMethod(method)).Inc() + switch method { case mcp.MethodToolsList: return s.makeToolsCallResponse(ctx, response) @@ -272,6 +283,7 @@ func (s *sessionHandler) processServerResponse(ctx context.Context, response *mc func (s *sessionHandler) processServerNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) { s.logger.DebugContext(ctx, "Received server notification.", "method", notification.Method) + messagesFromServer.WithLabelValues(s.transport, "notification", reportNotificationMethod(notification.Method)).Inc() } func (s *sessionHandler) makeToolsCallResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage { diff --git a/lib/srv/mcp/sse.go b/lib/srv/mcp/sse.go index 63fc39bd6f8b0..ccdfd996dcd53 100644 --- a/lib/srv/mcp/sse.go +++ b/lib/srv/mcp/sse.go @@ -64,6 +64,7 @@ func (s *Server) handleStdioToSSE(ctx context.Context, sessionCtx *SessionCtx) e // Initialize SSE stream. sseResponseReader, sseRequestWriter, err := mcputils.ConnectSSEServer(ctx, baseURL, httpTransport) if err != nil { + setupErrors.WithLabelValues(sessionCtx.transport).Inc() return trace.Wrap(err) } session.logger.DebugContext(ctx, "Received SSE endpoint", "endpoint_url", sseRequestWriter.GetEndpointURL()) diff --git a/lib/srv/mcp/stdio.go b/lib/srv/mcp/stdio.go index d48727ebe8d7f..35a1a581016c3 100644 --- a/lib/srv/mcp/stdio.go +++ b/lib/srv/mcp/stdio.go @@ -176,6 +176,7 @@ func (s *execServer) getStdinPipe() (io.WriteCloser, error) { func (s *execServer) run(context.Context) error { if err := s.cmd.Start(); err != nil { + setupErrors.WithLabelValues(s.session.transport).Inc() return trace.Wrap(err) } diff --git a/lib/srv/mcp/stdio_test.go b/lib/srv/mcp/stdio_test.go index bb175dcf9b2ed..65a1ebd54ef80 100644 --- a/lib/srv/mcp/stdio_test.go +++ b/lib/srv/mcp/stdio_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "github.com/mark3labs/mcp-go/mcp" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -179,6 +180,14 @@ func TestHandleSession_execMCPServer(t *testing.T) { // Check container is running. require.NotEmpty(t, findDockerContainerID(t.Context(), dockerClient, containerName)) + + // Note that some metrics may be incremented by other tests too so here + // just checking they are non-zero. + require.Positive(t, testutil.ToFloat64(accumulatedSessions.WithLabelValues(types.MCPTransportStdio))) + require.Positive(t, testutil.ToFloat64(activeSessions.WithLabelValues(types.MCPTransportStdio))) + require.Positive(t, testutil.ToFloat64(messagesFromClient.WithLabelValues(types.MCPTransportStdio, "request", "initialize"))) + require.Positive(t, testutil.ToFloat64(messagesFromClient.WithLabelValues(types.MCPTransportStdio, "notification", "notifications/initialized"))) + require.Positive(t, testutil.ToFloat64(messagesFromServer.WithLabelValues(types.MCPTransportStdio, "response", "initialize"))) } tests := []struct { @@ -219,6 +228,9 @@ func TestHandleSession_execMCPServer(t *testing.T) { cmd: "fail-to-start", checkHandlerError: require.Error, waitForHandlerExit: time.Second * 5, + afterHandlerStop: func(t *testing.T, _ *testContext, _ string) { + require.Positive(t, testutil.ToFloat64(setupErrors.WithLabelValues(types.MCPTransportStdio))) + }, }, { // Make sure handler is not blocked when command starts then fails