From 3d1256241d8e8792fa9d163568e9eaeea68244a4 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Mon, 26 May 2025 15:59:26 -0400 Subject: [PATCH 1/2] MCP access part 6: "tsh mcp ls" --- api/types/app.go | 29 ++- api/types/app_test.go | 17 ++ api/types/appserver.go | 10 + api/types/resource.go | 5 + lib/services/role.go | 44 ++++ lib/services/role_test.go | 64 ++++++ lib/utils/slices/slices.go | 18 ++ lib/utils/slices/slices_test.go | 22 ++ tool/common/common.go | 33 +++ tool/tsh/common/db.go | 6 +- tool/tsh/common/db_exec_test.go | 16 +- tool/tsh/common/mcp.go | 34 +++ tool/tsh/common/mcp_app.go | 218 +++++++++++++++++++ tool/tsh/common/mcp_app_test.go | 364 ++++++++++++++++++++++++++++++++ tool/tsh/common/tsh.go | 20 +- tool/tsh/common/tsh_test.go | 8 +- 16 files changed, 874 insertions(+), 34 deletions(-) create mode 100644 tool/tsh/common/mcp.go create mode 100644 tool/tsh/common/mcp_app.go create mode 100644 tool/tsh/common/mcp_app_test.go diff --git a/api/types/app.go b/api/types/app.go index 146f9de88b756..ae87481431006 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -18,6 +18,7 @@ package types import ( "fmt" + "iter" "net/url" "slices" "strconv" @@ -309,6 +310,9 @@ func (a *AppV3) GetProtocol() string { if a.IsTCP() { return "TCP" } + if a.IsMCP() { + return "MCP" + } return "HTTP" } @@ -565,18 +569,27 @@ func (a *AppV3) GetMCP() *MCP { // DeduplicateApps deduplicates apps by combination of app name and public address. // Apps can have the same name but also could have different addresses. -func DeduplicateApps(apps []Application) (result []Application) { +func DeduplicateApps(apps []Application) []Application { + return slices.Collect(DeduplicatedApps(slices.Values(apps))) +} + +// DeduplicatedApps iterates deduplicated apps by combination of app name and +// public address. This is the iter.Seq version of DeduplicateApps. +func DeduplicatedApps(apps iter.Seq[Application]) iter.Seq[Application] { type key struct{ name, addr string } seen := make(map[key]struct{}) - for _, app := range apps { - key := key{app.GetName(), app.GetPublicAddr()} - if _, ok := seen[key]; ok { - continue + return func(yield func(Application) bool) { + for app := range apps { + key := key{app.GetName(), app.GetPublicAddr()} + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + if !yield(app) { + return + } } - seen[key] = struct{}{} - result = append(result, app) } - return result } // Apps is a list of app resources. diff --git a/api/types/app_test.go b/api/types/app_test.go index c9d5e3da25d20..afb494f376082 100644 --- a/api/types/app_test.go +++ b/api/types/app_test.go @@ -18,6 +18,7 @@ package types import ( "fmt" + "slices" "strconv" "testing" @@ -736,3 +737,19 @@ func TestGetMCPServerTransportType(t *testing.T) { }) } } + +func TestDeduplicateApps(t *testing.T) { + var apps []Application + for _, name := range []string{"a", "b", "c", "b", "a", "d"} { + app_, err := NewAppV3(Metadata{ + Name: name, + }, AppSpecV3{ + URI: "localhost:3080", + }) + require.NoError(t, err) + apps = append(apps, app_) + } + + deduped := DeduplicateApps(apps) + require.Equal(t, []string{"a", "b", "c", "d"}, slices.Collect(ResourceNames(deduped))) +} diff --git a/api/types/appserver.go b/api/types/appserver.go index f2094e25391af..449b7ce98610a 100644 --- a/api/types/appserver.go +++ b/api/types/appserver.go @@ -18,6 +18,8 @@ package types import ( "fmt" + "iter" + "slices" "sort" "time" @@ -26,6 +28,7 @@ import ( "github.com/gravitational/teleport/api" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/iterutils" ) // AppServer represents a single proxied web app. @@ -409,3 +412,10 @@ func (s AppServers) GetFieldVals(field string) ([]string, error) { return vals, nil } + +// Applications iterates over the applications that the AppServers proxy. +func (s AppServers) Applications() iter.Seq[Application] { + return iterutils.Map(func(appServer AppServer) Application { + return appServer.GetApp() + }, slices.Values(s)) +} diff --git a/api/types/resource.go b/api/types/resource.go index 571572ef7235a..9690a80f8a8f3 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -92,6 +92,11 @@ func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] { return iterutils.Map(GetName, slices.Values(s)) } +// CompareResourceByNames compares resources by their names. +func CompareResourceByNames[R Resource](a, b R) int { + return strings.Compare(a.GetName(), b.GetName()) +} + // ResourceDetails includes details about the resource type ResourceDetails struct { Hostname string diff --git a/lib/services/role.go b/lib/services/role.go index f5c709110abf2..a654584f293f1 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -53,6 +53,7 @@ import ( awsutils "github.com/gravitational/teleport/lib/utils/aws" logutils "github.com/gravitational/teleport/lib/utils/log" "github.com/gravitational/teleport/lib/utils/parse" + slicesutils "github.com/gravitational/teleport/lib/utils/slices" ) // DefaultImplicitRules provides access to the default set of implicit rules @@ -1029,6 +1030,19 @@ func (result *EnumerationResult) WildcardDenied() bool { return result.wildcardDenied } +// ToEntities converts result back to allowed and denied entity slices. Wildcard +// will be placed in the slice if any. +func (result *EnumerationResult) ToEntities() (allowed, denied []string) { + if result.wildcardDenied { + return nil, []string{types.Wildcard} + } + // Only include denied entities if the wildcard is allowed. + if result.wildcardAllowed { + return []string{types.Wildcard}, result.Denied() + } + return result.Allowed(), nil +} + // NewEnumerationResult returns new EnumerationResult. func NewEnumerationResult() EnumerationResult { return EnumerationResult{ @@ -1038,6 +1052,36 @@ func NewEnumerationResult() EnumerationResult { } } +// NewEnumerationResultFromEntities creates a new EnumerationResult and +// populates the result with provided allowed and denied entries. +func NewEnumerationResultFromEntities(allowed, denied []string) EnumerationResult { + switch { + // Wildcard deny, ignore all other entries. + case slices.Contains(denied, types.Wildcard): + return EnumerationResult{ + allowedDeniedMap: make(map[string]bool), + wildcardDenied: true, + } + + // Wildcard allow, ignore all allowed entries. + case slices.Contains(allowed, types.Wildcard): + return EnumerationResult{ + allowedDeniedMap: slicesutils.ToMapWithDefaultValue(denied, false), + wildcardAllowed: true, + } + + // No wildcard. + default: + allowedDeniedMap := slicesutils.ToMapWithDefaultValue(allowed, true) + for _, deny := range denied { + delete(allowedDeniedMap, deny) + } + return EnumerationResult{ + allowedDeniedMap: allowedDeniedMap, + } + } +} + // MatchNamespace returns true if given list of namespace matches // target namespace, wildcard matches everything. func MatchNamespace(selectors []string, namespace string) (bool, string) { diff --git a/lib/services/role_test.go b/lib/services/role_test.go index 028410a6ba09d..0a42640233be8 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -10066,3 +10066,67 @@ func TestMCPToolMatcher(t *testing.T) { }) } } + +func TestNewEnumerationResultFromEntities(t *testing.T) { + tests := []struct { + name string + inputAllowed []string + inputDenied []string + wantResult EnumerationResult + wantAllowed []string + wantDenied []string + }{ + { + name: "empty", + wantResult: EnumerationResult{ + allowedDeniedMap: make(map[string]bool), + }, + }, + { + name: "wildcard denied", + inputAllowed: []string{"allow_entry"}, + inputDenied: []string{"deny_entry", "*"}, + wantResult: EnumerationResult{ + allowedDeniedMap: make(map[string]bool), + wildcardDenied: true, + }, + wantDenied: []string{"*"}, + }, + { + name: "wildcard allowed", + inputAllowed: []string{"allow_entry", "*"}, + inputDenied: []string{"deny_entry"}, + wantResult: EnumerationResult{ + allowedDeniedMap: map[string]bool{ + "deny_entry": false, + }, + wildcardAllowed: true, + }, + wantAllowed: []string{"*"}, + wantDenied: []string{"deny_entry"}, + }, + { + name: "no wildcard", + inputAllowed: []string{"allow_entry_1", "deny_overwrite", "allow_entry_2"}, + inputDenied: []string{"deny_overwrite", "deny_entry"}, + wantResult: EnumerationResult{ + allowedDeniedMap: map[string]bool{ + "allow_entry_1": true, + "allow_entry_2": true, + }, + }, + wantAllowed: []string{"allow_entry_1", "allow_entry_2"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actualResult := NewEnumerationResultFromEntities(test.inputAllowed, test.inputDenied) + require.Equal(t, test.wantResult, actualResult) + + actualAllowed, actualDenied := actualResult.ToEntities() + require.Equal(t, test.wantAllowed, actualAllowed) + require.Equal(t, test.wantDenied, actualDenied) + }) + } +} diff --git a/lib/utils/slices/slices.go b/lib/utils/slices/slices.go index 5e524b7f1ed9b..115905e4446c8 100644 --- a/lib/utils/slices/slices.go +++ b/lib/utils/slices/slices.go @@ -80,3 +80,21 @@ func DeduplicateKey[T any](s []T, key func(T) string) []T { } return out } + +// ToMap converts elements in a slice as the keys to a map, and use provided +// function to calculate their values in the map. +func ToMap[K comparable, V any, S ~[]K](s S, makeValue func(K) V) map[K]V { + m := make(map[K]V, len(s)) + for _, key := range s { + m[key] = makeValue(key) + } + return m +} + +// ToMapWithDefaultValue converts elements in a slice as the keys to a map, and +// all values in the map are set to provided default value. +func ToMapWithDefaultValue[K comparable, V any, S ~[]K](s S, defaultValue V) map[K]V { + return ToMap(s, func(K) V { + return defaultValue + }) +} diff --git a/lib/utils/slices/slices_test.go b/lib/utils/slices/slices_test.go index b0fe86a1bfd2e..6a0925480af79 100644 --- a/lib/utils/slices/slices_test.go +++ b/lib/utils/slices/slices_test.go @@ -20,6 +20,7 @@ package slices import ( "fmt" + "strconv" "strings" "testing" @@ -182,3 +183,24 @@ func TestDeduplicateKey(t *testing.T) { }) } } + +func TestToMap(t *testing.T) { + numbers := []int{1, 2, 3, 4} + m := ToMap(numbers, strconv.Itoa) + require.Equal(t, map[int]string{ + 1: "1", + 2: "2", + 3: "3", + 4: "4", + }, m) +} + +func TestToMapWithDefaultValue(t *testing.T) { + s := []string{"a", "b", "c"} + m := ToMapWithDefaultValue(s, true) + require.Equal(t, map[string]bool{ + "a": true, + "b": true, + "c": true, + }, m) +} diff --git a/tool/common/common.go b/tool/common/common.go index d0f9ef2ae9664..da3ec73dd83f2 100644 --- a/tool/common/common.go +++ b/tool/common/common.go @@ -27,6 +27,7 @@ import ( "sort" "strings" + "github.com/ghodss/yaml" "github.com/gravitational/trace" "github.com/gravitational/teleport" @@ -238,3 +239,35 @@ func FormatDefault[T comparable](val, defaultVal T) string { } return fmt.Sprintf("%v", val) } + +// FormatAllowedEntities returns a human-readable string describing the allowed +// entities, optionally including a list of denied entities as exceptions. +func FormatAllowedEntities(allowed []string, denied []string) string { + if len(allowed) == 0 { + return "(none)" + } + if len(denied) == 0 { + return fmt.Sprintf("%v", allowed) + } + return fmt.Sprintf("%v, except: %v", allowed, denied) +} + +// PrintJSONIndent prints provided value in JSON with default indentation. +func PrintJSONIndent(w io.Writer, v any) error { + out, err := utils.FastMarshalIndent(v, "", " ") + if err != nil { + return trace.Wrap(err) + } + _, err = fmt.Fprintln(w, string(out)) + return trace.Wrap(err) +} + +// PrintYAML prints provided value in YAML. +func PrintYAML(w io.Writer, v any) error { + out, err := yaml.Marshal(v) + if err != nil { + return trace.Wrap(err) + } + _, err = fmt.Fprintln(w, string(out)) + return trace.Wrap(err) +} diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 6dbf2fbf5bf91..8cc96273ab3ba 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -1005,12 +1005,16 @@ func (d *databaseInfo) getChecker(ctx context.Context, tc *client.TeleportClient } defer clusterClient.Close() + return makeAccessChecker(ctx, tc, clusterClient.AuthClient) +} + +func makeAccessChecker(ctx context.Context, tc *client.TeleportClient, auth services.CurrentUserRoleGetter) (services.AccessChecker, error) { profile, err := tc.ProfileStatus() if err != nil { return nil, trace.Wrap(err) } - checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient) + checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, auth) return checker, trace.Wrap(err) } diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go index dc99644fd1b04..27b9850364fad 100644 --- a/tool/tsh/common/db_exec_test.go +++ b/tool/tsh/common/db_exec_test.go @@ -326,6 +326,14 @@ func (c *fakeDatabaseExecClient) issueCert(context.Context, *databaseInfo) (tls. return c.cert, nil } func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, req *proto.ListResourcesRequest) ([]types.Database, error) { + filtered, err := matchResources(req, c.allDatabaseServers) + if err != nil { + return nil, trace.Wrap(err) + } + return types.DatabaseServers(filtered).ToDatabases(), nil +} + +func matchResources[R types.ResourceWithLabels](req *proto.ListResourcesRequest, s []R) ([]R, error) { filter := services.MatchResourceFilter{ ResourceKind: req.ResourceType, Labels: req.Labels, @@ -339,13 +347,13 @@ func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, re filter.PredicateExpression = expression } - var filtered []types.Database - for _, dbServer := range c.allDatabaseServers { - match, err := services.MatchResourceByFilters(dbServer, filter, nil) + var filtered []R + for _, r := range s { + match, err := services.MatchResourceByFilters(r, filter, nil) if err != nil { return nil, trace.Wrap(err) } else if match { - filtered = append(filtered, dbServer.GetDatabase()) + filtered = append(filtered, r) } } return filtered, nil diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go new file mode 100644 index 0000000000000..7e3b4b153c5de --- /dev/null +++ b/tool/tsh/common/mcp.go @@ -0,0 +1,34 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "github.com/alecthomas/kingpin/v2" +) + +type mcpCommands struct { + list *mcpListCommand +} + +func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands { + mcp := app.Command("mcp", "View and control proxied MCP servers.") + return &mcpCommands{ + list: newMCPListCommand(mcp, cf), + } +} diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go new file mode 100644 index 0000000000000..f709e2b60e057 --- /dev/null +++ b/tool/tsh/common/mcp_app.go @@ -0,0 +1,218 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "cmp" + "context" + "fmt" + "io" + "iter" + "slices" + "strings" + + "github.com/alecthomas/kingpin/v2" + "github.com/gravitational/trace" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/iterutils" + "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/tool/common" +) + +func newMCPListCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpListCommand { + cmd := &mcpListCommand{ + CmdClause: parent.Command("ls", "List available MCP server applications"), + cf: cf, + } + + cmd.Flag("verbose", "Show extra MCP server fields.").Short('v').BoolVar(&cf.Verbose) + cmd.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) + cmd.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + cmd.Arg("labels", labelHelp).StringVar(&cf.Labels) + cmd.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...) + return cmd +} + +// mcpListCommand implements `tsh mcp ls` command. +type mcpListCommand struct { + *kingpin.CmdClause + cf *CLIConf + accessChecker services.AccessChecker + mcpServers []types.Application +} + +func (c *mcpListCommand) run() error { + if err := c.fetch(); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(c.print()) +} + +func (c *mcpListCommand) fetch() error { + ctx := c.cf.Context + tc, err := makeClient(c.cf) + if err != nil { + return trace.Wrap(err) + } + + var clusterClient *client.ClusterClient + err = client.RetryWithRelogin(ctx, tc, func() error { + clusterClient, err = tc.ConnectToCluster(ctx) + return trace.Wrap(err) + }) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + c.accessChecker, err = makeAccessChecker(ctx, tc, clusterClient.AuthClient) + if err != nil { + return trace.Wrap(err) + } + + c.mcpServers, err = fetchMCPServers(ctx, tc, clusterClient.AuthClient) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +func (c *mcpListCommand) print() error { + mcpServers := iterutils.Map(func(app types.Application) mcpServerWithDetails { + return newMCPServerWithDetails(app, c.accessChecker) + }, slices.Values(c.mcpServers)) + + switch c.cf.Format { + case "", teleport.Text: + if c.cf.Verbose { + return trace.Wrap(printMCPServersInVerboseText(c.cf.Stdout(), mcpServers)) + } + return trace.Wrap(printMCPServersInText(c.cf.Stdout(), mcpServers)) + + case teleport.JSON: + return trace.Wrap(common.PrintJSONIndent(c.cf.Stdout(), slices.Collect(mcpServers))) + case teleport.YAML: + return trace.Wrap(common.PrintYAML(c.cf.Stdout(), slices.Collect(mcpServers))) + + default: + return trace.BadParameter("unsupported format %q", c.cf.Format) + } +} + +func fetchMCPServers(ctx context.Context, tc *client.TeleportClient, auth apiclient.GetResourcesClient) ([]types.Application, error) { + ctx, span := tc.Tracer.Start( + ctx, + "fetchMCPServers", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + filter := tc.ResourceFilter(types.KindAppServer) + filter.PredicateExpression = withMCPServerAppFilter(filter.PredicateExpression) + + appServers, err := apiclient.GetAllResources[types.AppServer](ctx, auth, filter) + if err != nil { + return nil, trace.Wrap(err) + } + + return slices.SortedFunc( + types.DeduplicatedApps(types.AppServers(appServers).Applications()), + types.CompareResourceByNames, + ), nil +} + +func withMCPServerAppFilter(predicateExpression string) string { + return makePredicateConjunction( + predicateExpression, + "hasPrefix(resource.spec.uri, \"mcp+\")", + ) +} + +// mcpServerWithDetails defines an MCP server application with permission +// details, for printing purpose. +type mcpServerWithDetails struct { + // Use a real type for inline. + *types.AppV3 + + Permissions struct { + MCP struct { + Tools struct { + Allowed []string `json:"allowed"` + Denied []string `json:"denied,omitempty"` + } `json:"tools"` + } `json:"mcp"` + } `json:"permissions"` +} + +func (a *mcpServerWithDetails) updateToolsPermissions(accessChecker services.AccessChecker) { + if accessChecker == nil { + return + } + + mcpTools := accessChecker.EnumerateMCPTools(a.AppV3) + a.Permissions.MCP.Tools.Allowed, a.Permissions.MCP.Tools.Denied = mcpTools.ToEntities() +} + +func newMCPServerWithDetails(app types.Application, accessChecker services.AccessChecker) mcpServerWithDetails { + a := mcpServerWithDetails{ + AppV3: app.Copy(), + } + a.updateToolsPermissions(accessChecker) + return a +} + +func printMCPServersInText(w io.Writer, mcpServers iter.Seq[mcpServerWithDetails]) error { + var rows [][]string + for mcpServer := range mcpServers { + rows = append(rows, []string{ + mcpServer.GetName(), + mcpServer.GetDescription(), + types.GetMCPServerTransportType(mcpServer.GetURI()), + common.FormatLabels(mcpServer.GetAllLabels(), false), + }) + } + t := asciitable.MakeTableWithTruncatedColumn([]string{"Name", "Description", "Type", "Labels"}, rows, "Labels") + _, err := fmt.Fprintln(w, t.AsBuffer().String()) + return trace.Wrap(err) +} + +func printMCPServersInVerboseText(w io.Writer, mcpServers iter.Seq[mcpServerWithDetails]) error { + t := asciitable.MakeTable([]string{"Name", "Description", "Type", "Labels", "Command", "Args", "Allowed Tools"}) + for mcpServer := range mcpServers { + mcpSpec := cmp.Or(mcpServer.GetMCP(), &types.MCP{}) + t.AddRow([]string{ + mcpServer.GetName(), + mcpServer.GetDescription(), + types.GetMCPServerTransportType(mcpServer.GetURI()), + common.FormatLabels(mcpServer.GetAllLabels(), true), + mcpSpec.Command, + strings.Join(mcpSpec.Args, " "), + common.FormatAllowedEntities(mcpServer.Permissions.MCP.Tools.Allowed, mcpServer.Permissions.MCP.Tools.Denied), + }) + } + _, err := fmt.Fprintln(w, t.AsBuffer().String()) + return trace.Wrap(err) +} diff --git a/tool/tsh/common/mcp_app_test.go b/tool/tsh/common/mcp_app_test.go new file mode 100644 index 0000000000000..dec2d83efe005 --- /dev/null +++ b/tool/tsh/common/mcp_app_test.go @@ -0,0 +1,364 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "bytes" + "context" + "slices" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/services" +) + +func Test_fetchMCPServers(t *testing.T) { + devLabels := map[string]string{"env": "dev"} + prodLabels := map[string]string{"env": "prod"} + + nonMCPAppServer := mustMakeNewAppServer(t, mustMakeNewAppV3(t, + types.Metadata{ + Name: "non-mcp-app", + Labels: devLabels, + }, + types.AppSpecV3{ + URI: "https://example.com", + }, + ), "host1") + + devMCPAppHost1 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "dev", devLabels), "host1") + devMCPAppHost2 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "dev", devLabels), "host2") + proMCPApp1 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "prod-1", prodLabels), "host1") + proMCPApp2 := mustMakeNewAppServer(t, mustMakeMCPAppWithNameAndLabels(t, "prod-2", prodLabels), "host1") + + fakeClient := &fakeResourcesClient{ + resources: []types.ResourceWithLabels{ + proMCPApp1, nonMCPAppServer, devMCPAppHost1, devMCPAppHost2, proMCPApp2, + }, + } + + tests := []struct { + name string + searchConfig client.Config + wantNames []string + }{ + { + name: "no match", + searchConfig: client.Config{ + Labels: map[string]string{"env": "not-found"}, + }, + wantNames: nil, + }, + { + name: "all", + searchConfig: client.Config{}, + wantNames: []string{"dev", "prod-1", "prod-2"}, + }, + { + name: "by label", + searchConfig: client.Config{ + Labels: map[string]string{"env": "prod"}, + }, + wantNames: []string{"prod-1", "prod-2"}, + }, + { + name: "by keywords", + searchConfig: client.Config{ + SearchKeywords: []string{"prod"}, + }, + wantNames: []string{"prod-1", "prod-2"}, + }, + { + name: "by predicate", + searchConfig: client.Config{ + PredicateExpression: "name==\"dev\"", + }, + wantNames: []string{"dev"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &client.TeleportClient{ + Config: tt.searchConfig, + } + tc.Tracer = tracing.NoopTracer(teleport.ComponentTSH) + + mcpServers, err := fetchMCPServers(context.Background(), tc, fakeClient) + require.NoError(t, err) + require.Equal(t, tt.wantNames, slices.Collect(types.ResourceNames(mcpServers))) + }) + } +} + +// Test_mcpListCommand tests "tsh mcp ls". +// Note that mcpListCommand.fetch is not interesting to test and some of its +// logic is tested separately (see Test_fetchMCPServers above). Thus, this test +// mocks fetch results and tests mcpListCommand.print. +func Test_mcpListCommand(t *testing.T) { + devLabels := map[string]string{"env": "dev"} + mcpServers := []types.Application{ + mustMakeMCPAppWithNameAndLabels(t, "allow-read", devLabels), + mustMakeMCPAppWithNameAndLabels(t, "deny-write", devLabels), + } + accessChecker := fakeMCPServerAccessChecker{} + + tests := []struct { + name string + cf *CLIConf + wantOutput string + }{ + { + name: "text format", + cf: &CLIConf{}, + wantOutput: `Name Description Type Labels +---------- ----------- ----- ------- +allow-read description stdio env=dev +deny-write description stdio env=dev + +`, + }, + { + name: "text format in verbose", + cf: &CLIConf{ + Verbose: true, + }, + wantOutput: `Name Description Type Labels Command Args Allowed Tools +---------- ----------- ----- ------- ------- ---- ---------------------- +allow-read description stdio env=dev test arg [read_*] +deny-write description stdio env=dev test arg [*], except: [write_*] + +`, + }, + { + name: "JSON format", + cf: &CLIConf{ + Format: "json", + }, + wantOutput: `[ + { + "kind": "app", + "sub_kind": "mcp", + "version": "v3", + "metadata": { + "name": "allow-read", + "description": "description", + "labels": { + "env": "dev" + } + }, + "spec": { + "uri": "mcp+stdio://", + "insecure_skip_verify": false, + "mcp": { + "command": "test", + "args": [ + "arg" + ], + "run_as_host_user": "test" + } + }, + "permissions": { + "mcp": { + "tools": { + "allowed": [ + "read_*" + ] + } + } + } + }, + { + "kind": "app", + "sub_kind": "mcp", + "version": "v3", + "metadata": { + "name": "deny-write", + "description": "description", + "labels": { + "env": "dev" + } + }, + "spec": { + "uri": "mcp+stdio://", + "insecure_skip_verify": false, + "mcp": { + "command": "test", + "args": [ + "arg" + ], + "run_as_host_user": "test" + } + }, + "permissions": { + "mcp": { + "tools": { + "allowed": [ + "*" + ], + "denied": [ + "write_*" + ] + } + } + } + } +] +`, + }, + { + name: "YAML format", + cf: &CLIConf{ + Format: "yaml", + }, + wantOutput: `- kind: app + metadata: + description: description + labels: + env: dev + name: allow-read + permissions: + mcp: + tools: + allowed: + - read_* + spec: + insecure_skip_verify: false + mcp: + args: + - arg + command: test + run_as_host_user: test + uri: mcp+stdio:// + sub_kind: mcp + version: v3 +- kind: app + metadata: + description: description + labels: + env: dev + name: deny-write + permissions: + mcp: + tools: + allowed: + - '*' + denied: + - write_* + spec: + insecure_skip_verify: false + mcp: + args: + - arg + command: test + run_as_host_user: test + uri: mcp+stdio:// + sub_kind: mcp + version: v3 + +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + cf := tt.cf + cf.OverrideStdout = &buf + cf.Context = context.Background() + + cmd := &mcpListCommand{ + cf: tt.cf, + mcpServers: mcpServers, + accessChecker: accessChecker, + } + + err := cmd.print() + require.NoError(t, err) + require.Equal(t, tt.wantOutput, buf.String()) + }) + } +} + +type fakeResourcesClient struct { + resources []types.ResourceWithLabels +} + +func (f *fakeResourcesClient) GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { + filtered, err := matchResources(req, f.resources) + if err != nil { + return nil, trace.Wrap(err) + } + + paginatedResources, err := services.MakePaginatedResources(req.ResourceType, filtered, nil) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.ListResourcesResponse{ + Resources: paginatedResources, + TotalCount: int32(len(filtered)), + }, nil +} + +func mustMakeNewAppServer(t *testing.T, app *types.AppV3, host string) types.AppServer { + t.Helper() + appServer, err := types.NewAppServerV3FromApp(app, host, host) + require.NoError(t, err) + return appServer +} + +func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string) *types.AppV3 { + t.Helper() + return mustMakeNewAppV3(t, + types.Metadata{ + Name: name, + Description: "description", + Labels: labels, + }, + types.AppSpecV3{ + MCP: &types.MCP{ + Command: "test", + Args: []string{"arg"}, + RunAsHostUser: "test", + }, + }, + ) +} + +type fakeMCPServerAccessChecker struct { + services.AccessChecker +} + +func (f fakeMCPServerAccessChecker) EnumerateMCPTools(app types.Application) services.EnumerationResult { + switch app.GetName() { + case "allow-read": + return services.NewEnumerationResultFromEntities([]string{"read_*"}, nil) + case "deny-write": + return services.NewEnumerationResultFromEntities([]string{"*"}, []string{"write_*"}) + default: + return services.NewEnumerationResult() + } +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 550811255519b..f885487c5341f 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1339,6 +1339,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { gitCmd := newGitCommands(app) pivCmd := newPIVCommands(app) + mcpCmd := newMCPCommands(app, &cf) if runtime.GOOS == constants.WindowsOS { bench.Hidden() @@ -1745,6 +1746,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = gitCmd.clone.run(&cf) case pivCmd.agent.FullCommand(): err = pivCmd.agent.run(&cf) + case mcpCmd.list.FullCommand(): + err = mcpCmd.list.run() default: // Handle commands that might not be available. switch { @@ -3202,14 +3205,7 @@ func getDBUsers(db types.Database, accessChecker services.AccessChecker) *dbUser ) return &dbUsers{} } - var denied []string - allowed := users.Allowed() - if users.WildcardAllowed() { - // start the list with *. - allowed = append([]string{types.Wildcard}, allowed...) - // only include denied users if the wildcard is allowed. - denied = append(denied, users.Denied()...) - } + allowed, denied := users.ToEntities() return &dbUsers{ Allowed: allowed, Denied: denied, @@ -3274,9 +3270,6 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec } dbUsers := getDBUsers(database, accessChecker) - if len(dbUsers.Allowed) == 0 { - return "(none)" - } // Add a note for auto-provisioned user. if database.IsAutoUsersEnabled() { @@ -3293,10 +3286,7 @@ func formatUsersForDB(database types.Database, accessChecker services.AccessChec } } - if len(dbUsers.Denied) == 0 { - return fmt.Sprintf("%v", dbUsers.Allowed) - } - return fmt.Sprintf("%v, except: %v", dbUsers.Allowed, dbUsers.Denied) + return common.FormatAllowedEntities(dbUsers.Allowed, dbUsers.Denied) } // TODO(greedy52) more refactoring on db printing and move them to db_print.go. diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 81b56f63fe03e..e30c0c8532e2b 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -4366,9 +4366,7 @@ func TestSerializeDatabases(t *testing.T) { dbUsersData: `, "users": { "allowed": [ - "*", - "bar", - "foo" + "*" ], "denied": [ "baz", @@ -4423,9 +4421,7 @@ func TestSerializeDatabases(t *testing.T) { dbUsersData: `, "users": { "allowed": [ - "*", - "bar", - "foo" + "*" ] }`, roles: services.RoleSet{ From 94e5218f7873bf316fcbc3e8f5766555a7f4aad0 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Wed, 4 Jun 2025 14:31:07 -0400 Subject: [PATCH 2/2] address feedback --- lib/services/role.go | 54 ++++++++++++++++----------------- lib/services/role_test.go | 37 +++++----------------- lib/utils/slices/slices.go | 18 ----------- lib/utils/slices/slices_test.go | 22 -------------- tool/tsh/common/app.go | 4 +++ tool/tsh/common/mcp_app.go | 2 +- tool/tsh/common/proxy.go | 4 +++ tool/tsh/common/tsh_test.go | 8 +++-- 8 files changed, 50 insertions(+), 99 deletions(-) diff --git a/lib/services/role.go b/lib/services/role.go index a654584f293f1..151d81800ec42 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -53,7 +53,6 @@ import ( awsutils "github.com/gravitational/teleport/lib/utils/aws" logutils "github.com/gravitational/teleport/lib/utils/log" "github.com/gravitational/teleport/lib/utils/parse" - slicesutils "github.com/gravitational/teleport/lib/utils/slices" ) // DefaultImplicitRules provides access to the default set of implicit rules @@ -1030,15 +1029,18 @@ func (result *EnumerationResult) WildcardDenied() bool { return result.wildcardDenied } -// ToEntities converts result back to allowed and denied entity slices. Wildcard -// will be placed in the slice if any. +// ToEntities converts result back to allowed and denied entity slices. +// +// If wildcard is denied, only "*" is returned for the denied slice. +// If wildcard is allowed, allowed entities will be appended to the allowed +// slice after the "*" as a hint for users to select. +// Denied entities is only included if the wildcard is allowed. func (result *EnumerationResult) ToEntities() (allowed, denied []string) { if result.wildcardDenied { return nil, []string{types.Wildcard} } - // Only include denied entities if the wildcard is allowed. if result.wildcardAllowed { - return []string{types.Wildcard}, result.Denied() + return append([]string{types.Wildcard}, result.Allowed()...), result.Denied() } return result.Allowed(), nil } @@ -1055,30 +1057,28 @@ func NewEnumerationResult() EnumerationResult { // NewEnumerationResultFromEntities creates a new EnumerationResult and // populates the result with provided allowed and denied entries. func NewEnumerationResultFromEntities(allowed, denied []string) EnumerationResult { - switch { - // Wildcard deny, ignore all other entries. - case slices.Contains(denied, types.Wildcard): - return EnumerationResult{ - allowedDeniedMap: make(map[string]bool), - wildcardDenied: true, - } - - // Wildcard allow, ignore all allowed entries. - case slices.Contains(allowed, types.Wildcard): - return EnumerationResult{ - allowedDeniedMap: slicesutils.ToMapWithDefaultValue(denied, false), - wildcardAllowed: true, - } - - // No wildcard. - default: - allowedDeniedMap := slicesutils.ToMapWithDefaultValue(allowed, true) - for _, deny := range denied { - delete(allowedDeniedMap, deny) + var wildcardAllowed bool + var wildcardDenied bool + allowedDeniedMap := make(map[string]bool) + for _, allow := range allowed { + if allow == types.Wildcard { + wildcardAllowed = true + } else { + allowedDeniedMap[allow] = true } - return EnumerationResult{ - allowedDeniedMap: allowedDeniedMap, + } + for _, deny := range denied { + if deny == types.Wildcard { + wildcardDenied = true + wildcardAllowed = false + break } + allowedDeniedMap[deny] = false + } + return EnumerationResult{ + allowedDeniedMap: allowedDeniedMap, + wildcardAllowed: wildcardAllowed, + wildcardDenied: wildcardDenied, } } diff --git a/lib/services/role_test.go b/lib/services/role_test.go index 0a42640233be8..1f312befc001e 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -10072,59 +10072,38 @@ func TestNewEnumerationResultFromEntities(t *testing.T) { name string inputAllowed []string inputDenied []string - wantResult EnumerationResult wantAllowed []string wantDenied []string }{ { name: "empty", - wantResult: EnumerationResult{ - allowedDeniedMap: make(map[string]bool), - }, }, { name: "wildcard denied", inputAllowed: []string{"allow_entry"}, inputDenied: []string{"deny_entry", "*"}, - wantResult: EnumerationResult{ - allowedDeniedMap: make(map[string]bool), - wildcardDenied: true, - }, - wantDenied: []string{"*"}, + wantDenied: []string{"*"}, }, { name: "wildcard allowed", - inputAllowed: []string{"allow_entry", "*"}, - inputDenied: []string{"deny_entry"}, - wantResult: EnumerationResult{ - allowedDeniedMap: map[string]bool{ - "deny_entry": false, - }, - wildcardAllowed: true, - }, - wantAllowed: []string{"*"}, - wantDenied: []string{"deny_entry"}, + inputAllowed: []string{"allow_entry", "*", "deny_overwrite"}, + inputDenied: []string{"deny_overwrite"}, + wantAllowed: []string{"*", "allow_entry"}, + wantDenied: []string{"deny_overwrite"}, }, { name: "no wildcard", inputAllowed: []string{"allow_entry_1", "deny_overwrite", "allow_entry_2"}, inputDenied: []string{"deny_overwrite", "deny_entry"}, - wantResult: EnumerationResult{ - allowedDeniedMap: map[string]bool{ - "allow_entry_1": true, - "allow_entry_2": true, - }, - }, - wantAllowed: []string{"allow_entry_1", "allow_entry_2"}, + wantAllowed: []string{"allow_entry_1", "allow_entry_2"}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - actualResult := NewEnumerationResultFromEntities(test.inputAllowed, test.inputDenied) - require.Equal(t, test.wantResult, actualResult) + result := NewEnumerationResultFromEntities(test.inputAllowed, test.inputDenied) - actualAllowed, actualDenied := actualResult.ToEntities() + actualAllowed, actualDenied := result.ToEntities() require.Equal(t, test.wantAllowed, actualAllowed) require.Equal(t, test.wantDenied, actualDenied) }) diff --git a/lib/utils/slices/slices.go b/lib/utils/slices/slices.go index 115905e4446c8..5e524b7f1ed9b 100644 --- a/lib/utils/slices/slices.go +++ b/lib/utils/slices/slices.go @@ -80,21 +80,3 @@ func DeduplicateKey[T any](s []T, key func(T) string) []T { } return out } - -// ToMap converts elements in a slice as the keys to a map, and use provided -// function to calculate their values in the map. -func ToMap[K comparable, V any, S ~[]K](s S, makeValue func(K) V) map[K]V { - m := make(map[K]V, len(s)) - for _, key := range s { - m[key] = makeValue(key) - } - return m -} - -// ToMapWithDefaultValue converts elements in a slice as the keys to a map, and -// all values in the map are set to provided default value. -func ToMapWithDefaultValue[K comparable, V any, S ~[]K](s S, defaultValue V) map[K]V { - return ToMap(s, func(K) V { - return defaultValue - }) -} diff --git a/lib/utils/slices/slices_test.go b/lib/utils/slices/slices_test.go index 6a0925480af79..b0fe86a1bfd2e 100644 --- a/lib/utils/slices/slices_test.go +++ b/lib/utils/slices/slices_test.go @@ -20,7 +20,6 @@ package slices import ( "fmt" - "strconv" "strings" "testing" @@ -183,24 +182,3 @@ func TestDeduplicateKey(t *testing.T) { }) } } - -func TestToMap(t *testing.T) { - numbers := []int{1, 2, 3, 4} - m := ToMap(numbers, strconv.Itoa) - require.Equal(t, map[int]string{ - 1: "1", - 2: "2", - 3: "3", - 4: "4", - }, m) -} - -func TestToMapWithDefaultValue(t *testing.T) { - s := []string{"a", "b", "c"} - m := ToMapWithDefaultValue(s, true) - require.Equal(t, map[string]bool{ - "a": true, - "b": true, - "c": true, - }, m) -} diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 43d20a558898f..7de3d1c2d97ca 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -79,6 +79,10 @@ func onAppLogin(cf *CLIConf) error { } defer clusterClient.Close() + if app.IsMCP() { + return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp login --help' for more details.") + } + if err := validateTargetPort(app, int(cf.TargetPort)); err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go index f709e2b60e057..9a1958473f9cb 100644 --- a/tool/tsh/common/mcp_app.go +++ b/tool/tsh/common/mcp_app.go @@ -147,7 +147,7 @@ func fetchMCPServers(ctx context.Context, tc *client.TeleportClient, auth apicli func withMCPServerAppFilter(predicateExpression string) string { return makePredicateConjunction( predicateExpression, - "hasPrefix(resource.spec.uri, \"mcp+\")", + `resource.sub_kind == "mcp"`, ) } diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index a8cc0b832a472..484507550af34 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -501,6 +501,10 @@ func onProxyCommandApp(cf *CLIConf) error { return trace.Wrap(err) } + if app.IsMCP() { + return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp login --help' for more details.") + } + proxyApp, err := newLocalProxyAppWithPortMapping(cf.Context, tc, profile, appInfo.RouteToApp, app, portMapping, cf.InsecureSkipVerify) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index e30c0c8532e2b..81b56f63fe03e 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -4366,7 +4366,9 @@ func TestSerializeDatabases(t *testing.T) { dbUsersData: `, "users": { "allowed": [ - "*" + "*", + "bar", + "foo" ], "denied": [ "baz", @@ -4421,7 +4423,9 @@ func TestSerializeDatabases(t *testing.T) { dbUsersData: `, "users": { "allowed": [ - "*" + "*", + "bar", + "foo" ] }`, roles: services.RoleSet{