diff --git a/api/types/app.go b/api/types/app.go index 146f9de88b756..ae87481431006 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -18,6 +18,7 @@ package types import ( "fmt" + "iter" "net/url" "slices" "strconv" @@ -309,6 +310,9 @@ func (a *AppV3) GetProtocol() string { if a.IsTCP() { return "TCP" } + if a.IsMCP() { + return "MCP" + } return "HTTP" } @@ -565,18 +569,27 @@ func (a *AppV3) GetMCP() *MCP { // DeduplicateApps deduplicates apps by combination of app name and public address. // Apps can have the same name but also could have different addresses. -func DeduplicateApps(apps []Application) (result []Application) { +func DeduplicateApps(apps []Application) []Application { + return slices.Collect(DeduplicatedApps(slices.Values(apps))) +} + +// DeduplicatedApps iterates deduplicated apps by combination of app name and +// public address. This is the iter.Seq version of DeduplicateApps. +func DeduplicatedApps(apps iter.Seq[Application]) iter.Seq[Application] { type key struct{ name, addr string } seen := make(map[key]struct{}) - for _, app := range apps { - key := key{app.GetName(), app.GetPublicAddr()} - if _, ok := seen[key]; ok { - continue + return func(yield func(Application) bool) { + for app := range apps { + key := key{app.GetName(), app.GetPublicAddr()} + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + if !yield(app) { + return + } } - seen[key] = struct{}{} - result = append(result, app) } - return result } // Apps is a list of app resources. diff --git a/api/types/app_test.go b/api/types/app_test.go index c9d5e3da25d20..afb494f376082 100644 --- a/api/types/app_test.go +++ b/api/types/app_test.go @@ -18,6 +18,7 @@ package types import ( "fmt" + "slices" "strconv" "testing" @@ -736,3 +737,19 @@ func TestGetMCPServerTransportType(t *testing.T) { }) } } + +func TestDeduplicateApps(t *testing.T) { + var apps []Application + for _, name := range []string{"a", "b", "c", "b", "a", "d"} { + app_, err := NewAppV3(Metadata{ + Name: name, + }, AppSpecV3{ + URI: "localhost:3080", + }) + require.NoError(t, err) + apps = append(apps, app_) + } + + deduped := DeduplicateApps(apps) + require.Equal(t, []string{"a", "b", "c", "d"}, slices.Collect(ResourceNames(deduped))) +} diff --git a/api/types/appserver.go b/api/types/appserver.go index f2094e25391af..449b7ce98610a 100644 --- a/api/types/appserver.go +++ b/api/types/appserver.go @@ -18,6 +18,8 @@ package types import ( "fmt" + "iter" + "slices" "sort" "time" @@ -26,6 +28,7 @@ import ( "github.com/gravitational/teleport/api" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/iterutils" ) // AppServer represents a single proxied web app. @@ -409,3 +412,10 @@ func (s AppServers) GetFieldVals(field string) ([]string, error) { return vals, nil } + +// Applications iterates over the applications that the AppServers proxy. +func (s AppServers) Applications() iter.Seq[Application] { + return iterutils.Map(func(appServer AppServer) Application { + return appServer.GetApp() + }, slices.Values(s)) +} diff --git a/api/types/resource.go b/api/types/resource.go index 571572ef7235a..9690a80f8a8f3 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -92,6 +92,11 @@ func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] { return iterutils.Map(GetName, slices.Values(s)) } +// CompareResourceByNames compares resources by their names. +func CompareResourceByNames[R Resource](a, b R) int { + return strings.Compare(a.GetName(), b.GetName()) +} + // ResourceDetails includes details about the resource type ResourceDetails struct { Hostname string diff --git a/lib/services/role.go b/lib/services/role.go index f5c709110abf2..151d81800ec42 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -1029,6 +1029,22 @@ func (result *EnumerationResult) WildcardDenied() bool { return result.wildcardDenied } +// ToEntities converts result back to allowed and denied entity slices. +// +// If wildcard is denied, only "*" is returned for the denied slice. +// If wildcard is allowed, allowed entities will be appended to the allowed +// slice after the "*" as a hint for users to select. +// Denied entities is only included if the wildcard is allowed. +func (result *EnumerationResult) ToEntities() (allowed, denied []string) { + if result.wildcardDenied { + return nil, []string{types.Wildcard} + } + if result.wildcardAllowed { + return append([]string{types.Wildcard}, result.Allowed()...), result.Denied() + } + return result.Allowed(), nil +} + // NewEnumerationResult returns new EnumerationResult. func NewEnumerationResult() EnumerationResult { return EnumerationResult{ @@ -1038,6 +1054,34 @@ func NewEnumerationResult() EnumerationResult { } } +// NewEnumerationResultFromEntities creates a new EnumerationResult and +// populates the result with provided allowed and denied entries. +func NewEnumerationResultFromEntities(allowed, denied []string) EnumerationResult { + var wildcardAllowed bool + var wildcardDenied bool + allowedDeniedMap := make(map[string]bool) + for _, allow := range allowed { + if allow == types.Wildcard { + wildcardAllowed = true + } else { + allowedDeniedMap[allow] = true + } + } + for _, deny := range denied { + if deny == types.Wildcard { + wildcardDenied = true + wildcardAllowed = false + break + } + allowedDeniedMap[deny] = false + } + return EnumerationResult{ + allowedDeniedMap: allowedDeniedMap, + wildcardAllowed: wildcardAllowed, + wildcardDenied: wildcardDenied, + } +} + // MatchNamespace returns true if given list of namespace matches // target namespace, wildcard matches everything. func MatchNamespace(selectors []string, namespace string) (bool, string) { diff --git a/lib/services/role_test.go b/lib/services/role_test.go index 028410a6ba09d..1f312befc001e 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -10066,3 +10066,46 @@ func TestMCPToolMatcher(t *testing.T) { }) } } + +func TestNewEnumerationResultFromEntities(t *testing.T) { + tests := []struct { + name string + inputAllowed []string + inputDenied []string + wantAllowed []string + wantDenied []string + }{ + { + name: "empty", + }, + { + name: "wildcard denied", + inputAllowed: []string{"allow_entry"}, + inputDenied: []string{"deny_entry", "*"}, + wantDenied: []string{"*"}, + }, + { + name: "wildcard allowed", + inputAllowed: []string{"allow_entry", "*", "deny_overwrite"}, + inputDenied: []string{"deny_overwrite"}, + wantAllowed: []string{"*", "allow_entry"}, + wantDenied: []string{"deny_overwrite"}, + }, + { + name: "no wildcard", + inputAllowed: []string{"allow_entry_1", "deny_overwrite", "allow_entry_2"}, + inputDenied: []string{"deny_overwrite", "deny_entry"}, + wantAllowed: []string{"allow_entry_1", "allow_entry_2"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := NewEnumerationResultFromEntities(test.inputAllowed, test.inputDenied) + + actualAllowed, actualDenied := result.ToEntities() + require.Equal(t, test.wantAllowed, actualAllowed) + require.Equal(t, test.wantDenied, actualDenied) + }) + } +} diff --git a/tool/common/common.go b/tool/common/common.go index d0f9ef2ae9664..da3ec73dd83f2 100644 --- a/tool/common/common.go +++ b/tool/common/common.go @@ -27,6 +27,7 @@ import ( "sort" "strings" + "github.com/ghodss/yaml" "github.com/gravitational/trace" "github.com/gravitational/teleport" @@ -238,3 +239,35 @@ func FormatDefault[T comparable](val, defaultVal T) string { } return fmt.Sprintf("%v", val) } + +// FormatAllowedEntities returns a human-readable string describing the allowed +// entities, optionally including a list of denied entities as exceptions. +func FormatAllowedEntities(allowed []string, denied []string) string { + if len(allowed) == 0 { + return "(none)" + } + if len(denied) == 0 { + return fmt.Sprintf("%v", allowed) + } + return fmt.Sprintf("%v, except: %v", allowed, denied) +} + +// PrintJSONIndent prints provided value in JSON with default indentation. +func PrintJSONIndent(w io.Writer, v any) error { + out, err := utils.FastMarshalIndent(v, "", " ") + if err != nil { + return trace.Wrap(err) + } + _, err = fmt.Fprintln(w, string(out)) + return trace.Wrap(err) +} + +// PrintYAML prints provided value in YAML. +func PrintYAML(w io.Writer, v any) error { + out, err := yaml.Marshal(v) + if err != nil { + return trace.Wrap(err) + } + _, err = fmt.Fprintln(w, string(out)) + return trace.Wrap(err) +} diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 8f5c38cd38743..408632b4e0af5 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -79,6 +79,10 @@ func onAppLogin(cf *CLIConf) error { } defer clusterClient.Close() + if app.IsMCP() { + return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp login --help' for more details.") + } + if err := validateTargetPort(app, int(cf.TargetPort)); err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 6dbf2fbf5bf91..8cc96273ab3ba 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -1005,12 +1005,16 @@ func (d *databaseInfo) getChecker(ctx context.Context, tc *client.TeleportClient } defer clusterClient.Close() + return makeAccessChecker(ctx, tc, clusterClient.AuthClient) +} + +func makeAccessChecker(ctx context.Context, tc *client.TeleportClient, auth services.CurrentUserRoleGetter) (services.AccessChecker, error) { profile, err := tc.ProfileStatus() if err != nil { return nil, trace.Wrap(err) } - checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient) + checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, auth) return checker, trace.Wrap(err) } diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go index dc99644fd1b04..27b9850364fad 100644 --- a/tool/tsh/common/db_exec_test.go +++ b/tool/tsh/common/db_exec_test.go @@ -326,6 +326,14 @@ func (c *fakeDatabaseExecClient) issueCert(context.Context, *databaseInfo) (tls. return c.cert, nil } func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, req *proto.ListResourcesRequest) ([]types.Database, error) { + filtered, err := matchResources(req, c.allDatabaseServers) + if err != nil { + return nil, trace.Wrap(err) + } + return types.DatabaseServers(filtered).ToDatabases(), nil +} + +func matchResources[R types.ResourceWithLabels](req *proto.ListResourcesRequest, s []R) ([]R, error) { filter := services.MatchResourceFilter{ ResourceKind: req.ResourceType, Labels: req.Labels, @@ -339,13 +347,13 @@ func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, re filter.PredicateExpression = expression } - var filtered []types.Database - for _, dbServer := range c.allDatabaseServers { - match, err := services.MatchResourceByFilters(dbServer, filter, nil) + var filtered []R + for _, r := range s { + match, err := services.MatchResourceByFilters(r, filter, nil) if err != nil { return nil, trace.Wrap(err) } else if match { - filtered = append(filtered, dbServer.GetDatabase()) + filtered = append(filtered, r) } } return filtered, nil diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go index 60f08d445233d..93d716d061ef3 100644 --- a/tool/tsh/common/mcp.go +++ b/tool/tsh/common/mcp.go @@ -16,16 +16,20 @@ package common -import "github.com/alecthomas/kingpin/v2" +import ( + "github.com/alecthomas/kingpin/v2" +) type mcpCommands struct { dbStart *mcpDBStartCommand + list *mcpListCommand } -func newMCPCommands(app *kingpin.Application) *mcpCommands { +func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands { mcp := app.Command("mcp", "View and control proxied MCP servers.") db := mcp.Command("db", "Database access for MCP servers.") return &mcpCommands{ dbStart: newMCPDBCommand(db), + list: newMCPListCommand(mcp, cf), } } diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go new file mode 100644 index 0000000000000..9a1958473f9cb --- /dev/null +++ b/tool/tsh/common/mcp_app.go @@ -0,0 +1,218 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "cmp" + "context" + "fmt" + "io" + "iter" + "slices" + "strings" + + "github.com/alecthomas/kingpin/v2" + "github.com/gravitational/trace" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/iterutils" + "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/tool/common" +) + +func newMCPListCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpListCommand { + cmd := &mcpListCommand{ + CmdClause: parent.Command("ls", "List available MCP server applications"), + cf: cf, + } + + cmd.Flag("verbose", "Show extra MCP server fields.").Short('v').BoolVar(&cf.Verbose) + cmd.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) + cmd.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + cmd.Arg("labels", labelHelp).StringVar(&cf.Labels) + cmd.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) + return cmd +} + +// mcpListCommand implements `tsh mcp ls` command. +type mcpListCommand struct { + *kingpin.CmdClause + cf *CLIConf + accessChecker services.AccessChecker + mcpServers []types.Application +} + +func (c *mcpListCommand) run() error { + if err := c.fetch(); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(c.print()) +} + +func (c *mcpListCommand) fetch() error { + ctx := c.cf.Context + tc, err := makeClient(c.cf) + if err != nil { + return trace.Wrap(err) + } + + var clusterClient *client.ClusterClient + err = client.RetryWithRelogin(ctx, tc, func() error { + clusterClient, err = tc.ConnectToCluster(ctx) + return trace.Wrap(err) + }) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + c.accessChecker, err = makeAccessChecker(ctx, tc, clusterClient.AuthClient) + if err != nil { + return trace.Wrap(err) + } + + c.mcpServers, err = fetchMCPServers(ctx, tc, clusterClient.AuthClient) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +func (c *mcpListCommand) print() error { + mcpServers := iterutils.Map(func(app types.Application) mcpServerWithDetails { + return newMCPServerWithDetails(app, c.accessChecker) + }, slices.Values(c.mcpServers)) + + switch c.cf.Format { + case "", teleport.Text: + if c.cf.Verbose { + return trace.Wrap(printMCPServersInVerboseText(c.cf.Stdout(), mcpServers)) + } + return trace.Wrap(printMCPServersInText(c.cf.Stdout(), mcpServers)) + + case teleport.JSON: + return trace.Wrap(common.PrintJSONIndent(c.cf.Stdout(), slices.Collect(mcpServers))) + case teleport.YAML: + return trace.Wrap(common.PrintYAML(c.cf.Stdout(), slices.Collect(mcpServers))) + + default: + return trace.BadParameter("unsupported format %q", c.cf.Format) + } +} + +func fetchMCPServers(ctx context.Context, tc *client.TeleportClient, auth apiclient.GetResourcesClient) ([]types.Application, error) { + ctx, span := tc.Tracer.Start( + ctx, + "fetchMCPServers", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + filter := tc.ResourceFilter(types.KindAppServer) + filter.PredicateExpression = withMCPServerAppFilter(filter.PredicateExpression) + + appServers, err := apiclient.GetAllResources[types.AppServer](ctx, auth, filter) + if err != nil { + return nil, trace.Wrap(err) + } + + return slices.SortedFunc( + types.DeduplicatedApps(types.AppServers(appServers).Applications()), + types.CompareResourceByNames, + ), nil +} + +func withMCPServerAppFilter(predicateExpression string) string { + return makePredicateConjunction( + predicateExpression, + `resource.sub_kind == "mcp"`, + ) +} + +// mcpServerWithDetails defines an MCP server application with permission +// details, for printing purpose. +type mcpServerWithDetails struct { + // Use a real type for inline. + *types.AppV3 + + Permissions struct { + MCP struct { + Tools struct { + Allowed []string `json:"allowed"` + Denied []string `json:"denied,omitempty"` + } `json:"tools"` + } `json:"mcp"` + } `json:"permissions"` +} + +func (a *mcpServerWithDetails) updateToolsPermissions(accessChecker services.AccessChecker) { + if accessChecker == nil { + return + } + + mcpTools := accessChecker.EnumerateMCPTools(a.AppV3) + a.Permissions.MCP.Tools.Allowed, a.Permissions.MCP.Tools.Denied = mcpTools.ToEntities() +} + +func newMCPServerWithDetails(app types.Application, accessChecker services.AccessChecker) mcpServerWithDetails { + a := mcpServerWithDetails{ + AppV3: app.Copy(), + } + a.updateToolsPermissions(accessChecker) + return a +} + +func printMCPServersInText(w io.Writer, mcpServers iter.Seq[mcpServerWithDetails]) error { + var rows [][]string + for mcpServer := range mcpServers { + rows = append(rows, []string{ + mcpServer.GetName(), + mcpServer.GetDescription(), + types.GetMCPServerTransportType(mcpServer.GetURI()), + common.FormatLabels(mcpServer.GetAllLabels(), false), + }) + } + t := asciitable.MakeTableWithTruncatedColumn([]string{"Name", "Description", "Type", "Labels"}, rows, "Labels") + _, err := fmt.Fprintln(w, t.AsBuffer().String()) + return trace.Wrap(err) +} + +func printMCPServersInVerboseText(w io.Writer, mcpServers iter.Seq[mcpServerWithDetails]) error { + t := asciitable.MakeTable([]string{"Name", "Description", "Type", "Labels", "Command", "Args", "Allowed Tools"}) + for mcpServer := range mcpServers { + mcpSpec := cmp.Or(mcpServer.GetMCP(), &types.MCP{}) + t.AddRow([]string{ + mcpServer.GetName(), + mcpServer.GetDescription(), + types.GetMCPServerTransportType(mcpServer.GetURI()), + common.FormatLabels(mcpServer.GetAllLabels(), true), + mcpSpec.Command, + strings.Join(mcpSpec.Args, " "), + common.FormatAllowedEntities(mcpServer.Permissions.MCP.Tools.Allowed, mcpServer.Permissions.MCP.Tools.Denied), + }) + } + _, err := fmt.Fprintln(w, t.AsBuffer().String()) + return trace.Wrap(err) +} diff --git a/tool/tsh/common/mcp_app_test.go b/tool/tsh/common/mcp_app_test.go new file mode 100644 index 0000000000000..dec2d83efe005 --- /dev/null +++ b/tool/tsh/common/mcp_app_test.go @@ -0,0 +1,364 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "bytes" + "context" + "slices" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/services" +) + +func Test_fetchMCPServers(t *testing.T) { + devLabels := map[string]string{"env": "dev"} + prodLabels := map[string]string{"env": "prod"} + + nonMCPAppServer := mustMakeNewAppServer(t, mustMakeNewAppV3(t, + types.Metadata{ + Name: "non-mcp-app", + Labels: devLabels, + }, + types.AppSpecV3{ + URI: "https://example.com", + }, + ), "host1") + + devMCPAppHost1 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "dev", devLabels), "host1") + devMCPAppHost2 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "dev", devLabels), "host2") + proMCPApp1 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "prod-1", prodLabels), "host1") + proMCPApp2 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "prod-2", prodLabels), "host1") + + fakeClient := &fakeResourcesClient{ + resources: []types.ResourceWithLabels{ + proMCPApp1, nonMCPAppServer, devMCPAppHost1, devMCPAppHost2, proMCPApp2, + }, + } + + tests := []struct { + name string + searchConfig client.Config + wantNames []string + }{ + { + name: "no match", + searchConfig: client.Config{ + Labels: map[string]string{"env": "not-found"}, + }, + wantNames: nil, + }, + { + name: "all", + searchConfig: client.Config{}, + wantNames: []string{"dev", "prod-1", "prod-2"}, + }, + { + name: "by label", + searchConfig: client.Config{ + Labels: map[string]string{"env": "prod"}, + }, + wantNames: []string{"prod-1", "prod-2"}, + }, + { + name: "by keywords", + searchConfig: client.Config{ + SearchKeywords: []string{"prod"}, + }, + wantNames: []string{"prod-1", "prod-2"}, + }, + { + name: "by predicate", + searchConfig: client.Config{ + PredicateExpression: "name==\"dev\"", + }, + wantNames: []string{"dev"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &client.TeleportClient{ + Config: tt.searchConfig, + } + tc.Tracer = tracing.NoopTracer(teleport.ComponentTSH) + + mcpServers, err := fetchMCPServers(context.Background(), tc, fakeClient) + require.NoError(t, err) + require.Equal(t, tt.wantNames, slices.Collect(types.ResourceNames(mcpServers))) + }) + } +} + +// Test_mcpListCommand tests "tsh mcp ls". +// Note that mcpListCommand.fetch is not interesting to test and some of its +// logic is tested separately (see Test_fetchMCPServers above). Thus, this test +// mocks fetch results and tests mcpListCommand.print. +func Test_mcpListCommand(t *testing.T) { + devLabels := map[string]string{"env": "dev"} + mcpServers := []types.Application{ + mustMakeMCPAppWithNameAndLabels(t, "allow-read", devLabels), + mustMakeMCPAppWithNameAndLabels(t, "deny-write", devLabels), + } + accessChecker := fakeMCPServerAccessChecker{} + + tests := []struct { + name string + cf *CLIConf + wantOutput string + }{ + { + name: "text format", + cf: &CLIConf{}, + wantOutput: `Name Description Type Labels +---------- ----------- ----- ------- +allow-read description stdio env=dev +deny-write description stdio env=dev + +`, + }, + { + name: "text format in verbose", + cf: &CLIConf{ + Verbose: true, + }, + wantOutput: `Name Description Type Labels Command Args Allowed Tools +---------- ----------- ----- ------- ------- ---- ---------------------- +allow-read description stdio env=dev test arg [read_*] +deny-write description stdio env=dev test arg [*], except: [write_*] + +`, + }, + { + name: "JSON format", + cf: &CLIConf{ + Format: "json", + }, + wantOutput: `[ + { + "kind": "app", + "sub_kind": "mcp", + "version": "v3", + "metadata": { + "name": "allow-read", + "description": "description", + "labels": { + "env": "dev" + } + }, + "spec": { + "uri": "mcp+stdio://", + "insecure_skip_verify": false, + "mcp": { + "command": "test", + "args": [ + "arg" + ], + "run_as_host_user": "test" + } + }, + "permissions": { + "mcp": { + "tools": { + "allowed": [ + "read_*" + ] + } + } + } + }, + { + "kind": "app", + "sub_kind": "mcp", + "version": "v3", + "metadata": { + "name": "deny-write", + "description": "description", + "labels": { + "env": "dev" + } + }, + "spec": { + "uri": "mcp+stdio://", + "insecure_skip_verify": false, + "mcp": { + "command": "test", + "args": [ + "arg" + ], + "run_as_host_user": "test" + } + }, + "permissions": { + "mcp": { + "tools": { + "allowed": [ + "*" + ], + "denied": [ + "write_*" + ] + } + } + } + } +] +`, + }, + { + name: "YAML format", + cf: &CLIConf{ + Format: "yaml", + }, + wantOutput: `- kind: app + metadata: + description: description + labels: + env: dev + name: allow-read + permissions: + mcp: + tools: + allowed: + - read_* + spec: + insecure_skip_verify: false + mcp: + args: + - arg + command: test + run_as_host_user: test + uri: mcp+stdio:// + sub_kind: mcp + version: v3 +- kind: app + metadata: + description: description + labels: + env: dev + name: deny-write + permissions: + mcp: + tools: + allowed: + - '*' + denied: + - write_* + spec: + insecure_skip_verify: false + mcp: + args: + - arg + command: test + run_as_host_user: test + uri: mcp+stdio:// + sub_kind: mcp + version: v3 + +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + cf := tt.cf + cf.OverrideStdout = &buf + cf.Context = context.Background() + + cmd := &mcpListCommand{ + cf: tt.cf, + mcpServers: mcpServers, + accessChecker: accessChecker, + } + + err := cmd.print() + require.NoError(t, err) + require.Equal(t, tt.wantOutput, buf.String()) + }) + } +} + +type fakeResourcesClient struct { + resources []types.ResourceWithLabels +} + +func (f *fakeResourcesClient) GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { + filtered, err := matchResources(req, f.resources) + if err != nil { + return nil, trace.Wrap(err) + } + + paginatedResources, err := services.MakePaginatedResources(req.ResourceType, filtered, nil) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.ListResourcesResponse{ + Resources: paginatedResources, + TotalCount: int32(len(filtered)), + }, nil +} + +func mustMakeNewAppServer(t *testing.T, app *types.AppV3, host string) types.AppServer { + t.Helper() + appServer, err := types.NewAppServerV3FromApp(app, host, host) + require.NoError(t, err) + return appServer +} + +func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string) *types.AppV3 { + t.Helper() + return mustMakeNewAppV3(t, + types.Metadata{ + Name: name, + Description: "description", + Labels: labels, + }, + types.AppSpecV3{ + MCP: &types.MCP{ + Command: "test", + Args: []string{"arg"}, + RunAsHostUser: "test", + }, + }, + ) +} + +type fakeMCPServerAccessChecker struct { + services.AccessChecker +} + +func (f fakeMCPServerAccessChecker) EnumerateMCPTools(app types.Application) services.EnumerationResult { + switch app.GetName() { + case "allow-read": + return services.NewEnumerationResultFromEntities([]string{"read_*"}, nil) + case "deny-write": + return services.NewEnumerationResultFromEntities([]string{"*"}, []string{"write_*"}) + default: + return services.NewEnumerationResult() + } +} diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index a8cc0b832a472..484507550af34 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -501,6 +501,10 @@ func onProxyCommandApp(cf *CLIConf) error { return trace.Wrap(err) } + if app.IsMCP() { + return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp login --help' for more details.") + } + proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index c6612cd7f1ba6..99d58d823b9e0 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1344,8 +1344,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { gitCmd := newGitCommands(app) pivCmd := newPIVCommands(app) - - mcpCmd := newMCPCommands(app) + mcpCmd := newMCPCommands(app, &cf) if runtime.GOOS == constants.WindowsOS { bench.Hidden() @@ -1754,6 +1753,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = pivCmd.agent.run(&cf) case mcpCmd.dbStart.FullCommand(): err = mcpCmd.dbStart.run(&cf) + case mcpCmd.list.FullCommand(): + err = mcpCmd.list.run() default: // Handle commands that might not be available. switch { @@ -3211,14 +3212,7 @@ func getDBUsers(db types.Database, accessChecker services.AccessChecker) *dbUser ) return &dbUsers{} } - var denied []string - allowed := users.Allowed() - if users.WildcardAllowed() { - // start the list with *. - allowed = append([]string{types.Wildcard}, allowed...) - // only include denied users if the wildcard is allowed. - denied = append(denied, users.Denied()...) - } + allowed, denied := users.ToEntities() return &dbUsers{ Allowed: allowed, Denied: denied, @@ -3283,9 +3277,6 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec } dbUsers := getDBUsers(database, accessChecker) - if len(dbUsers.Allowed) == 0 { - return "(none)" - } // Add a note for auto-provisioned user. if database.IsAutoUsersEnabled() { @@ -3302,10 +3293,7 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec } } - if len(dbUsers.Denied) == 0 { - return fmt.Sprintf("%v", dbUsers.Allowed) - } - return fmt.Sprintf("%v, except: %v", dbUsers.Allowed, dbUsers.Denied) + return common.FormatAllowedEntities(dbUsers.Allowed, dbUsers.Denied) } // TODO(greedy52) more refactoring on db printing and move them to db_print.go.