Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions api/types/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package types

import (
"fmt"
"iter"
"net/url"
"slices"
"strconv"
Expand Down Expand Up @@ -309,6 +310,9 @@ func (a *AppV3) GetProtocol() string {
if a.IsTCP() {
return "TCP"
}
if a.IsMCP() {
return "MCP"
}
return "HTTP"
}

Expand Down Expand Up @@ -565,18 +569,27 @@ func (a *AppV3) GetMCP() *MCP {

// DeduplicateApps deduplicates apps by combination of app name and public address.
// Apps can have the same name but also could have different addresses.
func DeduplicateApps(apps []Application) (result []Application) {
func DeduplicateApps(apps []Application) []Application {
return slices.Collect(DeduplicatedApps(slices.Values(apps)))
}

// DeduplicatedApps iterates deduplicated apps by combination of app name and
// public address. This is the iter.Seq version of DeduplicateApps.
func DeduplicatedApps(apps iter.Seq[Application]) iter.Seq[Application] {
type key struct{ name, addr string }
seen := make(map[key]struct{})
for _, app := range apps {
key := key{app.GetName(), app.GetPublicAddr()}
if _, ok := seen[key]; ok {
continue
return func(yield func(Application) bool) {
for app := range apps {
key := key{app.GetName(), app.GetPublicAddr()}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
if !yield(app) {
return
}
}
seen[key] = struct{}{}
result = append(result, app)
}
return result
}

// Apps is a list of app resources.
Expand Down
17 changes: 17 additions & 0 deletions api/types/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package types

import (
"fmt"
"slices"
"strconv"
"testing"

Expand Down Expand Up @@ -736,3 +737,19 @@ func TestGetMCPServerTransportType(t *testing.T) {
})
}
}

func TestDeduplicateApps(t *testing.T) {
var apps []Application
for _, name := range []string{"a", "b", "c", "b", "a", "d"} {
app_, err := NewAppV3(Metadata{
Name: name,
}, AppSpecV3{
URI: "localhost:3080",
})
require.NoError(t, err)
apps = append(apps, app_)
}

deduped := DeduplicateApps(apps)
require.Equal(t, []string{"a", "b", "c", "d"}, slices.Collect(ResourceNames(deduped)))
}
10 changes: 10 additions & 0 deletions api/types/appserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package types

import (
"fmt"
"iter"
"slices"
"sort"
"time"

Expand All @@ -26,6 +28,7 @@ import (
"github.com/gravitational/teleport/api"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/iterutils"
)

// AppServer represents a single proxied web app.
Expand Down Expand Up @@ -409,3 +412,10 @@ func (s AppServers) GetFieldVals(field string) ([]string, error) {

return vals, nil
}

// Applications iterates over the applications that the AppServers proxy.
func (s AppServers) Applications() iter.Seq[Application] {
return iterutils.Map(func(appServer AppServer) Application {
return appServer.GetApp()
}, slices.Values(s))
}
5 changes: 5 additions & 0 deletions api/types/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] {
return iterutils.Map(GetName, slices.Values(s))
}

// CompareResourceByNames compares resources by their names.
func CompareResourceByNames[R Resource](a, b R) int {
return strings.Compare(a.GetName(), b.GetName())
}

// ResourceDetails includes details about the resource
type ResourceDetails struct {
Hostname string
Expand Down
44 changes: 44 additions & 0 deletions lib/services/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,22 @@ func (result *EnumerationResult) WildcardDenied() bool {
return result.wildcardDenied
}

// ToEntities converts result back to allowed and denied entity slices.
//
// If wildcard is denied, only "*" is returned for the denied slice.
// If wildcard is allowed, allowed entities will be appended to the allowed
// slice after the "*" as a hint for users to select.
// Denied entities is only included if the wildcard is allowed.
func (result *EnumerationResult) ToEntities() (allowed, denied []string) {
if result.wildcardDenied {
return nil, []string{types.Wildcard}
}
if result.wildcardAllowed {
return append([]string{types.Wildcard}, result.Allowed()...), result.Denied()
}
return result.Allowed(), nil
}

// NewEnumerationResult returns new EnumerationResult.
func NewEnumerationResult() EnumerationResult {
return EnumerationResult{
Expand All @@ -1038,6 +1054,34 @@ func NewEnumerationResult() EnumerationResult {
}
}

// NewEnumerationResultFromEntities creates a new EnumerationResult and
// populates the result with provided allowed and denied entries.
func NewEnumerationResultFromEntities(allowed, denied []string) EnumerationResult {
var wildcardAllowed bool
var wildcardDenied bool
allowedDeniedMap := make(map[string]bool)
for _, allow := range allowed {
if allow == types.Wildcard {
wildcardAllowed = true
} else {
allowedDeniedMap[allow] = true
}
}
for _, deny := range denied {
if deny == types.Wildcard {
wildcardDenied = true
wildcardAllowed = false
break
}
allowedDeniedMap[deny] = false
}
return EnumerationResult{
allowedDeniedMap: allowedDeniedMap,
wildcardAllowed: wildcardAllowed,
wildcardDenied: wildcardDenied,
}
}

// MatchNamespace returns true if given list of namespace matches
// target namespace, wildcard matches everything.
func MatchNamespace(selectors []string, namespace string) (bool, string) {
Expand Down
43 changes: 43 additions & 0 deletions lib/services/role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10066,3 +10066,46 @@ func TestMCPToolMatcher(t *testing.T) {
})
}
}

func TestNewEnumerationResultFromEntities(t *testing.T) {
tests := []struct {
name string
inputAllowed []string
inputDenied []string
wantAllowed []string
wantDenied []string
}{
{
name: "empty",
},
{
name: "wildcard denied",
inputAllowed: []string{"allow_entry"},
inputDenied: []string{"deny_entry", "*"},
wantDenied: []string{"*"},
},
{
name: "wildcard allowed",
inputAllowed: []string{"allow_entry", "*", "deny_overwrite"},
inputDenied: []string{"deny_overwrite"},
wantAllowed: []string{"*", "allow_entry"},
wantDenied: []string{"deny_overwrite"},
},
{
name: "no wildcard",
inputAllowed: []string{"allow_entry_1", "deny_overwrite", "allow_entry_2"},
inputDenied: []string{"deny_overwrite", "deny_entry"},
wantAllowed: []string{"allow_entry_1", "allow_entry_2"},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := NewEnumerationResultFromEntities(test.inputAllowed, test.inputDenied)

actualAllowed, actualDenied := result.ToEntities()
require.Equal(t, test.wantAllowed, actualAllowed)
require.Equal(t, test.wantDenied, actualDenied)
})
}
}
33 changes: 33 additions & 0 deletions tool/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"sort"
"strings"

"github.com/ghodss/yaml"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -238,3 +239,35 @@ func FormatDefault[T comparable](val, defaultVal T) string {
}
return fmt.Sprintf("%v", val)
}

// FormatAllowedEntities returns a human-readable string describing the allowed
// entities, optionally including a list of denied entities as exceptions.
func FormatAllowedEntities(allowed []string, denied []string) string {
if len(allowed) == 0 {
return "(none)"
}
if len(denied) == 0 {
return fmt.Sprintf("%v", allowed)
}
return fmt.Sprintf("%v, except: %v", allowed, denied)
}

// PrintJSONIndent prints provided value in JSON with default indentation.
func PrintJSONIndent(w io.Writer, v any) error {
out, err := utils.FastMarshalIndent(v, "", " ")
if err != nil {
return trace.Wrap(err)
}
_, err = fmt.Fprintln(w, string(out))
return trace.Wrap(err)
}

// PrintYAML prints provided value in YAML.
func PrintYAML(w io.Writer, v any) error {
out, err := yaml.Marshal(v)
if err != nil {
return trace.Wrap(err)
}
_, err = fmt.Fprintln(w, string(out))
return trace.Wrap(err)
}
4 changes: 4 additions & 0 deletions tool/tsh/common/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ func onAppLogin(cf *CLIConf) error {
}
defer clusterClient.Close()

if app.IsMCP() {
return trace.BadParameter("MCP applications are not supported. Please see 'tsh mcp login --help' for more details.")
}

if err := validateTargetPort(app, int(cf.TargetPort)); err != nil {
return trace.Wrap(err)
}
Expand Down
6 changes: 5 additions & 1 deletion tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1005,12 +1005,16 @@ func (d *databaseInfo) getChecker(ctx context.Context, tc *client.TeleportClient
}
defer clusterClient.Close()

return makeAccessChecker(ctx, tc, clusterClient.AuthClient)
}

