diff --git a/api/types/accesslist/accesslist.go b/api/types/accesslist/accesslist.go index e379321088437..0d45014d9cf0c 100644 --- a/api/types/accesslist/accesslist.go +++ b/api/types/accesslist/accesslist.go @@ -406,8 +406,8 @@ func (a *AccessList) MatchSearch(values []string) bool { return types.MatchSearch(fieldVals, values, nil) } -// CloneResource returns a copy of the resource as types.ResourceWithLabels. -func (a *AccessList) CloneResource() types.ResourceWithLabels { +// Clone returns a copy of the list. +func (a *AccessList) Clone() *AccessList { var copy *AccessList utils.StrictObjectToStruct(a, ©) return copy diff --git a/api/types/accesslist/member.go b/api/types/accesslist/member.go index 28a71ac1bcf5e..fc3633e087ac1 100644 --- a/api/types/accesslist/member.go +++ b/api/types/accesslist/member.go @@ -135,3 +135,10 @@ func (a *AccessListMember) MatchSearch(values []string) bool { fieldVals := append(utils.MapToStrings(a.GetAllLabels()), a.GetName()) return types.MatchSearch(fieldVals, values, nil) } + +// Clone returns a copy of the member. +func (a *AccessListMember) Clone() *AccessListMember { + var copy *AccessListMember + utils.StrictObjectToStruct(a, ©) + return copy +} diff --git a/api/types/accesslist/review.go b/api/types/accesslist/review.go index 0cdfe76db1512..776d06bdc1638 100644 --- a/api/types/accesslist/review.go +++ b/api/types/accesslist/review.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/header/convert/legacy" + "github.com/gravitational/teleport/api/utils" ) // Review is an access list review resource. @@ -114,6 +115,13 @@ func (r *Review) GetMetadata() types.Metadata { return legacy.FromHeaderMetadata(r.Metadata) } +// Clone returns a copy of the review. +func (a *Review) Clone() *Review { + var copy *Review + utils.StrictObjectToStruct(a, ©) + return copy +} + func (r *ReviewSpec) UnmarshalJSON(data []byte) error { type Alias ReviewSpec review := struct { diff --git a/lib/cache/access_list.go b/lib/cache/access_list.go new file mode 100644 index 0000000000000..3a12183ab41e4 --- /dev/null +++ b/lib/cache/access_list.go @@ -0,0 +1,405 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/sortcache" +) + +type accessListIndex string + +const accessListNameIndex accessListIndex = "name" + +func newAccessListCollection(upstream services.AccessLists, w types.WatchKind) (*collection[*accesslist.AccessList, accessListIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter AccessLists") + } + + return &collection[*accesslist.AccessList, accessListIndex]{ + store: newStore(map[accessListIndex]func(*accesslist.AccessList) string{ + accessListNameIndex: func(al *accesslist.AccessList) string { + return al.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*accesslist.AccessList, error) { + var resources []*accesslist.AccessList + var nextToken string + for { + var page []*accesslist.AccessList + var err error + page, nextToken, err = upstream.ListAccessLists(ctx, 0 /* page size */, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, page...) + + if nextToken == "" { + break + } + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *accesslist.AccessList { + return &accesslist.AccessList{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// GetAccessLists returns a list of all access lists. +func (c *Cache) GetAccessLists(ctx context.Context) ([]*accesslist.AccessList, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetAccessLists") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.accessLists) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, err := c.Config.AccessLists.GetAccessLists(ctx) + return out, trace.Wrap(err) + } + + out := make([]*accesslist.AccessList, 0, rg.store.len()) + for n := range rg.store.resources(accessListNameIndex, "", "") { + out = append(out, n.Clone()) + } + return out, nil +} + +// ListAccessLists returns a paginated list of access lists. +func (c *Cache) ListAccessLists(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.AccessList, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAccessLists") + defer span.End() + + lister := genericLister[*accesslist.AccessList, accessListIndex]{ + cache: c, + collection: c.collections.accessLists, + index: accessListNameIndex, + defaultPageSize: 100, + upstreamList: c.Config.AccessLists.ListAccessLists, + nextToken: func(t *accesslist.AccessList) string { + return t.GetMetadata().Name + }, + clone: func(al *accesslist.AccessList) *accesslist.AccessList { + return al.Clone() + }, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +// GetAccessList returns the specified access list resource. +func (c *Cache) GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetAccessList") + defer span.End() + + var upstreamRead bool + getter := genericGetter[*accesslist.AccessList, accessListIndex]{ + cache: c, + collection: c.collections.accessLists, + index: accessListNameIndex, + upstreamGet: func(ctx context.Context, s string) (*accesslist.AccessList, error) { + upstreamRead = true + return c.Config.AccessLists.GetAccessList(ctx, s) + }, + clone: func(al *accesslist.AccessList) *accesslist.AccessList { + return al.Clone() + }, + } + out, err := getter.get(ctx, name) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if item, err := c.Config.AccessLists.GetAccessList(ctx, name); err == nil { + return item, nil + } + } + return out, trace.Wrap(err) +} + +type accessListMemberIndex string + +const ( + accessListMemberNameIndex accessListMemberIndex = "name" + accessListMemberKindIndex accessListMemberIndex = "kind" +) + +func newAccessListMemberCollection(upstream services.AccessLists, w types.WatchKind) (*collection[*accesslist.AccessListMember, accessListMemberIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter AccessLists") + } + + return &collection[*accesslist.AccessListMember, accessListMemberIndex]{ + store: newStore(map[accessListMemberIndex]func(*accesslist.AccessListMember) string{ + accessListMemberNameIndex: func(r *accesslist.AccessListMember) string { + return r.Spec.AccessList + "/" + r.GetName() + }, + accessListMemberKindIndex: func(r *accesslist.AccessListMember) string { + return r.Spec.AccessList + "/" + r.Spec.MembershipKind + "/" + r.GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*accesslist.AccessListMember, error) { + var resources []*accesslist.AccessListMember + var nextToken string + for { + var page []*accesslist.AccessListMember + var err error + page, nextToken, err = upstream.ListAllAccessListMembers(ctx, 0 /* page size */, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, page...) + + if nextToken == "" { + break + } + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *accesslist.AccessListMember { + return &accesslist.AccessListMember{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + Spec: accesslist.AccessListMemberSpec{ + AccessList: hdr.Metadata.Description, + }, + } + }, + watch: w, + }, nil +} + +// CountAccessListMembers will count all access list members. +func (c *Cache) CountAccessListMembers(ctx context.Context, accessListName string) (uint32, uint32, error) { + ctx, span := c.Tracer.Start(ctx, "cache/CountAccessListMembers") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.accessListMembers) + if err != nil { + return 0, 0, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + count, listCount, err := c.Config.AccessLists.CountAccessListMembers(ctx, accessListName) + return count, listCount, trace.Wrap(err) + } + + startKey := accessListName + "/" + accesslist.MembershipKindList + "/" + endKey := sortcache.NextKey(startKey) + listCount := uint32(rg.store.count(accessListMemberKindIndex, startKey, endKey)) + + return uint32(rg.store.len()) - listCount, listCount, trace.Wrap(err) +} + +// ListAccessListMembers returns a paginated list of all access list members. +// May return a DynamicAccessListError if the requested access list has an +// implicit member list and the underlying implementation does not have +// enough information to compute the dynamic member list. +func (c *Cache) ListAccessListMembers(ctx context.Context, accessListName string, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAccessListMembers") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.accessListMembers) + if err != nil { + return nil, "", trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, next, err := c.Config.AccessLists.ListAccessListMembers(ctx, accessListName, pageSize, pageToken) + return out, next, trace.Wrap(err) + } + + start := accessListName + if pageToken != "" { + start += "/" + pageToken + } + + if pageSize <= 0 { + pageSize = defaults.DefaultChunkSize + } + + var out []*accesslist.AccessListMember + for member := range rg.store.resources(accessListMemberNameIndex, start, "") { + if len(out) == pageSize { + return out, accessListName + "/" + member.GetName(), nil + } + + out = append(out, member.Clone()) + } + + return out, "", trace.Wrap(err) +} + +// ListAllAccessListMembers returns a paginated list of all access list members for all access lists. +func (c *Cache) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAllAccessListMembers") + defer span.End() + + lister := genericLister[*accesslist.AccessListMember, accessListMemberIndex]{ + cache: c, + collection: c.collections.accessListMembers, + index: accessListMemberNameIndex, + defaultPageSize: 200, + upstreamList: c.Config.AccessLists.ListAllAccessListMembers, + nextToken: func(t *accesslist.AccessListMember) string { + return t.GetMetadata().Name + }, + clone: func(al *accesslist.AccessListMember) *accesslist.AccessListMember { + return al.Clone() + }, + } + out, next, err := lister.list(ctx, pageSize, nextToken) + return out, next, trace.Wrap(err) +} + +// GetAccessListMember returns the specified access list member resource. +// May return a DynamicAccessListError if the requested access list has an +// implicit member list and the underlying implementation does not have +// enough information to compute the dynamic member record. +func (c *Cache) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetAccessListMember") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.accessListMembers) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, err := c.Config.AccessLists.GetAccessListMember(ctx, accessList, memberName) + return out, trace.Wrap(err) + } + + member, err := rg.store.get(accessListMemberNameIndex, accessList+"/"+memberName) + if err != nil { + return nil, trace.Wrap(err) + } + return member.Clone(), nil +} + +type accessListReviewIndex string + +const accessListReviewNameIndex = "name" + +func newAccessListReviewCollection(upstream services.AccessLists, w types.WatchKind) (*collection[*accesslist.Review, accessListReviewIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter AccessLists") + } + + return &collection[*accesslist.Review, accessListReviewIndex]{ + store: newStore(map[accessListReviewIndex]func(*accesslist.Review) string{ + accessListReviewNameIndex: func(r *accesslist.Review) string { + return r.Spec.AccessList + "/" + r.GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*accesslist.Review, error) { + var resources []*accesslist.Review + var nextToken string + for { + var page []*accesslist.Review + var err error + page, nextToken, err = upstream.ListAllAccessListReviews(ctx, 0 /* page size */, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, page...) + + if nextToken == "" { + break + } + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *accesslist.Review { + return &accesslist.Review{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + Spec: accesslist.ReviewSpec{ + AccessList: hdr.Metadata.Description, + }, + } + }, + watch: w, + }, nil +} + +// ListAccessListReviews will list access list reviews for a particular access list. +func (c *Cache) ListAccessListReviews(ctx context.Context, accessList string, pageSize int, pageToken string) ([]*accesslist.Review, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAccessListReviews") + defer span.End() + + lister := genericLister[*accesslist.Review, accessListReviewIndex]{ + cache: c, + collection: c.collections.accessListReviews, + index: accessListReviewNameIndex, + defaultPageSize: 200, + upstreamList: func(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.Review, string, error) { + reviews, next, err := c.AccessLists.ListAccessListReviews(ctx, accessList, pageSize, pageToken) + return reviews, next, trace.Wrap(err) + }, + nextToken: func(t *accesslist.Review) string { + return t.GetName() + }, + clone: func(r *accesslist.Review) *accesslist.Review { + return r.Clone() + }, + } + + start := accessList + if pageToken != "" { + start += "/" + pageToken + } + + out, next, err := lister.list(ctx, pageSize, start) + return out, next, trace.Wrap(err) +} diff --git a/lib/cache/access_list_test.go b/lib/cache/access_list_test.go new file mode 100644 index 0000000000000..d287da49da021 --- /dev/null +++ b/lib/cache/access_list_test.go @@ -0,0 +1,177 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types/accesslist" +) + +// TestAccessList tests that CRUD operations on access list resources are +// replicated from the backend to the cache. +func TestAccessList(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + clock := clockwork.NewFakeClock() + + testResources(t, p, testFuncs[*accesslist.AccessList]{ + newResource: func(name string) (*accesslist.AccessList, error) { + return newAccessList(t, name, clock), nil + }, + create: func(ctx context.Context, item *accesslist.AccessList) error { + _, err := p.accessLists.UpsertAccessList(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*accesslist.AccessList, error) { + items, _, err := p.accessLists.ListAccessLists(ctx, 0 /* page size */, "") + return items, trace.Wrap(err) + }, + cacheGet: p.cache.GetAccessList, + cacheList: func(ctx context.Context) ([]*accesslist.AccessList, error) { + items, _, err := p.cache.ListAccessLists(ctx, 0 /* page size */, "") + return items, trace.Wrap(err) + }, + update: func(ctx context.Context, item *accesslist.AccessList) error { + _, err := p.accessLists.UpsertAccessList(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.accessLists.DeleteAllAccessLists, + }) +} + +// TestAccessListMembers tests that CRUD operations on access list member resources are +// replicated from the backend to the cache. +func TestAccessListMembers(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + clock := clockwork.NewFakeClock() + + al, err := p.accessLists.UpsertAccessList(context.Background(), newAccessList(t, "access-list", clock)) + require.NoError(t, err) + + testResources(t, p, testFuncs[*accesslist.AccessListMember]{ + newResource: func(name string) (*accesslist.AccessListMember, error) { + return newAccessListMember(t, al.GetName(), name), nil + }, + create: func(ctx context.Context, item *accesslist.AccessListMember) error { + _, err := p.accessLists.UpsertAccessListMember(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*accesslist.AccessListMember, error) { + items, _, err := p.accessLists.ListAllAccessListMembers(ctx, 0 /* page size */, "") + return items, trace.Wrap(err) + }, + cacheGet: func(ctx context.Context, name string) (*accesslist.AccessListMember, error) { + return p.cache.GetAccessListMember(ctx, al.GetName(), name) + }, + cacheList: func(ctx context.Context) ([]*accesslist.AccessListMember, error) { + items, _, err := p.cache.ListAccessListMembers(ctx, al.GetName(), 0 /* page size */, "") + return items, trace.Wrap(err) + }, + update: func(ctx context.Context, item *accesslist.AccessListMember) error { + _, err := p.accessLists.UpsertAccessListMember(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.accessLists.DeleteAllAccessListMembers, + }) + + // Verify counting. + ctx := context.Background() + for i := 0; i < 40; i++ { + _, err = p.accessLists.UpsertAccessListMember(ctx, newAccessListMember(t, al.GetName(), strconv.Itoa(i))) + require.NoError(t, err) + } + + count, listCount, err := p.accessLists.CountAccessListMembers(ctx, al.GetName()) + require.NoError(t, err) + require.Equal(t, uint32(40), count) + require.Equal(t, uint32(0), listCount) + + // Eventually, this should be reflected in the cache. + require.Eventually(t, func() bool { + // Make sure the cache has a single resource in it. + count, listCount, err := p.cache.CountAccessListMembers(ctx, al.GetName()) + assert.NoError(t, err) + return count == uint32(40) && listCount == uint32(0) + }, time.Second*2, time.Millisecond*250) +} + +// TestAccessListReviews tests that CRUD operations on access list review resources are +// replicated from the backend to the cache. +func TestAccessListReviews(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + clock := clockwork.NewFakeClock() + + al, _, err := p.accessLists.UpsertAccessListWithMembers(context.Background(), newAccessList(t, "access-list", clock), + []*accesslist.AccessListMember{ + newAccessListMember(t, "access-list", "member1"), + newAccessListMember(t, "access-list", "member2"), + newAccessListMember(t, "access-list", "member3"), + newAccessListMember(t, "access-list", "member4"), + newAccessListMember(t, "access-list", "member5"), + }) + require.NoError(t, err) + + // Keep track of the reviews, as create can update them. We'll use this + // to make sure the values are up to date during the test. + reviews := map[string]*accesslist.Review{} + + testResources(t, p, testFuncs[*accesslist.Review]{ + newResource: func(name string) (*accesslist.Review, error) { + review := newAccessListReview(t, al.GetName(), name) + // Store the name in the description. + review.Metadata.Description = name + reviews[name] = review + return review, nil + }, + create: func(ctx context.Context, item *accesslist.Review) error { + review, _, err := p.accessLists.CreateAccessListReview(ctx, item) + // Use the old name from the description. + oldName := review.Metadata.Description + reviews[oldName].SetName(review.GetName()) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*accesslist.Review, error) { + items, _, err := p.accessLists.ListAllAccessListReviews(ctx, 0 /* page size */, "") + return items, trace.Wrap(err) + }, + cacheList: func(ctx context.Context) ([]*accesslist.Review, error) { + items, _, err := p.cache.ListAccessListReviews(ctx, al.GetName(), 0 /* page size */, "") + return items, trace.Wrap(err) + }, + deleteAll: p.accessLists.DeleteAllAccessListReviews, + }) +} diff --git a/lib/cache/cache.go b/lib/cache/cache.go index e0fbffc3c61d8..e747dee550698 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -50,7 +50,6 @@ import ( "github.com/gravitational/teleport/api/internalutils/stream" apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" "github.com/gravitational/teleport/api/types/secreports" "github.com/gravitational/teleport/api/types/userloginstate" @@ -62,7 +61,6 @@ import ( "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" - "github.com/gravitational/teleport/lib/services/simple" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" ) @@ -522,7 +520,6 @@ type Cache struct { headlessAuthenticationsCache services.HeadlessAuthenticationService secReportsCache services.SecReports userLoginStateCache services.UserLoginStates - accessListCache *simple.AccessListService eventsFanout *services.FanoutV2 lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2] kubeWaitingContsCache *local.KubeWaitingContainerService @@ -959,12 +956,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - accessListCache, err := simple.NewAccessListService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - databaseObjectsCache, err := local.NewDatabaseObjectService(config.Backend) if err != nil { cancel() @@ -1062,7 +1053,6 @@ func New(config Config) (*Cache, error) { headlessAuthenticationsCache: identityService, secReportsCache: secReportsCache, userLoginStateCache: userLoginStatesCache, - accessListCache: accessListCache, databaseObjectsCache: databaseObjectsCache, eventsFanout: fanout, lowVolumeEventsFanout: utils.NewRoundRobin(lowVolumeFanouts), @@ -2403,126 +2393,6 @@ func (c *Cache) GetUserLoginState(ctx context.Context, name string) (*userlogins return uls, trace.Wrap(err) } -// GetAccessLists returns a list of all access lists. -func (c *Cache) GetAccessLists(ctx context.Context) ([]*accesslist.AccessList, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetAccessLists") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessLists) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetAccessLists(ctx) -} - -// ListAccessLists returns a paginated list of access lists. -func (c *Cache) ListAccessLists(ctx context.Context, pageSize int, nextToken string) ([]*accesslist.AccessList, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAccessLists") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessLists) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListAccessLists(ctx, pageSize, nextToken) -} - -// GetAccessList returns the specified access list resource. -func (c *Cache) GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetAccessList") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessLists) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - item, err := rg.reader.GetAccessList(ctx, name) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if item, err := c.Config.AccessLists.GetAccessList(ctx, name); err == nil { - return item, nil - } - } - return item, trace.Wrap(err) -} - -// CountAccessListMembers will count all access list members. -func (c *Cache) CountAccessListMembers(ctx context.Context, accessListName string) (uint32, uint32, error) { - ctx, span := c.Tracer.Start(ctx, "cache/CountAccessListMembers") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessListMembers) - if err != nil { - return 0, 0, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.CountAccessListMembers(ctx, accessListName) -} - -// ListAccessListMembers returns a paginated list of all access list members. -// May return a DynamicAccessListError if the requested access list has an -// implicit member list and the underlying implementation does not have -// enough information to compute the dynamic member list. -func (c *Cache) ListAccessListMembers(ctx context.Context, accessListName string, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAccessListMembers") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessListMembers) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListAccessListMembers(ctx, accessListName, pageSize, pageToken) -} - -// ListAllAccessListMembers returns a paginated list of all access list members for all access lists. -func (c *Cache) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAllAccessListMembers") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessListMembers) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListAllAccessListMembers(ctx, pageSize, pageToken) -} - -// GetAccessListMember returns the specified access list member resource. -// May return a DynamicAccessListError if the requested access list has an -// implicit member list and the underlying implementation does not have -// enough information to compute the dynamic member record. -func (c *Cache) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetAccessListMember") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessListMembers) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetAccessListMember(ctx, accessList, memberName) -} - -// ListAccessListReviews will list access list reviews for a particular access list. -func (c *Cache) ListAccessListReviews(ctx context.Context, accessList string, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAccessListReviews") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessListReviews) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListAccessListReviews(ctx, accessList, pageSize, pageToken) -} - // ListAccessMonitoringRules returns a paginated list of access monitoring rules. func (c *Cache) ListAccessMonitoringRules(ctx context.Context, pageSize int, nextToken string) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListAccessMonitoringRules") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index b5c9a2a25a242..c4e4e72fa7c50 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -24,7 +24,6 @@ import ( "log/slog" "os" "slices" - "strconv" "sync" "testing" "time" @@ -1561,152 +1560,6 @@ func TestUserLoginStates(t *testing.T) { }) } -// TestAccessList tests that CRUD operations on access list resources are -// replicated from the backend to the cache. -func TestAccessList(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - clock := clockwork.NewFakeClockAt(time.Now()) - - testResources(t, p, testFuncs[*accesslist.AccessList]{ - newResource: func(name string) (*accesslist.AccessList, error) { - return newAccessList(t, name, clock), nil - }, - create: func(ctx context.Context, item *accesslist.AccessList) error { - _, err := p.accessLists.UpsertAccessList(ctx, item) - return trace.Wrap(err) - }, - list: func(ctx context.Context) ([]*accesslist.AccessList, error) { - items, _, err := p.accessLists.ListAccessLists(ctx, 0 /* page size */, "") - return items, trace.Wrap(err) - }, - cacheGet: p.cache.GetAccessList, - cacheList: func(ctx context.Context) ([]*accesslist.AccessList, error) { - items, _, err := p.cache.ListAccessLists(ctx, 0 /* page size */, "") - return items, trace.Wrap(err) - }, - update: func(ctx context.Context, item *accesslist.AccessList) error { - _, err := p.accessLists.UpsertAccessList(ctx, item) - return trace.Wrap(err) - }, - deleteAll: p.accessLists.DeleteAllAccessLists, - }) -} - -// TestAccessListMembers tests that CRUD operations on access list member resources are -// replicated from the backend to the cache. -func TestAccessListMembers(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - clock := clockwork.NewFakeClockAt(time.Now()) - - al, err := p.accessLists.UpsertAccessList(context.Background(), newAccessList(t, "access-list", clock)) - require.NoError(t, err) - - testResources(t, p, testFuncs[*accesslist.AccessListMember]{ - newResource: func(name string) (*accesslist.AccessListMember, error) { - return newAccessListMember(t, al.GetName(), name), nil - }, - create: func(ctx context.Context, item *accesslist.AccessListMember) error { - _, err := p.accessLists.UpsertAccessListMember(ctx, item) - return trace.Wrap(err) - }, - list: func(ctx context.Context) ([]*accesslist.AccessListMember, error) { - items, _, err := p.accessLists.ListAllAccessListMembers(ctx, 0 /* page size */, "") - return items, trace.Wrap(err) - }, - cacheGet: func(ctx context.Context, name string) (*accesslist.AccessListMember, error) { - return p.cache.GetAccessListMember(ctx, al.GetName(), name) - }, - cacheList: func(ctx context.Context) ([]*accesslist.AccessListMember, error) { - items, _, err := p.cache.ListAccessListMembers(ctx, al.GetName(), 0 /* page size */, "") - return items, trace.Wrap(err) - }, - update: func(ctx context.Context, item *accesslist.AccessListMember) error { - _, err := p.accessLists.UpsertAccessListMember(ctx, item) - return trace.Wrap(err) - }, - deleteAll: p.accessLists.DeleteAllAccessListMembers, - }) - - // Verify counting. - ctx := context.Background() - for i := 0; i < 40; i++ { - _, err = p.accessLists.UpsertAccessListMember(ctx, newAccessListMember(t, al.GetName(), strconv.Itoa(i))) - require.NoError(t, err) - } - - count, listCount, err := p.accessLists.CountAccessListMembers(ctx, al.GetName()) - require.NoError(t, err) - require.Equal(t, uint32(40), count) - require.Equal(t, uint32(0), listCount) - - // Eventually, this should be reflected in the cache. - require.Eventually(t, func() bool { - // Make sure the cache has a single resource in it. - count, listCount, err := p.cache.CountAccessListMembers(ctx, al.GetName()) - assert.NoError(t, err) - return count == uint32(40) && listCount == uint32(0) - }, time.Second*2, time.Millisecond*250) -} - -// TestAccessListReviews tests that CRUD operations on access list review resources are -// replicated from the backend to the cache. -func TestAccessListReviews(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - clock := clockwork.NewFakeClockAt(time.Now()) - - al, _, err := p.accessLists.UpsertAccessListWithMembers(context.Background(), newAccessList(t, "access-list", clock), - []*accesslist.AccessListMember{ - newAccessListMember(t, "access-list", "member1"), - newAccessListMember(t, "access-list", "member2"), - newAccessListMember(t, "access-list", "member3"), - newAccessListMember(t, "access-list", "member4"), - newAccessListMember(t, "access-list", "member5"), - }) - require.NoError(t, err) - - // Keep track of the reviews, as create can update them. We'll use this - // to make sure the values are up to date during the test. - reviews := map[string]*accesslist.Review{} - - testResources(t, p, testFuncs[*accesslist.Review]{ - newResource: func(name string) (*accesslist.Review, error) { - review := newAccessListReview(t, al.GetName(), name) - // Store the name in the description. - review.Metadata.Description = name - reviews[name] = review - return review, nil - }, - create: func(ctx context.Context, item *accesslist.Review) error { - review, _, err := p.accessLists.CreateAccessListReview(ctx, item) - // Use the old name from the description. - oldName := review.Metadata.Description - reviews[oldName].SetName(review.GetName()) - return trace.Wrap(err) - }, - list: func(ctx context.Context) ([]*accesslist.Review, error) { - items, _, err := p.accessLists.ListAllAccessListReviews(ctx, 0 /* page size */, "") - return items, trace.Wrap(err) - }, - cacheList: func(ctx context.Context) ([]*accesslist.Review, error) { - items, _, err := p.cache.ListAccessListReviews(ctx, al.GetName(), 0 /* page size */, "") - return items, trace.Wrap(err) - }, - deleteAll: p.accessLists.DeleteAllAccessListReviews, - }) -} - // TestCrownJewel tests that CRUD operations on user notification resources are // replicated from the backend to the cache. func TestCrownJewel(t *testing.T) { diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 7c8c5ca8d3e79..4abcd0bcb94eb 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -28,6 +28,7 @@ import ( notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1" workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" ) // collectionHandler is used by the [Cache] to seed the initial @@ -95,6 +96,9 @@ type collections struct { webSessions *collection[types.WebSession, webSessionIndex] appSessions *collection[types.WebSession, appSessionIndex] snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex] + accessLists *collection[*accesslist.AccessList, accessListIndex] + accessListMembers *collection[*accesslist.AccessListMember, accessListMemberIndex] + accessListReviews *collection[*accesslist.Review, accessListReviewIndex] } // setupCollections ensures that the appropriate [collection] is @@ -440,6 +444,30 @@ func setupCollections(c Config) (*collections, error) { out.webSessions = collect out.byKind[resourceKind] = out.webSessions } + case types.KindAccessList: + collect, err := newAccessListCollection(c.AccessLists, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.accessLists = collect + out.byKind[resourceKind] = out.accessLists + case types.KindAccessListMember: + collect, err := newAccessListMemberCollection(c.AccessLists, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.accessListMembers = collect + out.byKind[resourceKind] = out.accessListMembers + case types.KindAccessListReview: + collect, err := newAccessListReviewCollection(c.AccessLists, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.accessListReviews = collect + out.byKind[resourceKind] = out.accessListReviews } } diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index 32dd49c9278e2..4baa117d07cca 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -37,7 +37,6 @@ import ( userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" "github.com/gravitational/teleport/api/types/secreports" "github.com/gravitational/teleport/api/types/userloginstate" @@ -108,9 +107,6 @@ type legacyCollections struct { auditQueries collectionReader[services.SecurityAuditQueryGetter] secReports collectionReader[services.SecurityReportGetter] secReportsStates collectionReader[services.SecurityReportStateGetter] - accessLists collectionReader[accessListsGetter] - accessListMembers collectionReader[accessListMembersGetter] - accessListReviews collectionReader[accessListReviewsGetter] tunnelConnections collectionReader[tunnelConnectionGetter] databaseObjects collectionReader[services.DatabaseObjectsGetter] discoveryConfigs collectionReader[services.DiscoveryConfigsGetter] @@ -291,24 +287,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect } collections.userLoginStates = &genericCollection[*userloginstate.UserLoginState, services.UserLoginStatesGetter, userLoginStateExecutor]{cache: c, watch: watch} collections.byKind[resourceKind] = collections.userLoginStates - case types.KindAccessList: - if c.AccessLists == nil { - return nil, trace.BadParameter("missing parameter AccessLists") - } - collections.accessLists = &genericCollection[*accesslist.AccessList, accessListsGetter, accessListExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.accessLists - case types.KindAccessListMember: - if c.AccessLists == nil { - return nil, trace.BadParameter("missing parameter AccessLists") - } - collections.accessListMembers = &genericCollection[*accesslist.AccessListMember, accessListMembersGetter, accessListMemberExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.accessListMembers - case types.KindAccessListReview: - if c.AccessLists == nil { - return nil, trace.BadParameter("missing parameter AccessLists") - } - collections.accessListReviews = &genericCollection[*accesslist.Review, accessListReviewsGetter, accessListReviewExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.accessListReviews case types.KindKubeWaitingContainer: if c.KubeWaitingContainers == nil { return nil, trace.BadParameter("missing parameter KubeWaitingContainers") @@ -1359,171 +1337,6 @@ func (userLoginStateExecutor) getReader(cache *Cache, cacheOK bool) services.Use var _ executor[*userloginstate.UserLoginState, services.UserLoginStatesGetter] = userLoginStateExecutor{} -type accessListExecutor struct{} - -func (accessListExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*accesslist.AccessList, error) { - var resources []*accesslist.AccessList - var nextToken string - for { - var page []*accesslist.AccessList - var err error - page, nextToken, err = cache.AccessLists.ListAccessLists(ctx, 0 /* page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (accessListExecutor) upsert(ctx context.Context, cache *Cache, resource *accesslist.AccessList) error { - _, err := cache.accessListCache.UnconditionalUpsertAccessList(ctx, resource) - return trace.Wrap(err) -} - -func (accessListExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.accessListCache.DeleteAllAccessLists(ctx) -} - -func (accessListExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessListCache.UnconditionalDeleteAccessList(ctx, resource.GetName()) -} - -func (accessListExecutor) isSingleton() bool { return false } - -func (accessListExecutor) getReader(cache *Cache, cacheOK bool) accessListsGetter { - if cacheOK { - return cache.accessListCache - } - return cache.Config.AccessLists -} - -type accessListsGetter interface { - GetAccessLists(ctx context.Context) ([]*accesslist.AccessList, error) - ListAccessLists(ctx context.Context, pageSize int, nextToken string) ([]*accesslist.AccessList, string, error) - GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) -} - -type accessListMemberExecutor struct{} - -func (accessListMemberExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*accesslist.AccessListMember, error) { - var resources []*accesslist.AccessListMember - var nextToken string - for { - var page []*accesslist.AccessListMember - var err error - page, nextToken, err = cache.AccessLists.ListAllAccessListMembers(ctx, 0 /* page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (accessListMemberExecutor) upsert(ctx context.Context, cache *Cache, resource *accesslist.AccessListMember) error { - _, err := cache.accessListCache.UnconditionalUpsertAccessListMember(ctx, resource) - return trace.Wrap(err) -} - -func (accessListMemberExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.accessListCache.DeleteAllAccessListMembers(ctx) -} - -func (accessListMemberExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessListCache.UnconditionalDeleteAccessListMember(ctx, - resource.GetMetadata().Description, // Cache passes access ID via description field. - resource.GetName()) -} - -func (accessListMemberExecutor) isSingleton() bool { return false } - -func (accessListMemberExecutor) getReader(cache *Cache, cacheOK bool) accessListMembersGetter { - if cacheOK { - return cache.accessListCache - } - return cache.Config.AccessLists -} - -type accessListMembersGetter interface { - CountAccessListMembers(ctx context.Context, accessListName string) (uint32, uint32, error) - ListAccessListMembers(ctx context.Context, accessListName string, pageSize int, nextToken string) ([]*accesslist.AccessListMember, string, error) - GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) - ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.AccessListMember, string, error) -} - -type accessListReviewExecutor struct{} - -func (accessListReviewExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*accesslist.Review, error) { - var resources []*accesslist.Review - var nextToken string - for { - var page []*accesslist.Review - var err error - page, nextToken, err = cache.AccessLists.ListAllAccessListReviews(ctx, 0 /* page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (accessListReviewExecutor) upsert(ctx context.Context, cache *Cache, resource *accesslist.Review) error { - if _, _, err := cache.accessListCache.CreateAccessListReview(ctx, resource); err != nil { - if !trace.IsAlreadyExists(err) { - return trace.Wrap(err) - } - - if err := cache.accessListCache.DeleteAccessListReview(ctx, resource.Spec.AccessList, resource.GetName()); err != nil { - return trace.Wrap(err) - } - - if _, _, err := cache.accessListCache.CreateAccessListReview(ctx, resource); err != nil { - return trace.Wrap(err) - } - } - return nil -} - -func (accessListReviewExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.accessListCache.DeleteAllAccessListReviews(ctx) -} - -func (accessListReviewExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessListCache.DeleteAccessListReview(ctx, - resource.GetMetadata().Description, // Cache passes access ID via description field. - resource.GetName()) -} - -func (accessListReviewExecutor) isSingleton() bool { return false } - -func (accessListReviewExecutor) getReader(cache *Cache, cacheOK bool) accessListReviewsGetter { - if cacheOK { - return cache.accessListCache - } - return cache.Config.AccessLists -} - -type accessListReviewsGetter interface { - ListAccessListReviews(ctx context.Context, accessList string, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) -} - type accessMonitoringRulesExecutor struct{} func (accessMonitoringRulesExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { diff --git a/lib/cache/store.go b/lib/cache/store.go index 470e0d45bf5a1..58a0e6bd37053 100644 --- a/lib/cache/store.go +++ b/lib/cache/store.go @@ -89,3 +89,13 @@ func (s *store[T, I]) get(index I, key string) (T, error) { func (s *store[T, I]) resources(index I, start, stop string) iter.Seq[T] { return s.cache.Ascend(index, start, stop) } + +// count returns the number of items that exist in the provided range. +func (s *store[T, I]) count(index I, start, stop string) int { + var n int + for range s.cache.Ascend(index, start, stop) { + n++ + } + + return n +}