Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions tool/tsh/common/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,8 @@ func onProxyCommandApp(cf *CLIConf) error {
return trace.Wrap(err)
}

if app.IsMCP() {
// TODO(greedy52) refactor and implement "tsh proxy mcp".
switch types.GetMCPServerTransportType(app.GetURI()) {
case types.MCPTransportHTTP:
// continue
default:
return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.")
}
if err := checkProxyMCPCompatibility(cf.command, app); err != nil {
return trace.Wrap(err)
}

proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify)
Expand Down Expand Up @@ -554,6 +548,26 @@ func onProxyCommandApp(cf *CLIConf) error {
return nil
}

func checkProxyMCPCompatibility(command string, app types.Application) error {
if !app.IsMCP() {
switch command {
case "proxy mcp":
Comment thread
Tener marked this conversation as resolved.
return trace.BadParameter("%q is not an MCP application", app.GetName())
default:
// tsh proxy app
return nil
}
}

mcpTransport := types.GetMCPServerTransportType(app.GetURI())
switch mcpTransport {
case types.MCPTransportHTTP:
return nil
default:
return trace.BadParameter("MCP applications with %s transport are not supported. Please see 'tsh mcp config --help' for more details.", mcpTransport)
}
}

// onProxyCommandAWS creates local proxes for AWS apps.
func onProxyCommandAWS(cf *CLIConf) error {
if err := checkProxyAWSFormatCompatibility(cf); err != nil {
Expand Down
58 changes: 58 additions & 0 deletions tool/tsh/common/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1754,3 +1754,61 @@ func mustDialLocalAppProxy(t *testing.T, port string, expectedName string) {
require.Equal(t, expectedName, r.Header.Get("Server"), "the response header \"Server\" does not have the expected value")
}, 5*time.Second, 50*time.Millisecond)
}

func Test_checkProxyMCPCompatibility(t *testing.T) {
tests := []struct {
name string
command string
appURI string
checkResult require.ErrorAssertionFunc
}{
{
name: "streamable HTTP allowed for tsh proxy app",
command: "proxy app",
appURI: "mcp+http://example.com/mcp",
checkResult: require.NoError,
},
{
name: "streamable HTTP allowed for tsh proxy mcp",
command: "proxy mcp",
appURI: "mcp+http://example.com/mcp",
checkResult: require.NoError,
},
{
name: "unsupported transport fails for tsh proxy app",
command: "proxy app",
appURI: "mcp+sse+http://example.com/sse",
checkResult: require.Error,
},
{
name: "unsupported transport fails for tsh proxy mcp",
command: "proxy mcp",
appURI: "mcp+sse+http://example.com/sse",
checkResult: require.Error,
},
{
name: "regular app allowed for tsh proxy app",
command: "proxy app",
appURI: "http://example.com",
checkResult: require.NoError,
},
{
name: "regular app fails for tsh proxy mcp",
command: "proxy mcp",
appURI: "http://example.com",
checkResult: require.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app, err := types.NewAppV3(types.Metadata{
Name: t.Name(),
}, types.AppSpecV3{
URI: tt.appURI,
})
require.NoError(t, err)
tt.checkResult(t, checkProxyMCPCompatibility(tt.command, app))
})
}
}
7 changes: 7 additions & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,11 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
proxyApp.Flag("port", "Specifies the listening port used by by the proxy app listener. Accepts an optional target port of a multi-port TCP app after a colon, e.g. \"1234:5678\".").Short('p').StringVar(&cf.LocalProxyPortMapping)
proxyApp.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)

proxyMCP := proxy.Command("mcp", "Start local proxy for MCP access.")
proxyMCP.Arg("app", "The name of the MCP application to start local proxy for.").Required().StringVar(&cf.AppName)
proxyMCP.Flag("port", "Specifies the listening port used by by the proxy app listener.").Short('p').StringVar(&cf.LocalProxyPortMapping)
proxyMCP.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)

proxyAWS := proxy.Command("aws", "Start local proxy for AWS access.")
proxyAWS.Flag("app", "Optional Name of the AWS application to use if logged into multiple.").StringVar(&cf.AppName)
proxyAWS.Flag("port", "Specifies the source port used by the proxy listener.").Short('p').StringVar(&cf.LocalProxyPort)
Expand Down Expand Up @@ -1781,6 +1786,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
err = onProxyCommandDB(&cf)
case proxyApp.FullCommand():
err = onProxyCommandApp(&cf)
case proxyMCP.FullCommand():
err = onProxyCommandApp(&cf)
case proxyAWS.FullCommand():
err = onProxyCommandAWS(&cf)
case proxyAzure.FullCommand():
Expand Down
Loading