Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/types/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (a *AppV3) checkMCP() error {
switch GetMCPServerTransportType(a.Spec.URI) {
case MCPTransportStdio:
return trace.Wrap(a.checkMCPStdio())
case MCPTransportSSE:
case MCPTransportSSE, MCPTransportHTTP:
_, err := url.Parse(a.Spec.URI)
return trace.Wrap(err)
default:
Expand Down Expand Up @@ -690,6 +690,8 @@ func GetMCPServerTransportType(uri string) string {
return MCPTransportStdio
case SchemeMCPSSEHTTP, SchemeMCPSSEHTTPS:
return MCPTransportSSE
case SchemeMCPHTTP, SchemeMCPHTTPS:
return MCPTransportHTTP
default:
return ""
}
Expand Down
25 changes: 24 additions & 1 deletion api/types/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ func TestNewAppV3(t *testing.T) {
wantErr: require.NoError,
},
{
name: "mcp with uri",
name: "mcp with SSE transport",
meta: Metadata{
Name: "mcp-everything",
},
Expand All @@ -711,6 +711,29 @@ func TestNewAppV3(t *testing.T) {
},
wantErr: require.NoError,
},
{
name: "mcp with streamable HTTP transport",
meta: Metadata{
Name: "mcp-everything",
},
spec: AppSpecV3{
URI: "mcp+http://localhost:12345/mcp",
},
want: &AppV3{
Kind: "app",
SubKind: "mcp",
Version: "v3",
Metadata: Metadata{
Name: "mcp-everything",
Namespace: "default",
Labels: map[string]string{AppSubKindLabel: "mcp"},
},
Spec: AppSpecV3{
URI: "mcp+http://localhost:12345/mcp",
},
},
wantErr: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions api/types/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,14 @@ const (
SchemeMCPSSEHTTPS = "mcp+sse+https"
// MCPTransportSSE indicates the MCP server uses SSE transport.
MCPTransportSSE = "SSE"
// SchemeMCPHTTP is a URI scheme for MCP servers using HTTP with streamable
// HTTP transport.
SchemeMCPHTTP = "mcp+http"
// SchemeMCPHTTPS is a URI scheme for MCP servers using HTTPS with
// streamable HTTP transport.
SchemeMCPHTTPS = "mcp+https"
// MCPTransportHTTP indicates the MCP server uses SSE transport.
MCPTransportHTTP = "Streamable HTTP"

// DiscoveredResourceNode identifies a discovered SSH node.
DiscoveredResourceNode = "node"
Expand Down
20 changes: 16 additions & 4 deletions integration/appaccess/appaccess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -35,6 +36,7 @@ import (
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -62,10 +64,20 @@ import (
func TestAppAccess(t *testing.T) {
// Enable MCP test servers.
sseServerURL := mcptest.MustStartSSETestServer(t)
extraApps := []servicecfg.App{{
Name: "test-sse",
URI: "mcp+sse+" + sseServerURL,
}}
streamableHTTPServer := mcpserver.NewTestStreamableHTTPServer(mcptest.NewServer())
streamableHTTPServerURL := fmt.Sprintf("mcp+%s/mcp", streamableHTTPServer.URL)
t.Cleanup(streamableHTTPServer.Close)

extraApps := []servicecfg.App{
{
Name: "test-sse",
URI: "mcp+sse+" + sseServerURL,
},
{
Name: "test-http",
URI: streamableHTTPServerURL,
},
}

// Reusing the pack as much as we can.
pack := SetupWithOptions(t, AppTestOptions{
Expand Down
59 changes: 59 additions & 0 deletions integration/appaccess/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,19 @@ package appaccess

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"testing"

mcpclient "github.com/mark3labs/mcp-go/client"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client"
libmcp "github.com/gravitational/teleport/lib/srv/mcp"
"github.com/gravitational/teleport/lib/utils/mcptest"
)
Expand All @@ -41,6 +49,10 @@ func testMCP(pack *Pack, t *testing.T) {
t.Run("DialMCPServer stdio to sse success", func(t *testing.T) {
testMCPDialStdioToSSE(t, pack, "test-sse")
})

t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) {
testMCPProxyStreamableHTTP(t, pack, "test-http")
})
}

func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
Expand Down Expand Up @@ -81,3 +93,50 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {

mcptest.MustCallServerTool(t, ctx, stdioClient)
}

func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
require.NoError(t, pack.tc.SaveProfile(false))

// Find the MCP server.
filter := pack.tc.ResourceFilter(types.KindAppServer)
filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName)
apps, err := pack.tc.ListApps(t.Context(), filter)
require.NoError(t, err)
require.Len(t, apps, 1)

// Issue a TLS cert with app route.
keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{
RouteToCluster: pack.rootCluster.Secrets.SiteName,
RouteToApp: proto.RouteToApp{
ClusterName: pack.rootCluster.Secrets.SiteName,
Name: apps[0].GetName(),
PublicAddr: apps[0].GetPublicAddr(),
},
})
require.NoError(t, err)
appCert, err := keyRing.AppTLSCert(appName)
require.NoError(t, err)

// Create an MCP client with app cert.
ctx := t.Context()
mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
"https://"+pack.rootCluster.Web,
mcpclienttransport.WithHTTPBasicClient(&http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{appCert},
InsecureSkipVerify: true,
},
},
}),
)
require.NoError(t, err)
client := mcpclient.NewClient(mcpClientTransport)
require.NoError(t, client.Start(ctx))
defer client.Close()