func makeAccessChecker(ctx context.Context, tc *client.TeleportClient, auth services.CurrentUserRoleGetter) (services.AccessChecker, error) {
profile, err := tc.ProfileStatus()
if err != nil {
return nil, trace.Wrap(err)
}

checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient)
checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, auth)
return checker, trace.Wrap(err)
}

Expand Down
16 changes: 12 additions & 4 deletions tool/tsh/common/db_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,14 @@ func (c *fakeDatabaseExecClient) issueCert(context.Context, *databaseInfo) (tls.
return c.cert, nil
}
func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, req *proto.ListResourcesRequest) ([]types.Database, error) {
filtered, err := matchResources(req, c.allDatabaseServers)
if err != nil {
return nil, trace.Wrap(err)
}
return types.DatabaseServers(filtered).ToDatabases(), nil
}

func matchResources[R types.ResourceWithLabels](req *proto.ListResourcesRequest, s []R) ([]R, error) {
filter := services.MatchResourceFilter{
ResourceKind: req.ResourceType,
Labels: req.Labels,
Expand All @@ -339,13 +347,13 @@ func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, re
filter.PredicateExpression = expression
}

var filtered []types.Database
for _, dbServer := range c.allDatabaseServers {
match, err := services.MatchResourceByFilters(dbServer, filter, nil)
var filtered []R
for _, r := range s {
match, err := services.MatchResourceByFilters(r, filter, nil)
if err != nil {
return nil, trace.Wrap(err)
} else if match {
filtered = append(filtered, dbServer.GetDatabase())
filtered = append(filtered, r)
}
}
return filtered, nil
Expand Down
8 changes: 6 additions & 2 deletions tool/tsh/common/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@

package common

import "github.com/alecthomas/kingpin/v2"
import (
"github.com/alecthomas/kingpin/v2"
)

type mcpCommands struct {
dbStart *mcpDBStartCommand
list *mcpListCommand
}

func newMCPCommands(app *kingpin.Application) *mcpCommands {
func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands {
mcp := app.Command("mcp", "View and control proxied MCP servers.")
db := mcp.Command("db", "Database access for MCP servers.")
return &mcpCommands{
dbStart: newMCPDBCommand(db),
list: newMCPListCommand(mcp, cf),
}
}
Loading
Loading