diff --git a/api/types/app.go b/api/types/app.go index 389feb935c30b..fa550c2cafcc4 100644 --- a/api/types/app.go +++ b/api/types/app.go @@ -195,6 +195,17 @@ func (a *AppV3) SetDynamicLabels(dl map[string]CommandLabel) { a.Spec.DynamicLabels = LabelsToV2(dl) } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (a *AppV3) GetLabel(key string) (value string, ok bool) { + if cmd, ok := a.Spec.DynamicLabels[key]; ok { + return cmd.Result, ok + } + + v, ok := a.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns the app combined static and dynamic labels. func (a *AppV3) GetAllLabels() map[string]string { return CombineLabels(a.Metadata.Labels, a.Spec.DynamicLabels) diff --git a/api/types/appserver.go b/api/types/appserver.go index 65bc0e79eb15d..ce817f3df8528 100644 --- a/api/types/appserver.go +++ b/api/types/appserver.go @@ -276,6 +276,19 @@ func (s *AppServerV3) SetProxyIDs(proxyIDs []string) { s.Spec.ProxyIDs = proxyIDs } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (s *AppServerV3) GetLabel(key string) (value string, ok bool) { + if s.Spec.App != nil { + if v, ok := s.Spec.App.GetLabel(key); ok { + return v, ok + } + } + + v, ok := s.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns all resource's labels. Considering: // * Static labels from `Metadata.Labels` and `Spec.App`. // * Dynamic labels from `Spec.App.Spec`. diff --git a/api/types/connection_diagnostic.go b/api/types/connection_diagnostic.go index 4c6ad7e34e835..361daf6d34310 100644 --- a/api/types/connection_diagnostic.go +++ b/api/types/connection_diagnostic.go @@ -86,6 +86,13 @@ func (c *ConnectionDiagnosticV1) CheckAndSetDefaults() error { return nil } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (c *ConnectionDiagnosticV1) GetLabel(key string) (value string, ok bool) { + v, ok := c.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns combined static and dynamic labels. func (c *ConnectionDiagnosticV1) GetAllLabels() map[string]string { return CombineLabels(c.Metadata.Labels, nil) diff --git a/api/types/database.go b/api/types/database.go index 5498610bab74c..35d637498b34a 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -221,6 +221,17 @@ func (d *DatabaseV3) SetDynamicLabels(dl map[string]CommandLabel) { d.Spec.DynamicLabels = LabelsToV2(dl) } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (d *DatabaseV3) GetLabel(key string) (value string, ok bool) { + if cmd, ok := d.Spec.DynamicLabels[key]; ok { + return cmd.Result, ok + } + + v, ok := d.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns the database combined static and dynamic labels. func (d *DatabaseV3) GetAllLabels() map[string]string { return CombineLabels(d.Metadata.Labels, d.Spec.DynamicLabels) diff --git a/api/types/databaseserver.go b/api/types/databaseserver.go index 11c3c859ade93..f7e30c1e037d8 100644 --- a/api/types/databaseserver.go +++ b/api/types/databaseserver.go @@ -266,6 +266,19 @@ func (s *DatabaseServerV3) SetOrigin(origin string) { s.Metadata.SetOrigin(origin) } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (s *DatabaseServerV3) GetLabel(key string) (value string, ok bool) { + if s.Spec.Database != nil { + if v, ok := s.Spec.Database.GetLabel(key); ok { + return v, ok + } + } + + v, ok := s.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns all resource's labels. Considering: // * Static labels from `Metadata.Labels` and `Spec.Database`. // * Dynamic labels from `Spec.DynamicLabels`. diff --git a/api/types/desktop.go b/api/types/desktop.go index 454491acba904..cb0a409dc32c4 100644 --- a/api/types/desktop.go +++ b/api/types/desktop.go @@ -116,6 +116,13 @@ func (s *WindowsDesktopServiceV3) SetProxyIDs(proxyIDs []string) { s.Spec.ProxyIDs = proxyIDs } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (s *WindowsDesktopServiceV3) GetLabel(key string) (value string, ok bool) { + v, ok := s.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns the resources labels. func (s *WindowsDesktopServiceV3) GetAllLabels() map[string]string { return s.Metadata.Labels @@ -204,6 +211,13 @@ func (d *WindowsDesktopV3) GetHostID() string { return d.Spec.HostID } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (d *WindowsDesktopV3) GetLabel(key string) (value string, ok bool) { + v, ok := d.Metadata.Labels[key] + return v, ok +} + // GetAllLabels returns combined static and dynamic labels. func (d *WindowsDesktopV3) GetAllLabels() map[string]string { // TODO(zmb3): add dynamic labels when running in agent mode diff --git a/api/types/kubernetes.go b/api/types/kubernetes.go index 4c48bdb805178..172f52f5aa5f7 100644 --- a/api/types/kubernetes.go +++ b/api/types/kubernetes.go @@ -142,6 +142,17 @@ func (k *KubernetesClusterV3) SetName(name string) { k.Metadata.Name = name } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (k *KubernetesClusterV3) GetLabel(key string) (value string, ok bool) { + if cmd, ok := k.Spec.DynamicLabels[key]; ok { + return cmd.Result, ok + } + + v, ok := k.Metadata.Labels[key] + return v, ok +} + // GetStaticLabels returns the static labels. func (k *KubernetesClusterV3) GetStaticLabels() map[string]string { return k.Metadata.Labels diff --git a/api/types/resource.go b/api/types/resource.go index 2bfd87e28ed61..f3fbe4b62c780 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -80,6 +80,8 @@ type ResourceWithOrigin interface { type ResourceWithLabels interface { // ResourceWithOrigin is the base resource interface. ResourceWithOrigin + // GetLabel retrieves the label with the provided key. + GetLabel(key string) (value string, ok bool) // GetAllLabels returns all resource's labels. GetAllLabels() map[string]string // GetStaticLabels returns the resource's static labels. diff --git a/api/types/server.go b/api/types/server.go index 13f4c578647fb..476a412b0d71b 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -227,16 +227,29 @@ func (s *ServerV2) GetHostname() string { return s.Spec.Hostname } +// GetLabel retrieves the label with the provided key. If not found +// value will be empty and ok will be false. +func (s *ServerV2) GetLabel(key string) (value string, ok bool) { + if cmd, ok := s.Spec.CmdLabels[key]; ok { + return cmd.Result, ok + } + + v, ok := s.Metadata.Labels[key] + return v, ok +} + +// GetLabels returns server's static label key pairs. // GetLabels and GetStaticLabels are the same, and that is intentional. GetLabels // exists to preserve backwards compatibility, while GetStaticLabels exists to // implement ResourcesWithLabels. - -// GetLabels returns server's static label key pairs func (s *ServerV2) GetLabels() map[string]string { return s.Metadata.Labels } // GetStaticLabels returns the server static labels. +// GetLabels and GetStaticLabels are the same, and that is intentional. GetLabels +// exists to preserve backwards compatibility, while GetStaticLabels exists to +// implement ResourcesWithLabels. func (s *ServerV2) GetStaticLabels() map[string]string { return s.Metadata.Labels } diff --git a/go.mod b/go.mod index 81e89a2711c67..9886e711be555 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/gravitational/trace v1.2.1 github.com/gravitational/ttlmap v0.0.0-20171116003245-91fd36b9004c github.com/grpc-ecosystem/go-grpc-middleware/providers/openmetrics/v2 v2.0.0-20220308023801-e4a6915ea237 - github.com/hashicorp/golang-lru v0.5.4 + github.com/hashicorp/golang-lru/v2 v2.0.2 github.com/jackc/pgconn v1.13.0 github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgproto3/v2 v2.3.1 diff --git a/go.sum b/go.sum index a52d0c21ddd54..865df9016985b 100644 --- a/go.sum +++ b/go.sum @@ -603,8 +603,8 @@ github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= -github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru/v2 v2.0.2 h1:Dwmkdr5Nc/oBiXgJS3CDHNhJtIHkuZ3DZF5twqnfBdU= +github.com/hashicorp/golang-lru/v2 v2.0.2/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 7153f572dcb53..956217cbc6886 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1333,6 +1333,11 @@ func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResou req.SearchKeywords = nil req.PredicateExpression = "" + // Increase the limit to one more than was requested so + // that an additional page load is not needed to determine + // the next key. + req.Limit++ + resourceChecker, err := a.newResourceAccessChecker(req.ResourceType) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 51b9a717d9021..7f078aa493ee4 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "crypto/x509/pkix" "fmt" + "io" "testing" "time" @@ -29,12 +30,14 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/pquerna/otp/totp" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/installers" @@ -1502,6 +1505,105 @@ func TestSessionRecordingConfigRBAC(t *testing.T) { }) } +// time go test ./lib/auth -bench=. -run=^$ -v +// goos: darwin +// goarch: amd64 +// pkg: github.com/gravitational/teleport/lib/auth +// cpu: Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz +// BenchmarkListNodes +// BenchmarkListNodes-16 1 1000469673 ns/op 518721960 B/op 8344858 allocs/op +// PASS +// ok github.com/gravitational/teleport/lib/auth 3.695s +// go test ./lib/auth -bench=. -run=^$ -v 19.02s user 3.87s system 244% cpu 9.376 total +func BenchmarkListNodes(b *testing.B) { + const nodeCount = 50_000 + const roleCount = 32 + + logger := logrus.StandardLogger() + logger.ReplaceHooks(make(logrus.LevelHooks)) + logger.SetLevel(logrus.DebugLevel) + logger.SetOutput(io.Discard) + + ctx := context.Background() + srv := newTestTLSServer(b) + + var values []string + for i := 0; i < roleCount; i++ { + values = append(values, uuid.New().String()) + } + + values[0] = "hidden" + + var hiddenNodes int + // Create test nodes. + for i := 0; i < nodeCount; i++ { + name := uuid.New().String() + val := values[i%len(values)] + if val == "hidden" { + hiddenNodes++ + } + node, err := types.NewServerWithLabels( + name, + types.KindNode, + types.ServerSpecV2{}, + map[string]string{"key": val}, + ) + require.NoError(b, err) + + _, err = srv.Auth().UpsertNode(ctx, node) + require.NoError(b, err) + } + + testNodes, err := srv.Auth().GetNodes(ctx, defaults.Namespace) + require.NoError(b, err) + require.Len(b, testNodes, nodeCount) + + var roles []types.Role + for _, val := range values { + role, err := types.NewRole(fmt.Sprintf("role-%s", val), types.RoleSpecV5{}) + require.NoError(b, err) + + if val == "hidden" { + role.SetNodeLabels(types.Deny, types.Labels{"key": {val}}) + } else { + role.SetNodeLabels(types.Allow, types.Labels{"key": {val}}) + } + roles = append(roles, role) + } + + // create user, role, and client + username := "user" + + user, err := CreateUser(srv.Auth(), username, roles...) + require.NoError(b, err) + identity := TestUser(user.GetName()) + clt, err := srv.NewClient(identity) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + var resources []types.ResourceWithLabels + req := proto.ListResourcesRequest{ + ResourceType: types.KindNode, + Namespace: apidefaults.Namespace, + Limit: 1_000, + } + for { + rsp, err := clt.ListResources(ctx, req) + require.NoError(b, err) + + resources = append(resources, rsp.Resources...) + req.StartKey = rsp.NextKey + if req.StartKey == "" { + break + } + } + require.Len(b, resources, nodeCount-hiddenNodes) + } +} + // TestGetAndList_Nodes users can retrieve nodes with various filters // and with the appropriate permissions. func TestGetAndList_Nodes(t *testing.T) { diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index e3c40c064a2b9..ecb648d346570 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -3449,7 +3449,7 @@ func verifyJWT(clock clockwork.Clock, clusterName string, pairs []*types.JWTKeyP return nil, trace.NewAggregate(errs...) } -func newTestTLSServer(t *testing.T) *TestTLSServer { +func newTestTLSServer(t testing.TB) *TestTLSServer { as, err := NewTestAuthServer(TestAuthServerConfig{ Dir: t.TempDir(), Clock: clockwork.NewFakeClock(), diff --git a/lib/backend/report.go b/lib/backend/report.go index 6c551309319bc..a2b2613c741e1 100644 --- a/lib/backend/report.go +++ b/lib/backend/report.go @@ -22,7 +22,7 @@ import ( "time" "github.com/gravitational/trace" - lru "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru/v2" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -81,7 +81,7 @@ type Reporter struct { // // This will keep an upper limit on our memory usage while still always // reporting the most active keys. - topRequestsCache *lru.Cache + topRequestsCache *lru.Cache[topRequestsCacheKey, struct{}] } // NewReporter returns a new Reporter. @@ -95,12 +95,7 @@ func NewReporter(cfg ReporterConfig) (*Reporter, error) { return nil, trace.Wrap(err) } - cache, err := lru.NewWithEvict(cfg.TopRequestsCount, func(key interface{}, value interface{}) { - labels, ok := key.(topRequestsCacheKey) - if !ok { - log.Errorf("BUG: invalid cache key type: %T", key) - return - } + cache, err := lru.NewWithEvict(cfg.TopRequestsCount, func(labels topRequestsCacheKey, value struct{}) { // Evict the key from requests metric. requests.DeleteLabelValues(labels.component, labels.key, labels.isRange) }) diff --git a/lib/services/role.go b/lib/services/role.go index a38229c99667e..d4162c5c98634 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -1058,6 +1058,24 @@ func MatchDatabaseUser(selectors []string, user string) (bool, string) { // MatchLabels matches selector against target. Empty selector matches // nothing, wildcard matches everything. func MatchLabels(selector types.Labels, target map[string]string) (bool, string, error) { + return MatchLabelGetter(selector, mapLabelGetter(target)) +} + +// LabelGetter allows retrieving a particular label by name. +type LabelGetter interface { + GetLabel(key string) (value string, ok bool) +} + +type mapLabelGetter map[string]string + +func (m mapLabelGetter) GetLabel(key string) (value string, ok bool) { + v, ok := m[key] + return v, ok +} + +// MatchLabelGetter matches selector against labelGetter. Empty selector matches +// nothing, wildcard matches everything. +func MatchLabelGetter(selector types.Labels, labelGetter LabelGetter) (bool, string, error) { // Empty selector matches nothing. if len(selector) == 0 { return false, "no match, empty selector", nil @@ -1071,19 +1089,20 @@ func MatchLabels(selector types.Labels, target map[string]string) (bool, string, // Perform full match. for key, selectorValues := range selector { - targetVal, hasKey := target[key] - + targetVal, hasKey := labelGetter.GetLabel(key) if !hasKey { return false, fmt.Sprintf("no key match: '%v'", key), nil } - if !apiutils.SliceContainsStr(selectorValues, types.Wildcard) { - result, err := utils.SliceMatchesRegex(targetVal, selectorValues) - if err != nil { - return false, "", trace.Wrap(err) - } else if !result { - return false, fmt.Sprintf("no value match: got '%v' want: '%v'", targetVal, selectorValues), nil - } + if apiutils.SliceContainsStr(selectorValues, types.Wildcard) { + continue + } + + result, err := utils.SliceMatchesRegex(targetVal, selectorValues) + if err != nil { + return false, "", trace.Wrap(err) + } else if !result { + return false, fmt.Sprintf("no value match: got '%v' want: '%v'", targetVal, selectorValues), nil } } @@ -1955,16 +1974,16 @@ type AccessCheckable interface { GetKind() string GetName() string GetMetadata() types.Metadata - GetAllLabels() map[string]string + GetLabel(key string) (value string, ok bool) } // rbacDebugLogger creates a debug logger for Teleport's RBAC component. // It also returns a flag indicating whether debug logging is enabled, // allowing the RBAC system to generate more verbose errors in debug mode. func rbacDebugLogger() (debugEnabled bool, debugf func(format string, args ...interface{})) { - isDebugEnabled := log.IsLevelEnabled(log.DebugLevel) + isDebugEnabled := log.IsLevelEnabled(log.TraceLevel) log := log.WithField(trace.Component, teleport.ComponentRBAC) - return isDebugEnabled, log.Debugf + return isDebugEnabled, log.Tracef } // checkAccess checks if this role set has access to a particular resource, @@ -1981,7 +2000,6 @@ func (set RoleSet) checkAccess(r AccessCheckable, mfa AccessMFAParams, matchers } namespace := types.ProcessNamespace(r.GetMetadata().Namespace) - allLabels := r.GetAllLabels() // Additional message depending on kind of resource // so there's more context on why the user might not have access. @@ -2016,7 +2034,7 @@ func (set RoleSet) checkAccess(r AccessCheckable, mfa AccessMFAParams, matchers continue } - matchLabels, labelsMessage, err := MatchLabels(getRoleLabels(role, types.Deny), allLabels) + matchLabels, labelsMessage, err := MatchLabelGetter(getRoleLabels(role, types.Deny), r) if err != nil { return trace.Wrap(err) } @@ -2054,7 +2072,7 @@ func (set RoleSet) checkAccess(r AccessCheckable, mfa AccessMFAParams, matchers continue } - matchLabels, labelsMessage, err := MatchLabels(getRoleLabels(role, types.Allow), allLabels) + matchLabels, labelsMessage, err := MatchLabelGetter(getRoleLabels(role, types.Allow), r) if err != nil { return trace.Wrap(err) } diff --git a/lib/services/server.go b/lib/services/server.go index ce1799e25bc98..e482deea441ba 100644 --- a/lib/services/server.go +++ b/lib/services/server.go @@ -285,38 +285,28 @@ func UnmarshalServer(bytes []byte, kind string, opts ...MarshalOption) (types.Se return nil, trace.BadParameter("missing server data") } - var h types.ResourceHeader - if err = utils.FastUnmarshal(bytes, &h); err != nil { + var s types.ServerV2 + if err := utils.FastUnmarshal(bytes, &s); err != nil { + return nil, trace.BadParameter(err.Error()) + } + s.Kind = kind + if err := s.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - - switch h.Version { - case types.V2: - var s types.ServerV2 - - if err := utils.FastUnmarshal(bytes, &s); err != nil { - return nil, trace.BadParameter(err.Error()) - } - s.Kind = kind - if err := s.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - if cfg.ID != 0 { - s.SetResourceID(cfg.ID) - } - if !cfg.Expires.IsZero() { - s.SetExpiry(cfg.Expires) - } - if s.Metadata.Expires != nil { - apiutils.UTC(s.Metadata.Expires) - } - // Force the timestamps to UTC for consistency. - // See https://github.com/gogo/protobuf/issues/519 for details on issues this causes for proto.Clone - apiutils.UTC(&s.Spec.Rotation.Started) - apiutils.UTC(&s.Spec.Rotation.LastRotated) - return &s, nil + if cfg.ID != 0 { + s.SetResourceID(cfg.ID) + } + if !cfg.Expires.IsZero() { + s.SetExpiry(cfg.Expires) + } + if s.Metadata.Expires != nil { + apiutils.UTC(s.Metadata.Expires) } - return nil, trace.BadParameter("server resource version %q is not supported", h.Version) + // Force the timestamps to UTC for consistency. + // See https://github.com/gogo/protobuf/issues/519 for details on issues this causes for proto.Clone + apiutils.UTC(&s.Spec.Rotation.Started) + apiutils.UTC(&s.Spec.Rotation.LastRotated) + return &s, nil } // MarshalServer marshals the Server resource to JSON. diff --git a/lib/utils/replace.go b/lib/utils/replace.go index 47d6ac004645d..f0a48df159291 100644 --- a/lib/utils/replace.go +++ b/lib/utils/replace.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/gravitational/trace" + lru "github.com/hashicorp/golang-lru/v2" ) // ContainsExpansion returns true if value contains @@ -88,27 +89,54 @@ type RegexpConfig struct { // match is always evaluated as a regex either an exact match or regexp. func SliceMatchesRegex(input string, expressions []string) (bool, error) { for _, expression := range expressions { - if !strings.HasPrefix(expression, "^") || !strings.HasSuffix(expression, "$") { - // replace glob-style wildcards with regexp wildcards - // for plain strings, and quote all characters that could - // be interpreted in regular expression - expression = "^" + GlobToRegexp(expression) + "$" + result, err := matchString(input, expression) + if err != nil || result { + return result, trace.Wrap(err) } + } - expr, err := regexp.Compile(expression) - if err != nil { - return false, trace.BadParameter(err.Error()) - } + return false, nil +} - // Since the expression is always surrounded by ^ and $ this is an exact - // match for either a a plain string (for example ^hello$) or for a regexp - // (for example ^hel*o$). - if expr.MatchString(input) { - return true, nil - } +// mustCache initializes a new [lru.Cache] with the provided size. +// A panic will be triggered if the creation of the cache fails. +func mustCache[K comparable, V any](size int) *lru.Cache[K, V] { + cache, err := lru.New[K, V](size) + if err != nil { + panic(err) } - return false, nil + return cache +} + +// exprCache interns compiled regular expressions created in matchString +// to improve performance. +var exprCache = mustCache[string, *regexp.Regexp](1000) + +func matchString(input, expression string) (bool, error) { + if expr, ok := exprCache.Get(expression); ok { + return expr.MatchString(input), nil + } + + original := expression + if !strings.HasPrefix(expression, "^") || !strings.HasSuffix(expression, "$") { + // replace glob-style wildcards with regexp wildcards + // for plain strings, and quote all characters that could + // be interpreted in regular expression + expression = "^" + GlobToRegexp(expression) + "$" + } + + expr, err := regexp.Compile(expression) + if err != nil { + return false, trace.BadParameter(err.Error()) + } + + exprCache.Add(original, expr) + + // Since the expression is always surrounded by ^ and $ this is an exact + // match for either a plain string (for example ^hello$) or for a regexp + // (for example ^hel*o$). + return expr.MatchString(input), nil } var replaceWildcard = regexp.MustCompile(`(\\\*)`)