Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions lib/srv/mcp/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ func (a *sessionAuditor) emitRequestEvent(ctx context.Context, msg *mcputils.JSO
a.emitEvent(ctx, event)
}

//nolint:unused //TODO(greedy52) remove nolint
func (a *sessionAuditor) emitListenSSEStreamEvent(ctx context.Context, err error) {
event := &apievents.MCPSessionListenSSEStream{
Metadata: a.makeEventMetadata(
Expand All @@ -210,7 +209,6 @@ func (a *sessionAuditor) emitListenSSEStreamEvent(ctx context.Context, err error
a.emitEvent(ctx, event)
}

//nolint:unused //TODO(greedy52) remove nolint
func (a *sessionAuditor) emitInvalidHTTPRequest(ctx context.Context, r *http.Request) {
body, _ := utils.GetAndReplaceRequestBody(r)
event := &apievents.MCPSessionInvalidHTTPRequest{
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/mcp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (t *streamableHTTPTransport) handleMCPMessage(r *http.Request) (*http.Respo
case baseMessage.IsRequest():
mcpRequest := baseMessage.MakeRequest()
// Only emit session start if "initialize" succeeded.
if mcpRequest.Method == "initialize" && respErrForAudit == nil {
if mcpRequest.Method == mcp.MethodInitialize && respErrForAudit == nil {
t.emitStartEvent(t.parentCtx)
}
t.emitRequestEvent(t.parentCtx, mcpRequest, respErrForAudit)
Expand Down
139 changes: 139 additions & 0 deletions lib/srv/mcp/reporting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package mcp

import (
"slices"

"github.com/mark3labs/mcp-go/mcp"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
)

var (
setupErrors = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "setup_errors_total",
Subsystem: "mcp",
Help: "Number of errors encountered when setting up MCP sessions",
},
[]string{"transport"},
Comment thread
greedy52 marked this conversation as resolved.
)

accumulatedSessions = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "sessions_total",
Subsystem: "mcp",
Help: "Number of accumulated MCP sessions",
},
[]string{"transport"},
)

activeSessions = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: teleport.MetricNamespace,
Name: "active_sessions_total",
Subsystem: "mcp",
Help: "Number of active MCP sessions",
},
[]string{"transport"},
)

messagesFromClient = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "messages_from_client_total",
Subsystem: "mcp",
Help: "Number of messages received from the MCP client",
},
[]string{"transport", "type", "method"},
Comment thread
greedy52 marked this conversation as resolved.
)

messagesFromServer = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "messages_from_server_total",
Subsystem: "mcp",
Help: "Number of messages received from the MCP server",
},
[]string{"transport", "type", "method"},
)

allPrometheusCollectors = []prometheus.Collector{
setupErrors,
accumulatedSessions, activeSessions,
messagesFromClient, messagesFromServer,
}

// knownNotificationMethods is a list of known method names for notifications.
//
// The list is obtained by searching these in addition to mcp-go:
// - https://github.com/modelcontextprotocol/modelcontextprotocol
// - https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/index.ts
knownNotificationMethods = []mcp.MCPMethod{
//nolint:misspell // "cancelled" is "UK" spelling but our linter is set to use US locale
"notifications/cancelled",
"notifications/initialized",
"notifications/message",
"notifications/progress",
mcp.MethodNotificationPromptsListChanged, // notifications/prompts/list_changed
mcp.MethodNotificationResourcesListChanged, // notifications/resources/list_changed
mcp.MethodNotificationResourceUpdated, // notifications/resources/updated
mcp.MethodNotificationToolsListChanged, // notifications/tools/list_changed
"notifications/roots/list_changed",
}