// Initialize client and call a tool.
_, err = mcptest.InitializeClient(ctx, client)
require.NoError(t, err)
mcptest.MustCallServerTool(t, ctx, client)
}
6 changes: 6 additions & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5567,6 +5567,12 @@ func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (ne
return nil, trace.BadParameter("app %q is not a MCP server", appName)
}

// TODO(greedy52) support streamable HTTP for "tsh mcp connect" before
// release.
if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP {
return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport)
}

cert, err := tc.issueMCPCertWithMFA(ctx, apps[0])
if err != nil {
return nil, trace.Wrap(err)
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/app/connections_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler
AccessPoint: c.cfg.AccessPoint,
EnableDemoServer: c.cfg.MCPDemoServer,
CipherSuites: c.cfg.CipherSuites,
AuthClient: c.cfg.AuthClient,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -614,7 +615,7 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) {
case app.IsTCP():
return nil, trace.Wrap(err)
case app.IsMCP():
return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, err))
return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, app, err))
default:
c.setConnAuth(tlsConn, err)
}
Expand Down
36 changes: 30 additions & 6 deletions lib/srv/mcp/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ func TestMain(m *testing.M) {
}

type setupTestContextOptions struct {
roleSet services.RoleSet
app types.Application
roleSet services.RoleSet
app types.Application
clientConn net.Conn
}

type setupTestContextOptionFunc func(*setupTestContextOptions)
Expand All @@ -69,6 +70,12 @@ func withRole(role types.Role) setupTestContextOptionFunc {
}
}

func withClientConn(conn net.Conn) setupTestContextOptionFunc {
return func(opts *setupTestContextOptions) {
opts.clientConn = conn
}
}

// withAdminRole assigns to ai_user a role that allows all MCP servers and their
// tools.
func withAdminRole(t *testing.T) setupTestContextOptionFunc {
Expand Down Expand Up @@ -139,8 +146,13 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
applyOpt(&opts)
}

// Fake connection.
clientSourceConn, clientDestConn := makeDualPipeNetConn(t)
// Fake connection if not passed in.
var clientSourceConn, clientDestConn net.Conn
if opts.clientConn != nil {
clientDestConn = opts.clientConn
} else {
clientSourceConn, clientDestConn = makeDualPipeNetConn(t)
}

// App.
if opts.app == nil {
Expand All @@ -164,7 +176,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
sessionCtx := &SessionCtx{
ClientConn: clientDestConn,
App: opts.app,
AuthCtx: makeTestAuthContext(t, opts.roleSet),
AuthCtx: makeTestAuthContext(t, opts.roleSet, opts.app),
}
require.NoError(t, sessionCtx.checkAndSetDefaults())

Expand All @@ -174,7 +186,7 @@ func setupTestContext(t *testing.T, applyOpts ...setupTestContextOptionFunc) tes
}
}

func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context {
func makeTestAuthContext(t *testing.T, roleSet services.RoleSet, app types.Application) *authz.Context {
t.Helper()

user, err := types.NewUser("ai")
Expand All @@ -189,6 +201,11 @@ func makeTestAuthContext(t *testing.T, roleSet services.RoleSet) *authz.Context
Principals: user.GetLogins(),
},
}
if app != nil {
identity.Identity.RouteToApp.Name = app.GetName()
identity.Identity.RouteToApp.SessionID = "session-id-for+" + app.GetName()
}

accessInfo, err := services.AccessInfoFromLocalTLSIdentity(identity.Identity)
require.NoError(t, err)
checker := services.NewAccessCheckerWithRoleSet(accessInfo, "my-cluster", roleSet)
Expand Down Expand Up @@ -327,3 +344,10 @@ func forceRemoveContainer(t *testing.T, dockerClient *docker.Client, containerNa
}
}
}

type mockAuthClient struct {
}

func (m mockAuthClient) GenerateAppToken(_ context.Context, req types.GenerateAppTokenRequest) (string, error) {
return "app-token-for-" + req.Username, nil
}
Loading
Loading