diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 86605d3c54ae1..34ac1d6c3b81f 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -515,14 +515,8 @@ func onProxyCommandApp(cf *CLIConf) error { return trace.Wrap(err) } - if app.IsMCP() { - // TODO(greedy52) refactor and implement "tsh proxy mcp". - switch types.GetMCPServerTransportType(app.GetURI()) { - case types.MCPTransportHTTP: - // continue - default: - return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp config --help' for more details.") - } + if err := checkProxyMCPCompatibility(cf.command, app); err != nil { + return trace.Wrap(err) } proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify) @@ -554,6 +548,26 @@ func onProxyCommandApp(cf *CLIConf) error { return nil } +func checkProxyMCPCompatibility(command string, app types.Application) error { + if !app.IsMCP() { + switch command { + case "proxy mcp": + return trace.BadParameter("%q is not an MCP application", app.GetName()) + default: + // tsh proxy app + return nil + } + } + + mcpTransport := types.GetMCPServerTransportType(app.GetURI()) + switch mcpTransport { + case types.MCPTransportHTTP: + return nil + default: + return trace.BadParameter("MCP applications with %s transport are not supported. Please see 'tsh mcp config --help' for more details.", mcpTransport) + } +} + // onProxyCommandAWS creates local proxes for AWS apps. func onProxyCommandAWS(cf *CLIConf) error { if err := checkProxyAWSFormatCompatibility(cf); err != nil { diff --git a/tool/tsh/common/proxy_test.go b/tool/tsh/common/proxy_test.go index d0039df0a8533..424d7c30ef6cb 100644 --- a/tool/tsh/common/proxy_test.go +++ b/tool/tsh/common/proxy_test.go @@ -1754,3 +1754,61 @@ func mustDialLocalAppProxy(t *testing.T, port string, expectedName string) { require.Equal(t, expectedName, r.Header.Get("Server"), "the response header \"Server\" does not have the expected value") }, 5*time.Second, 50*time.Millisecond) } + +func Test_checkProxyMCPCompatibility(t *testing.T) { + tests := []struct { + name string + command string + appURI string + checkResult require.ErrorAssertionFunc + }{ + { + name: "streamable HTTP allowed for tsh proxy app", + command: "proxy app", + appURI: "mcp+http://example.com/mcp", + checkResult: require.NoError, + }, + { + name: "streamable HTTP allowed for tsh proxy mcp", + command: "proxy mcp", + appURI: "mcp+http://example.com/mcp", + checkResult: require.NoError, + }, + { + name: "unsupported transport fails for tsh proxy app", + command: "proxy app", + appURI: "mcp+sse+http://example.com/sse", + checkResult: require.Error, + }, + { + name: "unsupported transport fails for tsh proxy mcp", + command: "proxy mcp", + appURI: "mcp+sse+http://example.com/sse", + checkResult: require.Error, + }, + { + name: "regular app allowed for tsh proxy app", + command: "proxy app", + appURI: "http://example.com", + checkResult: require.NoError, + }, + { + name: "regular app fails for tsh proxy mcp", + command: "proxy mcp", + appURI: "http://example.com", + checkResult: require.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app, err := types.NewAppV3(types.Metadata{ + Name: t.Name(), + }, types.AppSpecV3{ + URI: tt.appURI, + }) + require.NoError(t, err) + tt.checkResult(t, checkProxyMCPCompatibility(tt.command, app)) + }) + } +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index ef408f3ea1722..8b85f60a25b0a 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1084,6 +1084,11 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { proxyApp.Flag("port", "Specifies the listening port used by by the proxy app listener. Accepts an optional target port of a multi-port TCP app after a colon, e.g. \"1234:5678\".").Short('p').StringVar(&cf.LocalProxyPortMapping) proxyApp.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) + proxyMCP := proxy.Command("mcp", "Start local proxy for MCP access.") + proxyMCP.Arg("app", "The name of the MCP application to start local proxy for.").Required().StringVar(&cf.AppName) + proxyMCP.Flag("port", "Specifies the listening port used by by the proxy app listener.").Short('p').StringVar(&cf.LocalProxyPortMapping) + proxyMCP.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) + proxyAWS := proxy.Command("aws", "Start local proxy for AWS access.") proxyAWS.Flag("app", "Optional Name of the AWS application to use if logged into multiple.").StringVar(&cf.AppName) proxyAWS.Flag("port", "Specifies the source port used by the proxy listener.").Short('p').StringVar(&cf.LocalProxyPort) @@ -1781,6 +1786,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = onProxyCommandDB(&cf) case proxyApp.FullCommand(): err = onProxyCommandApp(&cf) + case proxyMCP.FullCommand(): + err = onProxyCommandApp(&cf) case proxyAWS.FullCommand(): err = onProxyCommandAWS(&cf) case proxyAzure.FullCommand():