// knownRequestMethods is a list of known method names for requests.
//
// The list is obtained by searching these in addition to mcp-go:
// - https://github.com/modelcontextprotocol/modelcontextprotocol
// - https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/index.ts
knownRequestMethods = []mcp.MCPMethod{
mcp.MethodInitialize, // initialize
mcp.MethodPing, // ping
mcp.MethodResourcesList, // resources/list
mcp.MethodResourcesTemplatesList, // resources/templates/list
mcp.MethodResourcesRead, // resources/read
mcp.MethodPromptsList, // prompts/list
mcp.MethodPromptsGet, // prompts/get
mcp.MethodToolsList, // tools/list
mcp.MethodToolsCall, // tools/call
mcp.MethodSetLogLevel, // logging/setLevel
mcp.MethodElicitationCreate, // elicitation/create
"roots/list",
"sampling/createMessage",
}
)

func reportNotificationMethod(method mcp.MCPMethod) string {
if slices.Contains(knownNotificationMethods, method) {
return string(method)
}
return "unknown"
}

func reportRequestMethod(method mcp.MCPMethod) string {
if slices.Contains(knownRequestMethods, method) {
return string(method)
}
return "unknown"
}
17 changes: 13 additions & 4 deletions lib/srv/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/services"
appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -103,7 +104,6 @@ func (c *ServerConfig) CheckAndSetDefaults() error {
}

// Server handles forwarding client connections to MCP servers.
// TODO(greedy52) add server metrics.
type Server struct {
cfg ServerConfig

Expand All @@ -116,6 +116,10 @@ func NewServer(cfg ServerConfig) (*Server, error) {
return nil, trace.Wrap(err)
}

if err := metrics.RegisterPrometheusCollectors(allPrometheusCollectors...); err != nil {
return nil, trace.Wrap(err)
}

cache, err := utils.NewFnCache(utils.FnCacheConfig{
TTL: 10 * time.Minute,
Context: cfg.ParentContext,
Expand All @@ -137,19 +141,24 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro
if err := sessionCtx.checkAndSetDefaults(); err != nil {
return trace.Wrap(err)
}

// Metrics.
accumulatedSessions.WithLabelValues(sessionCtx.transport).Inc()
activeSessions.WithLabelValues(sessionCtx.transport).Inc()
defer activeSessions.WithLabelValues(sessionCtx.transport).Dec()

if s.cfg.EnableDemoServer && isDemoServerApp(sessionCtx.App) {
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeDemoServerRunner))
}
transportType := types.GetMCPServerTransportType(sessionCtx.App.GetURI())
switch transportType {
switch sessionCtx.transport {
case types.MCPTransportStdio:
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner))
case types.MCPTransportSSE:
return trace.Wrap(s.handleStdioToSSE(ctx, sessionCtx))
case types.MCPTransportHTTP:
return trace.Wrap(s.handleStreamableHTTP(ctx, sessionCtx))
default:
return trace.BadParameter("unknown transport type: %v", transportType)
return trace.BadParameter("unknown transport type: %v", sessionCtx.transport)
}
}

Expand Down
14 changes: 13 additions & 1 deletion lib/srv/mcp/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type SessionCtx struct {

// traitsForRewriteHeaders are user traits used for rewriting headers.
traitsForRewriteHeaders wrappers.Traits

// transport is the transport type of the MCP server.
transport string
}

