Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/types/appserver_or_saml_idp_sp.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ func (a *AppServerOrSAMLIdPServiceProviderV1) GetLabel(key string) (value string
v, ok := appServer.Spec.App.Metadata.Labels[key]
return v, ok
} else {
return "", true
sp := a.GetSAMLIdPServiceProvider()
v, ok := sp.Metadata.Labels[key]
return v, ok
}
}

Expand Down
36 changes: 22 additions & 14 deletions api/types/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,8 @@ func (m *Metadata) CheckAndSetDefaults() error {
// MatchLabels takes a map of labels and returns `true` if the resource has ALL
// of them.
func MatchLabels(resource ResourceWithLabels, labels map[string]string) bool {
if len(labels) == 0 {
return true
}

resourceLabels := resource.GetAllLabels()
for name, value := range labels {
if resourceLabels[name] != value {
for key, value := range labels {
if v, ok := resource.GetLabel(key); !ok || v != value {
return false
}
}
Expand Down Expand Up @@ -535,15 +530,11 @@ func IsValidLabelKey(s string) bool {
// Returns true if all search vals were matched (or if nil search vals).
// Returns false if no or partial match (or nil field values).
func MatchSearch(fieldVals []string, searchVals []string, customMatch func(val string) bool) bool {
// Case fold all values to avoid repeated case folding while matching.
caseFoldedSearchVals := utils.ToLowerStrings(searchVals)
caseFoldedFieldVals := utils.ToLowerStrings(fieldVals)

Outer:
for _, searchV := range caseFoldedSearchVals {
for _, searchV := range searchVals {
// Iterate through field values to look for a match.
for _, fieldV := range caseFoldedFieldVals {
if strings.Contains(fieldV, searchV) {
for _, fieldV := range fieldVals {
if containsFold(fieldV, searchV) {
continue Outer
}
}
Expand All @@ -559,6 +550,23 @@ Outer:
return true
}

// containsFold is a case-insensitive alternative to strings.Contains, used to help avoid excess allocations during searches.
func containsFold(s, substr string) bool {
if len(s) < len(substr) {
return false
}

n := len(s) - len(substr)

for i := 0; i <= n; i++ {
if strings.EqualFold(s[i:i+len(substr)], substr) {
return true
}
}

return false
}

func stringCompare(a string, b string, isDesc bool) bool {
if isDesc {
return a > b
Expand Down
31 changes: 21 additions & 10 deletions api/types/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,11 @@ func (s *ServerV2) GetAllLabels() map[string]string {

// CombineLabels combines the passed in static and dynamic labels.
func CombineLabels(static map[string]string, dynamic map[string]CommandLabelV2) map[string]string {
lmap := make(map[string]string)
if len(dynamic) == 0 {
return static
}

lmap := make(map[string]string, len(static)+len(dynamic))
for key, value := range static {
lmap[key] = value
}
Expand Down Expand Up @@ -492,20 +496,27 @@ func (s *ServerV2) CheckAndSetDefaults() error {
// MatchSearch goes through select field values and tries to
// match against the list of search values.
func (s *ServerV2) MatchSearch(values []string) bool {
var fieldVals []string
if s.GetKind() != KindNode {
return false
}

var custom func(val string) bool
if s.GetUseTunnel() {
custom = func(val string) bool {
return strings.EqualFold(val, "tunnel")
}
}

if s.GetKind() == KindNode {
fieldVals = append(utils.MapToStrings(s.GetAllLabels()), s.GetName(), s.GetHostname(), s.GetAddr())
fieldVals = append(fieldVals, s.GetPublicAddrs()...)
fieldVals := make([]string, 0, (len(s.Metadata.Labels)*2)+(len(s.Spec.CmdLabels)*2)+len(s.Spec.PublicAddrs)+3)

if s.GetUseTunnel() {
custom = func(val string) bool {
return strings.EqualFold(val, "tunnel")
}
}
labels := CombineLabels(s.Metadata.Labels, s.Spec.CmdLabels)
for key, value := range labels {
fieldVals = append(fieldVals, key, value)
}

fieldVals = append(fieldVals, s.Metadata.Name, s.Spec.Hostname, s.Spec.Addr)
fieldVals = append(fieldVals, s.Spec.PublicAddrs...)

return MatchSearch(fieldVals, values, custom)
}

Expand Down
46 changes: 38 additions & 8 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1329,18 +1329,34 @@ func (a *ServerWithRoles) checkUnifiedAccess(resource types.ResourceWithLabels,
return false, trace.Wrap(canAccessErr)
}

if resourceKind != types.KindSAMLIdPServiceProvider {
if err := checker.CanAccess(resource); err != nil {
// Filter first and only check RBAC if there is a match to improve perf.
match, err := services.MatchResourceByFilters(resource, filter, nil)
if err != nil {
log.WithFields(logrus.Fields{
"resource_name": resource.GetName(),
"resource_kind": resourceKind,
"error": err,
}).
Warn("Unable to determine access to resource, matching with filter failed")
return false, nil
}

if trace.IsAccessDenied(err) {
return false, nil
}
return false, trace.Wrap(err)
if !match {
return false, nil
}

if resourceKind == types.KindSAMLIdPServiceProvider {
return true, nil
}

if err := checker.CanAccess(resource); err != nil {
if trace.IsAccessDenied(err) {
return false, nil
}
return false, trace.Wrap(err)
}

match, err := services.MatchResourceByFilters(resource, filter, nil)
return match, trace.Wrap(err)
return true, nil
}

// ListUnifiedResources returns a paginated list of unified resources filtered by user access.
Expand All @@ -1358,6 +1374,20 @@ func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.L
Kinds: req.Kinds,
}

// If a predicate expression was provided, evaluate it with an empty
// server to determine if the expression is valid before attempting
// to do any listing.
if filter.PredicateExpression != "" {
parser, err := services.NewResourceParser(&types.ServerV2{})
if err != nil {
return nil, trace.Wrap(err)
}

if _, err := parser.EvalBoolPredicate(filter.PredicateExpression); err != nil {
return nil, trace.BadParameter("failed to parse predicate expression: %s", err.Error())
}
}

// Populate resourceAccessMap with any access errors the user has for each possible
// resource kind. This allows the access check to occur a single time per resource
// kind instead of once per matching resource.
Expand Down
137 changes: 133 additions & 4 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4721,6 +4721,130 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) {
require.Error(t, err)
}

func BenchmarkListUnifiedResourcesFilter(b *testing.B) {
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
logger.ReplaceHooks(make(logrus.LevelHooks))
logrus.SetFormatter(logutils.NewTestJSONFormatter())
logger.SetLevel(logrus.PanicLevel)
logger.SetOutput(io.Discard)

ctx := context.Background()
srv := newTestTLSServer(b)

var ids []string
for i := 0; i < roleCount; i++ {
ids = append(ids, uuid.New().String())
}

ids[0] = "hidden"

var hiddenNodes int
// Create test nodes.
for i := 0; i < nodeCount; i++ {
name := uuid.New().String()
id := ids[i%len(ids)]
if id == "hidden" {
hiddenNodes++
}

labels := map[string]string{
"kEy": id,
"grouP": "useRs",
}

if i == 10 {
labels["ip"] = "10.20.30.40"
labels["ADDRESS"] = "10.20.30.41"
labels["food"] = "POTATO"
}

node, err := types.NewServerWithLabels(
name,
types.KindNode,
types.ServerSpecV2{},
labels,
)
require.NoError(b, err)

_, err = srv.Auth().UpsertNode(ctx, node)
require.NoError(b, err)
}

b.Run("labels", func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
1,
srv,
ids,
func(role types.Role, id string) {
role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}})
},
func(req *proto.ListUnifiedResourcesRequest) {
req.Labels = map[string]string{"ip": "10.20.30.40"}
},
)
})
b.Run("predicate path", func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
1,
srv,
ids,
func(role types.Role, id string) {
role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}})
},
func(req *proto.ListUnifiedResourcesRequest) {
req.PredicateExpression = `labels.ip == "10.20.30.40"`
},
)
})
b.Run("predicate index", func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
1,
srv,
ids,
func(role types.Role, id string) {
role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}})
},
func(req *proto.ListUnifiedResourcesRequest) {
req.PredicateExpression = `labels["ip"] == "10.20.30.40"`
},
)
})
b.Run("search lower", func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
1,
srv,
ids,
func(role types.Role, id string) {
role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}})
},
func(req *proto.ListUnifiedResourcesRequest) {
req.SearchKeywords = []string{"10.20.30.40"}
},
)
})
b.Run("search upper", func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
1,
srv,
ids,
func(role types.Role, id string) {
role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}})
},
func(req *proto.ListUnifiedResourcesRequest) {
req.SearchKeywords = []string{"POTATO"}
},
)
})
}

// go test ./lib/auth -bench=BenchmarkListUnifiedResources -run=^$ -v -benchtime 1x
// goos: darwin
// goarch: arm64
Expand All @@ -4731,7 +4855,7 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) {
// PASS
// ok github.com/gravitational/teleport/lib/auth 2.878s
func BenchmarkListUnifiedResources(b *testing.B) {
const nodeCount = 50_000
const nodeCount = 150_000
const roleCount = 32

logger := logrus.StandardLogger()
Expand Down Expand Up @@ -4791,21 +4915,23 @@ func BenchmarkListUnifiedResources(b *testing.B) {
b.Run(tc.desc, func(b *testing.B) {
benchmarkListUnifiedResources(
b, ctx,
nodeCount, hiddenNodes,
nodeCount-hiddenNodes,
srv,
ids,
tc.editRole,
func(req *proto.ListUnifiedResourcesRequest) {},
)
})
}
}

func benchmarkListUnifiedResources(
b *testing.B, ctx context.Context,
nodeCount, hiddenNodes int,
expectedCount int,
srv *TestTLSServer,
ids []string,
editRole func(r types.Role, id string),
editReq func(req *proto.ListUnifiedResourcesRequest),
) {
var roles []types.Role
for _, id := range ids {
Expand Down Expand Up @@ -4839,6 +4965,9 @@ func benchmarkListUnifiedResources(
SortBy: types.SortBy{IsDesc: false, Field: types.ResourceMetadataName},
Limit: 1_000,
}

editReq(req)

for {
rsp, err := clt.ListUnifiedResources(ctx, req)
require.NoError(b, err)
Expand All @@ -4849,7 +4978,7 @@ func benchmarkListUnifiedResources(
break
}
}
require.Len(b, resources, nodeCount-hiddenNodes)
require.Len(b, resources, expectedCount)
}
}

Expand Down
Loading