diff --git a/api/types/app.go b/api/types/app.go
index da514c22abb84..38dd917840c17 100644
--- a/api/types/app.go
+++ b/api/types/app.go
@@ -526,7 +526,7 @@ func (a *AppV3) checkMCP() error {
switch GetMCPServerTransportType(a.Spec.URI) {
case MCPTransportStdio:
return trace.Wrap(a.checkMCPStdio())
- case MCPTransportSSE:
+ case MCPTransportSSE, MCPTransportHTTP:
_, err := url.Parse(a.Spec.URI)
return trace.Wrap(err)
default:
@@ -690,6 +690,8 @@ func GetMCPServerTransportType(uri string) string {
return MCPTransportStdio
case SchemeMCPSSEHTTP, SchemeMCPSSEHTTPS:
return MCPTransportSSE
+ case SchemeMCPHTTP, SchemeMCPHTTPS:
+ return MCPTransportHTTP
default:
return ""
}
diff --git a/api/types/app_test.go b/api/types/app_test.go
index 861f3f7000b59..12d3110fb00ac 100644
--- a/api/types/app_test.go
+++ b/api/types/app_test.go
@@ -689,7 +689,7 @@ func TestNewAppV3(t *testing.T) {
wantErr: require.NoError,
},
{
- name: "mcp with uri",
+ name: "mcp with SSE transport",
meta: Metadata{
Name: "mcp-everything",
},
@@ -711,6 +711,29 @@ func TestNewAppV3(t *testing.T) {
},
wantErr: require.NoError,
},
+ {
+ name: "mcp with streamable HTTP transport",
+ meta: Metadata{
+ Name: "mcp-everything",
+ },
+ spec: AppSpecV3{
+ URI: "mcp+http://localhost:12345/mcp",
+ },
+ want: &AppV3{
+ Kind: "app",
+ SubKind: "mcp",
+ Version: "v3",
+ Metadata: Metadata{
+ Name: "mcp-everything",
+ Namespace: "default",
+ Labels: map[string]string{AppSubKindLabel: "mcp"},
+ },
+ Spec: AppSpecV3{
+ URI: "mcp+http://localhost:12345/mcp",
+ },
+ },
+ wantErr: require.NoError,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff --git a/api/types/constants.go b/api/types/constants.go
index 4d63544a6d086..9adb0156d0a99 100644
--- a/api/types/constants.go
+++ b/api/types/constants.go
@@ -951,6 +951,14 @@ const (
SchemeMCPSSEHTTPS = "mcp+sse+https"
// MCPTransportSSE indicates the MCP server uses SSE transport.
MCPTransportSSE = "SSE"
+ // SchemeMCPHTTP is a URI scheme for MCP servers using HTTP with streamable
+ // HTTP transport.
+ SchemeMCPHTTP = "mcp+http"
+ // SchemeMCPHTTPS is a URI scheme for MCP servers using HTTPS with
+ // streamable HTTP transport.
+ SchemeMCPHTTPS = "mcp+https"
+ // MCPTransportHTTP indicates the MCP server uses SSE transport.
+ MCPTransportHTTP = "Streamable HTTP"
// DiscoveredResourceNode identifies a discovered SSH node.
DiscoveredResourceNode = "node"
diff --git a/integration/appaccess/appaccess_test.go b/integration/appaccess/appaccess_test.go
index 9c1eb79b8d30c..6122b9b840abd 100644
--- a/integration/appaccess/appaccess_test.go
+++ b/integration/appaccess/appaccess_test.go
@@ -23,6 +23,7 @@ import (
"context"
"crypto/tls"
"errors"
+ "fmt"
"io"
"net"
"net/http"
@@ -35,6 +36,7 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
+ mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -62,10 +64,20 @@ import (
func TestAppAccess(t *testing.T) {
// Enable MCP test servers.
sseServerURL := mcptest.MustStartSSETestServer(t)
- extraApps := []servicecfg.App{{
- Name: "test-sse",
- URI: "mcp+sse+" + sseServerURL,
- }}
+ streamableHTTPServer := mcpserver.NewTestStreamableHTTPServer(mcptest.NewServer())
+ streamableHTTPServerURL := fmt.Sprintf("mcp+%s/mcp", streamableHTTPServer.URL)
+ t.Cleanup(streamableHTTPServer.Close)
+
+ extraApps := []servicecfg.App{
+ {
+ Name: "test-sse",
+ URI: "mcp+sse+" + sseServerURL,
+ },
+ {
+ Name: "test-http",
+ URI: streamableHTTPServerURL,
+ },
+ }
// Reusing the pack as much as we can.
pack := SetupWithOptions(t, AppTestOptions{
diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go
index c6dd65964362f..cf513e33acf60 100644
--- a/integration/appaccess/mcp_test.go
+++ b/integration/appaccess/mcp_test.go
@@ -20,11 +20,19 @@ package appaccess
import (
"context"
+ "crypto/tls"
+ "fmt"
+ "net/http"
"testing"
+ mcpclient "github.com/mark3labs/mcp-go/client"
+ mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/client"
libmcp "github.com/gravitational/teleport/lib/srv/mcp"
"github.com/gravitational/teleport/lib/utils/mcptest"
)
@@ -41,6 +49,10 @@ func testMCP(pack *Pack, t *testing.T) {
t.Run("DialMCPServer stdio to sse success", func(t *testing.T) {
testMCPDialStdioToSSE(t, pack, "test-sse")
})
+
+ t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) {
+ testMCPProxyStreamableHTTP(t, pack, "test-http")
+ })
}
func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
@@ -81,3 +93,50 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
mcptest.MustCallServerTool(t, ctx, stdioClient)
}
+
+func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
+ require.NoError(t, pack.tc.SaveProfile(false))
+
+ // Find the MCP server.
+ filter := pack.tc.ResourceFilter(types.KindAppServer)
+ filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName)
+ apps, err := pack.tc.ListApps(t.Context(), filter)
+ require.NoError(t, err)
+ require.Len(t, apps, 1)
+
+ // Issue a TLS cert with app route.
+ keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{
+ RouteToCluster: pack.rootCluster.Secrets.SiteName,
+ RouteToApp: proto.RouteToApp{
+ ClusterName: pack.rootCluster.Secrets.SiteName,
+ Name: apps[0].GetName(),
+ PublicAddr: apps[0].GetPublicAddr(),
+ },
+ })
+ require.NoError(t, err)
+ appCert, err := keyRing.AppTLSCert(appName)
+ require.NoError(t, err)
+
+ // Create an MCP client with app cert.
+ ctx := t.Context()
+ mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
+ "https://"+pack.rootCluster.Web,
+ mcpclienttransport.WithHTTPBasicClient(&http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ Certificates: []tls.Certificate{appCert},
+ InsecureSkipVerify: true,
+ },
+ },
+ }),
+ )
+ require.NoError(t, err)
+ client := mcpclient.NewClient(mcpClientTransport)
+ require.NoError(t, client.Start(ctx))
+ defer client.Close()
+
+ // Initialize client and call a tool.
+ _, err = mcptest.InitializeClient(ctx, client)
+ require.NoError(t, err)
+ mcptest.MustCallServerTool(t, ctx, client)
+}
diff --git a/lib/client/api.go b/lib/client/api.go
index 65abe6fb5cb2d..4b3f5c24d81be 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -5567,6 +5567,12 @@ func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (ne
return nil, trace.BadParameter("app %q is not a MCP server", appName)
}
+ // TODO(greedy52) support streamable HTTP for "tsh mcp connect" before
+ // release.
+ if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP {
+ return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport)
+ }
+
cert, err := tc.issueMCPCertWithMFA(ctx, apps[0])
if err != nil {
return nil, trace.Wrap(err)
diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go
index 8c5f13c1a1666..3953ce05fef2d 100644
--- a/lib/srv/app/connections_handler.go
+++ b/lib/srv/app/connections_handler.go
@@ -285,6 +285,7 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler
AccessPoint: c.cfg.AccessPoint,
EnableDemoServer: c.cfg.MCPDemoServer,
CipherSuites: c.cfg.CipherSuites,
+ AuthClient: c.cfg.AuthClient,
})
if err != nil {
return nil, trace.Wrap(err)
@@ -614,7 +615,7 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) {
case app.IsTCP():
return nil, trace.Wrap(err)
case app.IsMCP():
- return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, err))
+ return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, app, err))
default:
c.setConnAuth(tlsConn, err)
}
diff --git a/lib/srv/mcp/helpers_test.go b/lib/srv/mcp/helpers_test.go
index ca2df2f9f9c51..fae369d42b259 100644
--- a/lib/srv/mcp/helpers_test.go
+++ b/lib/srv/mcp/helpers_test.go
@@ -51,8 +51,9 @@ func TestMain(m *testing.M) {
}
type setupTestContextOptions struct {
- roleSet services.RoleSet
- app types.Application
+ roleSet services.RoleSet
+ app types.Application
+ clientConn net.Conn
}
type setupTestContextOptionFunc func(*setupTestContextOptions)
@@ -69,6 +70,12 @@ func withRole(role types.Role) setupTestContextOptionFunc {
}
}
+func withClientConn(conn net.Conn) setupTestContextOptionFunc {
+ return func(opts *setupTestContextOptions) {
+ opts.clientConn = conn
+ }
+}
+
// withAdminRole assigns to ai_user a role that allows all MCP servers and their
// tools.
func withAdminRole(t *testing.T) setupTestContextOptionFunc {
@@ -139,8 +146,13 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
applyOpt(&opts)
}
- // Fake connection.
- clientSourceConn, clientDestConn := makeDualPipeNetConn(t)
+ // Fake connection if not passed in.
+ var clientSourceConn, clientDestConn net.Conn
+ if opts.clientConn != nil {
+ clientDestConn = opts.clientConn
+ } else {
+ clientSourceConn, clientDestConn = makeDualPipeNetConn(t)
+ }
// App.
if opts.app == nil {
@@ -164,7 +176,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
sessionCtx := &SessionCtx{
ClientConn: clientDestConn,
App: opts.app,
- AuthCtx: makeTestAuthContext(t, opts.roleSet),
+ AuthCtx: makeTestAuthContext(t, opts.roleSet, opts.app),
}
require.NoError(t, sessionCtx.checkAndSetDefaults())
@@ -174,7 +186,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
}
}
-func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context {
+func makeTestAuthContext(t *testing.T, roleSet services.RoleSet, app types.Application) *authz.Context {
t.Helper()
user, err := types.NewUser("ai")
@@ -189,6 +201,11 @@ func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context
Principals: user.GetLogins(),
},
}
+ if app != nil {
+ identity.Identity.RouteToApp.Name = app.GetName()
+ identity.Identity.RouteToApp.SessionID = "session-id-for+" + app.GetName()
+ }
+
accessInfo, err := services.AccessInfoFromLocalTLSIdentity(identity.Identity)
require.NoError(t, err)
checker := services.NewAccessCheckerWithRoleSet(accessInfo, "my-cluster", roleSet)
@@ -327,3 +344,10 @@ func forceRemoveContainer(t *testing.T, dockerClient *docker.Client, containerNa
}
}
}
+
+type mockAuthClient struct {
+}
+
+func (m mockAuthClient) GenerateAppToken(_ context.Context, req types.GenerateAppTokenRequest) (string, error) {
+ return "app-token-for-" + req.Username, nil
+}
diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go
new file mode 100644
index 0000000000000..c181df4881262
--- /dev/null
+++ b/lib/srv/mcp/http.go
@@ -0,0 +1,309 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/httplib/reverseproxy"
+ "github.com/gravitational/teleport/lib/services"
+ appcommon "github.com/gravitational/teleport/lib/srv/app/common"
+ "github.com/gravitational/teleport/lib/utils"
+ listenerutils "github.com/gravitational/teleport/lib/utils/listener"
+ "github.com/gravitational/teleport/lib/utils/mcputils"
+)
+
+const (
+ mcpSessionIDHeader = "Mcp-Session-Id"
+)
+
+func (s *Server) serveHTTPConn(ctx context.Context, conn net.Conn, handler http.Handler) error {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ waitConn := utils.NewCloserConn(conn)
+ listener := listenerutils.NewSingleUseListener(waitConn)
+ go func() {
+ // Make sure connection is closed when ctx is canceled.
+ <-ctx.Done()
+ waitConn.Close()
+ }()
+
+ httpServer := &http.Server{
+ Handler: handler,
+ BaseContext: func(net.Listener) context.Context {
+ return ctx
+ },
+ }
+ if err := httpServer.Serve(listener); err != nil && !utils.IsOKNetworkError(err) {
+ return trace.Wrap(err)
+ }
+ waitConn.Wait()
+ return nil
+}
+
+func (s *Server) handleAuthErrHTTP(ctx context.Context, clientConn net.Conn, authErr error) error {
+ return trace.Wrap(s.serveHTTPConn(ctx, clientConn, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ trace.WriteError(w, authErr)
+ })))
+}
+
+func (s *Server) handleStreamableHTTP(ctx context.Context, sessionCtx *SessionCtx) error {
+ session, err := s.getSessionHandlerWithJWT(ctx, sessionCtx)
+ if err != nil {
+ return trace.Wrap(err, "setting up session handler")
+ }
+
+ transport, err := s.makeStreamableHTTPTransport(session)
+ if err != nil {
+ return trace.Wrap(err, "setting up streamable http transport")
+ }
+
+ session.logger.DebugContext(ctx, "Started handling HTTP request")
+ defer session.logger.DebugContext(ctx, "Completed handling HTTP request")
+
+ delegate := reverseproxy.NewHeaderRewriter()
+ reverseProxy, err := reverseproxy.New(
+ reverseproxy.WithFlushInterval(100*time.Millisecond),
+ reverseproxy.WithRoundTripper(transport),
+ reverseproxy.WithLogger(session.logger),
+ reverseproxy.WithRewriter(appcommon.NewHeaderRewriter(delegate)),
+ reverseproxy.WithResponseModifier(func(resp *http.Response) error {
+ if resp.Request != nil && resp.Request.Method == http.MethodDelete {
+ // Nothing to modify here.
+ return nil
+ }
+ return trace.Wrap(mcputils.ReplaceHTTPResponse(ctx, resp, newHTTPResponseReplacer(session)))
+ }),
+ )
+ if err != nil {
+ return trace.Wrap(err, "creating reverse proxy")
+ }
+
+ return trace.Wrap(s.serveHTTPConn(ctx, sessionCtx.ClientConn, reverseProxy))
+}
+
+func (s *Server) makeStreamableHTTPTransport(session *sessionHandler) (http.RoundTripper, error) {
+ targetURI, err := url.Parse(session.App.GetURI())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ targetURI.Scheme = strings.TrimPrefix(targetURI.Scheme, "mcp+")
+
+ targetTransport, err := s.makeHTTPTransport(session.App)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &streamableHTTPTransport{
+ sessionHandler: session,
+ targetURI: targetURI,
+ targetTransport: targetTransport,
+ }, nil
+}
+
+type streamableHTTPTransport struct {
+ *sessionHandler
+ targetURI *url.URL
+ targetTransport http.RoundTripper
+}
+
+func (t *streamableHTTPTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+ t.setExternalSessionID(r.Header)
+
+ switch r.Method {
+ case http.MethodDelete:
+ return t.handleSessionEndRequest(r)
+ case http.MethodGet:
+ return t.handleListenSSEStreamRequest(r)
+ case http.MethodPost:
+ return t.handleMCPMessage(r)
+
+ default:
+ t.emitInvalidHTTPRequest(t.parentCtx, r)
+ return &http.Response{
+ Request: r,
+ StatusCode: http.StatusMethodNotAllowed,
+ }, nil
+ }
+}
+
+func (t *streamableHTTPTransport) setExternalSessionID(header http.Header) {
+ if id := header.Get(mcpSessionIDHeader); id != "" {
+ t.mcpSessionID.Store(&id)
+ }
+}
+
+func (t *streamableHTTPTransport) rewriteRequest(r *http.Request) *http.Request {
+ r = r.Clone(r.Context())
+ r.URL.Scheme = t.targetURI.Scheme
+ r.URL.Host = t.targetURI.Host
+
+ // Defaults to the endpoint defined in the app if client is not providing it.
+ // By spec, streamable HTTP should use a single endpoint except the
+ // ".well-known" used for OAuth.
+ // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
+ if t.targetURI.Path != "" && (r.URL.Path == "" || r.URL.Path == "/") {
+ r.URL.Path = t.targetURI.Path
+ }
+
+ // Add in JWT headers. By default, JWT is not put into "Authorization"
+ // headers since the auth token can also come from the client and Teleport
+ // just pass it through. If the remote MCP server does verify the auth token
+ // signed by Teleport, the server can take the token from the
+ // "teleport-jwt-assertion" header or use a rewrite setting to set the JWT
+ // as "Bearer" in "Authorization".
+ r.Header.Set(teleport.AppJWTHeader, t.jwt)
+ // Add headers from rewrite configuration.
+ rewriteHeaders := appcommon.AppRewriteHeaders(r.Context(), t.App.GetRewrite(), t.logger)
+ services.RewriteHeadersAndApplyValueTraits(r, rewriteHeaders, t.traitsForRewriteHeaders, t.logger)
+ return r
+}
+
+func (t *streamableHTTPTransport) rewriteAndSendRequest(r *http.Request) (*http.Response, error) {
+ rCopy := t.rewriteRequest(r)
+ return t.targetTransport.RoundTrip(rCopy)
+}
+
+func (t *streamableHTTPTransport) handleSessionEndRequest(r *http.Request) (*http.Response, error) {
+ resp, err := t.rewriteAndSendRequest(r)
+ t.emitEndEvent(t.parentCtx, convertHTTPResponseErrorForAudit(resp, err))
+ return resp, trace.Wrap(err)
+}
+
+func (t *streamableHTTPTransport) handleListenSSEStreamRequest(r *http.Request) (*http.Response, error) {
+ resp, err := t.rewriteAndSendRequest(r)
+ t.emitListenSSEStreamEvent(t.parentCtx, convertHTTPResponseErrorForAudit(resp, err))
+ return resp, trace.Wrap(err)
+}
+
+func (t *streamableHTTPTransport) handleMCPMessage(r *http.Request) (*http.Response, error) {
+ var baseMessage mcputils.BaseJSONRPCMessage
+ if reqBody, err := utils.GetAndReplaceRequestBody(r); err != nil {
+ t.emitInvalidHTTPRequest(t.parentCtx, r)
+ return nil, trace.BadParameter("invalid request body %v", err)
+ } else if err := json.Unmarshal(reqBody, &baseMessage); err != nil {
+ t.emitInvalidHTTPRequest(t.parentCtx, r)
+ return nil, trace.BadParameter("invalid request body %v", err)
+ }
+
+ switch {
+ case baseMessage.IsRequest():
+ mcpRequest := baseMessage.MakeRequest()
+ if errResp, authErr := t.sessionHandler.processClientRequestNoAudit(r.Context(), mcpRequest); authErr != nil {
+ return t.handleRequestAuthError(r, mcpRequest, errResp, authErr)
+ }
+ case baseMessage.IsNotification():
+ // nothing to do, yet.
+ default:
+ // Not sending it to the server if we don't understand it.
+ t.emitInvalidHTTPRequest(t.parentCtx, r)
+ return nil, trace.BadParameter("not a MCP request or notification")
+ }
+
+ resp, err := t.rewriteAndSendRequest(r)
+ // Prefer session ID from server response if present. For example,
+ // "initialize" request does not have an ID but the server response may have
+ // it.
+ if resp != nil {
+ t.setExternalSessionID(resp.Header)
+ }
+
+ // Take care of audit events after round trip.
+ respErrForAudit := convertHTTPResponseErrorForAudit(resp, err)
+ switch {
+ case baseMessage.IsRequest():
+ mcpRequest := baseMessage.MakeRequest()
+ // Only emit session start if "initialize" succeeded.
+ if mcpRequest.Method == "initialize" && respErrForAudit == nil {
+ t.emitStartEvent(t.parentCtx)
+ }
+ t.emitRequestEvent(t.parentCtx, mcpRequest, respErrForAudit)
+ case baseMessage.IsNotification():
+ t.emitNotificationEvent(t.parentCtx, baseMessage.MakeNotification(), respErrForAudit)
+ }
+ return resp, trace.Wrap(err)
+}
+
+func (t *streamableHTTPTransport) handleRequestAuthError(r *http.Request, mcpRequest *mcputils.JSONRPCRequest, errResp mcp.JSONRPCMessage, authErr error) (*http.Response, error) {
+ t.emitRequestEvent(t.parentCtx, mcpRequest, authErr)
+
+ errRespAsBody, err := json.Marshal(errResp)
+ if err != nil {
+ // Should not happen. If it does, we are failing the request either way.
+ return nil, trace.Wrap(err)
+ }
+
+ httpResp := &http.Response{
+ Request: r,
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(bytes.NewReader(errRespAsBody)),
+ Header: make(http.Header),
+ }
+ httpResp.Header.Set("Content-Type", "application/json")
+ return httpResp, nil
+}
+
+func convertHTTPResponseErrorForAudit(resp *http.Response, err error) error {
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if resp == nil {
+ // Should not happen.
+ return trace.BadParameter("missing response")
+ }
+ if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
+ return nil
+ }
+ if resp.Status == "" {
+ return trace.Errorf("HTTP %d %s", resp.StatusCode, http.StatusText(resp.StatusCode))
+ }
+ return trace.Errorf("HTTP %s", resp.Status)
+}
+
+// streamableHTTPResponseReplacer is a wrapper of sessionHandler to satisfy
+// mcputils.ServerMessageProcessor.
+type streamableHTTPResponseReplacer struct {
+ *sessionHandler
+}
+
+func newHTTPResponseReplacer(sessionHandler *sessionHandler) *streamableHTTPResponseReplacer {
+ return &streamableHTTPResponseReplacer{
+ sessionHandler: sessionHandler,
+ }
+}
+
+func (p *streamableHTTPResponseReplacer) ProcessResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage {
+ return p.processServerResponse(ctx, resp)
+}
+func (p *streamableHTTPResponseReplacer) ProcessNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) mcp.JSONRPCMessage {
+ p.processServerNotification(ctx, notification)
+ return notification
+}
diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go
new file mode 100644
index 0000000000000..89455e4325246
--- /dev/null
+++ b/lib/srv/mcp/http_test.go
@@ -0,0 +1,205 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ mcpclient "github.com/mark3labs/mcp-go/client"
+ mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+ apievents "github.com/gravitational/teleport/api/types/events"
+ libevents "github.com/gravitational/teleport/lib/events"
+ "github.com/gravitational/teleport/lib/events/eventstest"
+ "github.com/gravitational/teleport/lib/utils"
+ listenerutils "github.com/gravitational/teleport/lib/utils/listener"
+ "github.com/gravitational/teleport/lib/utils/mcptest"
+ sliceutils "github.com/gravitational/teleport/lib/utils/slices"
+)
+
+func Test_handleStreamableHTTP(t *testing.T) {
+ t.Parallel()
+
+ remoteMCPServer := mcpserver.NewStreamableHTTPServer(mcptest.NewServer())
+ remoteMCPHTTPServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path != "/mcp":
+ // Unhappy scenario.
+ w.WriteHeader(http.StatusNotFound)
+ case r.Header.Get("Authorization") != "Bearer app-token-for-ai":
+ // Verify rewrite headers.
+ w.WriteHeader(http.StatusUnauthorized)
+ default:
+ remoteMCPServer.ServeHTTP(w, r)
+ }
+ }))
+ t.Cleanup(remoteMCPHTTPServer.Close)
+
+ app, err := types.NewAppV3(types.Metadata{
+ Name: "test-http",
+ }, types.AppSpecV3{
+ URI: fmt.Sprintf("mcp+%s/mcp", remoteMCPHTTPServer.URL),
+ Rewrite: &types.Rewrite{
+ Headers: []*types.Header{{
+ Name: "Authorization",
+ Value: "Bearer {{internal.jwt}}",
+ }},
+ },
+ })
+ require.NoError(t, err)
+
+ emitter := eventstest.MockRecorderEmitter{}
+ s, err := NewServer(ServerConfig{
+ Emitter: &emitter,
+ ParentContext: t.Context(),
+ HostID: "my-host-id",
+ AccessPoint: fakeAccessPoint{},
+ CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
+ })
+ require.NoError(t, err)
+
+ // Run MCP handler behind a listener.
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ listener := listenerutils.NewInMemoryListener()
+ require.NoError(t, err)
+ defer listener.Close()
+ go func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ assert.True(t, utils.IsOKNetworkError(err))
+ return
+ }
+ wg.Go(func() {
+ defer conn.Close()
+ testCtx := setupTestContext(t, withAdminRole(t), withApp(app), withClientConn(conn))
+ assert.NoError(t, s.HandleSession(t.Context(), testCtx.SessionCtx))
+ })
+ }
+ }()
+
+ t.Run("success", func(t *testing.T) {
+ ctx := t.Context()
+ emitter.Reset()
+ mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
+ "http://memory",
+ mcpclienttransport.WithHTTPBasicClient(listener.MakeHTTPClient()),
+ mcpclienttransport.WithContinuousListening(),
+ )
+ require.NoError(t, err)
+ client := mcpclient.NewClient(mcpClientTransport)
+ require.NoError(t, client.Start(ctx))
+
+ // Initialize client, then call a tool. Note that the order can be
+ // undeterministic as the listen request is sent from a go-routine by
+ // mcp-go client.
+ getEventCode := func(e apievents.AuditEvent) string {
+ return e.GetCode()
+ }
+ _, err = mcptest.InitializeClient(ctx, client)
+ require.NoError(t, err)
+ mcptest.MustCallServerTool(t, ctx, client)
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ require.ElementsMatch(t, []string{
+ libevents.MCPSessionStartCode,
+ libevents.MCPSessionRequestCode, // "initialize"
+ libevents.MCPSessionNotificationCode,
+ libevents.MCPSessionListenSSEStreamCode,
+ libevents.MCPSessionRequestCode, // "tools/call"
+ }, sliceutils.Map(emitter.Events(), getEventCode))
+ }, 2*time.Second, time.Millisecond*100, "waiting for events")
+
+ // Close client and wait for end event.
+ require.NoError(t, client.Close())
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ require.Equal(t, libevents.MCPSessionEndEvent, emitter.LastEvent().GetType())
+ }, 2*time.Second, time.Millisecond*100, "waiting for end event")
+ })
+
+ t.Run("endpoint not found", func(t *testing.T) {
+ ctx := t.Context()
+ emitter.Reset()
+ mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
+ "http://memory/notfound",
+ mcpclienttransport.WithHTTPBasicClient(listener.MakeHTTPClient()),
+ )
+ require.NoError(t, err)
+
+ // Initialize client should fail.
+ client := mcpclient.NewClient(mcpClientTransport)
+ _, err = mcptest.InitializeClient(ctx, client)
+ require.Error(t, err)
+
+ // Close client and verify failure event.
+ events := emitter.Events()
+ require.Len(t, events, 1)
+ lastEvent, ok := events[0].(*apievents.MCPSessionRequest)
+ require.True(t, ok)
+ require.Equal(t, libevents.MCPSessionRequestEvent, lastEvent.GetType())
+ require.Equal(t, libevents.MCPSessionRequestFailureCode, lastEvent.GetCode())
+ require.False(t, lastEvent.Success)
+ require.Equal(t, "HTTP 404 Not Found", lastEvent.Error)
+ })
+}
+
+func Test_handleAuthErrHTTP(t *testing.T) {
+ s, err := NewServer(ServerConfig{
+ Emitter: &libevents.DiscardEmitter{},
+ ParentContext: t.Context(),
+ HostID: "my-host-id",
+ AccessPoint: fakeAccessPoint{},
+ CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
+ })
+
+ require.NoError(t, err)
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer listener.Close()
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ assert.True(t, utils.IsOKNetworkError(err))
+ return
+ }
+ defer conn.Close()
+ s.handleAuthErrHTTP(t.Context(), conn, trace.AccessDenied("access denied"))
+ }()
+
+ mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
+ fmt.Sprintf("http://%s", listener.Addr().String()),
+ )
+ require.NoError(t, err)
+ client := mcpclient.NewClient(mcpClientTransport)
+ _, err = mcptest.InitializeClient(t.Context(), client)
+ require.ErrorContains(t, err, "access denied")
+}
diff --git a/lib/srv/mcp/server.go b/lib/srv/mcp/server.go
index 9bdc0de102e8e..04ae5fde991a8 100644
--- a/lib/srv/mcp/server.go
+++ b/lib/srv/mcp/server.go
@@ -33,6 +33,8 @@ import (
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
+ appcommon "github.com/gravitational/teleport/lib/srv/app/common"
+ "github.com/gravitational/teleport/lib/utils"
)
// AccessPoint defines functions that the MCP server requires from the caching
@@ -42,6 +44,12 @@ type AccessPoint interface {
services.ClusterNameGetter
}
+// AuthClient defines functions that the MCP server requires from the auth
+// client.
+type AuthClient interface {
+ appcommon.AppTokenGenerator
+}
+
// ServerConfig is the config for the MCP forward server.
type ServerConfig struct {
// Emitter is used for emitting audit events.
@@ -54,6 +62,8 @@ type ServerConfig struct {
HostID string
// AccessPoint is a caching client connected to the Auth Server.
AccessPoint AccessPoint
+ // AuthClient is a client directly connected to the Auth server.
+ AuthClient AuthClient
// EnableDemoServer enables the "Teleport Demo" MCP server.
EnableDemoServer bool
// CipherSuites is the list of TLS cipher suites that have been configured
@@ -74,6 +84,9 @@ func (c *ServerConfig) CheckAndSetDefaults() error {
if c.HostID == "" {
return trace.BadParameter("missing HostID")
}
+ if c.AuthClient == nil {
+ return trace.BadParameter("missing AuthClient")
+ }
if c.AccessPoint == nil {
return trace.BadParameter("missing AccessPoint")
}
@@ -93,6 +106,8 @@ func (c *ServerConfig) CheckAndSetDefaults() error {
// TODO(greedy52) add server metrics.
type Server struct {
cfg ServerConfig
+
+ sessionCache *utils.FnCache
}
// NewServer creates a new Server.
@@ -100,8 +115,20 @@ func NewServer(cfg ServerConfig) (*Server, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
+
+ cache, err := utils.NewFnCache(utils.FnCacheConfig{
+ TTL: 10 * time.Minute,
+ Context: cfg.ParentContext,
+ Clock: cfg.clock,
+ ReloadOnErr: true,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
return &Server{
- cfg: cfg,
+ cfg: cfg,
+ sessionCache: cache,
}, nil
}
@@ -119,6 +146,8 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner))
case types.MCPTransportSSE:
return trace.Wrap(s.handleStdioToSSE(ctx, sessionCtx))
+ case types.MCPTransportHTTP:
+ return trace.Wrap(s.handleStreamableHTTP(ctx, sessionCtx))
default:
return trace.BadParameter("unknown transport type: %v", transportType)
}
@@ -127,10 +156,16 @@ func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) erro
// HandleUnauthorizedConnection handles an unauthorized client connection.
// This function has a hardcoded 30 seconds timeout in case the proper error
// message cannot be delivered to the client.
-func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn net.Conn, authErr error) error {
+func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn net.Conn, app types.Application, authErr error) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
- return trace.Wrap(s.handleAuthErrStdio(ctx, clientConn, authErr))
+ transportType := types.GetMCPServerTransportType(app.GetURI())
+ switch transportType {
+ case types.MCPTransportHTTP:
+ return trace.Wrap(s.handleAuthErrHTTP(ctx, clientConn, authErr))
+ default:
+ return trace.Wrap(s.handleAuthErrStdio(ctx, clientConn, authErr))
+ }
}
func (s *Server) makeSessionAuditor(ctx context.Context, sessionCtx *SessionCtx, logger *slog.Logger) (*sessionAuditor, error) {
@@ -167,6 +202,7 @@ func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx *SessionCtx)
logger := s.cfg.Log.With(
"client_ip", sessionCtx.ClientConn.RemoteAddr(),
"app", sessionCtx.App.GetName(),
+ "app_uri", sessionCtx.App.GetURI(),
"user", sessionCtx.AuthCtx.User.GetName(),
"session_id", sessionCtx.sessionID,
)
@@ -185,3 +221,17 @@ func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx *SessionCtx)
clock: s.cfg.clock,
})
}
+
+func (s *Server) makeSessionHandlerWithJWT(ctx context.Context, sessionCtx *SessionCtx) (*sessionHandler, error) {
+ if err := sessionCtx.generateJWTAndTraits(ctx, s.cfg.AuthClient); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return s.makeSessionHandler(ctx, sessionCtx)
+}
+
+func (s *Server) getSessionHandlerWithJWT(ctx context.Context, sessionCtx *SessionCtx) (*sessionHandler, error) {
+ ttl := min(sessionCtx.Identity.Expires.Sub(s.cfg.clock.Now()), 10*time.Minute)
+ return utils.FnCacheGetWithTTL(ctx, s.sessionCache, sessionCtx.sessionID, ttl, func(ctx context.Context) (*sessionHandler, error) {
+ return s.makeSessionHandlerWithJWT(ctx, sessionCtx)
+ })
+}
diff --git a/lib/srv/mcp/session.go b/lib/srv/mcp/session.go
index 29be7cb08150f..07fff770e914b 100644
--- a/lib/srv/mcp/session.go
+++ b/lib/srv/mcp/session.go
@@ -31,10 +31,12 @@ import (
"github.com/mark3labs/mcp-go/mcp"
"github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/authz"
dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
+ appcommon "github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/mcputils"
@@ -59,6 +61,12 @@ type SessionCtx struct {
// mcpSessionID is the MCP session ID tracked by remote MCP server.
mcpSessionID atomicString
+
+ // jwt is the jwt token signed for this identity by Auth server.
+ jwt string
+
+ // traitsForRewriteHeaders are user traits used for rewriting headers.
+ traitsForRewriteHeaders wrappers.Traits
}
func (c *SessionCtx) checkAndSetDefaults() error {
@@ -75,7 +83,14 @@ func (c *SessionCtx) checkAndSetDefaults() error {
c.Identity = c.AuthCtx.Identity.GetIdentity()
}
if c.sessionID == "" {
- c.sessionID = session.NewID()
+ if types.MCPTransportHTTP == types.GetMCPServerTransportType(c.App.GetURI()) {
+ // A single HTTP request is handled at a time so take session ID
+ // from cert.
+ c.sessionID = session.ID(c.Identity.RouteToApp.SessionID)
+ }
+ if c.sessionID == "" {
+ c.sessionID = session.NewID()
+ }
}
return nil
}
@@ -89,6 +104,11 @@ func (c *SessionCtx) getAccessState(authPref types.AuthPreference) services.Acce
return state
}
+func (c *SessionCtx) generateJWTAndTraits(ctx context.Context, auth AuthClient) (err error) {
+ c.jwt, c.traitsForRewriteHeaders, err = appcommon.GenerateJWTAndTraits(ctx, &c.Identity, c.App, auth)
+ return trace.Wrap(err)
+}
+
type sessionHandlerConfig struct {
*SessionCtx
*sessionAuditor
@@ -198,6 +218,7 @@ func (s *sessionHandler) onClientRequest(clientResponseWriter, serverRequestWrit
func (s *sessionHandler) onServerNotification(clientResponseWriter mcputils.MessageWriter) mcputils.HandleNotificationFunc {
return func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
+ s.processServerNotification(ctx, notification)
return trace.Wrap(clientResponseWriter.WriteMessage(ctx, notification))
}
}
@@ -217,17 +238,27 @@ const (
)
func (s *sessionHandler) processClientRequest(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, replyDirection) {
+ s.idTracker.PushRequest(req)
+ reply, authErr := s.processClientRequestNoAudit(ctx, req)
+ s.emitRequestEvent(ctx, req, authErr)
+
+ // Not forwarding to server. Just send the auth error to client.
+ if authErr != nil {
+ return reply, replyToClient
+ }
+ return reply, replyToServer
+}
+
+func (s *sessionHandler) processClientRequestNoAudit(ctx context.Context, req *mcputils.JSONRPCRequest) (mcp.JSONRPCMessage, error) {
s.idTracker.PushRequest(req)
switch req.Method {
case mcp.MethodToolsCall:
methodName, _ := req.Params.GetName()
if authErr := s.checkAccessToTool(ctx, methodName); authErr != nil {
- s.emitRequestEvent(ctx, req, authErr)
- return makeToolAccessDeniedResponse(req, authErr), replyToClient
+ return makeToolAccessDeniedResponse(req, authErr), trace.Wrap(authErr)
}
}
- s.emitRequestEvent(ctx, req, nil)
- return req, replyToServer
+ return req, nil
}
func (s *sessionHandler) processServerResponse(ctx context.Context, response *mcputils.JSONRPCResponse) mcp.JSONRPCMessage {
@@ -239,6 +270,10 @@ func (s *sessionHandler) processServerResponse(ctx context.Context, response *mc
return response
}
+func (s *sessionHandler) processServerNotification(ctx context.Context, notification *mcputils.JSONRPCNotification) {
+ s.logger.DebugContext(ctx, "Received server notification.", "method", notification.Method)
+}
+
func (s *sessionHandler) makeToolsCallResponse(ctx context.Context, resp *mcputils.JSONRPCResponse) mcp.JSONRPCMessage {
// Nothing to do, likely an error response.
if resp.Result == nil {
diff --git a/lib/srv/mcp/sse.go b/lib/srv/mcp/sse.go
index 2384fa7b17347..63fc39bd6f8b0 100644
--- a/lib/srv/mcp/sse.go
+++ b/lib/srv/mcp/sse.go
@@ -52,6 +52,7 @@ func (s *Server) handleStdioToSSE(ctx context.Context, sessionCtx *SessionCtx) e
if err != nil {
return trace.Wrap(err, "creating HTTP transport")
}
+ // TODO(greedy52) support JWT for SSE transport.
session, err := s.makeSessionHandler(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err, "setting up session handler")
diff --git a/lib/srv/mcp/sse_test.go b/lib/srv/mcp/sse_test.go
index 855675d4b9d6e..8ec15e972f685 100644
--- a/lib/srv/mcp/sse_test.go
+++ b/lib/srv/mcp/sse_test.go
@@ -51,6 +51,7 @@ func Test_handleStdioToSSE(t *testing.T) {
HostID: "my-host-id",
AccessPoint: fakeAccessPoint{},
CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
})
require.NoError(t, err)
diff --git a/lib/srv/mcp/stdio_test.go b/lib/srv/mcp/stdio_test.go
index 044ebea39428e..6a90af1e09a03 100644
--- a/lib/srv/mcp/stdio_test.go
+++ b/lib/srv/mcp/stdio_test.go
@@ -46,20 +46,21 @@ func Test_handleAuthErrStdio(t *testing.T) {
HostID: "my-host-id",
AccessPoint: fakeAccessPoint{},
CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
})
require.NoError(t, err)
- clientSourceConn, clientDestConn := makeDualPipeNetConn(t)
+ testCtx := setupTestContext(t, withAdminRole(t))
originalAuthErr := trace.AccessDenied("test access denied")
handlerDoneCh := make(chan struct{}, 1)
go func() {
- handlerErr := s.HandleUnauthorizedConnection(ctx, clientDestConn, originalAuthErr)
+ handlerErr := s.HandleUnauthorizedConnection(ctx, testCtx.SessionCtx.ClientConn, testCtx.SessionCtx.App, originalAuthErr)
handlerDoneCh <- struct{}{}
require.ErrorIs(t, handlerErr, originalAuthErr)
}()
- stdioClient := mcptest.NewStdioClientFromConn(t, clientSourceConn)
+ stdioClient := mcptest.NewStdioClientFromConn(t, testCtx.clientSourceConn)
_, err = mcptest.InitializeClient(ctx, stdioClient)
require.EqualError(t, err, originalAuthErr.Error())
@@ -80,6 +81,7 @@ func Test_handleStdio(t *testing.T) {
HostID: "my-host-id",
AccessPoint: fakeAccessPoint{},
CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
})
require.NoError(t, err)
@@ -154,6 +156,7 @@ func TestHandleSession_execMCPServer(t *testing.T) {
HostID: "my-host-id",
AccessPoint: fakeAccessPoint{},
CipherSuites: utils.DefaultCipherSuites(),
+ AuthClient: mockAuthClient{},
})
require.NoError(t, err)
diff --git a/lib/utils/listener/memory.go b/lib/utils/listener/memory.go
index 5f7795a69952d..107a56ed18edb 100644
--- a/lib/utils/listener/memory.go
+++ b/lib/utils/listener/memory.go
@@ -21,6 +21,7 @@ import (
"errors"
"io"
"net"
+ "net/http"
"sync"
)
@@ -86,6 +87,15 @@ func (m *InMemoryListener) DialContext(ctx context.Context, _ string, _ string)
return clientConn, nil
}
+// MakeHTTPClient is a helper to generate an HTTP client that dials this listener.
+func (m *InMemoryListener) MakeHTTPClient() *http.Client {
+ return &http.Client{
+ Transport: &http.Transport{
+ DialContext: m.DialContext,
+ },
+ }
+}
+
// ErrListenerClosed is the error returned by dial when the listener is closed.
var ErrListenerClosed = errors.New("in-memory listener closed")
diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go
index 3bb3421ebd0a4..86605d3c54ae1 100644
--- a/tool/tsh/common/proxy.go
+++ b/tool/tsh/common/proxy.go
@@ -516,7 +516,13 @@ func onProxyCommandApp(cf *CLIConf) error {
}
if app.IsMCP() {
- return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.")
+ // TODO(greedy52) refactor and implement "tsh proxy mcp".
+ switch types.GetMCPServerTransportType(app.GetURI()) {
+ case types.MCPTransportHTTP:
+ // continue
+ default:
+ return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.")
+ }
}
proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify)