diff --git a/api/types/appserver_or_saml_idp_sp.go b/api/types/appserver_or_saml_idp_sp.go index 14465f9869179..9ca138630065c 100644 --- a/api/types/appserver_or_saml_idp_sp.go +++ b/api/types/appserver_or_saml_idp_sp.go @@ -193,7 +193,9 @@ func (a *AppServerOrSAMLIdPServiceProviderV1) GetLabel(key string) (value string v, ok := appServer.Spec.App.Metadata.Labels[key] return v, ok } else { - return "", true + sp := a.GetSAMLIdPServiceProvider() + v, ok := sp.Metadata.Labels[key] + return v, ok } } diff --git a/api/types/resource.go b/api/types/resource.go index 57293f4729046..22f11ac2b2b70 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -495,13 +495,8 @@ func (m *Metadata) CheckAndSetDefaults() error { // MatchLabels takes a map of labels and returns `true` if the resource has ALL // of them. func MatchLabels(resource ResourceWithLabels, labels map[string]string) bool { - if len(labels) == 0 { - return true - } - - resourceLabels := resource.GetAllLabels() - for name, value := range labels { - if resourceLabels[name] != value { + for key, value := range labels { + if v, ok := resource.GetLabel(key); !ok || v != value { return false } } @@ -535,15 +530,11 @@ func IsValidLabelKey(s string) bool { // Returns true if all search vals were matched (or if nil search vals). // Returns false if no or partial match (or nil field values). func MatchSearch(fieldVals []string, searchVals []string, customMatch func(val string) bool) bool { - // Case fold all values to avoid repeated case folding while matching. - caseFoldedSearchVals := utils.ToLowerStrings(searchVals) - caseFoldedFieldVals := utils.ToLowerStrings(fieldVals) - Outer: - for _, searchV := range caseFoldedSearchVals { + for _, searchV := range searchVals { // Iterate through field values to look for a match. - for _, fieldV := range caseFoldedFieldVals { - if strings.Contains(fieldV, searchV) { + for _, fieldV := range fieldVals { + if containsFold(fieldV, searchV) { continue Outer } } @@ -559,6 +550,23 @@ Outer: return true } +// containsFold is a case-insensitive alternative to strings.Contains, used to help avoid excess allocations during searches. +func containsFold(s, substr string) bool { + if len(s) < len(substr) { + return false + } + + n := len(s) - len(substr) + + for i := 0; i <= n; i++ { + if strings.EqualFold(s[i:i+len(substr)], substr) { + return true + } + } + + return false +} + func stringCompare(a string, b string, isDesc bool) bool { if isDesc { return a > b diff --git a/api/types/server.go b/api/types/server.go index 62e9d2368e531..b550b027c3582 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -350,7 +350,11 @@ func (s *ServerV2) GetAllLabels() map[string]string { // CombineLabels combines the passed in static and dynamic labels. func CombineLabels(static map[string]string, dynamic map[string]CommandLabelV2) map[string]string { - lmap := make(map[string]string) + if len(dynamic) == 0 { + return static + } + + lmap := make(map[string]string, len(static)+len(dynamic)) for key, value := range static { lmap[key] = value } @@ -492,20 +496,27 @@ func (s *ServerV2) CheckAndSetDefaults() error { // MatchSearch goes through select field values and tries to // match against the list of search values. func (s *ServerV2) MatchSearch(values []string) bool { - var fieldVals []string + if s.GetKind() != KindNode { + return false + } + var custom func(val string) bool + if s.GetUseTunnel() { + custom = func(val string) bool { + return strings.EqualFold(val, "tunnel") + } + } - if s.GetKind() == KindNode { - fieldVals = append(utils.MapToStrings(s.GetAllLabels()), s.GetName(), s.GetHostname(), s.GetAddr()) - fieldVals = append(fieldVals, s.GetPublicAddrs()...) + fieldVals := make([]string, 0, (len(s.Metadata.Labels)*2)+(len(s.Spec.CmdLabels)*2)+len(s.Spec.PublicAddrs)+3) - if s.GetUseTunnel() { - custom = func(val string) bool { - return strings.EqualFold(val, "tunnel") - } - } + labels := CombineLabels(s.Metadata.Labels, s.Spec.CmdLabels) + for key, value := range labels { + fieldVals = append(fieldVals, key, value) } + fieldVals = append(fieldVals, s.Metadata.Name, s.Spec.Hostname, s.Spec.Addr) + fieldVals = append(fieldVals, s.Spec.PublicAddrs...) + return MatchSearch(fieldVals, values, custom) } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 726d68d76d97f..edd98add2a032 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1329,18 +1329,34 @@ func (a *ServerWithRoles) checkUnifiedAccess(resource types.ResourceWithLabels, return false, trace.Wrap(canAccessErr) } - if resourceKind != types.KindSAMLIdPServiceProvider { - if err := checker.CanAccess(resource); err != nil { + // Filter first and only check RBAC if there is a match to improve perf. + match, err := services.MatchResourceByFilters(resource, filter, nil) + if err != nil { + log.WithFields(logrus.Fields{ + "resource_name": resource.GetName(), + "resource_kind": resourceKind, + "error": err, + }). + Warn("Unable to determine access to resource, matching with filter failed") + return false, nil + } - if trace.IsAccessDenied(err) { - return false, nil - } - return false, trace.Wrap(err) + if !match { + return false, nil + } + + if resourceKind == types.KindSAMLIdPServiceProvider { + return true, nil + } + + if err := checker.CanAccess(resource); err != nil { + if trace.IsAccessDenied(err) { + return false, nil } + return false, trace.Wrap(err) } - match, err := services.MatchResourceByFilters(resource, filter, nil) - return match, trace.Wrap(err) + return true, nil } // ListUnifiedResources returns a paginated list of unified resources filtered by user access. @@ -1358,6 +1374,20 @@ func (a *ServerWithRoles) ListUnifiedResources(ctx context.Context, req *proto.L Kinds: req.Kinds, } + // If a predicate expression was provided, evaluate it with an empty + // server to determine if the expression is valid before attempting + // to do any listing. + if filter.PredicateExpression != "" { + parser, err := services.NewResourceParser(&types.ServerV2{}) + if err != nil { + return nil, trace.Wrap(err) + } + + if _, err := parser.EvalBoolPredicate(filter.PredicateExpression); err != nil { + return nil, trace.BadParameter("failed to parse predicate expression: %s", err.Error()) + } + } + // Populate resourceAccessMap with any access errors the user has for each possible // resource kind. This allows the access check to occur a single time per resource // kind instead of once per matching resource. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 92c6f3dfb7bea..5b9431e2e7c59 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4721,6 +4721,130 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) { require.Error(t, err) } +func BenchmarkListUnifiedResourcesFilter(b *testing.B) { + const nodeCount = 150_000 + const roleCount = 32 + + logger := logrus.StandardLogger() + logger.ReplaceHooks(make(logrus.LevelHooks)) + logrus.SetFormatter(logutils.NewTestJSONFormatter()) + logger.SetLevel(logrus.PanicLevel) + logger.SetOutput(io.Discard) + + ctx := context.Background() + srv := newTestTLSServer(b) + + var ids []string + for i := 0; i < roleCount; i++ { + ids = append(ids, uuid.New().String()) + } + + ids[0] = "hidden" + + var hiddenNodes int + // Create test nodes. + for i := 0; i < nodeCount; i++ { + name := uuid.New().String() + id := ids[i%len(ids)] + if id == "hidden" { + hiddenNodes++ + } + + labels := map[string]string{ + "kEy": id, + "grouP": "useRs", + } + + if i == 10 { + labels["ip"] = "10.20.30.40" + labels["ADDRESS"] = "10.20.30.41" + labels["food"] = "POTATO" + } + + node, err := types.NewServerWithLabels( + name, + types.KindNode, + types.ServerSpecV2{}, + labels, + ) + require.NoError(b, err) + + _, err = srv.Auth().UpsertNode(ctx, node) + require.NoError(b, err) + } + + b.Run("labels", func(b *testing.B) { + benchmarkListUnifiedResources( + b, ctx, + 1, + srv, + ids, + func(role types.Role, id string) { + role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}}) + }, + func(req *proto.ListUnifiedResourcesRequest) { + req.Labels = map[string]string{"ip": "10.20.30.40"} + }, + ) + }) + b.Run("predicate path", func(b *testing.B) { + benchmarkListUnifiedResources( + b, ctx, + 1, + srv, + ids, + func(role types.Role, id string) { + role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}}) + }, + func(req *proto.ListUnifiedResourcesRequest) { + req.PredicateExpression = `labels.ip == "10.20.30.40"` + }, + ) + }) + b.Run("predicate index", func(b *testing.B) { + benchmarkListUnifiedResources( + b, ctx, + 1, + srv, + ids, + func(role types.Role, id string) { + role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}}) + }, + func(req *proto.ListUnifiedResourcesRequest) { + req.PredicateExpression = `labels["ip"] == "10.20.30.40"` + }, + ) + }) + b.Run("search lower", func(b *testing.B) { + benchmarkListUnifiedResources( + b, ctx, + 1, + srv, + ids, + func(role types.Role, id string) { + role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}}) + }, + func(req *proto.ListUnifiedResourcesRequest) { + req.SearchKeywords = []string{"10.20.30.40"} + }, + ) + }) + b.Run("search upper", func(b *testing.B) { + benchmarkListUnifiedResources( + b, ctx, + 1, + srv, + ids, + func(role types.Role, id string) { + role.SetNodeLabels(types.Allow, types.Labels{types.Wildcard: []string{types.Wildcard}}) + }, + func(req *proto.ListUnifiedResourcesRequest) { + req.SearchKeywords = []string{"POTATO"} + }, + ) + }) +} + // go test ./lib/auth -bench=BenchmarkListUnifiedResources -run=^$ -v -benchtime 1x // goos: darwin // goarch: arm64 @@ -4731,7 +4855,7 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) { // PASS // ok github.com/gravitational/teleport/lib/auth 2.878s func BenchmarkListUnifiedResources(b *testing.B) { - const nodeCount = 50_000 + const nodeCount = 150_000 const roleCount = 32 logger := logrus.StandardLogger() @@ -4791,10 +4915,11 @@ func BenchmarkListUnifiedResources(b *testing.B) { b.Run(tc.desc, func(b *testing.B) { benchmarkListUnifiedResources( b, ctx, - nodeCount, hiddenNodes, + nodeCount-hiddenNodes, srv, ids, tc.editRole, + func(req *proto.ListUnifiedResourcesRequest) {}, ) }) } @@ -4802,10 +4927,11 @@ func BenchmarkListUnifiedResources(b *testing.B) { func benchmarkListUnifiedResources( b *testing.B, ctx context.Context, - nodeCount, hiddenNodes int, + expectedCount int, srv *TestTLSServer, ids []string, editRole func(r types.Role, id string), + editReq func(req *proto.ListUnifiedResourcesRequest), ) { var roles []types.Role for _, id := range ids { @@ -4839,6 +4965,9 @@ func benchmarkListUnifiedResources( SortBy: types.SortBy{IsDesc: false, Field: types.ResourceMetadataName}, Limit: 1_000, } + + editReq(req) + for { rsp, err := clt.ListUnifiedResources(ctx, req) require.NoError(b, err) @@ -4849,7 +4978,7 @@ func benchmarkListUnifiedResources( break } } - require.Len(b, resources, nodeCount-hiddenNodes) + require.Len(b, resources, expectedCount) } } diff --git a/lib/services/matchers.go b/lib/services/matchers.go index 31698cf51968a..cb1a78fcec1cd 100644 --- a/lib/services/matchers.go +++ b/lib/services/matchers.go @@ -19,7 +19,6 @@ package services import ( - "fmt" "slices" "github.com/gravitational/trace" @@ -129,7 +128,7 @@ func MatchResourceLabels(matchers []ResourceMatcher, labels map[string]string) b // ResourceSeenKey is used as a key for a map that keeps track // of unique resource names and address. Currently "addr" // only applies to resource Application. -type ResourceSeenKey struct{ name, addr string } +type ResourceSeenKey struct{ name, kind, addr string } // MatchResourceByFilters returns true if all filter values given matched against the resource. // @@ -144,20 +143,21 @@ type ResourceSeenKey struct{ name, addr string } // is not provided but is provided for kind `KubernetesCluster`. func MatchResourceByFilters(resource types.ResourceWithLabels, filter MatchResourceFilter, seenMap map[ResourceSeenKey]struct{}) (bool, error) { var specResource types.ResourceWithLabels - resourceKind := resource.GetKind() + kind := resource.GetKind() // We assume when filtering for services like KubeService, AppServer, and DatabaseServer // the user is wanting to filter the contained resource ie. KubeClusters, Application, and Database. - resourceKey := ResourceSeenKey{} - switch resourceKind { + key := ResourceSeenKey{ + kind: kind, + name: resource.GetName(), + } + switch kind { case types.KindNode, types.KindDatabaseService, types.KindKubernetesCluster, types.KindWindowsDesktop, types.KindWindowsDesktopService, types.KindUserGroup: specResource = resource - resourceKey.name = fmt.Sprintf("%s/%s", specResource.GetName(), resourceKind) - case types.KindKubeServer: if seenMap != nil { return false, trace.BadParameter("checking for duplicate matches for resource kind %q is not supported", filter.ResourceKind) @@ -170,18 +170,17 @@ func MatchResourceByFilters(resource types.ResourceWithLabels, filter MatchResou return false, trace.BadParameter("expected types.DatabaseServer, got %T", resource) } specResource = server.GetDatabase() - resourceKey.name = fmt.Sprintf("%s/%s/", specResource.GetName(), resourceKind) - + key.name = specResource.GetName() case types.KindAppServer, types.KindSAMLIdPServiceProvider, types.KindAppOrSAMLIdPServiceProvider: switch appOrSP := resource.(type) { case types.AppServer: app := appOrSP.GetApp() specResource = app - resourceKey.name = fmt.Sprintf("%s/%s/", specResource.GetName(), resourceKind) - resourceKey.addr = app.GetPublicAddr() + key.addr = app.GetPublicAddr() + key.name = app.GetName() case types.SAMLIdPServiceProvider: specResource = appOrSP - resourceKey.name = fmt.Sprintf("%s/%s/", specResource.GetName(), resourceKind) + key.name = specResource.GetName() default: return false, trace.BadParameter("expected types.SAMLIdPServiceProvider or types.AppServer, got %T", resource) } @@ -190,10 +189,9 @@ func MatchResourceByFilters(resource types.ResourceWithLabels, filter MatchResou // of cases we need to handle. If the resource type didn't match any arm before // and it is not a Kubernetes resource kind, we return an error. if !slices.Contains(types.KubernetesResourcesKinds, filter.ResourceKind) { - return false, trace.NotImplemented("filtering for resource kind %q not supported", resourceKind) + return false, trace.NotImplemented("filtering for resource kind %q not supported", kind) } specResource = resource - resourceKey.name = fmt.Sprintf("%s/%s/", specResource.GetName(), resourceKind) } var match bool @@ -212,10 +210,10 @@ func MatchResourceByFilters(resource types.ResourceWithLabels, filter MatchResou // Deduplicate matches. if match && seenMap != nil { - if _, exists := seenMap[resourceKey]; exists { + if _, exists := seenMap[key]; exists { return false, nil } - seenMap[resourceKey] = struct{}{} + seenMap[key] = struct{}{} } return match, nil diff --git a/lib/services/parser.go b/lib/services/parser.go index 797719478ae77..141c469749a87 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -800,18 +800,17 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, GetIdentifier: func(fields []string) (interface{}, error) { switch fields[0] { case ResourceLabelsIdentifier: - combinedLabels := resource.GetAllLabels() switch { // Field length of 1 means the user is using // an index expression ie: labels["env"], which the // parser will expect a map for lookup in `GetProperty`. case len(fields) == 1: - return labels(combinedLabels), nil + return resource, nil case len(fields) > 2: return nil, trace.BadParameter("only two fields are supported with identifier %q, got %d: %v", ResourceLabelsIdentifier, len(fields), fields) default: key := fields[1] - val, ok := combinedLabels[key] + val, ok := resource.GetLabel(key) if ok { return label{key: key, value: val}, nil } @@ -838,7 +837,7 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, } }, GetProperty: func(mapVal, keyVal interface{}) (interface{}, error) { - m, ok := mapVal.(labels) + r, ok := mapVal.(types.ResourceWithLabels) if !ok { return GetStringMapValue(mapVal, keyVal) } @@ -848,7 +847,7 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, return nil, trace.BadParameter("only string keys are supported") } - val, ok := m[key] + val, ok := r.GetLabel(key) if ok { return label{key: key, value: val}, nil } @@ -865,5 +864,3 @@ func NewResourceParser(resource types.ResourceWithLabels) (BoolPredicateParser, type label struct { key, value string } - -type labels map[string]string diff --git a/lib/services/role.go b/lib/services/role.go index c4a29263d1664..e176c3d0e0882 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -2422,9 +2422,14 @@ type AccessCheckable interface { // It also returns a flag indicating whether debug logging is enabled, // allowing the RBAC system to generate more verbose errors in debug mode. func rbacDebugLogger() (debugEnabled bool, debugf func(format string, args ...interface{})) { - isDebugEnabled := log.IsLevelEnabled(log.TraceLevel) - log := log.WithField(trace.Component, teleport.ComponentRBAC) - return isDebugEnabled, log.Tracef + debugEnabled = log.IsLevelEnabled(log.TraceLevel) + debugf = func(format string, args ...interface{}) {} + + if debugEnabled { + debugf = log.WithField(trace.Component, teleport.ComponentRBAC).Tracef + } + + return } func (set RoleSet) checkAccess(r AccessCheckable, traits wrappers.Traits, state AccessState, matchers ...RoleMatcher) error { diff --git a/lib/services/suite/presence_test.go b/lib/services/suite/presence_test.go index 66c6fc69b807c..7cc0558635858 100644 --- a/lib/services/suite/presence_test.go +++ b/lib/services/suite/presence_test.go @@ -32,7 +32,6 @@ func TestServerLabels(t *testing.T) { emptyLabels := make(map[string]string) // empty server := &types.ServerV2{} - require.Empty(t, cmp.Diff(server.GetAllLabels(), emptyLabels)) require.Empty(t, server.GetAllLabels()) require.True(t, types.MatchLabels(server, emptyLabels)) require.False(t, types.MatchLabels(server, map[string]string{"a": "b"}))