From cb2a774ac838a025f96a3f794a42894950053f8d Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Tue, 6 May 2025 16:58:54 -0400 Subject: [PATCH] Convert identity center provisioning to new cache mechanism Moves identity center assignments and provisioning states to the new cache collection scheme that was introduced in #52210. In addition to migrating the collections, the cache was also updated to implement the following missing methods: - GetProvisioningState - ListProvisioningStatesForAllDownstreams - ListPrincipalAssignments --- lib/cache/cache.go | 31 ------- lib/cache/collections.go | 148 ++++++++++++++++++------------- lib/cache/identitycenter.go | 119 ++++++++++++++++--------- lib/cache/identitycenter_test.go | 58 +++++++----- lib/cache/legacy_collections.go | 32 +------ lib/cache/provisioning.go | 148 +++++++++++++++++-------------- lib/cache/provisioning_test.go | 56 +++++++----- 7 files changed, 313 insertions(+), 279 deletions(-) diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 945460994a687..c80cb505096c3 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -38,7 +38,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1" - provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" "github.com/gravitational/teleport/api/internalutils/stream" apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" @@ -505,8 +504,6 @@ type Cache struct { secReportsCache services.SecReports eventsFanout *services.FanoutV2 lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2] - provisioningStatesCache *local.ProvisioningStateService - identityCenterCache *local.IdentityCenterService pluginStaticCredentialsCache *local.PluginStaticCredentialsService gitServersCache *local.GitServerService @@ -900,12 +897,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - provisioningStatesCache, err := local.NewProvisioningStateService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - secReportsCache, err := local.NewSecReportsService(config.Backend, config.Clock) if err != nil { cancel() @@ -924,13 +915,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - identityCenterCache, err := local.NewIdentityCenterService(local.IdentityCenterServiceConfig{ - Backend: config.Backend}) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - pluginStaticCredentialsCache, err := local.NewPluginStaticCredentialsService(config.Backend) if err != nil { cancel() @@ -966,8 +950,6 @@ func New(config Config) (*Cache, error) { secReportsCache: secReportsCache, eventsFanout: fanout, lowVolumeEventsFanout: utils.NewRoundRobin(lowVolumeFanouts), - provisioningStatesCache: provisioningStatesCache, - identityCenterCache: identityCenterCache, pluginStaticCredentialsCache: pluginStaticCredentialsCache, gitServersCache: gitServersCache, collections: collections, @@ -2054,16 +2036,3 @@ func buildListResourcesResponse[T types.ResourceWithLabels](resources iter.Seq[T return &resp, nil } - -func (c *Cache) GetProvisioningState(ctx context.Context, downstream services.DownstreamID, id services.ProvisioningStateID) (*provisioningv1.PrincipalState, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetProvisioningState") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.provisioningStates) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - return rg.reader.GetProvisioningState(ctx, downstream, id) -} diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 3b4046be9b038..4996069aa51b9 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -31,6 +31,7 @@ import ( kubewaitingcontainerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/kubewaitingcontainer/v1" machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1" + provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" userprovisioningv2 "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" @@ -64,70 +65,72 @@ type collectionHandler interface { type collections struct { byKind map[resourceKind]collectionHandler - provisionTokens *collection[types.ProvisionToken, provisionTokenIndex] - staticTokens *collection[types.StaticTokens, staticTokensIndex] - certAuthorities *collection[types.CertAuthority, certAuthorityIndex] - users *collection[types.User, userIndex] - roles *collection[types.Role, roleIndex] - authServers *collection[types.Server, authServerIndex] - proxyServers *collection[types.Server, proxyServerIndex] - nodes *collection[types.Server, nodeIndex] - apps *collection[types.Application, appIndex] - appServers *collection[types.AppServer, appServerIndex] - dbs *collection[types.Database, databaseIndex] - dbServers *collection[types.DatabaseServer, databaseServerIndex] - dbServices *collection[types.DatabaseService, databaseServiceIndex] - kubeServers *collection[types.KubeServer, kubeServerIndex] - kubeClusters *collection[types.KubeCluster, kubeClusterIndex] - kubeWaitingContainers *collection[*kubewaitingcontainerv1.KubernetesWaitingContainer, kubeWaitingContainerIndex] - windowsDesktops *collection[types.WindowsDesktop, windowsDesktopIndex] - windowsDesktopServices *collection[types.WindowsDesktopService, windowsDesktopServiceIndex] - dynamicWindowsDesktops *collection[types.DynamicWindowsDesktop, dynamicWindowsDesktopIndex] - userGroups *collection[types.UserGroup, userGroupIndex] - identityCenterAccounts *collection[*identitycenterv1.Account, identityCenterAccountIndex] - identityCenterAccountAssignments *collection[*identitycenterv1.AccountAssignment, identityCenterAccountAssignmentIndex] - healthCheckConfig *collection[*healthcheckconfigv1.HealthCheckConfig, healthCheckConfigIndex] - reverseTunnels *collection[types.ReverseTunnel, reverseTunnelIndex] - spiffeFederations *collection[*machineidv1.SPIFFEFederation, spiffeFederationIndex] - workloadIdentity *collection[*workloadidentityv1.WorkloadIdentity, workloadIdentityIndex] - userNotifications *collection[*notificationsv1.Notification, userNotificationIndex] - globalNotifications *collection[*notificationsv1.GlobalNotification, globalNotificationIndex] - clusterName *collection[types.ClusterName, clusterNameIndex] - auditConfig *collection[types.ClusterAuditConfig, clusterAuditConfigIndex] - networkingConfig *collection[types.ClusterNetworkingConfig, clusterNetworkingConfigIndex] - authPreference *collection[types.AuthPreference, authPreferenceIndex] - sessionRecordingConfig *collection[types.SessionRecordingConfig, sessionRecordingConfigIndex] - autoUpdateConfig *collection[*autoupdatev1.AutoUpdateConfig, autoUpdateConfigIndex] - autoUpdateVerion *collection[*autoupdatev1.AutoUpdateVersion, autoUpdateVersionIndex] - autoUpdateRollout *collection[*autoupdatev1.AutoUpdateAgentRollout, autoUpdateAgentRolloutIndex] - oktaImportRules *collection[types.OktaImportRule, oktaImportRuleIndex] - oktaAssignments *collection[types.OktaAssignment, oktaAssignmentIndex] - samlIdPServiceProviders *collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex] - samlIdPSessions *collection[types.WebSession, samlIdPSessionIndex] - webSessions *collection[types.WebSession, webSessionIndex] - appSessions *collection[types.WebSession, appSessionIndex] - snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex] - accessLists *collection[*accesslist.AccessList, accessListIndex] - accessListMembers *collection[*accesslist.AccessListMember, accessListMemberIndex] - accessListReviews *collection[*accesslist.Review, accessListReviewIndex] - crownJewels *collection[*crownjewelv1.CrownJewel, crownJewelIndex] - accessGraphSettings *collection[*clusterconfigv1.AccessGraphSettings, accessGraphSettingsIndex] - integrations *collection[types.Integration, integrationIndex] - pluginStaticCredentials *collection[types.PluginStaticCredentials, pluginStaticCredentialsIndex] - accessMonitoringRules *collection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex] - webTokens *collection[types.WebToken, webTokenIndex] - uiConfigs *collection[types.UIConfig, webUIConfigIndex] - installers *collection[types.Installer, installerIndex] - locks *collection[types.Lock, lockIndex] - tunnelConnections *collection[types.TunnelConnection, tunnelConnectionIndex] - remoteClusters *collection[types.RemoteCluster, remoteClusterIndex] - userTasks *collection[*usertasksv1.UserTask, userTaskIndex] - userLoginStates *collection[*userloginstate.UserLoginState, userLoginStateIndex] - gitServers *collection[types.Server, gitServerIndex] - databaseObjects *collection[*dbobjectv1.DatabaseObject, databaseObjectIndex] - staticHostUsers *collection[*userprovisioningv2.StaticHostUser, staticHostUserIndex] - networkRestrictions *collection[types.NetworkRestrictions, networkingRestrictionIndex] - discoveryConfigs *collection[*discoveryconfig.DiscoveryConfig, discoveryConfigIndex] + provisionTokens *collection[types.ProvisionToken, provisionTokenIndex] + staticTokens *collection[types.StaticTokens, staticTokensIndex] + certAuthorities *collection[types.CertAuthority, certAuthorityIndex] + users *collection[types.User, userIndex] + roles *collection[types.Role, roleIndex] + authServers *collection[types.Server, authServerIndex] + proxyServers *collection[types.Server, proxyServerIndex] + nodes *collection[types.Server, nodeIndex] + apps *collection[types.Application, appIndex] + appServers *collection[types.AppServer, appServerIndex] + dbs *collection[types.Database, databaseIndex] + dbServers *collection[types.DatabaseServer, databaseServerIndex] + dbServices *collection[types.DatabaseService, databaseServiceIndex] + kubeServers *collection[types.KubeServer, kubeServerIndex] + kubeClusters *collection[types.KubeCluster, kubeClusterIndex] + kubeWaitingContainers *collection[*kubewaitingcontainerv1.KubernetesWaitingContainer, kubeWaitingContainerIndex] + windowsDesktops *collection[types.WindowsDesktop, windowsDesktopIndex] + windowsDesktopServices *collection[types.WindowsDesktopService, windowsDesktopServiceIndex] + dynamicWindowsDesktops *collection[types.DynamicWindowsDesktop, dynamicWindowsDesktopIndex] + userGroups *collection[types.UserGroup, userGroupIndex] + identityCenterAccounts *collection[*identitycenterv1.Account, identityCenterAccountIndex] + identityCenterAccountAssignments *collection[*identitycenterv1.AccountAssignment, identityCenterAccountAssignmentIndex] + healthCheckConfig *collection[*healthcheckconfigv1.HealthCheckConfig, healthCheckConfigIndex] + reverseTunnels *collection[types.ReverseTunnel, reverseTunnelIndex] + spiffeFederations *collection[*machineidv1.SPIFFEFederation, spiffeFederationIndex] + workloadIdentity *collection[*workloadidentityv1.WorkloadIdentity, workloadIdentityIndex] + userNotifications *collection[*notificationsv1.Notification, userNotificationIndex] + globalNotifications *collection[*notificationsv1.GlobalNotification, globalNotificationIndex] + clusterName *collection[types.ClusterName, clusterNameIndex] + auditConfig *collection[types.ClusterAuditConfig, clusterAuditConfigIndex] + networkingConfig *collection[types.ClusterNetworkingConfig, clusterNetworkingConfigIndex] + authPreference *collection[types.AuthPreference, authPreferenceIndex] + sessionRecordingConfig *collection[types.SessionRecordingConfig, sessionRecordingConfigIndex] + autoUpdateConfig *collection[*autoupdatev1.AutoUpdateConfig, autoUpdateConfigIndex] + autoUpdateVerion *collection[*autoupdatev1.AutoUpdateVersion, autoUpdateVersionIndex] + autoUpdateRollout *collection[*autoupdatev1.AutoUpdateAgentRollout, autoUpdateAgentRolloutIndex] + oktaImportRules *collection[types.OktaImportRule, oktaImportRuleIndex] + oktaAssignments *collection[types.OktaAssignment, oktaAssignmentIndex] + samlIdPServiceProviders *collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex] + samlIdPSessions *collection[types.WebSession, samlIdPSessionIndex] + webSessions *collection[types.WebSession, webSessionIndex] + appSessions *collection[types.WebSession, appSessionIndex] + snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex] + accessLists *collection[*accesslist.AccessList, accessListIndex] + accessListMembers *collection[*accesslist.AccessListMember, accessListMemberIndex] + accessListReviews *collection[*accesslist.Review, accessListReviewIndex] + crownJewels *collection[*crownjewelv1.CrownJewel, crownJewelIndex] + accessGraphSettings *collection[*clusterconfigv1.AccessGraphSettings, accessGraphSettingsIndex] + integrations *collection[types.Integration, integrationIndex] + pluginStaticCredentials *collection[types.PluginStaticCredentials, pluginStaticCredentialsIndex] + accessMonitoringRules *collection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex] + webTokens *collection[types.WebToken, webTokenIndex] + uiConfigs *collection[types.UIConfig, webUIConfigIndex] + installers *collection[types.Installer, installerIndex] + locks *collection[types.Lock, lockIndex] + tunnelConnections *collection[types.TunnelConnection, tunnelConnectionIndex] + remoteClusters *collection[types.RemoteCluster, remoteClusterIndex] + userTasks *collection[*usertasksv1.UserTask, userTaskIndex] + userLoginStates *collection[*userloginstate.UserLoginState, userLoginStateIndex] + gitServers *collection[types.Server, gitServerIndex] + databaseObjects *collection[*dbobjectv1.DatabaseObject, databaseObjectIndex] + staticHostUsers *collection[*userprovisioningv2.StaticHostUser, staticHostUserIndex] + networkRestrictions *collection[types.NetworkRestrictions, networkingRestrictionIndex] + discoveryConfigs *collection[*discoveryconfig.DiscoveryConfig, discoveryConfigIndex] + provisioningStates *collection[*provisioningv1.PrincipalState, principalStateIndex] + identityCenterPrincipalAssignments *collection[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex] } // setupCollections ensures that the appropriate [collection] is @@ -658,6 +661,23 @@ func setupCollections(c Config) (*collections, error) { out.discoveryConfigs = collect out.byKind[resourceKind] = out.discoveryConfigs + case types.KindProvisioningPrincipalState: + + collect, err := newPrincipalStateCollection(c.ProvisioningStates, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.provisioningStates = collect + out.byKind[resourceKind] = out.provisioningStates + case types.KindIdentityCenterPrincipalAssignment: + collect, err := newIdentityCenterPrincipalAssignmentCollection(c.IdentityCenter, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.identityCenterPrincipalAssignments = collect + out.byKind[resourceKind] = out.identityCenterPrincipalAssignments } } diff --git a/lib/cache/identitycenter.go b/lib/cache/identitycenter.go index c42f1cf21a494..b2f94a45eb273 100644 --- a/lib/cache/identitycenter.go +++ b/lib/cache/identitycenter.go @@ -244,62 +244,95 @@ func (c *Cache) ListAccountAssignments(ctx context.Context, pageSize int, pageTo } return assignments, "", nil - } -type identityCenterPrincipalAssignmentGetter interface { - GetPrincipalAssignment(context.Context, services.PrincipalAssignmentID) (*identitycenterv1.PrincipalAssignment, error) - ListPrincipalAssignments(context.Context, int, *pagination.PageRequestToken) ([]*identitycenterv1.PrincipalAssignment, pagination.NextPageToken, error) -} +type identityCenterPrincipalAssignmentIndex string -type identityCenterPrincipalAssignmentExecutor struct{} +const identityCenterPrincipalAssignmentNameIndex identityCenterPrincipalAssignmentIndex = "name" -var _ executor[ - *identitycenterv1.PrincipalAssignment, - identityCenterPrincipalAssignmentGetter, -] = identityCenterPrincipalAssignmentExecutor{} +func newIdentityCenterPrincipalAssignmentCollection(upstream services.IdentityCenter, w types.WatchKind) (*collection[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter IdentityCenter") + } -func (identityCenterPrincipalAssignmentExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*identitycenterv1.PrincipalAssignment, error) { - var pageToken pagination.PageRequestToken - var resources []*identitycenterv1.PrincipalAssignment - for { - resourcesPage, nextPage, err := cache.IdentityCenter.ListPrincipalAssignments(ctx, 0, &pageToken) - if err != nil { - return nil, trace.Wrap(err) - } + return &collection[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex]{ + store: newStore(map[identityCenterPrincipalAssignmentIndex]func(*identitycenterv1.PrincipalAssignment) string{ + identityCenterPrincipalAssignmentNameIndex: func(r *identitycenterv1.PrincipalAssignment) string { + return r.GetMetadata().GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*identitycenterv1.PrincipalAssignment, error) { + var pageToken pagination.PageRequestToken + var resources []*identitycenterv1.PrincipalAssignment + for { + resourcesPage, nextPage, err := upstream.ListPrincipalAssignments(ctx, 0, &pageToken) + if err != nil { + return nil, trace.Wrap(err) + } - resources = append(resources, resourcesPage...) + resources = append(resources, resourcesPage...) - if nextPage == pagination.EndOfList { - break - } - pageToken.Update(nextPage) - } - return resources, nil + if nextPage == "" { + break + } + pageToken.Update(nextPage) + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *identitycenterv1.PrincipalAssignment { + return &identitycenterv1.PrincipalAssignment{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: &headerv1.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil } -func (identityCenterPrincipalAssignmentExecutor) upsert(ctx context.Context, cache *Cache, resource *identitycenterv1.PrincipalAssignment) error { - _, err := cache.identityCenterCache.UpsertPrincipalAssignment(ctx, resource) - return trace.Wrap(err) -} +func (c *Cache) GetPrincipalAssignment(ctx context.Context, id services.PrincipalAssignmentID) (*identitycenterv1.PrincipalAssignment, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetPrincipalAssignment") + defer span.End() -func (identityCenterPrincipalAssignmentExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return trace.Wrap(cache.identityCenterCache.DeletePrincipalAssignment(ctx, - services.PrincipalAssignmentID(resource.GetName()))) + getter := genericGetter[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex]{ + cache: c, + collection: c.collections.identityCenterPrincipalAssignments, + index: identityCenterPrincipalAssignmentNameIndex, + upstreamGet: func(ctx context.Context, s string) (*identitycenterv1.PrincipalAssignment, error) { + out, err := c.Config.IdentityCenter.GetPrincipalAssignment(ctx, services.PrincipalAssignmentID(s)) + return out, trace.Wrap(err) + }, + clone: utils.CloneProtoMsg[*identitycenterv1.PrincipalAssignment], + } + out, err := getter.get(ctx, string(id)) + return out, trace.Wrap(err) } -func (identityCenterPrincipalAssignmentExecutor) deleteAll(ctx context.Context, cache *Cache) error { - _, err := cache.identityCenterCache.DeleteAllPrincipalAssignments(ctx, &identitycenterv1.DeleteAllPrincipalAssignmentsRequest{}) - return trace.Wrap(err) -} +func (c *Cache) ListPrincipalAssignments(ctx context.Context, pageSize int, req *pagination.PageRequestToken) ([]*identitycenterv1.PrincipalAssignment, pagination.NextPageToken, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListPrincipalAssignments") + defer span.End() + + lister := genericLister[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex]{ + cache: c, + collection: c.collections.identityCenterPrincipalAssignments, + index: identityCenterPrincipalAssignmentNameIndex, + upstreamList: func(ctx context.Context, pageSize int, s string) ([]*identitycenterv1.PrincipalAssignment, string, error) { + out, next, err := c.Config.IdentityCenter.ListPrincipalAssignments(ctx, pageSize, req) + return out, string(next), trace.Wrap(err) + }, + nextToken: func(t *identitycenterv1.PrincipalAssignment) string { + return t.GetMetadata().GetName() + }, + clone: utils.CloneProtoMsg[*identitycenterv1.PrincipalAssignment], + } -func (identityCenterPrincipalAssignmentExecutor) getReader(cache *Cache, cacheOK bool) identityCenterPrincipalAssignmentGetter { - if cacheOK { - return cache.identityCenterCache + nextToken, err := req.Consume() + if err != nil { + return nil, "", trace.Wrap(err) } - return cache.Config.IdentityCenter -} -func (identityCenterPrincipalAssignmentExecutor) isSingleton() bool { - return false + out, next, err := lister.list(ctx, pageSize, nextToken) + return out, pagination.NextPageToken(next), trace.Wrap(err) } diff --git a/lib/cache/identitycenter_test.go b/lib/cache/identitycenter_test.go index 707407d3c1881..26c2544372541 100644 --- a/lib/cache/identitycenter_test.go +++ b/lib/cache/identitycenter_test.go @@ -141,30 +141,11 @@ func newIdentityCenterPrincipalAssignment(id string) *identitycenterv1.Principal } } -// TestIdentityCenterPrincpialAssignment asserts that an Identity Center PrincipalAssignment can be cached +// TestIdentityCenterPrincipalAssignment asserts that an Identity Center PrincipalAssignment can be cached func TestIdentityCenterPrincipalAssignment(t *testing.T) { fixturePack := newTestPack(t, ForAuth) t.Cleanup(fixturePack.Close) - collect := func(ctx context.Context, src identityCenterPrincipalAssignmentGetter) ([]*identitycenterv1.PrincipalAssignment, error) { - var result []*identitycenterv1.PrincipalAssignment - var pageToken pagination.PageRequestToken - for { - page, nextPage, err := src.ListPrincipalAssignments(ctx, 0, &pageToken) - if err != nil { - return nil, trace.Wrap(err) - } - result = append(result, page...) - - if nextPage == pagination.EndOfList { - break - } - - pageToken.Update(nextPage) - } - return result, nil - } - testResources153(t, fixturePack, testFuncs153[*identitycenterv1.PrincipalAssignment]{ newResource: func(s string) (*identitycenterv1.PrincipalAssignment, error) { return newIdentityCenterPrincipalAssignment(s), nil @@ -178,7 +159,22 @@ func TestIdentityCenterPrincipalAssignment(t *testing.T) { return trace.Wrap(err) }, list: func(ctx context.Context) ([]*identitycenterv1.PrincipalAssignment, error) { - return collect(ctx, fixturePack.identityCenter) + var result []*identitycenterv1.PrincipalAssignment + var pageToken pagination.PageRequestToken + for { + page, nextPage, err := fixturePack.identityCenter.ListPrincipalAssignments(ctx, 0, &pageToken) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, page...) + + if nextPage == pagination.EndOfList { + break + } + + pageToken.Update(nextPage) + } + return result, nil }, delete: func(ctx context.Context, id string) error { return trace.Wrap(fixturePack.identityCenter.DeletePrincipalAssignment(ctx, services.PrincipalAssignmentID(id))) @@ -188,11 +184,25 @@ func TestIdentityCenterPrincipalAssignment(t *testing.T) { return trace.Wrap(err) }, cacheList: func(ctx context.Context) ([]*identitycenterv1.PrincipalAssignment, error) { - return collect(ctx, fixturePack.cache.identityCenterCache) + var result []*identitycenterv1.PrincipalAssignment + var pageToken pagination.PageRequestToken + for { + page, nextPage, err := fixturePack.cache.ListPrincipalAssignments(ctx, 0, &pageToken) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, page...) + + if nextPage == pagination.EndOfList { + break + } + + pageToken.Update(nextPage) + } + return result, nil }, cacheGet: func(ctx context.Context, id string) (*identitycenterv1.PrincipalAssignment, error) { - r, err := fixturePack.cache.identityCenterCache.GetPrincipalAssignment( - ctx, services.PrincipalAssignmentID(id)) + r, err := fixturePack.cache.GetPrincipalAssignment(ctx, services.PrincipalAssignmentID(id)) return r, trace.Wrap(err) }, }) diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index 0c681edcc47fc..556a733ae3468 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -26,8 +26,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" - identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1" - provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/types" @@ -90,11 +88,9 @@ type legacyCollections struct { // byKind is a map of registered collections by resource Kind/SubKind byKind map[resourceKind]legacyCollection - auditQueries collectionReader[services.SecurityAuditQueryGetter] - secReports collectionReader[services.SecurityReportGetter] - secReportsStates collectionReader[services.SecurityReportStateGetter] - provisioningStates collectionReader[provisioningStateGetter] - identityCenterPrincipalAssignments collectionReader[identityCenterPrincipalAssignmentGetter] + auditQueries collectionReader[services.SecurityAuditQueryGetter] + secReports collectionReader[services.SecurityReportGetter] + secReportsStates collectionReader[services.SecurityReportStateGetter] } // setupLegacyCollections returns a registry of legacyCollections. @@ -131,28 +127,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect } collections.secReportsStates = &genericCollection[*secreports.ReportState, services.SecurityReportStateGetter, secReportStateExecutor]{cache: c, watch: watch} collections.byKind[resourceKind] = collections.secReportsStates - case types.KindProvisioningPrincipalState: - if c.ProvisioningStates == nil { - return nil, trace.BadParameter("missing parameter KindProvisioningState") - } - collections.provisioningStates = &genericCollection[*provisioningv1.PrincipalState, provisioningStateGetter, provisioningStateExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.provisioningStates - case types.KindIdentityCenterPrincipalAssignment: - if c.IdentityCenter == nil { - return nil, trace.BadParameter("missing parameter IdentityCenter") - } - collections.identityCenterPrincipalAssignments = &genericCollection[ - *identitycenterv1.PrincipalAssignment, - identityCenterPrincipalAssignmentGetter, - identityCenterPrincipalAssignmentExecutor, - ]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.identityCenterPrincipalAssignments default: if _, ok := c.collections.byKind[resourceKind]; !ok { return nil, trace.BadParameter("resource %q is not supported", watch.Kind) diff --git a/lib/cache/provisioning.go b/lib/cache/provisioning.go index b078379da2d9e..80288f5cd733f 100644 --- a/lib/cache/provisioning.go +++ b/lib/cache/provisioning.go @@ -21,87 +21,105 @@ import ( "github.com/gravitational/trace" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils/pagination" ) -type provisioningStateGetter interface { - GetProvisioningState(context.Context, services.DownstreamID, services.ProvisioningStateID) (*provisioningv1.PrincipalState, error) - ListProvisioningStatesForAllDownstreams(context.Context, int, *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) -} - -type provisioningStateExecutor struct{} - -func (provisioningStateExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*provisioningv1.PrincipalState, error) { - if cache == nil { - return nil, trace.BadParameter("cache is nil") - } - - if cache.ProvisioningStates == nil { - return nil, trace.BadParameter("cache provisioning state source is not set") - } - - var page pagination.PageRequestToken - var resources []*provisioningv1.PrincipalState - for { - var resourcesPage []*provisioningv1.PrincipalState - var err error - - resourcesPage, nextPage, err := cache.ProvisioningStates.ListProvisioningStatesForAllDownstreams(ctx, 0, &page) - if err != nil { - return nil, trace.Wrap(err) - } +type principalStateIndex string - resources = append(resources, resourcesPage...) +const principalStateNameIndex principalStateIndex = "name" - if nextPage == pagination.EndOfList { - break - } - page.Update(nextPage) +func newPrincipalStateCollection(upstream services.ProvisioningStates, w types.WatchKind) (*collection[*provisioningv1.PrincipalState, principalStateIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter ProvisioningStates") } - return resources, nil -} -func (provisioningStateExecutor) upsert(ctx context.Context, cache *Cache, resource *provisioningv1.PrincipalState) error { - _, err := cache.provisioningStatesCache.UpsertProvisioningState(ctx, resource) - return trace.Wrap(err) + return &collection[*provisioningv1.PrincipalState, principalStateIndex]{ + store: newStore(map[principalStateIndex]func(*provisioningv1.PrincipalState) string{ + principalStateNameIndex: func(r *provisioningv1.PrincipalState) string { + return r.GetMetadata().GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*provisioningv1.PrincipalState, error) { + var page pagination.PageRequestToken + var resources []*provisioningv1.PrincipalState + for { + var resourcesPage []*provisioningv1.PrincipalState + var err error + + resourcesPage, nextPage, err := upstream.ListProvisioningStatesForAllDownstreams(ctx, 0, &page) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, resourcesPage...) + + if nextPage == "" { + break + } + page.Update(nextPage) + } + + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *provisioningv1.PrincipalState { + return &provisioningv1.PrincipalState{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: &headerv1.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil } -func (provisioningStateExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - unwrapper, ok := resource.(types.Resource153UnwrapperT[*provisioningv1.PrincipalState]) - if !ok { - return trace.BadParameter("resource must implement Resource153Unwrapper: %T", resource) - } - - principalState := unwrapper.UnwrapT() - principalStateID := principalState.GetMetadata().GetName() - downstreamID := principalState.GetSpec().GetDownstreamId() - if principalStateID == "" || downstreamID == "" { - return trace.BadParameter("malformed PrincipalState") +func (c *Cache) GetProvisioningState(ctx context.Context, downstream services.DownstreamID, id services.ProvisioningStateID) (*provisioningv1.PrincipalState, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetProvisioningState") + defer span.End() + + getter := genericGetter[*provisioningv1.PrincipalState, principalStateIndex]{ + cache: c, + collection: c.collections.provisioningStates, + index: principalStateNameIndex, + upstreamGet: func(ctx context.Context, s string) (*provisioningv1.PrincipalState, error) { + out, err := c.Config.ProvisioningStates.GetProvisioningState(ctx, downstream, id) + return out, trace.Wrap(err) + }, + clone: utils.CloneProtoMsg[*provisioningv1.PrincipalState], } - - err := cache.provisioningStatesCache.DeleteProvisioningState( - ctx, - services.DownstreamID(downstreamID), - services.ProvisioningStateID(principalStateID)) - return trace.Wrap(err) + out, err := getter.get(ctx, string(id)) + return out, trace.Wrap(err) } -func (provisioningStateExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return trace.Wrap(cache.provisioningStatesCache.DeleteAllProvisioningStates(ctx)) -} +func (c *Cache) ListProvisioningStatesForAllDownstreams(ctx context.Context, pageSize int, req *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListProvisioningStatesForAllDownstreams") + defer span.End() + + lister := genericLister[*provisioningv1.PrincipalState, principalStateIndex]{ + cache: c, + collection: c.collections.provisioningStates, + index: principalStateNameIndex, + upstreamList: func(ctx context.Context, pageSize int, s string) ([]*provisioningv1.PrincipalState, string, error) { + out, next, err := c.Config.ProvisioningStates.ListProvisioningStatesForAllDownstreams(ctx, pageSize, req) + return out, string(next), trace.Wrap(err) + }, + nextToken: func(t *provisioningv1.PrincipalState) string { + return t.GetMetadata().GetName() + }, + clone: utils.CloneProtoMsg[*provisioningv1.PrincipalState], + } -func (provisioningStateExecutor) getReader(cache *Cache, cacheOK bool) provisioningStateGetter { - if cacheOK { - return cache.provisioningStatesCache + nextToken, err := req.Consume() + if err != nil { + return nil, "", trace.Wrap(err) } - return cache.Config.ProvisioningStates -} -func (provisioningStateExecutor) isSingleton() bool { - return false + out, next, err := lister.list(ctx, pageSize, nextToken) + return out, pagination.NextPageToken(next), trace.Wrap(err) } - -var _ executor[*provisioningv1.PrincipalState, provisioningStateGetter] = provisioningStateExecutor{} diff --git a/lib/cache/provisioning_test.go b/lib/cache/provisioning_test.go index e13c363f7391b..365140069bb44 100644 --- a/lib/cache/provisioning_test.go +++ b/lib/cache/provisioning_test.go @@ -50,33 +50,13 @@ func newProvisioningPrincipalState(id string) *provisioningv1.PrincipalState { } } -// TestProvisioningState asserts that a ProvisioningPrincipalState can be cached +// TestProvisioningPrincipalState asserts that a ProvisioningPrincipalState can be cached func TestProvisioningPrincipalState(t *testing.T) { - t.Parallel() fixturePack := newTestPack(t, ForAuth) t.Cleanup(fixturePack.Close) - collect := func(ctx context.Context, src provisioningStateGetter) ([]*provisioningv1.PrincipalState, error) { - var result []*provisioningv1.PrincipalState - var pageToken pagination.PageRequestToken - for { - page, nextPage, err := src.ListProvisioningStatesForAllDownstreams(ctx, 0, &pageToken) - if err != nil { - return nil, trace.Wrap(err) - } - result = append(result, page...) - - if nextPage == pagination.EndOfList { - break - } - - pageToken.Update(nextPage) - } - return result, nil - } - testResources153(t, fixturePack, testFuncs153[*provisioningv1.PrincipalState]{ newResource: func(s string) (*provisioningv1.PrincipalState, error) { return newProvisioningPrincipalState(s), nil @@ -90,7 +70,22 @@ func TestProvisioningPrincipalState(t *testing.T) { return trace.Wrap(err) }, list: func(ctx context.Context) ([]*provisioningv1.PrincipalState, error) { - return collect(ctx, fixturePack.provisioningStates) + var result []*provisioningv1.PrincipalState + var pageToken pagination.PageRequestToken + for { + page, nextPage, err := fixturePack.provisioningStates.ListProvisioningStatesForAllDownstreams(ctx, 0, &pageToken) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, page...) + + if nextPage == pagination.EndOfList { + break + } + + pageToken.Update(nextPage) + } + return result, nil }, delete: func(ctx context.Context, id string) error { return trace.Wrap(fixturePack.provisioningStates.DeleteProvisioningState( @@ -100,7 +95,22 @@ func TestProvisioningPrincipalState(t *testing.T) { return trace.Wrap(fixturePack.provisioningStates.DeleteAllProvisioningStates(ctx)) }, cacheList: func(ctx context.Context) ([]*provisioningv1.PrincipalState, error) { - return collect(ctx, fixturePack.cache.provisioningStatesCache) + var result []*provisioningv1.PrincipalState + var pageToken pagination.PageRequestToken + for { + page, nextPage, err := fixturePack.cache.ListProvisioningStatesForAllDownstreams(ctx, 0, &pageToken) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, page...) + + if nextPage == pagination.EndOfList { + break + } + + pageToken.Update(nextPage) + } + return result, nil }, cacheGet: func(ctx context.Context, id string) (*provisioningv1.PrincipalState, error) { r, err := fixturePack.provisioningStates.GetProvisioningState(