From ed1e14d0d028d911708298a47be4fcd9c1b098bc Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 1 May 2025 16:52:39 -0400 Subject: [PATCH] Convert web session and SAML service provider cache collections Moves SAMLIdPServiceProviders, AppSessions, WebSession, SAMLSessions, and SnowflakeSessions to the new cache collection scheme that was introduced in #52210. No additional functionality changes have been made here. This should be a purely mechanical translation to the new internal caching machinery. --- api/types/session.go | 9 + lib/cache/cache.go | 171 ---------------- lib/cache/cache_test.go | 34 ---- lib/cache/cert_authority_test.go | 2 - lib/cache/collections.go | 48 +++++ lib/cache/generic_operations.go | 5 +- lib/cache/legacy_collections.go | 322 +------------------------------ lib/cache/saml_idp.go | 194 +++++++++++++++++++ lib/cache/saml_idp_test.go | 104 ++++++++++ lib/cache/web_session.go | 320 ++++++++++++++++++++++++++++++ lib/cache/web_session_test.go | 220 +++++++++++++++++++++ 11 files changed, 900 insertions(+), 529 deletions(-) create mode 100644 lib/cache/saml_idp.go create mode 100644 lib/cache/saml_idp_test.go create mode 100644 lib/cache/web_session.go create mode 100644 lib/cache/web_session_test.go diff --git a/api/types/session.go b/api/types/session.go index 500a1e1cb23d0..1b7268e1a0f71 100644 --- a/api/types/session.go +++ b/api/types/session.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils" ) // WebSessionsGetter provides access to web sessions @@ -115,6 +117,8 @@ type WebSession interface { // requirement. // See [TrustedDeviceRequirement]. GetTrustedDeviceRequirement() TrustedDeviceRequirement + // Copy returns a clone of the session resource. + Copy() WebSession } // NewWebSession returns new instance of the web session based on the V2 spec @@ -138,6 +142,11 @@ func (ws *WebSessionV2) GetKind() string { return ws.Kind } +// Copy returns a clone of the session resource. +func (ws *WebSessionV2) Copy() WebSession { + return utils.CloneProtoMsg(ws) +} + // GetSubKind gets resource SubKind func (ws *WebSessionV2) GetSubKind() string { return ws.SubKind diff --git a/lib/cache/cache.go b/lib/cache/cache.go index bc7c258743f68..e0fbffc3c61d8 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -513,13 +513,8 @@ type Cache struct { restrictionsCache services.Restrictions crownJewelsCache services.CrownJewels databaseObjectsCache *local.DatabaseObjectService - appSessionCache services.AppSession - snowflakeSessionCache services.SnowflakeSession - samlIdPSessionCache services.SAMLIdPSession //nolint:revive // Because we want this to be IdP. - webSessionCache types.WebSessionInterface webTokenCache types.WebTokenInterface dynamicWindowsDesktopsCache services.DynamicWindowsDesktops - samlIdPServiceProvidersCache services.SAMLIdPServiceProviders //nolint:revive // Because we want this to be IdP. userGroupsCache services.UserGroups integrationsCache services.Integrations userTasksCache services.UserTasks @@ -922,13 +917,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - //nolint:revive // Because we want this to be IdP. - samlIdPServiceProvidersCache, err := local.NewSAMLIdPServiceProviderService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - userGroupsCache, err := local.NewUserGroupService(config.Backend) if err != nil { cancel() @@ -1064,14 +1052,9 @@ func New(config Config) (*Cache, error) { presenceCache: local.NewPresenceService(config.Backend), restrictionsCache: local.NewRestrictionsService(config.Backend), crownJewelsCache: crownJewelCache, - appSessionCache: identityService, - snowflakeSessionCache: identityService, - samlIdPSessionCache: identityService, - webSessionCache: identityService.WebSessions(), webTokenCache: identityService.WebTokens(), dynamicWindowsDesktopsCache: dynamicDesktopsService, accessMontoringRuleCache: accessMonitoringRuleCache, - samlIdPServiceProvidersCache: samlIdPServiceProvidersCache, userGroupsCache: userGroupsCache, integrationsCache: integrationsCache, userTasksCache: userTasksCache, @@ -2047,105 +2030,6 @@ func (c *Cache) GetStaticHostUser(ctx context.Context, name string) (*userprovis return rg.reader.GetStaticHostUser(ctx, name) } -// GetAppSession gets an application web session. -func (c *Cache) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetAppSession") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - sess, err := rg.reader.GetAppSession(ctx, req) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if sess, err := c.Config.AppSession.GetAppSession(ctx, req); err == nil { - c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream", - "session_kind", sess.GetSubKind(), - "session", sess.GetName(), - ) - return sess, nil - } - } - - return sess, trace.Wrap(err) -} - -// ListAppSessions returns a page of application web sessions. -func (c *Cache) ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAppSessions") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListAppSessions(ctx, pageSize, pageToken, user) -} - -// GetSnowflakeSession gets Snowflake web session. -func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeSessionRequest) (types.WebSession, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSession") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.snowflakeSessions) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - sess, err := rg.reader.GetSnowflakeSession(ctx, req) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if sess, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, req); err == nil { - c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream", - "session_kind", sess.GetSubKind(), - "session", sess.GetName(), - ) - return sess, nil - } - } - - return sess, trace.Wrap(err) -} - -// GetSAMLIdPSession gets a SAML IdP session. -func (c *Cache) GetSAMLIdPSession(ctx context.Context, req types.GetSAMLIdPSessionRequest) (types.WebSession, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPSession") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPSessions) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - sess, err := rg.reader.GetSAMLIdPSession(ctx, req) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if sess, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, req); err == nil { - c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream", - "session_kind", sess.GetSubKind(), - "session", sess.GetName(), - ) - return sess, nil - } - } - - return sess, trace.Wrap(err) -} - func (c *Cache) GetDatabaseObject(ctx context.Context, name string) (*dbobjectv1.DatabaseObject, error) { ctx, span := c.Tracer.Start(ctx, "cache/GetDatabaseObject") defer span.End() @@ -2170,35 +2054,6 @@ func (c *Cache) ListDatabaseObjects(ctx context.Context, size int, pageToken str return rg.reader.ListDatabaseObjects(ctx, size, pageToken) } -// GetWebSession gets a regular web session. -func (c *Cache) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetWebSession") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.webSessions) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - sess, err := rg.reader.Get(ctx, req) - - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because method is never used - // in construction of derivative caches. - if sess, err := c.Config.WebSession.Get(ctx, req); err == nil { - c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream", - "session_kind", sess.GetSubKind(), - "session", sess.GetName(), - ) - return sess, nil - } - } - return sess, trace.Wrap(err) -} - // GetWebToken gets a web token. func (c *Cache) GetWebToken(ctx context.Context, req types.GetWebTokenRequest) (types.WebToken, error) { ctx, span := c.Tracer.Start(ctx, "cache/GetWebToken") @@ -2290,32 +2145,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage) } -// ListSAMLIdPServiceProviders returns a paginated list of SAML IdP service provider resources. -func (c *Cache) ListSAMLIdPServiceProviders(ctx context.Context, pageSize int, nextKey string) ([]types.SAMLIdPServiceProvider, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListSAMLIdPServiceProviders") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListSAMLIdPServiceProviders(ctx, pageSize, nextKey) -} - -// GetSAMLIdPServiceProvider returns the specified SAML IdP service provider resources. -func (c *Cache) GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPServiceProvider") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSAMLIdPServiceProvider(ctx, name) -} - // ListIntegrations returns a paginated list of all Integrations resources. func (c *Cache) ListIntegrations(ctx context.Context, pageSize int, nextKey string) ([]types.Integration, string, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListIntegrations") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index d5f2a1afaab00..b5c9a2a25a242 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1288,40 +1288,6 @@ func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.Databas return database } -// TestSAMLIdPServiceProviders tests that CRUD operations on SAML IdP service provider resources are -// replicated from the backend to the cache. -func TestSAMLIdPServiceProviders(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[types.SAMLIdPServiceProvider]{ - newResource: func(name string) (types.SAMLIdPServiceProvider, error) { - return types.NewSAMLIdPServiceProvider( - types.Metadata{ - Name: name, - }, - types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: testEntityDescriptor, - EntityID: "IAMShowcase", - }) - }, - create: p.samlIDPServiceProviders.CreateSAMLIdPServiceProvider, - list: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) { - results, _, err := p.samlIDPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, "") - return results, err - }, - cacheGet: p.cache.GetSAMLIdPServiceProvider, - cacheList: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) { - results, _, err := p.cache.ListSAMLIdPServiceProviders(ctx, 0, "") - return results, err - }, - update: p.samlIDPServiceProviders.UpdateSAMLIdPServiceProvider, - deleteAll: p.samlIDPServiceProviders.DeleteAllSAMLIdPServiceProviders, - }) -} - // TestLocks tests that CRUD operations on lock resources are // replicated from the backend to the cache. func TestLocks(t *testing.T) { diff --git a/lib/cache/cert_authority_test.go b/lib/cache/cert_authority_test.go index 7718f3054bdd2..d07523fc51ff4 100644 --- a/lib/cache/cert_authority_test.go +++ b/lib/cache/cert_authority_test.go @@ -94,8 +94,6 @@ func TestNodeCAFiltering(t *testing.T) { DynamicAccess: p.cache.dynamicAccessCache, Presence: p.cache.presenceCache, Restrictions: p.cache.restrictionsCache, - AppSession: p.cache.appSessionCache, - WebSession: p.cache.webSessionCache, WebToken: p.cache.webTokenCache, DynamicWindowsDesktops: p.cache.dynamicWindowsDesktopsCache, SAMLIdPServiceProviders: p.samlIDPServiceProviders, diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 845217f157bb6..7c8c5ca8d3e79 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -90,6 +90,11 @@ type collections struct { autoUpdateRollout *collection[*autoupdatev1.AutoUpdateAgentRollout, autoUpdateAgentRolloutIndex] oktaImportRules *collection[types.OktaImportRule, oktaImportRuleIndex] oktaAssignments *collection[types.OktaAssignment, oktaAssignmentIndex] + samlIdPServiceProviders *collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex] + samlIdPSessions *collection[types.WebSession, samlIdPSessionIndex] + webSessions *collection[types.WebSession, webSessionIndex] + appSessions *collection[types.WebSession, appSessionIndex] + snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex] } // setupCollections ensures that the appropriate [collection] is @@ -392,6 +397,49 @@ func setupCollections(c Config) (*collections, error) { out.oktaAssignments = collect out.byKind[resourceKind] = out.oktaAssignments + case types.KindSAMLIdPServiceProvider: + collect, err := newSAMLIdPServiceProviderCollection(c.SAMLIdPServiceProviders, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.samlIdPServiceProviders = collect + out.byKind[resourceKind] = out.samlIdPServiceProviders + case types.KindWebSession: + switch watch.SubKind { + case types.KindAppSession: + collect, err := newAppSessionCollection(c.AppSession, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.appSessions = collect + out.byKind[resourceKind] = out.appSessions + case types.KindSnowflakeSession: + collect, err := newSnowflakeSessionCollection(c.SnowflakeSession, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.snowflakeSessions = collect + out.byKind[resourceKind] = out.snowflakeSessions + case types.KindSAMLIdPSession: + collect, err := newSAMLIdPSessionCollection(c.SAMLIdPSession, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.samlIdPSessions = collect + out.byKind[resourceKind] = out.samlIdPSessions + case types.KindWebSession: + collect, err := newWebSessionCollection(c.WebSession, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.webSessions = collect + out.byKind[resourceKind] = out.webSessions + } } } diff --git a/lib/cache/generic_operations.go b/lib/cache/generic_operations.go index 2b21e47f8e21f..5b3de6b532d36 100644 --- a/lib/cache/generic_operations.go +++ b/lib/cache/generic_operations.go @@ -57,7 +57,10 @@ func (g genericGetter[T, I]) get(ctx context.Context, identifier string) (T, err } out, err := rg.store.get(g.index, identifier) - return g.clone(out), trace.Wrap(err) + if err != nil { + return t, trace.Wrap(err) + } + return g.clone(out), nil } // genericLister is a helper to retrieve a page of items from a cache collection. diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index e821e7d80ffc0..32dd49c9278e2 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -112,7 +112,6 @@ type legacyCollections struct { accessListMembers collectionReader[accessListMembersGetter] accessListReviews collectionReader[accessListReviewsGetter] tunnelConnections collectionReader[tunnelConnectionGetter] - appSessions collectionReader[appSessionGetter] databaseObjects collectionReader[services.DatabaseObjectsGetter] discoveryConfigs collectionReader[services.DiscoveryConfigsGetter] installers collectionReader[installerGetter] @@ -125,12 +124,8 @@ type legacyCollections struct { networkRestrictions collectionReader[networkRestrictionGetter] proxies collectionReader[services.ProxyGetter] remoteClusters collectionReader[remoteClusterGetter] - samlIdPServiceProviders collectionReader[samlIdPServiceProviderGetter] - samlIdPSessions collectionReader[samlIdPSessionGetter] - snowflakeSessions collectionReader[snowflakeSessionGetter] uiConfigs collectionReader[uiConfigGetter] userLoginStates collectionReader[services.UserLoginStatesGetter] - webSessions collectionReader[webSessionGetter] webTokens collectionReader[webTokenGetter] dynamicWindowsDesktops collectionReader[dynamicWindowsDesktopsGetter] accessGraphSettings collectionReader[accessGraphSettingsGetter] @@ -190,45 +185,7 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect return nil, trace.BadParameter("missing parameter DynamicAccess") } collections.byKind[resourceKind] = &genericCollection[types.AccessRequest, noReader, accessRequestExecutor]{cache: c, watch: watch} - case types.KindWebSession: - switch watch.SubKind { - case types.KindAppSession: - if c.AppSession == nil { - return nil, trace.BadParameter("missing parameter AppSession") - } - collections.appSessions = &genericCollection[types.WebSession, appSessionGetter, appSessionExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.appSessions - case types.KindSnowflakeSession: - if c.SnowflakeSession == nil { - return nil, trace.BadParameter("missing parameter SnowflakeSession") - } - collections.snowflakeSessions = &genericCollection[types.WebSession, snowflakeSessionGetter, snowflakeSessionExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.snowflakeSessions - case types.KindSAMLIdPSession: - if c.SAMLIdPSession == nil { - return nil, trace.BadParameter("missing parameter SAMLIdPSession") - } - collections.samlIdPSessions = &genericCollection[types.WebSession, samlIdPSessionGetter, samlIdPSessionExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.samlIdPSessions - case types.KindWebSession: - if c.WebSession == nil { - return nil, trace.BadParameter("missing parameter WebSession") - } - collections.webSessions = &genericCollection[types.WebSession, webSessionGetter, webSessionExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.webSessions - } + case types.KindWebToken: if c.WebToken == nil { return nil, trace.BadParameter("missing parameter WebToken") @@ -283,15 +240,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.dynamicWindowsDesktops - case types.KindSAMLIdPServiceProvider: - if c.SAMLIdPServiceProviders == nil { - return nil, trace.BadParameter("missing parameter SAMLIdPServiceProviders") - } - collections.samlIdPServiceProviders = &genericCollection[types.SAMLIdPServiceProvider, samlIdPServiceProviderGetter, samlIdPServiceProvidersExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.samlIdPServiceProviders case types.KindIntegration: if c.Integrations == nil { return nil, trace.BadParameter("missing parameter Integrations") @@ -679,216 +627,6 @@ func (databaseObjectExecutor) getReader(cache *Cache, cacheOK bool) services.Dat var _ executor[*dbobjectv1.DatabaseObject, services.DatabaseObjectsGetter] = databaseObjectExecutor{} -type appSessionExecutor struct{} - -func (appSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) { - var ( - startKey string - sessions []types.WebSession - ) - for { - webSessions, nextKey, err := cache.AppSession.ListAppSessions(ctx, 0, startKey, "") - if err != nil { - return nil, trace.Wrap(err) - } - - if !loadSecrets { - for i := 0; i < len(webSessions); i++ { - webSessions[i] = webSessions[i].WithoutSecrets() - } - } - - sessions = append(sessions, webSessions...) - - if nextKey == "" { - break - } - startKey = nextKey - } - return sessions, nil -} - -func (appSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error { - return cache.appSessionCache.UpsertAppSession(ctx, resource) -} - -func (appSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.appSessionCache.DeleteAllAppSessions(ctx) -} - -func (appSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.appSessionCache.DeleteAppSession(ctx, types.DeleteAppSessionRequest{ - SessionID: resource.GetName(), - }) -} - -func (appSessionExecutor) isSingleton() bool { return false } - -func (appSessionExecutor) getReader(cache *Cache, cacheOK bool) appSessionGetter { - if cacheOK { - return cache.appSessionCache - } - return cache.Config.AppSession -} - -type appSessionGetter interface { - GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) - ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) -} - -var _ executor[types.WebSession, appSessionGetter] = appSessionExecutor{} - -type snowflakeSessionExecutor struct{} - -func (snowflakeSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) { - webSessions, err := cache.SnowflakeSession.GetSnowflakeSessions(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - if !loadSecrets { - for i := 0; i < len(webSessions); i++ { - webSessions[i] = webSessions[i].WithoutSecrets() - } - } - - return webSessions, nil -} - -func (snowflakeSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error { - return cache.snowflakeSessionCache.UpsertSnowflakeSession(ctx, resource) -} - -func (snowflakeSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.snowflakeSessionCache.DeleteAllSnowflakeSessions(ctx) -} - -func (snowflakeSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.snowflakeSessionCache.DeleteSnowflakeSession(ctx, types.DeleteSnowflakeSessionRequest{ - SessionID: resource.GetName(), - }) -} - -func (snowflakeSessionExecutor) isSingleton() bool { return false } - -func (snowflakeSessionExecutor) getReader(cache *Cache, cacheOK bool) snowflakeSessionGetter { - if cacheOK { - return cache.snowflakeSessionCache - } - return cache.Config.SnowflakeSession -} - -type snowflakeSessionGetter interface { - GetSnowflakeSession(context.Context, types.GetSnowflakeSessionRequest) (types.WebSession, error) -} - -var _ executor[types.WebSession, snowflakeSessionGetter] = snowflakeSessionExecutor{} - -//nolint:revive // Because we want this to be IdP. -type samlIdPSessionExecutor struct{} - -func (samlIdPSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) { - var ( - startKey string - sessions []types.WebSession - ) - for { - webSessions, nextKey, err := cache.SAMLIdPSession.ListSAMLIdPSessions(ctx, 0, startKey, "") - if err != nil { - return nil, trace.Wrap(err) - } - - if !loadSecrets { - for i := 0; i < len(webSessions); i++ { - webSessions[i] = webSessions[i].WithoutSecrets() - } - } - - sessions = append(sessions, webSessions...) - - if nextKey == "" { - break - } - startKey = nextKey - } - return sessions, nil -} - -func (samlIdPSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error { - return cache.samlIdPSessionCache.UpsertSAMLIdPSession(ctx, resource) -} - -func (samlIdPSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.samlIdPSessionCache.DeleteAllSAMLIdPSessions(ctx) -} - -func (samlIdPSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.samlIdPSessionCache.DeleteSAMLIdPSession(ctx, types.DeleteSAMLIdPSessionRequest{ - SessionID: resource.GetName(), - }) -} - -func (samlIdPSessionExecutor) isSingleton() bool { return false } - -func (samlIdPSessionExecutor) getReader(cache *Cache, cacheOK bool) samlIdPSessionGetter { - if cacheOK { - return cache.samlIdPSessionCache - } - return cache.Config.SAMLIdPSession -} - -type samlIdPSessionGetter interface { - GetSAMLIdPSession(context.Context, types.GetSAMLIdPSessionRequest) (types.WebSession, error) -} - -var _ executor[types.WebSession, samlIdPSessionGetter] = samlIdPSessionExecutor{} - -type webSessionExecutor struct{} - -func (webSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) { - webSessions, err := cache.WebSession.List(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - if !loadSecrets { - for i := 0; i < len(webSessions); i++ { - webSessions[i] = webSessions[i].WithoutSecrets() - } - } - - return webSessions, nil -} - -func (webSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error { - return cache.webSessionCache.Upsert(ctx, resource) -} - -func (webSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.webSessionCache.DeleteAll(ctx) -} - -func (webSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.webSessionCache.Delete(ctx, types.DeleteWebSessionRequest{ - SessionID: resource.GetName(), - }) -} - -func (webSessionExecutor) isSingleton() bool { return false } - -func (webSessionExecutor) getReader(cache *Cache, cacheOK bool) webSessionGetter { - if cacheOK { - return cache.webSessionCache - } - return cache.Config.WebSession -} - -type webSessionGetter interface { - Get(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) -} - -var _ executor[types.WebSession, webSessionGetter] = webSessionExecutor{} - type webTokenExecutor struct{} func (webTokenExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebToken, error) { @@ -1314,64 +1052,6 @@ func (userTasksExecutor) getReader(cache *Cache, cacheOK bool) userTasksGetter { var _ executor[*usertasksv1.UserTask, userTasksGetter] = userTasksExecutor{} -//nolint:revive // Because we want this to be IdP. -type samlIdPServiceProvidersExecutor struct{} - -func (samlIdPServiceProvidersExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.SAMLIdPServiceProvider, error) { - var ( - startKey string - sps []types.SAMLIdPServiceProvider - ) - for { - var samlProviders []types.SAMLIdPServiceProvider - var err error - samlProviders, startKey, err = cache.SAMLIdPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, startKey) - if err != nil { - return nil, trace.Wrap(err) - } - - sps = append(sps, samlProviders...) - - if startKey == "" { - break - } - } - - return sps, nil -} - -func (samlIdPServiceProvidersExecutor) upsert(ctx context.Context, cache *Cache, resource types.SAMLIdPServiceProvider) error { - err := cache.samlIdPServiceProvidersCache.CreateSAMLIdPServiceProvider(ctx, resource) - if trace.IsAlreadyExists(err) { - err = cache.samlIdPServiceProvidersCache.UpdateSAMLIdPServiceProvider(ctx, resource) - } - return trace.Wrap(err) -} - -func (samlIdPServiceProvidersExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.samlIdPServiceProvidersCache.DeleteAllSAMLIdPServiceProviders(ctx) -} - -func (samlIdPServiceProvidersExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.samlIdPServiceProvidersCache.DeleteSAMLIdPServiceProvider(ctx, resource.GetName()) -} - -func (samlIdPServiceProvidersExecutor) isSingleton() bool { return false } - -func (samlIdPServiceProvidersExecutor) getReader(cache *Cache, cacheOK bool) samlIdPServiceProviderGetter { - if cacheOK { - return cache.samlIdPServiceProvidersCache - } - return cache.Config.SAMLIdPServiceProviders -} - -type samlIdPServiceProviderGetter interface { - ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error) - GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) -} - -var _ executor[types.SAMLIdPServiceProvider, samlIdPServiceProviderGetter] = samlIdPServiceProvidersExecutor{} - // collectionReader extends the collection interface, adding routing capabilities. type collectionReader[R any] interface { legacyCollection diff --git a/lib/cache/saml_idp.go b/lib/cache/saml_idp.go new file mode 100644 index 0000000000000..bf67a3f6b62a4 --- /dev/null +++ b/lib/cache/saml_idp.go @@ -0,0 +1,194 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +type samlIdPServiceProviderIndex string + +const samlIdPServiceProviderNameIndex samlIdPServiceProviderIndex = "name" + +func newSAMLIdPServiceProviderCollection(upstream services.SAMLIdPServiceProviders, w types.WatchKind) (*collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SAMLIdPSession") + } + + return &collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{ + store: newStore(map[samlIdPServiceProviderIndex]func(types.SAMLIdPServiceProvider) string{ + samlIdPServiceProviderNameIndex: func(r types.SAMLIdPServiceProvider) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.SAMLIdPServiceProvider, error) { + var startKey string + var sps []types.SAMLIdPServiceProvider + for { + var samlProviders []types.SAMLIdPServiceProvider + var err error + samlProviders, startKey, err = upstream.ListSAMLIdPServiceProviders(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + sps = append(sps, samlProviders...) + + if startKey == "" { + break + } + } + + return sps, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.SAMLIdPServiceProvider { + return &types.SAMLIdPServiceProviderV1{ + ResourceHeader: types.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// ListSAMLIdPServiceProviders returns a paginated list of SAML IdP service provider resources. +func (c *Cache) ListSAMLIdPServiceProviders(ctx context.Context, pageSize int, pageToken string) ([]types.SAMLIdPServiceProvider, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListSAMLIdPServiceProviders") + defer span.End() + + lister := genericLister[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{ + cache: c, + collection: c.collections.samlIdPServiceProviders, + index: samlIdPServiceProviderNameIndex, + defaultPageSize: 200, + upstreamList: c.Config.SAMLIdPServiceProviders.ListSAMLIdPServiceProviders, + nextToken: func(t types.SAMLIdPServiceProvider) string { + return t.GetMetadata().Name + }, + clone: types.SAMLIdPServiceProvider.Copy, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +// GetSAMLIdPServiceProvider returns the specified SAML IdP service provider resources. +func (c *Cache) GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPServiceProvider") + defer span.End() + + getter := genericGetter[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{ + cache: c, + collection: c.collections.samlIdPServiceProviders, + index: samlIdPServiceProviderNameIndex, + upstreamGet: c.Config.SAMLIdPServiceProviders.GetSAMLIdPServiceProvider, + clone: types.SAMLIdPServiceProvider.Copy, + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} + +type samlIdPSessionIndex string + +const samlIdPSessionNameIndex samlIdPSessionIndex = "name" + +func newSAMLIdPSessionCollection(upstream services.SAMLIdPSession, w types.WatchKind) (*collection[types.WebSession, samlIdPSessionIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SAMLIdPSession") + } + + return &collection[types.WebSession, samlIdPSessionIndex]{ + store: newStore(map[samlIdPSessionIndex]func(types.WebSession) string{ + samlIdPSessionNameIndex: func(r types.WebSession) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) { + var startKey string + var sessions []types.WebSession + for { + webSessions, nextKey, err := upstream.ListSAMLIdPSessions(ctx, 0, startKey, "") + if err != nil { + return nil, trace.Wrap(err) + } + + if !loadSecrets { + for i := 0; i < len(webSessions); i++ { + webSessions[i] = webSessions[i].WithoutSecrets() + } + } + + sessions = append(sessions, webSessions...) + + if nextKey == "" { + break + } + startKey = nextKey + } + return sessions, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.WebSession { + return &types.WebSessionV2{ + Kind: hdr.Kind, + SubKind: hdr.SubKind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// GetSAMLIdPSession gets a SAML IdP session. +func (c *Cache) GetSAMLIdPSession(ctx context.Context, req types.GetSAMLIdPSessionRequest) (types.WebSession, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPSession") + defer span.End() + + var upstreamRead bool + getter := genericGetter[types.WebSession, samlIdPSessionIndex]{ + cache: c, + collection: c.collections.samlIdPSessions, + index: samlIdPSessionNameIndex, + upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) { + upstreamRead = true + + session, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: s}) + return session, trace.Wrap(err) + }, + clone: types.WebSession.Copy, + } + out, err := getter.get(ctx, req.SessionID) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if item, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, req); err == nil { + return item, nil + } + } + return out, trace.Wrap(err) +} diff --git a/lib/cache/saml_idp_test.go b/lib/cache/saml_idp_test.go new file mode 100644 index 0000000000000..10b6ab64482a9 --- /dev/null +++ b/lib/cache/saml_idp_test.go @@ -0,0 +1,104 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +// TestSAMLIdPServiceProviders tests that CRUD operations on SAML IdP service provider resources are +// replicated from the backend to the cache. +func TestSAMLIdPServiceProviders(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.SAMLIdPServiceProvider]{ + newResource: func(name string) (types.SAMLIdPServiceProvider, error) { + return types.NewSAMLIdPServiceProvider( + types.Metadata{ + Name: name, + }, + types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: testEntityDescriptor, + EntityID: "IAMShowcase", + }) + }, + create: p.samlIDPServiceProviders.CreateSAMLIdPServiceProvider, + list: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) { + results, _, err := p.samlIDPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, "") + return results, err + }, + cacheGet: p.cache.GetSAMLIdPServiceProvider, + cacheList: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) { + results, _, err := p.cache.ListSAMLIdPServiceProviders(ctx, 0, "") + return results, err + }, + update: p.samlIDPServiceProviders.UpdateSAMLIdPServiceProvider, + deleteAll: p.samlIDPServiceProviders.DeleteAllSAMLIdPServiceProviders, + }) +} + +func TestSAMLIdPSessions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for i := 0; i < 31; i++ { + err := p.samlIdPSessionsS.UpsertSAMLIdPSession(t.Context(), &types.WebSessionV2{ + Kind: types.KindWebSession, + SubKind: types.KindSAMLIdPSession, + Version: types.V2, + Metadata: types.Metadata{ + Name: "saml-session" + strconv.Itoa(i+1), + }, + Spec: types.WebSessionSpecV2{ + User: "fish", + }, + }) + require.NoError(t, err) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + for i := 0; i < 31; i++ { + session, err := p.cache.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: "saml-session" + strconv.Itoa(i+1)}) + assert.NoError(t, err) + assert.NotNil(t, session) + } + }, 15*time.Second, 100*time.Millisecond) + + require.NoError(t, p.samlIdPSessionsS.DeleteAllSAMLIdPSessions(ctx)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + for i := 0; i < 31; i++ { + session, err := p.cache.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: "saml-session" + strconv.Itoa(i+1)}) + assert.Error(t, err) + assert.Nil(t, session) + } + }, 15*time.Second, 100*time.Millisecond) +} diff --git a/lib/cache/web_session.go b/lib/cache/web_session.go new file mode 100644 index 0000000000000..efdec6c24236c --- /dev/null +++ b/lib/cache/web_session.go @@ -0,0 +1,320 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "iter" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/sortcache" +) + +type webSessionIndex string + +const webSessionNameIndex webSessionIndex = "name" + +func newWebSessionCollection(upstream types.WebSessionInterface, w types.WatchKind) (*collection[types.WebSession, webSessionIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SAMLIdPSession") + } + + return &collection[types.WebSession, webSessionIndex]{ + store: newStore(map[webSessionIndex]func(types.WebSession) string{ + webSessionNameIndex: func(r types.WebSession) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) { + webSessions, err := upstream.List(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + if !loadSecrets { + for i := 0; i < len(webSessions); i++ { + webSessions[i] = webSessions[i].WithoutSecrets() + } + } + + return webSessions, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.WebSession { + return &types.WebSessionV2{ + Kind: hdr.Kind, + SubKind: hdr.SubKind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// GetWebSession gets a regular web session. +func (c *Cache) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetWebSession") + defer span.End() + + var upstreamRead bool + getter := genericGetter[types.WebSession, webSessionIndex]{ + cache: c, + collection: c.collections.webSessions, + index: webSessionNameIndex, + upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) { + upstreamRead = true + + session, err := c.Config.WebSession.Get(ctx, types.GetWebSessionRequest{SessionID: s}) + return session, trace.Wrap(err) + }, + clone: types.WebSession.Copy, + } + out, err := getter.get(ctx, req.SessionID) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if sess, err := c.Config.WebSession.Get(ctx, req); err == nil { + c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream", + "session_kind", sess.GetSubKind(), + "session", sess.GetName(), + ) + return sess, nil + } + } + return out, trace.Wrap(err) +} + +type appSessionIndex string + +const ( + appSessionNameIndex appSessionIndex = "name" + appSessionUserIndex appSessionIndex = "user" +) + +func newAppSessionCollection(upstream services.AppSession, w types.WatchKind) (*collection[types.WebSession, appSessionIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter AppSession") + } + + return &collection[types.WebSession, appSessionIndex]{ + store: newStore(map[appSessionIndex]func(types.WebSession) string{ + appSessionNameIndex: func(r types.WebSession) string { + return r.GetMetadata().Name + }, + appSessionUserIndex: func(r types.WebSession) string { + return r.GetUser() + "/" + r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) { + var startKey string + var sessions []types.WebSession + + for { + webSessions, nextKey, err := upstream.ListAppSessions(ctx, 0, startKey, "") + if err != nil { + return nil, trace.Wrap(err) + } + + if !loadSecrets { + for i := 0; i < len(webSessions); i++ { + webSessions[i] = webSessions[i].WithoutSecrets() + } + } + + sessions = append(sessions, webSessions...) + + if nextKey == "" { + break + } + startKey = nextKey + } + return sessions, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.WebSession { + return &types.WebSessionV2{ + Kind: hdr.Kind, + SubKind: hdr.SubKind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// GetAppSession gets an application web session. +func (c *Cache) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetAppSession") + defer span.End() + + var upstreamRead bool + getter := genericGetter[types.WebSession, appSessionIndex]{ + cache: c, + collection: c.collections.appSessions, + index: appSessionNameIndex, + upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) { + upstreamRead = true + + session, err := c.Config.AppSession.GetAppSession(ctx, types.GetAppSessionRequest{SessionID: s}) + return session, trace.Wrap(err) + }, + clone: types.WebSession.Copy, + } + out, err := getter.get(ctx, req.SessionID) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if sess, err := c.Config.AppSession.GetAppSession(ctx, req); err == nil { + c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream", + "session_kind", sess.GetSubKind(), + "session", sess.GetName(), + ) + return sess, nil + } + } + return out, trace.Wrap(err) +} + +// ListAppSessions returns a page of application web sessions. +func (c *Cache) ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAppSessions") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.appSessions) + if err != nil { + return nil, "", trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, next, err := c.Config.AppSession.ListAppSessions(ctx, pageSize, pageToken, user) + return out, next, trace.Wrap(err) + } + + // Adjust page size, so it can't be too large. + const maxSessionPageSize = 200 + if pageSize <= 0 || pageSize > maxSessionPageSize { + pageSize = maxSessionPageSize + } + + var sessions iter.Seq[types.WebSession] + if user == "" { + sessions = rg.store.resources(appSessionNameIndex, pageToken, "") + } else { + startKey := user + "/" + endKey := sortcache.NextKey(startKey) + if pageToken != "" { + startKey += pageToken + } + + sessions = rg.store.resources(appSessionUserIndex, startKey, endKey) + } + + var out []types.WebSession + for sess := range sessions { + if len(out) == pageSize { + return out, sess.GetName(), nil + } + + out = append(out, sess.Copy()) + } + + return out, "", nil +} + +type snowflakeSessionIndex string + +const snowflakeSessionNameIndex snowflakeSessionIndex = "name" + +func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.WatchKind) (*collection[types.WebSession, snowflakeSessionIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter AppSession") + } + + return &collection[types.WebSession, snowflakeSessionIndex]{ + store: newStore(map[snowflakeSessionIndex]func(types.WebSession) string{ + snowflakeSessionNameIndex: func(r types.WebSession) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) { + webSessions, err := upstream.GetSnowflakeSessions(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + if !loadSecrets { + for i := 0; i < len(webSessions); i++ { + webSessions[i] = webSessions[i].WithoutSecrets() + } + } + + return webSessions, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.WebSession { + return &types.WebSessionV2{ + Kind: hdr.Kind, + SubKind: hdr.SubKind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// GetSnowflakeSession gets Snowflake web session. +func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeSessionRequest) (types.WebSession, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSession") + defer span.End() + + var upstreamRead bool + getter := genericGetter[types.WebSession, snowflakeSessionIndex]{ + cache: c, + collection: c.collections.snowflakeSessions, + index: snowflakeSessionNameIndex, + upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) { + upstreamRead = true + + session, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: s}) + return session, trace.Wrap(err) + }, + clone: types.WebSession.Copy, + } + out, err := getter.get(ctx, req.SessionID) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because method is never used + // in construction of derivative caches. + if sess, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, req); err == nil { + c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream", + "session_kind", sess.GetSubKind(), + "session", sess.GetName(), + ) + return sess, nil + } + } + return out, trace.Wrap(err) +} diff --git a/lib/cache/web_session_test.go b/lib/cache/web_session_test.go new file mode 100644 index 0000000000000..0939fb810acdc --- /dev/null +++ b/lib/cache/web_session_test.go @@ -0,0 +1,220 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "strconv" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +func TestAppSessions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for i := 0; i < 31; i++ { + err := p.appSessionS.UpsertAppSession(t.Context(), &types.WebSessionV2{ + Kind: types.KindWebSession, + SubKind: types.KindAppSession, + Version: types.V2, + Metadata: types.Metadata{ + Name: "app-session" + strconv.Itoa(i+1), + }, + Spec: types.WebSessionSpecV2{ + User: "fish", + }, + }) + require.NoError(t, err) + } + + for i := 0; i < 3; i++ { + err := p.appSessionS.UpsertAppSession(t.Context(), &types.WebSessionV2{ + Kind: types.KindWebSession, + SubKind: types.KindAppSession, + Version: types.V2, + Metadata: types.Metadata{ + Name: "app-session" + strconv.Itoa(i+100), + }, + Spec: types.WebSessionSpecV2{ + User: "llama", + }, + }) + require.NoError(t, err) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + expected, next, err := p.appSessionS.ListAppSessions(ctx, 0, "", "") + assert.NoError(t, err) + assert.Empty(t, next) + assert.Len(t, expected, 34) + + cached, next, err := p.cache.ListAppSessions(ctx, 0, "", "") + assert.NoError(t, err) + assert.Empty(t, next) + assert.Len(t, cached, 34) + }, 15*time.Second, 100*time.Millisecond) + + session, err := p.cache.GetAppSession(ctx, types.GetAppSessionRequest{ + SessionID: "app-session100", + }) + require.NoError(t, err) + require.NotNil(t, session) + require.Equal(t, "llama", session.GetUser()) + + session, err = p.cache.GetAppSession(ctx, types.GetAppSessionRequest{ + SessionID: "app-session1", + }) + require.NoError(t, err) + require.NotNil(t, session) + require.Equal(t, "fish", session.GetUser()) + + var sessions []types.WebSession + for pageToken := ""; ; { + cached, next, err := p.cache.ListAppSessions(ctx, 1, pageToken, "llama") + if !assert.NoError(t, err) { + return + } + sessions = append(sessions, cached...) + pageToken = next + if next == "" { + break + } + } + assert.Len(t, sessions, 3) + + sessions = nil + for pageToken := ""; ; { + cached, next, err := p.cache.ListAppSessions(ctx, 7, pageToken, "fish") + if !assert.NoError(t, err) { + return + } + sessions = append(sessions, cached...) + pageToken = next + if next == "" { + break + } + } + assert.Len(t, sessions, 31) + + require.NoError(t, p.appSessionS.DeleteAllAppSessions(ctx)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + cached, next, err := p.cache.ListAppSessions(ctx, 0, "", "") + assert.NoError(t, err) + assert.Empty(t, next) + assert.Empty(t, cached) + }, 15*time.Second, 100*time.Millisecond) +} + +func TestWebSessions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for i := 0; i < 31; i++ { + err := p.webSessionS.Upsert(t.Context(), &types.WebSessionV2{ + Kind: types.KindWebSession, + SubKind: types.KindWebSession, + Version: types.V2, + Metadata: types.Metadata{ + Name: "web-session" + strconv.Itoa(i+1), + }, + Spec: types.WebSessionSpecV2{ + User: "fish", + }, + }) + require.NoError(t, err) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + expected, err := p.webSessionS.List(ctx) + assert.NoError(t, err) + assert.Len(t, expected, 31) + + for _, session := range expected { + cached, err := p.cache.GetWebSession(ctx, types.GetWebSessionRequest{SessionID: session.GetName()}) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(session, cached)) + } + }, 15*time.Second, 100*time.Millisecond) + + require.NoError(t, p.webSessionS.DeleteAll(ctx)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + for i := 0; i < 31; i++ { + session, err := p.cache.GetWebSession(ctx, types.GetWebSessionRequest{SessionID: "web-session" + strconv.Itoa(i+1)}) + assert.Error(t, err) + assert.Nil(t, session) + } + }, 15*time.Second, 100*time.Millisecond) +} + +func TestSnowflakeSessions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for i := 0; i < 31; i++ { + err := p.snowflakeSessionS.UpsertSnowflakeSession(t.Context(), &types.WebSessionV2{ + Kind: types.KindWebSession, + SubKind: types.KindSnowflakeSession, + Version: types.V2, + Metadata: types.Metadata{ + Name: "snow-session" + strconv.Itoa(i+1), + }, + Spec: types.WebSessionSpecV2{ + User: "fish", + }, + }) + require.NoError(t, err) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + expected, err := p.snowflakeSessionS.GetSnowflakeSessions(ctx) + assert.NoError(t, err) + assert.Len(t, expected, 31) + + for _, session := range expected { + cached, err := p.cache.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: session.GetName()}) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(session, cached)) + } + }, 15*time.Second, 100*time.Millisecond) + + require.NoError(t, p.snowflakeSessionS.DeleteAllSnowflakeSessions(ctx)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + for i := 0; i < 31; i++ { + session, err := p.cache.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: "snow-session" + strconv.Itoa(i+1)}) + assert.Error(t, err) + assert.Nil(t, session) + } + }, 15*time.Second, 100*time.Millisecond) +}