From 8e87bfb5d15f6d4cd9a38fdaad387d1a074b3b59 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Mon, 22 Sep 2025 10:09:43 -0400 Subject: [PATCH 1/3] [mcp] server handler for streamable hTTP transport --- api/types/app.go | 4 +- api/types/app_test.go | 25 +- api/types/constants.go | 8 + integration/appaccess/appaccess_test.go | 20 +- integration/appaccess/mcp_test.go | 59 +++++ lib/client/api.go | 6 + lib/srv/app/connections_handler.go | 3 +- lib/srv/mcp/helpers_test.go | 36 ++- lib/srv/mcp/http.go | 297 ++++++++++++++++++++++++ lib/srv/mcp/http_test.go | 200 ++++++++++++++++ lib/srv/mcp/server.go | 56 ++++- lib/srv/mcp/session.go | 48 +++- lib/srv/mcp/sse.go | 1 + lib/srv/mcp/sse_test.go | 1 + lib/srv/mcp/stdio_test.go | 9 +- tool/tsh/common/proxy.go | 8 +- 16 files changed, 756 insertions(+), 25 deletions(-) create mode 100644 lib/srv/mcp/http.go create mode 100644 lib/srv/mcp/http_test.go diff --git a/api/types/app.go b/api/types/app.go index da514c22abb84..38dd917840c17 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -526,7 +526,7 @@ func (a *AppV3) checkMCP() error { switch GetMCPServerTransportType(a.Spec.URI) { case MCPTransportStdio: return trace.Wrap(a.checkMCPStdio()) - case MCPTransportSSE: + case MCPTransportSSE, MCPTransportHTTP: _, err := url.Parse(a.Spec.URI) return trace.Wrap(err) default: @@ -690,6 +690,8 @@ func GetMCPServerTransportType(uri string) string { return MCPTransportStdio case SchemeMCPSSEHTTP, SchemeMCPSSEHTTPS: return MCPTransportSSE + case SchemeMCPHTTP, SchemeMCPHTTPS: + return MCPTransportHTTP default: return "" } diff --git a/api/types/app_test.go b/api/types/app_test.go index 861f3f7000b59..12d3110fb00ac 100644 --- a/api/types/app_test.go +++ b/api/types/app_test.go @@ -689,7 +689,7 @@ func TestNewAppV3(t *testing.T) { wantErr: require.NoError, }, { - name: "mcp with uri", + name: "mcp with SSE transport", meta: Metadata{ Name: "mcp-everything", }, @@ -711,6 +711,29 @@ func TestNewAppV3(t *testing.T) { }, wantErr: require.NoError, }, + { + name: "mcp with streamable HTTP transport", + meta: Metadata{ + Name: "mcp-everything", + }, + spec: AppSpecV3{ + URI: "mcp+http://localhost:12345/mcp", + }, + want: &AppV3{ + Kind: "app", + SubKind: "mcp", + Version: "v3", + Metadata: Metadata{ + Name: "mcp-everything", + Namespace: "default", + Labels: map[string]string{AppSubKindLabel: "mcp"}, + }, + Spec: AppSpecV3{ + URI: "mcp+http://localhost:12345/mcp", + }, + }, + wantErr: require.NoError, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/api/types/constants.go b/api/types/constants.go index 4d63544a6d086..9adb0156d0a99 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -951,6 +951,14 @@ const ( SchemeMCPSSEHTTPS = "mcp+sse+https" // MCPTransportSSE indicates the MCP server uses SSE transport. MCPTransportSSE = "SSE" + // SchemeMCPHTTP is a URI scheme for MCP servers using HTTP with streamable + // HTTP transport. + SchemeMCPHTTP = "mcp+http" + // SchemeMCPHTTPS is a URI scheme for MCP servers using HTTPS with + // streamable HTTP transport. + SchemeMCPHTTPS = "mcp+https" + // MCPTransportHTTP indicates the MCP server uses SSE transport. + MCPTransportHTTP = "Streamable HTTP" // DiscoveredResourceNode identifies a discovered SSH node. DiscoveredResourceNode = "node" diff --git a/integration/appaccess/appaccess_test.go b/integration/appaccess/appaccess_test.go index 9c1eb79b8d30c..6122b9b840abd 100644 --- a/integration/appaccess/appaccess_test.go +++ b/integration/appaccess/appaccess_test.go @@ -23,6 +23,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" "net/http" @@ -35,6 +36,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + mcpserver "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -62,10 +64,20 @@ import ( func TestAppAccess(t *testing.T) { // Enable MCP test servers. sseServerURL := mcptest.MustStartSSETestServer(t) - extraApps := []servicecfg.App{{ - Name: "test-sse", - URI: "mcp+sse+" + sseServerURL, - }} + streamableHTTPServer := mcpserver.NewTestStreamableHTTPServer(mcptest.NewServer()) + streamableHTTPServerURL := fmt.Sprintf("mcp+%s/mcp", streamableHTTPServer.URL) + t.Cleanup(streamableHTTPServer.Close) + + extraApps := []servicecfg.App{ + { + Name: "test-sse", + URI: "mcp+sse+" + sseServerURL, + }, + { + Name: "test-http", + URI: streamableHTTPServerURL, + }, + } // Reusing the pack as much as we can. pack := SetupWithOptions(t, AppTestOptions{ diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go index c6dd65964362f..cf513e33acf60 100644 --- a/integration/appaccess/mcp_test.go +++ b/integration/appaccess/mcp_test.go @@ -20,11 +20,19 @@ package appaccess import ( "context" + "crypto/tls" + "fmt" + "net/http" "testing" + mcpclient "github.com/mark3labs/mcp-go/client" + mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" libmcp "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/utils/mcptest" ) @@ -41,6 +49,10 @@ func testMCP(pack *Pack, t *testing.T) { t.Run("DialMCPServer stdio to sse success", func(t *testing.T) { testMCPDialStdioToSSE(t, pack, "test-sse") }) + + t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) { + testMCPProxyStreamableHTTP(t, pack, "test-http") + }) } func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) { @@ -81,3 +93,50 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) { mcptest.MustCallServerTool(t, ctx, stdioClient) } + +func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) { + require.NoError(t, pack.tc.SaveProfile(false)) + + // Find the MCP server. + filter := pack.tc.ResourceFilter(types.KindAppServer) + filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName) + apps, err := pack.tc.ListApps(t.Context(), filter) + require.NoError(t, err) + require.Len(t, apps, 1) + + // Issue a TLS cert with app route. + keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{ + RouteToCluster: pack.rootCluster.Secrets.SiteName, + RouteToApp: proto.RouteToApp{ + ClusterName: pack.rootCluster.Secrets.SiteName, + Name: apps[0].GetName(), + PublicAddr: apps[0].GetPublicAddr(), + }, + }) + require.NoError(t, err) + appCert, err := keyRing.AppTLSCert(appName) + require.NoError(t, err) + + // Create an MCP client with app cert. + ctx := t.Context() + mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( + "https://"+pack.rootCluster.Web, + mcpclienttransport.WithHTTPBasicClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{appCert}, + InsecureSkipVerify: true, + }, + }, + }), + ) + require.NoError(t, err) + client := mcpclient.NewClient(mcpClientTransport) + require.NoError(t, client.Start(ctx)) + defer client.Close() + + // Initialize client and call a tool. + _, err = mcptest.InitializeClient(ctx, client) + require.NoError(t, err) + mcptest.MustCallServerTool(t, ctx, client) +} diff --git a/lib/client/api.go b/lib/client/api.go index c25085a3f32d1..5e0fd09ea160f 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -5563,6 +5563,12 @@ func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (ne return nil, trace.BadParameter("app %q is not a MCP server", appName) } + // TODO(greedy52) support streamable HTTP for "tsh mcp connect" before + // release. + if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP { + return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport) + } + cert, err := tc.issueMCPCertWithMFA(ctx, apps[0]) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go index 8c5f13c1a1666..3953ce05fef2d 100644 --- a/lib/srv/app/connections_handler.go +++ b/lib/srv/app/connections_handler.go @@ -285,6 +285,7 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler AccessPoint: c.cfg.AccessPoint, EnableDemoServer: c.cfg.MCPDemoServer, CipherSuites: c.cfg.CipherSuites, + AuthClient: c.cfg.AuthClient, }) if err != nil { return nil, trace.Wrap(err) @@ -614,7 +615,7 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) { case app.IsTCP(): return nil, trace.Wrap(err) case app.IsMCP(): - return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, err)) + return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, app, err)) default: c.setConnAuth(tlsConn, err) } diff --git a/lib/srv/mcp/helpers_test.go b/lib/srv/mcp/helpers_test.go index ca2df2f9f9c51..fae369d42b259 100644 --- a/lib/srv/mcp/helpers_test.go +++ b/lib/srv/mcp/helpers_test.go @@ -51,8 +51,9 @@ func TestMain(m *testing.M) { } type setupTestContextOptions struct { - roleSet services.RoleSet - app types.Application + roleSet services.RoleSet + app types.Application + clientConn net.Conn } type setupTestContextOptionFunc func(*setupTestContextOptions) @@ -69,6 +70,12 @@ func withRole(role types.Role) setupTestContextOptionFunc { } } +func withClientConn(conn net.Conn) setupTestContextOptionFunc { + return func(opts *setupTestContextOptions) { + opts.clientConn = conn + } +} + // withAdminRole assigns to ai_user a role that allows all MCP servers and their // tools. func withAdminRole(t *testing.T) setupTestContextOptionFunc { @@ -139,8 +146,13 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes applyOpt(&opts) } - // Fake connection. - clientSourceConn, clientDestConn := makeDualPipeNetConn(t) + // Fake connection if not passed in. + var clientSourceConn, clientDestConn net.Conn + if opts.clientConn != nil { + clientDestConn = opts.clientConn + } else { + clientSourceConn, clientDestConn = makeDualPipeNetConn(t) + } // App. if opts.app == nil { @@ -164,7 +176,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes sessionCtx := &SessionCtx{ ClientConn: clientDestConn, App: opts.app, - AuthCtx: makeTestAuthContext(t, opts.roleSet), + AuthCtx: makeTestAuthContext(t, opts.roleSet, opts.app), } require.NoError(t, sessionCtx.checkAndSetDefaults()) @@ -174,7 +186,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes } } -func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context { +func makeTestAuthContext(t *testing.T, roleSet services.RoleSet, app types.Application) *authz.Context { t.Helper() user, err := types.NewUser("ai") @@ -189,6 +201,11 @@ func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context Principals: user.GetLogins(), }, } + if app != nil { + identity.Identity.RouteToApp.Name = app.GetName() + identity.Identity.RouteToApp.SessionID = "session-id-for+" + app.GetName() + } + accessInfo, err := services.AccessInfoFromLocalTLSIdentity(identity.Identity) require.NoError(t, err) checker := services.NewAccessCheckerWithRoleSet(accessInfo, "my-cluster", roleSet) @@ -327,3 +344,10 @@ func forceRemoveContainer(t *testing.T, dockerClient *docker.Client, containerNa } } } + +type mockAuthClient struct { +} + +func (m mockAuthClient) GenerateAppToken(_ context.Context, req types.GenerateAppTokenRequest) (string, error) { + return "app-token-for-" + req.Username, nil +} diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go new file mode 100644 index 0000000000000..ee1da6c42f66a --- /dev/null +++ b/lib/srv/mcp/http.go @@ -0,0 +1,297 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/httplib/reverseproxy" + "github.com/gravitational/teleport/lib/services" + appcommon "github.com/gravitational/teleport/lib/srv/app/common" + "github.com/gravitational/teleport/lib/utils" + listenerutils "github.com/gravitational/teleport/lib/utils/listener" + "github.com/gravitational/teleport/lib/utils/mcputils" +) + +const ( + mcpSessionIDHeader = "Mcp-Session-Id" +) + +func (s *Server) serveHTTPConn(ctx context.Context, conn net.Conn, handler http.Handler) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + waitConn := utils.NewCloserConn(conn) + listener := listenerutils.NewSingleUseListener(waitConn) + if err := http.Serve(listener, handler); err != nil && !utils.IsOKNetworkError(err) { + waitConn.Close() + return trace.Wrap(err) + } + go func() { + // Make sure handler returns if ctx is canceled. + <-ctx.Done() + waitConn.Close() + }() + waitConn.Wait() + return nil +} + +func (s *Server) handleAuthErrHTTP(ctx context.Context, clientConn net.Conn, authErr error) error { + return trace.Wrap(s.serveHTTPConn(ctx, clientConn, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + trace.WriteError(w, authErr) + }))) +} + +func (s *Server) handleStreamableHTTP(ctx context.Context, sessionCtx *SessionCtx) error { + session, err := s.getSessionHandlerWithJWT(ctx, sessionCtx) + if err != nil { + return trace.Wrap(err, "setting up session handler") + } + + transport, err := s.makeStreamableHTTPTransport(session) + if err != nil { + return trace.Wrap(err, "setting up streamable http transport") + } + + session.logger.DebugContext(ctx, "Started handling HTTP request") + defer session.logger.DebugContext(ctx, "Completed handling HTTP request") + + delegate := reverseproxy.NewHeaderRewriter() + reverseProxy, err := reverseproxy.New( + reverseproxy.WithFlushInterval(100*time.Millisecond), + reverseproxy.WithRoundTripper(transport), + reverseproxy.WithLogger(session.logger), + reverseproxy.WithRewriter(appcommon.NewHeaderRewriter(delegate)), + reverseproxy.WithResponseModifier(func(resp *http.Response) error { + if resp.Request != nil && resp.Request.Method == http.MethodDelete { + // Nothing to modify here. + return nil + } + return trace.Wrap(mcputils.ReplaceHTTPResponse(ctx, resp, newHTTPResponseReplacer(session))) + }), + ) + if err != nil { + return trace.Wrap(err, "creating reverse proxy") + } + + return trace.Wrap(s.serveHTTPConn(ctx, sessionCtx.ClientConn, reverseProxy)) +} + +func (s *Server) makeStreamableHTTPTransport(session *sessionHandler) (http.RoundTripper, error) { + targetURI, err := url.Parse(session.App.GetURI()) + if err != nil { + return nil, trace.Wrap(err) + } + targetURI.Scheme = strings.TrimPrefix(targetURI.Scheme, "mcp+") + + targetTransport, err := s.makeHTTPTransport(session.App) + if err != nil { + return nil, trace.Wrap(err) + } + return &streamableHTTPTransport{ + sessionHandler: session, + targetURI: targetURI, + targetTransport: targetTransport, + }, nil +} + +type streamableHTTPTransport struct { + *sessionHandler + targetURI *url.URL + targetTransport http.RoundTripper +} + +func (t *streamableHTTPTransport) RoundTrip(r *http.Request) (*http.Response, error) { + t.setExternalSessionID(r.Header) + + switch r.Method { + case http.MethodDelete: + return t.handleSessionEndRequest(r) + case http.MethodGet: + return t.handleListenSSEStreamRequest(r) + case http.MethodPost: + return t.handleMCPMessage(r) + + default: + t.emitInvalidHTTPRequest(t.parentCtx, r) + return &http.Response{ + Request: r, + StatusCode: http.StatusMethodNotAllowed, + }, nil + } +} + +func (t *streamableHTTPTransport) setExternalSessionID(header http.Header) { + if id := header.Get(mcpSessionIDHeader); id != "" { + t.mcpSessionID.Store(&id) + } +} + +func (t *streamableHTTPTransport) rewriteRequest(r *http.Request) *http.Request { + r = r.Clone(r.Context()) + r.URL.Scheme = t.targetURI.Scheme + r.URL.Host = t.targetURI.Host + + // Defaults to the endpoint defined in the app if client is not providing it. + // By spec, streamable HTTP should use a single endpoint except the + // ".well-known" used for OAuth. + if t.targetURI.Path != "" && (r.URL.Path == "" || r.URL.Path == "/") { + r.URL.Path = t.targetURI.Path + } + + // Add in JWT headers. + r.Header.Set(teleport.AppJWTHeader, t.jwt) + // Add headers from rewrite configuration. + rewriteHeaders := appcommon.AppRewriteHeaders(r.Context(), t.App.GetRewrite(), t.logger) + services.RewriteHeadersAndApplyValueTraits(r, rewriteHeaders, t.traitsForRewriteHeaders, t.logger) + return r +} + +func (t *streamableHTTPTransport) rewriteAndSendRequest(r *http.Request) (*http.Response, error) { + rCopy := t.rewriteRequest(r) + return t.targetTransport.RoundTrip(rCopy) +} + +func (t *streamableHTTPTransport) handleSessionEndRequest(r *http.Request) (*http.Response, error) { + resp, err := t.rewriteAndSendRequest(r) + t.emitEndEvent(t.parentCtx, convertHTTPResponseErrorForAudit(resp, err)) + return resp, trace.Wrap(err) +} + +func (t *streamableHTTPTransport) handleListenSSEStreamRequest(r *http.Request) (*http.Response, error) { + resp, err := t.rewriteAndSendRequest(r) + t.emitListenSSEStreamEvent(t.parentCtx, convertHTTPResponseErrorForAudit(resp, err)) + return resp, trace.Wrap(err) +} + +func (t *streamableHTTPTransport) handleMCPMessage(r *http.Request) (*http.Response, error) { + var baseMessage mcputils.BaseJSONRPCMessage + if reqBody, err := utils.GetAndReplaceRequestBody(r); err != nil { + t.emitInvalidHTTPRequest(t.parentCtx, r) + return nil, trace.BadParameter("invalid request body %v", err) + } else if err := json.Unmarshal(reqBody, &baseMessage); err != nil { + t.emitInvalidHTTPRequest(t.parentCtx, r) + return nil, trace.BadParameter("invalid request body %v", err) + } + + switch { + case baseMessage.IsRequest(): + mcpRequest := baseMessage.MakeRequest() + if errResp, authErr := t.sessionHandler.processClientRequestNoAudit(r.Context(), mcpRequest); authErr != nil { + return t.handleRequestAuthError(r, mcpRequest, errResp, authErr) + } + case baseMessage.IsNotification(): + // nothing to do, yet. + default: + // Not sending it to the server if we don't understand it. + t.emitInvalidHTTPRequest(t.parentCtx, r) + return nil, trace.BadParameter("not a MCP request or notification") + } + + resp, err := t.rewriteAndSendRequest(r) + // Prefer session ID from server response if present. For example, + // "initialize" request does not have an ID but the server response may have + // it. + if resp != nil { + t.setExternalSessionID(resp.Header) + } + + // Take care of audit events after round trip. + respErrForAudit := convertHTTPResponseErrorForAudit(resp, err) + switch { + case baseMessage.IsRequest(): + mcpRequest := baseMessage.MakeRequest() + // Only emit session start if "initialize" succeeded. + if mcpRequest.Method == "initialize" && respErrForAudit == nil { + t.emitStartEvent(t.parentCtx) + } + t.emitRequestEvent(t.parentCtx, mcpRequest, respErrForAudit) + case baseMessage.IsNotification(): + t.emitNotificationEvent(t.parentCtx, baseMessage.MakeNotification(), respErrForAudit) + } + return resp, trace.Wrap(err) +} + +func (t *streamableHTTPTransport) handleRequestAuthError(r *http.Request, mcpRequest *mcputils.JSONRPCRequest, errResp mcp.JSONRPCMessage, authErr error) (*http.Response, error) { + t.emitRequestEvent(t.parentCtx, mcpRequest, authErr) + + errRespAsBody, err := json.Marshal(errResp) + if err != nil { + // Should not happen. If it does, we are failing the request either way. + return nil, trace.Wrap(err) + } + + httpResp := &http.Response{ + Request: r, + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(errRespAsBody)), + Header: make(http.Header), + } + httpResp.Header.Set("Content-Type", "application/json") + return httpResp, nil +} + +func convertHTTPResponseErrorForAudit(resp *http.Response, err error) error { + if err != nil { + return trace.Wrap(err) + } + if resp == nil { + // Should not happen. + return trace.BadParameter("missing response") + } + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest { + return nil + } + if resp.Status == "" { + return trace.Errorf("HTTP %d %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + return trace.Errorf("HTTP %s", resp.Status) +} + +// streamableHTTPResponseReplacer is a wrapper of sessionHandler to satisfy +// mcputils.ServerMessageProcessor. +type streamableHTTPResponseReplacer struct { + *sessionHandler +} + +func newHTTPResponseReplacer(sessionHandler *sessionHandler) *streamableHTTPResponseReplacer { + return &streamableHTTPResponseReplacer{ + sessionHandler: sessionHandler, + } +} + +func (p *streamableHTTPResponseReplacer) ProcessResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage { + return p.processServerResponse(ctx, resp) +} +func (p *streamableHTTPResponseReplacer) ProcessNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) mcp.JSONRPCMessage { + p.processServerNotification(ctx, notification) + return notification +} diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go new file mode 100644 index 0000000000000..608d7bc032088 --- /dev/null +++ b/lib/srv/mcp/http_test.go @@ -0,0 +1,200 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gravitational/trace" + mcpclient "github.com/mark3labs/mcp-go/client" + mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + libevents "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/mcptest" + sliceutils "github.com/gravitational/teleport/lib/utils/slices" +) + +func Test_handleStreamableHTTP(t *testing.T) { + remoteMCPServer := mcpserver.NewStreamableHTTPServer(mcptest.NewServer()) + remoteMCPHTTPServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path != "/mcp": + // Unhappy scenario. + w.WriteHeader(http.StatusNotFound) + case r.Header.Get("Authorization") != "Bearer app-token-for-ai": + // Verify rewrite headers. + w.WriteHeader(http.StatusUnauthorized) + default: + remoteMCPServer.ServeHTTP(w, r) + } + })) + t.Cleanup(remoteMCPHTTPServer.Close) + + app, err := types.NewAppV3(types.Metadata{ + Name: "test-http", + }, types.AppSpecV3{ + URI: fmt.Sprintf("mcp+%s/mcp", remoteMCPHTTPServer.URL), + Rewrite: &types.Rewrite{ + Headers: []*types.Header{{ + Name: "Authorization", + Value: "Bearer {{internal.jwt}}", + }}, + }, + }) + require.NoError(t, err) + + emitter := eventstest.MockRecorderEmitter{} + s, err := NewServer(ServerConfig{ + Emitter: &emitter, + ParentContext: t.Context(), + HostID: "my-host-id", + AccessPoint: fakeAccessPoint{}, + CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, + }) + require.NoError(t, err) + + // Run MCP handler behind a listener. + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go func() { + for { + conn, err := listener.Accept() + if err != nil { + assert.True(t, utils.IsOKNetworkError(err)) + return + } + wg.Go(func() { + defer conn.Close() + testCtx := setupTestContext(t, withAdminRole(t), withApp(app), withClientConn(conn)) + assert.NoError(t, s.HandleSession(t.Context(), testCtx.SessionCtx)) + }) + } + }() + + t.Run("success", func(t *testing.T) { + ctx := t.Context() + emitter.Reset() + mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( + "http://"+listener.Addr().String(), + mcpclienttransport.WithContinuousListening(), + ) + require.NoError(t, err) + client := mcpclient.NewClient(mcpClientTransport) + require.NoError(t, client.Start(ctx)) + + // Initialize client, then call a tool. + _, err = mcptest.InitializeClient(ctx, client) + require.NoError(t, err) + mcptest.MustCallServerTool(t, ctx, client) + + // Close client and wait for end event. + require.NoError(t, client.Close()) + require.EventuallyWithT(t, func(t *assert.CollectT) { + require.Equal(t, libevents.MCPSessionEndEvent, emitter.LastEvent().GetType()) + }, time.Second, time.Millisecond*100, "waiting for end event") + + // Verify audit codes. Note that the order can be undeterministic as the + // listen request is sent from a go-routine by mcp-go client. + eventCodes := sliceutils.Map(emitter.Events(), func(e apievents.AuditEvent) string { + return e.GetCode() + }) + require.ElementsMatch(t, []string{ + libevents.MCPSessionStartCode, + libevents.MCPSessionRequestCode, // "initialize" + libevents.MCPSessionNotificationCode, + libevents.MCPSessionListenSSEStreamCode, + libevents.MCPSessionRequestCode, // "tools/call" + libevents.MCPSessionEndCode, + }, eventCodes) + }) + + t.Run("endpoint not found", func(t *testing.T) { + ctx := t.Context() + emitter.Reset() + mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( + fmt.Sprintf("http://%s/notfound", listener.Addr().String()), + ) + require.NoError(t, err) + + // Initialize client should fail. + client := mcpclient.NewClient(mcpClientTransport) + _, err = mcptest.InitializeClient(ctx, client) + require.Error(t, err) + + // Close client and verify failure event. + events := emitter.Events() + require.Len(t, events, 1) + lastEvent, ok := events[0].(*apievents.MCPSessionRequest) + require.True(t, ok) + require.Equal(t, libevents.MCPSessionRequestEvent, lastEvent.GetType()) + require.Equal(t, libevents.MCPSessionRequestFailureCode, lastEvent.GetCode()) + require.False(t, lastEvent.Success) + require.Equal(t, lastEvent.Error, "HTTP 404 Not Found") + }) +} + +func Test_handleAuthErrHTTP(t *testing.T) { + s, err := NewServer(ServerConfig{ + Emitter: &libevents.DiscardEmitter{}, + ParentContext: t.Context(), + HostID: "my-host-id", + AccessPoint: fakeAccessPoint{}, + CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, + }) + + require.NoError(t, err) + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go func() { + conn, err := listener.Accept() + if err != nil { + assert.True(t, utils.IsOKNetworkError(err)) + return + } + defer conn.Close() + s.handleAuthErrHTTP(t.Context(), conn, trace.AccessDenied("access denied")) + }() + + mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( + fmt.Sprintf("http://%s", listener.Addr().String()), + ) + require.NoError(t, err) + client := mcpclient.NewClient(mcpClientTransport) + _, err = mcptest.InitializeClient(t.Context(), client) + require.ErrorContains(t, err, "access denied") +} diff --git a/lib/srv/mcp/server.go b/lib/srv/mcp/server.go index 9bdc0de102e8e..04ae5fde991a8 100644 --- a/lib/srv/mcp/server.go +++ b/lib/srv/mcp/server.go @@ -33,6 +33,8 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" + appcommon "github.com/gravitational/teleport/lib/srv/app/common" + "github.com/gravitational/teleport/lib/utils" ) // AccessPoint defines functions that the MCP server requires from the caching @@ -42,6 +44,12 @@ type AccessPoint interface { services.ClusterNameGetter } +// AuthClient defines functions that the MCP server requires from the auth +// client. +type AuthClient interface { + appcommon.AppTokenGenerator +} + // ServerConfig is the config for the MCP forward server. type ServerConfig struct { // Emitter is used for emitting audit events. @@ -54,6 +62,8 @@ type ServerConfig struct { HostID string // AccessPoint is a caching client connected to the Auth Server. AccessPoint AccessPoint + // AuthClient is a client directly connected to the Auth server. + AuthClient AuthClient // EnableDemoServer enables the "Teleport Demo" MCP server. EnableDemoServer bool // CipherSuites is the list of TLS cipher suites that have been configured @@ -74,6 +84,9 @@ func (c *ServerConfig) CheckAndSetDefaults() error { if c.HostID == "" { return trace.BadParameter("missing HostID") } + if c.AuthClient == nil { + return trace.BadParameter("missing AuthClient") + } if c.AccessPoint == nil { return trace.BadParameter("missing AccessPoint") } @@ -93,6 +106,8 @@ func (c *ServerConfig) CheckAndSetDefaults() error { // TODO(greedy52) add server metrics. type Server struct { cfg ServerConfig + + sessionCache *utils.FnCache } // NewServer creates a new Server. @@ -100,8 +115,20 @@ func NewServer(cfg ServerConfig) (*Server, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: 10 * time.Minute, + Context: cfg.ParentContext, + Clock: cfg.clock, + ReloadOnErr: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &Server{ - cfg: cfg, + cfg: cfg, + sessionCache: cache, }, nil } @@ -119,6 +146,8 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner)) case types.MCPTransportSSE: return trace.Wrap(s.handleStdioToSSE(ctx, sessionCtx)) + case types.MCPTransportHTTP: + return trace.Wrap(s.handleStreamableHTTP(ctx, sessionCtx)) default: return trace.BadParameter("unknown transport type: %v", transportType) } @@ -127,10 +156,16 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro // HandleUnauthorizedConnection handles an unauthorized client connection. // This function has a hardcoded 30 seconds timeout in case the proper error // message cannot be delivered to the client. -func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn net.Conn, authErr error) error { +func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn net.Conn, app types.Application, authErr error) error { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - return trace.Wrap(s.handleAuthErrStdio(ctx, clientConn, authErr)) + transportType := types.GetMCPServerTransportType(app.GetURI()) + switch transportType { + case types.MCPTransportHTTP: + return trace.Wrap(s.handleAuthErrHTTP(ctx, clientConn, authErr)) + default: + return trace.Wrap(s.handleAuthErrStdio(ctx, clientConn, authErr)) + } } func (s *Server) makeSessionAuditor(ctx context.Context, sessionCtx *SessionCtx, logger *slog.Logger) (*sessionAuditor, error) { @@ -167,6 +202,7 @@ func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx *SessionCtx) logger := s.cfg.Log.With( "client_ip", sessionCtx.ClientConn.RemoteAddr(), "app", sessionCtx.App.GetName(), + "app_uri", sessionCtx.App.GetURI(), "user", sessionCtx.AuthCtx.User.GetName(), "session_id", sessionCtx.sessionID, ) @@ -185,3 +221,17 @@ func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx *SessionCtx) clock: s.cfg.clock, }) } + +func (s *Server) makeSessionHandlerWithJWT(ctx context.Context, sessionCtx *SessionCtx) (*sessionHandler, error) { + if err := sessionCtx.generateJWTAndTraits(ctx, s.cfg.AuthClient); err != nil { + return nil, trace.Wrap(err) + } + return s.makeSessionHandler(ctx, sessionCtx) +} + +func (s *Server) getSessionHandlerWithJWT(ctx context.Context, sessionCtx *SessionCtx) (*sessionHandler, error) { + ttl := min(sessionCtx.Identity.Expires.Sub(s.cfg.clock.Now()), 10*time.Minute) + return utils.FnCacheGetWithTTL(ctx, s.sessionCache, sessionCtx.sessionID, ttl, func(ctx context.Context) (*sessionHandler, error) { + return s.makeSessionHandlerWithJWT(ctx, sessionCtx) + }) +} diff --git a/lib/srv/mcp/session.go b/lib/srv/mcp/session.go index 29be7cb08150f..dd51bd85990ae 100644 --- a/lib/srv/mcp/session.go +++ b/lib/srv/mcp/session.go @@ -31,10 +31,12 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/authz" dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" + appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/mcputils" @@ -59,6 +61,12 @@ type SessionCtx struct { // mcpSessionID is the MCP session ID tracked by remote MCP server. mcpSessionID atomicString + + // jwt is the jwt token signed for this identity by Auth server. + jwt string + + // traitsForRewriteHeaders are user traits used for rewriting headers. + traitsForRewriteHeaders wrappers.Traits } func (c *SessionCtx) checkAndSetDefaults() error { @@ -75,7 +83,14 @@ func (c *SessionCtx) checkAndSetDefaults() error { c.Identity = c.AuthCtx.Identity.GetIdentity() } if c.sessionID == "" { - c.sessionID = session.NewID() + if types.MCPTransportHTTP == types.GetMCPServerTransportType(c.App.GetURI()) { + // A single HTTP request is handled at a time so take session ID + // from cert. + c.sessionID = session.ID(c.Identity.RouteToApp.SessionID) + } + if c.sessionID == "" { + c.sessionID = session.NewID() + } } return nil } @@ -89,6 +104,15 @@ func (c *SessionCtx) getAccessState(authPref types.AuthPreference) services.Acce return state } +func (c *SessionCtx) generateJWTAndTraits(ctx context.Context, auth AuthClient) (err error) { + if len(c.jwt) != 0 { + return nil + } + + c.jwt, c.traitsForRewriteHeaders, err = appcommon.GenerateJWTAndTraits(ctx, &c.Identity, c.App, auth) + return trace.Wrap(err) +} + type sessionHandlerConfig struct { *SessionCtx *sessionAuditor @@ -198,6 +222,7 @@ func (s *sessionHandler) onClientRequest(clientResponseWriter, serverRequestWrit func (s *sessionHandler) onServerNotification(clientResponseWriter mcputils.MessageWriter) mcputils.HandleNotificationFunc { return func(ctx context.Context, notification *mcputils.JSONRPCNotification) error { + s.processServerNotification(ctx, notification) return trace.Wrap(clientResponseWriter.WriteMessage(ctx, notification)) } } @@ -217,17 +242,26 @@ const ( ) func (s *sessionHandler) processClientRequest(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, replyDirection) { + s.idTracker.PushRequest(req) + reply, authErr := s.processClientRequestNoAudit(ctx, req) + if authErr != nil { + s.emitRequestEvent(ctx, req, authErr) + return reply, replyToClient + } + s.emitRequestEvent(ctx, req, nil) + return reply, replyToServer +} + +func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, error) { s.idTracker.PushRequest(req) switch req.Method { case mcp.MethodToolsCall: methodName, _ := req.Params.GetName() if authErr := s.checkAccessToTool(ctx, methodName); authErr != nil { - s.emitRequestEvent(ctx, req, authErr) - return makeToolAccessDeniedResponse(req, authErr), replyToClient + return makeToolAccessDeniedResponse(req, authErr), trace.Wrap(authErr) } } - s.emitRequestEvent(ctx, req, nil) - return req, replyToServer + return req, nil } func (s *sessionHandler) processServerResponse(ctx context.Context, response *mcputils.JSONRPCResponse) mcp.JSONRPCMessage { @@ -239,6 +273,10 @@ func (s *sessionHandler) processServerResponse(ctx context.Context, response *mc return response } +func (s *sessionHandler) processServerNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) { + s.logger.DebugContext(ctx, "Received server notification.", "method", notification.Method) +} + func (s *sessionHandler) makeToolsCallResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage { // Nothing to do, likely an error response. if resp.Result == nil { diff --git a/lib/srv/mcp/sse.go b/lib/srv/mcp/sse.go index 2384fa7b17347..63fc39bd6f8b0 100644 --- a/lib/srv/mcp/sse.go +++ b/lib/srv/mcp/sse.go @@ -52,6 +52,7 @@ func (s *Server) handleStdioToSSE(ctx context.Context, sessionCtx *SessionCtx) e if err != nil { return trace.Wrap(err, "creating HTTP transport") } + // TODO(greedy52) support JWT for SSE transport. session, err := s.makeSessionHandler(ctx, sessionCtx) if err != nil { return trace.Wrap(err, "setting up session handler") diff --git a/lib/srv/mcp/sse_test.go b/lib/srv/mcp/sse_test.go index 855675d4b9d6e..8ec15e972f685 100644 --- a/lib/srv/mcp/sse_test.go +++ b/lib/srv/mcp/sse_test.go @@ -51,6 +51,7 @@ func Test_handleStdioToSSE(t *testing.T) { HostID: "my-host-id", AccessPoint: fakeAccessPoint{}, CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, }) require.NoError(t, err) diff --git a/lib/srv/mcp/stdio_test.go b/lib/srv/mcp/stdio_test.go index 044ebea39428e..6a90af1e09a03 100644 --- a/lib/srv/mcp/stdio_test.go +++ b/lib/srv/mcp/stdio_test.go @@ -46,20 +46,21 @@ func Test_handleAuthErrStdio(t *testing.T) { HostID: "my-host-id", AccessPoint: fakeAccessPoint{}, CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, }) require.NoError(t, err) - clientSourceConn, clientDestConn := makeDualPipeNetConn(t) + testCtx := setupTestContext(t, withAdminRole(t)) originalAuthErr := trace.AccessDenied("test access denied") handlerDoneCh := make(chan struct{}, 1) go func() { - handlerErr := s.HandleUnauthorizedConnection(ctx, clientDestConn, originalAuthErr) + handlerErr := s.HandleUnauthorizedConnection(ctx, testCtx.SessionCtx.ClientConn, testCtx.SessionCtx.App, originalAuthErr) handlerDoneCh <- struct{}{} require.ErrorIs(t, handlerErr, originalAuthErr) }() - stdioClient := mcptest.NewStdioClientFromConn(t, clientSourceConn) + stdioClient := mcptest.NewStdioClientFromConn(t, testCtx.clientSourceConn) _, err = mcptest.InitializeClient(ctx, stdioClient) require.EqualError(t, err, originalAuthErr.Error()) @@ -80,6 +81,7 @@ func Test_handleStdio(t *testing.T) { HostID: "my-host-id", AccessPoint: fakeAccessPoint{}, CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, }) require.NoError(t, err) @@ -154,6 +156,7 @@ func TestHandleSession_execMCPServer(t *testing.T) { HostID: "my-host-id", AccessPoint: fakeAccessPoint{}, CipherSuites: utils.DefaultCipherSuites(), + AuthClient: mockAuthClient{}, }) require.NoError(t, err) diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 3bb3421ebd0a4..86605d3c54ae1 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -516,7 +516,13 @@ func onProxyCommandApp(cf *CLIConf) error { } if app.IsMCP() { - return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.") + // TODO(greedy52) refactor and implement "tsh proxy mcp". + switch types.GetMCPServerTransportType(app.GetURI()) { + case types.MCPTransportHTTP: + // continue + default: + return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.") + } } proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify) From ebdad41978f8bf448a463c8e0aabb2e657305bca Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Fri, 26 Sep 2025 12:52:42 -0400 Subject: [PATCH 2/3] review comments round 1 --- lib/srv/mcp/http.go | 16 +++++++++++----- lib/srv/mcp/http_test.go | 2 +- lib/srv/mcp/session.go | 9 +++------ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go index ee1da6c42f66a..6eed87ffb84af 100644 --- a/lib/srv/mcp/http.go +++ b/lib/srv/mcp/http.go @@ -51,15 +51,21 @@ func (s *Server) serveHTTPConn(ctx context.Context, conn net.Conn, handler http. waitConn := utils.NewCloserConn(conn) listener := listenerutils.NewSingleUseListener(waitConn) - if err := http.Serve(listener, handler); err != nil && !utils.IsOKNetworkError(err) { - waitConn.Close() - return trace.Wrap(err) - } go func() { - // Make sure handler returns if ctx is canceled. + // Make sure connection is closed when ctx is canceled. <-ctx.Done() waitConn.Close() }() + + httpServer := &http.Server{ + Handler: handler, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + } + if err := httpServer.Serve(listener); err != nil && !utils.IsOKNetworkError(err) { + return trace.Wrap(err) + } waitConn.Wait() return nil } diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go index 608d7bc032088..e924cbe44fd83 100644 --- a/lib/srv/mcp/http_test.go +++ b/lib/srv/mcp/http_test.go @@ -162,7 +162,7 @@ func Test_handleStreamableHTTP(t *testing.T) { require.Equal(t, libevents.MCPSessionRequestEvent, lastEvent.GetType()) require.Equal(t, libevents.MCPSessionRequestFailureCode, lastEvent.GetCode()) require.False(t, lastEvent.Success) - require.Equal(t, lastEvent.Error, "HTTP 404 Not Found") + require.Equal(t, "HTTP 404 Not Found", lastEvent.Error) }) } diff --git a/lib/srv/mcp/session.go b/lib/srv/mcp/session.go index dd51bd85990ae..07fff770e914b 100644 --- a/lib/srv/mcp/session.go +++ b/lib/srv/mcp/session.go @@ -105,10 +105,6 @@ func (c *SessionCtx) getAccessState(authPref types.AuthPreference) services.Acce } func (c *SessionCtx) generateJWTAndTraits(ctx context.Context, auth AuthClient) (err error) { - if len(c.jwt) != 0 { - return nil - } - c.jwt, c.traitsForRewriteHeaders, err = appcommon.GenerateJWTAndTraits(ctx, &c.Identity, c.App, auth) return trace.Wrap(err) } @@ -244,11 +240,12 @@ const ( func (s *sessionHandler) processClientRequest(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, replyDirection) { s.idTracker.PushRequest(req) reply, authErr := s.processClientRequestNoAudit(ctx, req) + s.emitRequestEvent(ctx, req, authErr) + + // Not forwarding to server. Just send the auth error to client. if authErr != nil { - s.emitRequestEvent(ctx, req, authErr) return reply, replyToClient } - s.emitRequestEvent(ctx, req, nil) return reply, replyToServer } From 217c6b40fa81466bf4e41229da7bbe93cf19dd61 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Mon, 29 Sep 2025 13:20:19 -0400 Subject: [PATCH 3/3] add comments and fix flaky test --- lib/srv/mcp/http.go | 8 ++++++- lib/srv/mcp/http_test.go | 43 ++++++++++++++++++++---------------- lib/utils/listener/memory.go | 10 +++++++++ 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go index 6eed87ffb84af..c181df4881262 100644 --- a/lib/srv/mcp/http.go +++ b/lib/srv/mcp/http.go @@ -169,11 +169,17 @@ func (t *streamableHTTPTransport) rewriteRequest(r *http.Request) *http.Request // Defaults to the endpoint defined in the app if client is not providing it. // By spec, streamable HTTP should use a single endpoint except the // ".well-known" used for OAuth. + // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http if t.targetURI.Path != "" && (r.URL.Path == "" || r.URL.Path == "/") { r.URL.Path = t.targetURI.Path } - // Add in JWT headers. + // Add in JWT headers. By default, JWT is not put into "Authorization" + // headers since the auth token can also come from the client and Teleport + // just pass it through. If the remote MCP server does verify the auth token + // signed by Teleport, the server can take the token from the + // "teleport-jwt-assertion" header or use a rewrite setting to set the JWT + // as "Bearer" in "Authorization". r.Header.Set(teleport.AppJWTHeader, t.jwt) // Add headers from rewrite configuration. rewriteHeaders := appcommon.AppRewriteHeaders(r.Context(), t.App.GetRewrite(), t.logger) diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go index e924cbe44fd83..89455e4325246 100644 --- a/lib/srv/mcp/http_test.go +++ b/lib/srv/mcp/http_test.go @@ -39,11 +39,14 @@ import ( libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/utils" + listenerutils "github.com/gravitational/teleport/lib/utils/listener" "github.com/gravitational/teleport/lib/utils/mcptest" sliceutils "github.com/gravitational/teleport/lib/utils/slices" ) func Test_handleStreamableHTTP(t *testing.T) { + t.Parallel() + remoteMCPServer := mcpserver.NewStreamableHTTPServer(mcptest.NewServer()) remoteMCPHTTPServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -86,7 +89,7 @@ func Test_handleStreamableHTTP(t *testing.T) { // Run MCP handler behind a listener. var wg sync.WaitGroup t.Cleanup(wg.Wait) - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener := listenerutils.NewInMemoryListener() require.NoError(t, err) defer listener.Close() go func() { @@ -108,44 +111,46 @@ func Test_handleStreamableHTTP(t *testing.T) { ctx := t.Context() emitter.Reset() mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( - "http://"+listener.Addr().String(), + "http://memory", + mcpclienttransport.WithHTTPBasicClient(listener.MakeHTTPClient()), mcpclienttransport.WithContinuousListening(), ) require.NoError(t, err) client := mcpclient.NewClient(mcpClientTransport) require.NoError(t, client.Start(ctx)) - // Initialize client, then call a tool. + // Initialize client, then call a tool. Note that the order can be + // undeterministic as the listen request is sent from a go-routine by + // mcp-go client. + getEventCode := func(e apievents.AuditEvent) string { + return e.GetCode() + } _, err = mcptest.InitializeClient(ctx, client) require.NoError(t, err) mcptest.MustCallServerTool(t, ctx, client) + require.EventuallyWithT(t, func(t *assert.CollectT) { + require.ElementsMatch(t, []string{ + libevents.MCPSessionStartCode, + libevents.MCPSessionRequestCode, // "initialize" + libevents.MCPSessionNotificationCode, + libevents.MCPSessionListenSSEStreamCode, + libevents.MCPSessionRequestCode, // "tools/call" + }, sliceutils.Map(emitter.Events(), getEventCode)) + }, 2*time.Second, time.Millisecond*100, "waiting for events") // Close client and wait for end event. require.NoError(t, client.Close()) require.EventuallyWithT(t, func(t *assert.CollectT) { require.Equal(t, libevents.MCPSessionEndEvent, emitter.LastEvent().GetType()) - }, time.Second, time.Millisecond*100, "waiting for end event") - - // Verify audit codes. Note that the order can be undeterministic as the - // listen request is sent from a go-routine by mcp-go client. - eventCodes := sliceutils.Map(emitter.Events(), func(e apievents.AuditEvent) string { - return e.GetCode() - }) - require.ElementsMatch(t, []string{ - libevents.MCPSessionStartCode, - libevents.MCPSessionRequestCode, // "initialize" - libevents.MCPSessionNotificationCode, - libevents.MCPSessionListenSSEStreamCode, - libevents.MCPSessionRequestCode, // "tools/call" - libevents.MCPSessionEndCode, - }, eventCodes) + }, 2*time.Second, time.Millisecond*100, "waiting for end event") }) t.Run("endpoint not found", func(t *testing.T) { ctx := t.Context() emitter.Reset() mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( - fmt.Sprintf("http://%s/notfound", listener.Addr().String()), + "http://memory/notfound", + mcpclienttransport.WithHTTPBasicClient(listener.MakeHTTPClient()), ) require.NoError(t, err) diff --git a/lib/utils/listener/memory.go b/lib/utils/listener/memory.go index 5f7795a69952d..107a56ed18edb 100644 --- a/lib/utils/listener/memory.go +++ b/lib/utils/listener/memory.go @@ -21,6 +21,7 @@ import ( "errors" "io" "net" + "net/http" "sync" ) @@ -86,6 +87,15 @@ func (m *InMemoryListener) DialContext(ctx context.Context, _ string, _ string) return clientConn, nil } +// MakeHTTPClient is a helper to generate an HTTP client that dials this listener. +func (m *InMemoryListener) MakeHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: m.DialContext, + }, + } +} + // ErrListenerClosed is the error returned by dial when the listener is closed. var ErrListenerClosed = errors.New("in-memory listener closed")