func (c *SessionCtx) checkAndSetDefaults() error {
Expand All @@ -82,8 +85,11 @@ func (c *SessionCtx) checkAndSetDefaults() error {
if c.Identity.Username == "" {
c.Identity = c.AuthCtx.Identity.GetIdentity()
}
if c.transport == "" {
c.transport = types.GetMCPServerTransportType(c.App.GetURI())
}
if c.sessionID == "" {
if types.MCPTransportHTTP == types.GetMCPServerTransportType(c.App.GetURI()) {
if types.MCPTransportHTTP == c.transport {
// A single HTTP request is handled at a time so take session ID
// from cert.
c.sessionID = session.ID(c.Identity.RouteToApp.SessionID)
Expand Down Expand Up @@ -197,6 +203,7 @@ func (s *sessionHandler) checkAccessToTool(ctx context.Context, toolName string)

func (s *sessionHandler) processClientNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) {
s.emitNotificationEvent(ctx, notification, nil)
messagesFromClient.WithLabelValues(s.transport, "notification", reportNotificationMethod(notification.Method)).Inc()
}

func (s *sessionHandler) onClientNotification(serverRequestWriter mcputils.MessageWriter) mcputils.HandleNotificationFunc {
Expand Down Expand Up @@ -250,6 +257,8 @@ func (s *sessionHandler) processClientRequest(ctx context.Context, req *mcputils
}

func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, error) {
messagesFromClient.WithLabelValues(s.transport, "request", reportRequestMethod(req.Method)).Inc()

s.idTracker.PushRequest(req)
switch req.Method {
case mcp.MethodToolsCall:
Expand All @@ -263,6 +272,8 @@ func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *m

func (s *sessionHandler) processServerResponse(ctx context.Context, response *mcputils.JSONRPCResponse) mcp.JSONRPCMessage {
method, _ := s.idTracker.PopByID(response.ID)
messagesFromServer.WithLabelValues(s.transport, "response", reportRequestMethod(method)).Inc()

switch method {
case mcp.MethodToolsList:
return s.makeToolsCallResponse(ctx, response)
Expand All @@ -272,6 +283,7 @@ func (s *sessionHandler) processServerResponse(ctx context.Context, response *mc

func (s *sessionHandler) processServerNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) {
s.logger.DebugContext(ctx, "Received server notification.", "method", notification.Method)
messagesFromServer.WithLabelValues(s.transport, "notification", reportNotificationMethod(notification.Method)).Inc()
}

func (s *sessionHandler) makeToolsCallResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage {
Expand Down
1 change: 1 addition & 0 deletions lib/srv/mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (s *Server) handleStdioToSSE(ctx context.Context, sessionCtx *SessionCtx) e
// Initialize SSE stream.
sseResponseReader, sseRequestWriter, err := mcputils.ConnectSSEServer(ctx, baseURL, httpTransport)
if err != nil {
setupErrors.WithLabelValues(sessionCtx.transport).Inc()
return trace.Wrap(err)
}
session.logger.DebugContext(ctx, "Received SSE endpoint", "endpoint_url", sseRequestWriter.GetEndpointURL())
Expand Down
1 change: 1 addition & 0 deletions lib/srv/mcp/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func (s *execServer) getStdinPipe() (io.WriteCloser, error) {

func (s *execServer) run(context.Context) error {
if err := s.cmd.Start(); err != nil {
setupErrors.WithLabelValues(s.session.transport).Inc()
return trace.Wrap(err)
}

Expand Down
12 changes: 12 additions & 0 deletions lib/srv/mcp/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/gravitational/trace"
"github.com/mark3labs/mcp-go/mcp"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -179,6 +180,14 @@ func TestHandleSession_execMCPServer(t *testing.T) {

// Check container is running.
require.NotEmpty(t, findDockerContainerID(t.Context(), dockerClient, containerName))

// Note that some metrics may be incremented by other tests too so here
// just checking they are non-zero.
require.Positive(t, testutil.ToFloat64(accumulatedSessions.WithLabelValues(types.MCPTransportStdio)))
require.Positive(t, testutil.ToFloat64(activeSessions.WithLabelValues(types.MCPTransportStdio)))
require.Positive(t, testutil.ToFloat64(messagesFromClient.WithLabelValues(types.MCPTransportStdio, "request", "initialize")))
require.Positive(t, testutil.ToFloat64(messagesFromClient.WithLabelValues(types.MCPTransportStdio, "notification", "notifications/initialized")))
require.Positive(t, testutil.ToFloat64(messagesFromServer.WithLabelValues(types.MCPTransportStdio, "response", "initialize")))
}

tests := []struct {
Expand Down Expand Up @@ -219,6 +228,9 @@ func TestHandleSession_execMCPServer(t *testing.T) {
cmd: "fail-to-start",
checkHandlerError: require.Error,
waitForHandlerExit: time.Second * 5,
afterHandlerStop: func(t *testing.T, _ *testContext, _ string) {
require.Positive(t, testutil.ToFloat64(setupErrors.WithLabelValues(types.MCPTransportStdio)))
},
},
{
// Make sure handler is not blocked when command starts then fails
Expand Down
Loading