diff --git a/api/types/userloginstate/user_login_state.go b/api/types/userloginstate/user_login_state.go index e566d8fb4314d..3dc82a21e7c09 100644 --- a/api/types/userloginstate/user_login_state.go +++ b/api/types/userloginstate/user_login_state.go @@ -23,6 +23,7 @@ import ( "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/header/convert/legacy" "github.com/gravitational/teleport/api/types/trait" + "github.com/gravitational/teleport/api/utils" ) // UserLoginState is the ephemeral user login state. This will hold data to differentiate @@ -99,6 +100,13 @@ func (u *UserLoginState) CheckAndSetDefaults() error { return nil } +// Clone returns a copy of the member. +func (u *UserLoginState) Clone() *UserLoginState { + var copy *UserLoginState + utils.StrictObjectToStruct(u, ©) + return copy +} + // GetOriginalRoles returns the original roles that the user login state was derived from. func (u *UserLoginState) GetOriginalRoles() []string { return u.Spec.OriginalRoles diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 49a5f30228641..6d79f1c31a3b1 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -42,13 +42,11 @@ import ( kubewaitingcontainerpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/kubewaitingcontainer/v1" provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1" userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" - usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/internalutils/stream" apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/discoveryconfig" "github.com/gravitational/teleport/api/types/secreports" - "github.com/gravitational/teleport/api/types/userloginstate" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/backend" @@ -508,11 +506,9 @@ type Cache struct { databaseObjectsCache *local.DatabaseObjectService dynamicWindowsDesktopsCache services.DynamicWindowsDesktops userGroupsCache services.UserGroups - userTasksCache services.UserTasks discoveryConfigsCache services.DiscoveryConfigs headlessAuthenticationsCache services.HeadlessAuthenticationService secReportsCache services.SecReports - userLoginStateCache services.UserLoginStates eventsFanout *services.FanoutV2 lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2] kubeWaitingContsCache *local.KubeWaitingContainerService @@ -918,12 +914,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - userTasksCache, err := local.NewUserTasksService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - discoveryConfigsCache, err := local.NewDiscoveryConfigService(config.Backend) if err != nil { cancel() @@ -936,12 +926,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - userLoginStatesCache, err := local.NewUserLoginStateService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - databaseObjectsCache, err := local.NewDatabaseObjectService(config.Backend) if err != nil { cancel() @@ -1018,11 +1002,9 @@ func New(config Config) (*Cache, error) { restrictionsCache: local.NewRestrictionsService(config.Backend), dynamicWindowsDesktopsCache: dynamicDesktopsService, userGroupsCache: userGroupsCache, - userTasksCache: userTasksCache, discoveryConfigsCache: discoveryConfigsCache, headlessAuthenticationsCache: identityService, secReportsCache: secReportsCache, - userLoginStateCache: userLoginStatesCache, databaseObjectsCache: databaseObjectsCache, eventsFanout: fanout, lowVolumeEventsFanout: utils.NewRoundRobin(lowVolumeFanouts), @@ -1904,32 +1886,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage) } -// ListUserTasks returns a list of UserTask resources. -func (c *Cache) ListUserTasks(ctx context.Context, pageSize int64, nextKey string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListUserTasks") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userTasks) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListUserTasks(ctx, pageSize, nextKey, filters) -} - -// GetUserTask returns the specified UserTask resource. -func (c *Cache) GetUserTask(ctx context.Context, name string) (*usertasksv1.UserTask, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetUserTask") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userTasks) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetUserTask(ctx, name) -} - // ListDiscoveryConfigs returns a paginated list of all DiscoveryConfig resources. func (c *Cache) ListDiscoveryConfigs(ctx context.Context, pageSize int, nextKey string) ([]*discoveryconfig.DiscoveryConfig, string, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListDiscoveryConfigs") @@ -2073,43 +2029,6 @@ func (c *Cache) ListSecurityReportsStates(ctx context.Context, pageSize int, nex return rg.reader.ListSecurityReportsStates(ctx, pageSize, nextKey) } -// GetUserLoginStates returns the all user login state resources. -func (c *Cache) GetUserLoginStates(ctx context.Context) ([]*userloginstate.UserLoginState, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginStates") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userLoginStates) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetUserLoginStates(ctx) -} - -// GetUserLoginState returns the specified user login state resource. -func (c *Cache) GetUserLoginState(ctx context.Context, name string) (*userloginstate.UserLoginState, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginState") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userLoginStates) - if err != nil { - return nil, trace.Wrap(err) - } - - uls, err := rg.reader.GetUserLoginState(ctx, name) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if uls, err := c.Config.UserLoginStates.GetUserLoginState(ctx, name); err == nil { - return uls, nil - } - } - defer rg.Release() - return uls, trace.Wrap(err) -} - // ListResources is a part of auth.Cache implementation func (c *Cache) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListResources") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index c2def644feada..d5ef68d705117 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1223,34 +1223,6 @@ func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.Databas return database } -// TestUserTasks tests that CRUD operations on user notification resources are -// replicated from the backend to the cache. -func TestUserTasks(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources153(t, p, testFuncs153[*usertasksv1.UserTask]{ - newResource: func(name string) (*usertasksv1.UserTask, error) { - return newUserTasks(t), nil - }, - create: func(ctx context.Context, item *usertasksv1.UserTask) error { - _, err := p.userTasks.CreateUserTask(ctx, item) - return trace.Wrap(err) - }, - list: func(ctx context.Context) ([]*usertasksv1.UserTask, error) { - items, _, err := p.userTasks.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{}) - return items, trace.Wrap(err) - }, - cacheList: func(ctx context.Context) ([]*usertasksv1.UserTask, error) { - items, _, err := p.userTasks.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{}) - return items, trace.Wrap(err) - }, - deleteAll: p.userTasks.DeleteAllUserTasks, - }) -} - func newUserTasks(t *testing.T) *usertasksv1.UserTask { t.Helper() @@ -1397,33 +1369,6 @@ func TestSecurityReportState(t *testing.T) { }) } -// TestUserLoginStates tests that CRUD operations on user login state resources are -// replicated from the backend to the cache. -func TestUserLoginStates(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[*userloginstate.UserLoginState]{ - newResource: func(name string) (*userloginstate.UserLoginState, error) { - return newUserLoginState(t, name), nil - }, - create: func(ctx context.Context, uls *userloginstate.UserLoginState) error { - _, err := p.userLoginStates.UpsertUserLoginState(ctx, uls) - return trace.Wrap(err) - }, - list: p.userLoginStates.GetUserLoginStates, - cacheGet: p.cache.GetUserLoginState, - cacheList: p.cache.GetUserLoginStates, - update: func(ctx context.Context, uls *userloginstate.UserLoginState) error { - _, err := p.userLoginStates.UpsertUserLoginState(ctx, uls) - return trace.Wrap(err) - }, - deleteAll: p.userLoginStates.DeleteAllUserLoginStates, - }) -} - func TestDatabaseObjects(t *testing.T) { t.Parallel() diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 83f7783c63f31..8a99c4dc8b082 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -29,9 +29,11 @@ import ( identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1" machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1" + usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/userloginstate" ) // collectionHandler is used by the [Cache] to seed the initial @@ -113,6 +115,8 @@ type collections struct { locks *collection[types.Lock, lockIndex] tunnelConnections *collection[types.TunnelConnection, tunnelConnectionIndex] remoteClusters *collection[types.RemoteCluster, remoteClusterIndex] + userTasks *collection[*usertasksv1.UserTask, userTaskIndex] + userLoginStates *collection[*userloginstate.UserLoginState, userLoginStateIndex] } // setupCollections ensures that the appropriate [collection] is @@ -449,6 +453,7 @@ func setupCollections(c Config) (*collections, error) { out.samlIdPSessions = collect out.byKind[resourceKind] = out.samlIdPSessions + case types.KindWebSession: collect, err := newWebSessionCollection(c.WebSession, watch) if err != nil { @@ -570,6 +575,22 @@ func setupCollections(c Config) (*collections, error) { out.remoteClusters = collect out.byKind[resourceKind] = out.remoteClusters + case types.KindUserTask: + collect, err := newUserTaskCollection(c.UserTasks, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.userTasks = collect + out.byKind[resourceKind] = out.userTasks + case types.KindUserLoginState: + collect, err := newUserLoginStateCollection(c.UserLoginStates, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.userLoginStates = collect + out.byKind[resourceKind] = out.userLoginStates } } diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index a0e7c4ca72c8a..08dc04c24efcd 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/discoveryconfig" "github.com/gravitational/teleport/api/types/secreports" - "github.com/gravitational/teleport/api/types/userloginstate" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" ) @@ -101,11 +100,9 @@ type legacyCollections struct { secReportsStates collectionReader[services.SecurityReportStateGetter] databaseObjects collectionReader[services.DatabaseObjectsGetter] discoveryConfigs collectionReader[services.DiscoveryConfigsGetter] - userTasks collectionReader[userTasksGetter] kubeWaitingContainers collectionReader[kubernetesWaitingContainerGetter] staticHostUsers collectionReader[staticHostUserGetter] networkRestrictions collectionReader[networkRestrictionGetter] - userLoginStates collectionReader[services.UserLoginStatesGetter] dynamicWindowsDesktops collectionReader[dynamicWindowsDesktopsGetter] provisioningStates collectionReader[provisioningStateGetter] identityCenterPrincipalAssignments collectionReader[identityCenterPrincipalAssignmentGetter] @@ -152,15 +149,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.dynamicWindowsDesktops - case types.KindUserTask: - if c.UserTasks == nil { - return nil, trace.BadParameter("missing parameter user tasks") - } - collections.userTasks = &genericCollection[*usertasksv1.UserTask, userTasksGetter, userTasksExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.userTasks case types.KindDiscoveryConfig: if c.DiscoveryConfigs == nil { return nil, trace.BadParameter("missing parameter DiscoveryConfigs") @@ -188,12 +176,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect } collections.secReportsStates = &genericCollection[*secreports.ReportState, services.SecurityReportStateGetter, secReportStateExecutor]{cache: c, watch: watch} collections.byKind[resourceKind] = collections.secReportsStates - case types.KindUserLoginState: - if c.UserLoginStates == nil { - return nil, trace.BadParameter("missing parameter UserLoginStates") - } - collections.userLoginStates = &genericCollection[*userloginstate.UserLoginState, services.UserLoginStatesGetter, userLoginStateExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.userLoginStates case types.KindKubeWaitingContainer: if c.KubeWaitingContainers == nil { return nil, trace.BadParameter("missing parameter KubeWaitingContainers") @@ -605,51 +587,6 @@ type staticHostUserGetter interface { GetStaticHostUser(ctx context.Context, name string) (*userprovisioningpb.StaticHostUser, error) } -type userTasksExecutor struct{} - -func (userTasksExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*usertasksv1.UserTask, error) { - var resources []*usertasksv1.UserTask - var nextToken string - for { - var page []*usertasksv1.UserTask - var err error - page, nextToken, err = cache.UserTasks.ListUserTasks(ctx, 0 /* page size */, nextToken, &usertasksv1.ListUserTasksFilters{}) - if err != nil { - return nil, trace.Wrap(err) - } - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (userTasksExecutor) upsert(ctx context.Context, cache *Cache, resource *usertasksv1.UserTask) error { - _, err := cache.userTasksCache.UpsertUserTask(ctx, resource) - return trace.Wrap(err) -} - -func (userTasksExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.userTasksCache.DeleteAllUserTasks(ctx) -} - -func (userTasksExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.userTasksCache.DeleteUserTask(ctx, resource.GetName()) -} - -func (userTasksExecutor) isSingleton() bool { return false } - -func (userTasksExecutor) getReader(cache *Cache, cacheOK bool) userTasksGetter { - if cacheOK { - return cache.userTasksCache - } - return cache.Config.UserTasks -} - -var _ executor[*usertasksv1.UserTask, userTasksGetter] = userTasksExecutor{} - // collectionReader extends the collection interface, adding routing capabilities. type collectionReader[R any] interface { legacyCollection @@ -873,34 +810,3 @@ func (noopExecutor) getReader(_ *Cache, _ bool) noReader { } var _ executor[*types.HeadlessAuthentication, noReader] = noopExecutor{} - -type userLoginStateExecutor struct{} - -func (userLoginStateExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*userloginstate.UserLoginState, error) { - resources, err := cache.UserLoginStates.GetUserLoginStates(ctx) - return resources, trace.Wrap(err) -} - -func (userLoginStateExecutor) upsert(ctx context.Context, cache *Cache, resource *userloginstate.UserLoginState) error { - _, err := cache.userLoginStateCache.UpsertUserLoginState(ctx, resource) - return trace.Wrap(err) -} - -func (userLoginStateExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.userLoginStateCache.DeleteAllUserLoginStates(ctx) -} - -func (userLoginStateExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.userLoginStateCache.DeleteUserLoginState(ctx, resource.GetName()) -} - -func (userLoginStateExecutor) isSingleton() bool { return false } - -func (userLoginStateExecutor) getReader(cache *Cache, cacheOK bool) services.UserLoginStatesGetter { - if cacheOK { - return cache.userLoginStateCache - } - return cache.Config.UserLoginStates -} - -var _ executor[*userloginstate.UserLoginState, services.UserLoginStatesGetter] = userLoginStateExecutor{} diff --git a/lib/cache/user_login_state.go b/lib/cache/user_login_state.go new file mode 100644 index 0000000000000..70778d177aab5 --- /dev/null +++ b/lib/cache/user_login_state.go @@ -0,0 +1,116 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/userloginstate" + "github.com/gravitational/teleport/lib/services" +) + +type userLoginStateIndex string + +const userLoginStateNameIndex userLoginStateIndex = "name" + +func newUserLoginStateCollection(upstream services.UserLoginStates, w types.WatchKind) (*collection[*userloginstate.UserLoginState, userLoginStateIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter UserTasks") + } + + return &collection[*userloginstate.UserLoginState, userLoginStateIndex]{ + store: newStore(map[userLoginStateIndex]func(*userloginstate.UserLoginState) string{ + userLoginStateNameIndex: func(r *userloginstate.UserLoginState) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*userloginstate.UserLoginState, error) { + out, err := upstream.GetUserLoginStates(ctx) + return out, trace.Wrap(err) + }, + headerTransform: func(hdr *types.ResourceHeader) *userloginstate.UserLoginState { + return &userloginstate.UserLoginState{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// GetUserLoginStates returns the all user login state resources. +func (c *Cache) GetUserLoginStates(ctx context.Context) ([]*userloginstate.UserLoginState, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginStates") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.userLoginStates) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + states, err := c.Config.UserLoginStates.GetUserLoginStates(ctx) + return states, trace.Wrap(err) + } + + states := make([]*userloginstate.UserLoginState, 0, rg.store.len()) + for uls := range rg.store.resources(userLoginStateNameIndex, "", "") { + states = append(states, uls.Clone()) + } + + return states, nil +} + +// GetUserLoginState returns the specified user login state resource. +func (c *Cache) GetUserLoginState(ctx context.Context, name string) (*userloginstate.UserLoginState, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginState") + defer span.End() + + var upstreamRead bool + getter := genericGetter[*userloginstate.UserLoginState, userLoginStateIndex]{ + cache: c, + collection: c.collections.userLoginStates, + index: userLoginStateNameIndex, + upstreamGet: func(ctx context.Context, name string) (*userloginstate.UserLoginState, error) { + upstreamRead = true + state, err := c.Config.UserLoginStates.GetUserLoginState(ctx, name) + return state, trace.Wrap(err) + }, + clone: func(uls *userloginstate.UserLoginState) *userloginstate.UserLoginState { + return uls.Clone() + }, + } + out, err := getter.get(ctx, name) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if uls, err := c.Config.UserLoginStates.GetUserLoginState(ctx, name); err == nil { + return uls, nil + } + } + return out, trace.Wrap(err) +} diff --git a/lib/cache/user_login_state_test.go b/lib/cache/user_login_state_test.go new file mode 100644 index 0000000000000..1dda017ff5da8 --- /dev/null +++ b/lib/cache/user_login_state_test.go @@ -0,0 +1,53 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types/userloginstate" +) + +// TestUserLoginStates tests that CRUD operations on user login state resources are +// replicated from the backend to the cache. +func TestUserLoginStates(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[*userloginstate.UserLoginState]{ + newResource: func(name string) (*userloginstate.UserLoginState, error) { + return newUserLoginState(t, name), nil + }, + create: func(ctx context.Context, uls *userloginstate.UserLoginState) error { + _, err := p.userLoginStates.UpsertUserLoginState(ctx, uls) + return trace.Wrap(err) + }, + list: p.userLoginStates.GetUserLoginStates, + cacheGet: p.cache.GetUserLoginState, + cacheList: p.cache.GetUserLoginStates, + update: func(ctx context.Context, uls *userloginstate.UserLoginState) error { + _, err := p.userLoginStates.UpsertUserLoginState(ctx, uls) + return trace.Wrap(err) + }, + deleteAll: p.userLoginStates.DeleteAllUserLoginStates, + }) +} diff --git a/lib/cache/user_task.go b/lib/cache/user_task.go new file mode 100644 index 0000000000000..47443c9869807 --- /dev/null +++ b/lib/cache/user_task.go @@ -0,0 +1,116 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/services" +) + +type userTaskIndex string + +const userTaskNameIndex userTaskIndex = "name" + +func newUserTaskCollection(upstream services.UserTasks, w types.WatchKind) (*collection[*usertasksv1.UserTask, userTaskIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter UserTasks") + } + + return &collection[*usertasksv1.UserTask, userTaskIndex]{ + store: newStore(map[userTaskIndex]func(*usertasksv1.UserTask) string{ + userTaskNameIndex: func(r *usertasksv1.UserTask) string { + return r.GetMetadata().GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*usertasksv1.UserTask, error) { + var resources []*usertasksv1.UserTask + var nextToken string + for { + var page []*usertasksv1.UserTask + var err error + page, nextToken, err = upstream.ListUserTasks(ctx, 0 /* page size */, nextToken, &usertasksv1.ListUserTasksFilters{}) + if err != nil { + return nil, trace.Wrap(err) + } + resources = append(resources, page...) + + if nextToken == "" { + break + } + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *usertasksv1.UserTask { + return &usertasksv1.UserTask{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: &headerv1.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// ListUserTasks returns a list of UserTask resources. +func (c *Cache) ListUserTasks(ctx context.Context, pageSize int64, pageToken string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListUserTasks") + defer span.End() + + lister := genericLister[*usertasksv1.UserTask, userTaskIndex]{ + cache: c, + collection: c.collections.userTasks, + index: userTaskNameIndex, + upstreamList: func(ctx context.Context, i int, s string) ([]*usertasksv1.UserTask, string, error) { + out, next, err := c.Config.UserTasks.ListUserTasks(ctx, pageSize, pageToken, filters) + return out, next, trace.Wrap(err) + }, + nextToken: func(t *usertasksv1.UserTask) string { + return t.GetMetadata().Name + }, + clone: utils.CloneProtoMsg[*usertasksv1.UserTask], + filter: func(ut *usertasksv1.UserTask) bool { + return services.MatchUserTask(ut, filters) + }, + } + out, next, err := lister.list(ctx, int(pageSize), pageToken) + return out, next, trace.Wrap(err) +} + +// GetUserTask returns the specified UserTask resource. +func (c *Cache) GetUserTask(ctx context.Context, name string) (*usertasksv1.UserTask, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetUserTask") + defer span.End() + + getter := genericGetter[*usertasksv1.UserTask, userTaskIndex]{ + cache: c, + collection: c.collections.userTasks, + index: userTaskNameIndex, + upstreamGet: c.Config.UserTasks.GetUserTask, + clone: utils.CloneProtoMsg[*usertasksv1.UserTask], + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} diff --git a/lib/cache/user_task_test.go b/lib/cache/user_task_test.go new file mode 100644 index 0000000000000..4b2b6d33cbd08 --- /dev/null +++ b/lib/cache/user_task_test.go @@ -0,0 +1,54 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" +) + +// TestUserTasks tests that CRUD operations on user notification resources are +// replicated from the backend to the cache. +func TestUserTasks(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources153(t, p, testFuncs153[*usertasksv1.UserTask]{ + newResource: func(name string) (*usertasksv1.UserTask, error) { + return newUserTasks(t), nil + }, + create: func(ctx context.Context, item *usertasksv1.UserTask) error { + _, err := p.userTasks.CreateUserTask(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*usertasksv1.UserTask, error) { + items, _, err := p.userTasks.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{}) + return items, trace.Wrap(err) + }, + cacheList: func(ctx context.Context) ([]*usertasksv1.UserTask, error) { + items, _, err := p.cache.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{}) + return items, trace.Wrap(err) + }, + deleteAll: p.userTasks.DeleteAllUserTasks, + }) +} diff --git a/lib/services/local/user_task.go b/lib/services/local/user_task.go index 39447372fdc6a..3e587494536d8 100644 --- a/lib/services/local/user_task.go +++ b/lib/services/local/user_task.go @@ -55,17 +55,7 @@ func NewUserTasksService(b backend.Backend) (*UserTasksService, error) { func (s *UserTasksService) ListUserTasks(ctx context.Context, pagesize int64, lastKey string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) { r, nextToken, err := s.service.ListResourcesWithFilter(ctx, int(pagesize), lastKey, func(ut *usertasksv1.UserTask) bool { - integrationFilter := filters.GetIntegration() - if integrationFilter != "" && integrationFilter != ut.GetSpec().GetIntegration() { - return false - } - - stateFilter := filters.GetTaskState() - if stateFilter != "" && stateFilter != ut.GetSpec().GetState() { - return false - } - - return true + return services.MatchUserTask(ut, filters) }) return r, nextToken, trace.Wrap(err) } diff --git a/lib/services/user_task.go b/lib/services/user_task.go index 263b7fff6d655..702c548b782af 100644 --- a/lib/services/user_task.go +++ b/lib/services/user_task.go @@ -51,3 +51,17 @@ func MarshalUserTask(object *usertasksv1.UserTask, opts ...MarshalOption) ([]byt func UnmarshalUserTask(data []byte, opts ...MarshalOption) (*usertasksv1.UserTask, error) { return UnmarshalProtoResource[*usertasksv1.UserTask](data, opts...) } + +func MatchUserTask(ut *usertasksv1.UserTask, filters *usertasksv1.ListUserTasksFilters) bool { + integrationFilter := filters.GetIntegration() + if integrationFilter != "" && integrationFilter != ut.GetSpec().GetIntegration() { + return false + } + + stateFilter := filters.GetTaskState() + if stateFilter != "" && stateFilter != ut.GetSpec().GetState() { + return false + } + + return true +}