diff --git a/api/types/session.go b/api/types/session.go
index 500a1e1cb23d0..1b7268e1a0f71 100644
--- a/api/types/session.go
+++ b/api/types/session.go
@@ -22,6 +22,8 @@ import (
"time"
"github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/utils"
)
// WebSessionsGetter provides access to web sessions
@@ -115,6 +117,8 @@ type WebSession interface {
// requirement.
// See [TrustedDeviceRequirement].
GetTrustedDeviceRequirement() TrustedDeviceRequirement
+ // Copy returns a clone of the session resource.
+ Copy() WebSession
}
// NewWebSession returns new instance of the web session based on the V2 spec
@@ -138,6 +142,11 @@ func (ws *WebSessionV2) GetKind() string {
return ws.Kind
}
+// Copy returns a clone of the session resource.
+func (ws *WebSessionV2) Copy() WebSession {
+ return utils.CloneProtoMsg(ws)
+}
+
// GetSubKind gets resource SubKind
func (ws *WebSessionV2) GetSubKind() string {
return ws.SubKind
diff --git a/lib/cache/cache.go b/lib/cache/cache.go
index bc7c258743f68..e0fbffc3c61d8 100644
--- a/lib/cache/cache.go
+++ b/lib/cache/cache.go
@@ -513,13 +513,8 @@ type Cache struct {
restrictionsCache services.Restrictions
crownJewelsCache services.CrownJewels
databaseObjectsCache *local.DatabaseObjectService
- appSessionCache services.AppSession
- snowflakeSessionCache services.SnowflakeSession
- samlIdPSessionCache services.SAMLIdPSession //nolint:revive // Because we want this to be IdP.
- webSessionCache types.WebSessionInterface
webTokenCache types.WebTokenInterface
dynamicWindowsDesktopsCache services.DynamicWindowsDesktops
- samlIdPServiceProvidersCache services.SAMLIdPServiceProviders //nolint:revive // Because we want this to be IdP.
userGroupsCache services.UserGroups
integrationsCache services.Integrations
userTasksCache services.UserTasks
@@ -922,13 +917,6 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}
- //nolint:revive // Because we want this to be IdP.
- samlIdPServiceProvidersCache, err := local.NewSAMLIdPServiceProviderService(config.Backend)
- if err != nil {
- cancel()
- return nil, trace.Wrap(err)
- }
-
userGroupsCache, err := local.NewUserGroupService(config.Backend)
if err != nil {
cancel()
@@ -1064,14 +1052,9 @@ func New(config Config) (*Cache, error) {
presenceCache: local.NewPresenceService(config.Backend),
restrictionsCache: local.NewRestrictionsService(config.Backend),
crownJewelsCache: crownJewelCache,
- appSessionCache: identityService,
- snowflakeSessionCache: identityService,
- samlIdPSessionCache: identityService,
- webSessionCache: identityService.WebSessions(),
webTokenCache: identityService.WebTokens(),
dynamicWindowsDesktopsCache: dynamicDesktopsService,
accessMontoringRuleCache: accessMonitoringRuleCache,
- samlIdPServiceProvidersCache: samlIdPServiceProvidersCache,
userGroupsCache: userGroupsCache,
integrationsCache: integrationsCache,
userTasksCache: userTasksCache,
@@ -2047,105 +2030,6 @@ func (c *Cache) GetStaticHostUser(ctx context.Context, name string) (*userprovis
return rg.reader.GetStaticHostUser(ctx, name)
}
-// GetAppSession gets an application web session.
-func (c *Cache) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/GetAppSession")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer rg.Release()
- sess, err := rg.reader.GetAppSession(ctx, req)
- if trace.IsNotFound(err) && rg.IsCacheRead() {
- // release read lock early
- rg.Release()
- // fallback is sane because method is never used
- // in construction of derivative caches.
- if sess, err := c.Config.AppSession.GetAppSession(ctx, req); err == nil {
- c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream",
- "session_kind", sess.GetSubKind(),
- "session", sess.GetName(),
- )
- return sess, nil
- }
- }
-
- return sess, trace.Wrap(err)
-}
-
-// ListAppSessions returns a page of application web sessions.
-func (c *Cache) ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/ListAppSessions")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions)
- if err != nil {
- return nil, "", trace.Wrap(err)
- }
- defer rg.Release()
- return rg.reader.ListAppSessions(ctx, pageSize, pageToken, user)
-}
-
-// GetSnowflakeSession gets Snowflake web session.
-func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeSessionRequest) (types.WebSession, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSession")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.snowflakeSessions)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer rg.Release()
-
- sess, err := rg.reader.GetSnowflakeSession(ctx, req)
- if trace.IsNotFound(err) && rg.IsCacheRead() {
- // release read lock early
- rg.Release()
- // fallback is sane because method is never used
- // in construction of derivative caches.
- if sess, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, req); err == nil {
- c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
- "session_kind", sess.GetSubKind(),
- "session", sess.GetName(),
- )
- return sess, nil
- }
- }
-
- return sess, trace.Wrap(err)
-}
-
-// GetSAMLIdPSession gets a SAML IdP session.
-func (c *Cache) GetSAMLIdPSession(ctx context.Context, req types.GetSAMLIdPSessionRequest) (types.WebSession, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPSession")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPSessions)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer rg.Release()
-
- sess, err := rg.reader.GetSAMLIdPSession(ctx, req)
- if trace.IsNotFound(err) && rg.IsCacheRead() {
- // release read lock early
- rg.Release()
- // fallback is sane because method is never used
- // in construction of derivative caches.
- if sess, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, req); err == nil {
- c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
- "session_kind", sess.GetSubKind(),
- "session", sess.GetName(),
- )
- return sess, nil
- }
- }
-
- return sess, trace.Wrap(err)
-}
-
func (c *Cache) GetDatabaseObject(ctx context.Context, name string) (*dbobjectv1.DatabaseObject, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetDatabaseObject")
defer span.End()
@@ -2170,35 +2054,6 @@ func (c *Cache) ListDatabaseObjects(ctx context.Context, size int, pageToken str
return rg.reader.ListDatabaseObjects(ctx, size, pageToken)
}
-// GetWebSession gets a regular web session.
-func (c *Cache) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/GetWebSession")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.webSessions)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer rg.Release()
-
- sess, err := rg.reader.Get(ctx, req)
-
- if trace.IsNotFound(err) && rg.IsCacheRead() {
- // release read lock early
- rg.Release()
- // fallback is sane because method is never used
- // in construction of derivative caches.
- if sess, err := c.Config.WebSession.Get(ctx, req); err == nil {
- c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
- "session_kind", sess.GetSubKind(),
- "session", sess.GetName(),
- )
- return sess, nil
- }
- }
- return sess, trace.Wrap(err)
-}
-
// GetWebToken gets a web token.
func (c *Cache) GetWebToken(ctx context.Context, req types.GetWebTokenRequest) (types.WebToken, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetWebToken")
@@ -2290,32 +2145,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne
return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage)
}
-// ListSAMLIdPServiceProviders returns a paginated list of SAML IdP service provider resources.
-func (c *Cache) ListSAMLIdPServiceProviders(ctx context.Context, pageSize int, nextKey string) ([]types.SAMLIdPServiceProvider, string, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/ListSAMLIdPServiceProviders")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders)
- if err != nil {
- return nil, "", trace.Wrap(err)
- }
- defer rg.Release()
- return rg.reader.ListSAMLIdPServiceProviders(ctx, pageSize, nextKey)
-}
-
-// GetSAMLIdPServiceProvider returns the specified SAML IdP service provider resources.
-func (c *Cache) GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) {
- ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPServiceProvider")
- defer span.End()
-
- rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer rg.Release()
- return rg.reader.GetSAMLIdPServiceProvider(ctx, name)
-}
-
// ListIntegrations returns a paginated list of all Integrations resources.
func (c *Cache) ListIntegrations(ctx context.Context, pageSize int, nextKey string) ([]types.Integration, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListIntegrations")
diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go
index d5f2a1afaab00..b5c9a2a25a242 100644
--- a/lib/cache/cache_test.go
+++ b/lib/cache/cache_test.go
@@ -1288,40 +1288,6 @@ func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.Databas
return database
}
-// TestSAMLIdPServiceProviders tests that CRUD operations on SAML IdP service provider resources are
-// replicated from the backend to the cache.
-func TestSAMLIdPServiceProviders(t *testing.T) {
- t.Parallel()
-
- p := newTestPack(t, ForAuth)
- t.Cleanup(p.Close)
-
- testResources(t, p, testFuncs[types.SAMLIdPServiceProvider]{
- newResource: func(name string) (types.SAMLIdPServiceProvider, error) {
- return types.NewSAMLIdPServiceProvider(
- types.Metadata{
- Name: name,
- },
- types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: testEntityDescriptor,
- EntityID: "IAMShowcase",
- })
- },
- create: p.samlIDPServiceProviders.CreateSAMLIdPServiceProvider,
- list: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
- results, _, err := p.samlIDPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, "")
- return results, err
- },
- cacheGet: p.cache.GetSAMLIdPServiceProvider,
- cacheList: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
- results, _, err := p.cache.ListSAMLIdPServiceProviders(ctx, 0, "")
- return results, err
- },
- update: p.samlIDPServiceProviders.UpdateSAMLIdPServiceProvider,
- deleteAll: p.samlIDPServiceProviders.DeleteAllSAMLIdPServiceProviders,
- })
-}
-
// TestLocks tests that CRUD operations on lock resources are
// replicated from the backend to the cache.
func TestLocks(t *testing.T) {
diff --git a/lib/cache/cert_authority_test.go b/lib/cache/cert_authority_test.go
index 7718f3054bdd2..d07523fc51ff4 100644
--- a/lib/cache/cert_authority_test.go
+++ b/lib/cache/cert_authority_test.go
@@ -94,8 +94,6 @@ func TestNodeCAFiltering(t *testing.T) {
DynamicAccess: p.cache.dynamicAccessCache,
Presence: p.cache.presenceCache,
Restrictions: p.cache.restrictionsCache,
- AppSession: p.cache.appSessionCache,
- WebSession: p.cache.webSessionCache,
WebToken: p.cache.webTokenCache,
DynamicWindowsDesktops: p.cache.dynamicWindowsDesktopsCache,
SAMLIdPServiceProviders: p.samlIDPServiceProviders,
diff --git a/lib/cache/collections.go b/lib/cache/collections.go
index 845217f157bb6..7c8c5ca8d3e79 100644
--- a/lib/cache/collections.go
+++ b/lib/cache/collections.go
@@ -90,6 +90,11 @@ type collections struct {
autoUpdateRollout *collection[*autoupdatev1.AutoUpdateAgentRollout, autoUpdateAgentRolloutIndex]
oktaImportRules *collection[types.OktaImportRule, oktaImportRuleIndex]
oktaAssignments *collection[types.OktaAssignment, oktaAssignmentIndex]
+ samlIdPServiceProviders *collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]
+ samlIdPSessions *collection[types.WebSession, samlIdPSessionIndex]
+ webSessions *collection[types.WebSession, webSessionIndex]
+ appSessions *collection[types.WebSession, appSessionIndex]
+ snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex]
}
// setupCollections ensures that the appropriate [collection] is
@@ -392,6 +397,49 @@ func setupCollections(c Config) (*collections, error) {
out.oktaAssignments = collect
out.byKind[resourceKind] = out.oktaAssignments
+ case types.KindSAMLIdPServiceProvider:
+ collect, err := newSAMLIdPServiceProviderCollection(c.SAMLIdPServiceProviders, watch)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ out.samlIdPServiceProviders = collect
+ out.byKind[resourceKind] = out.samlIdPServiceProviders
+ case types.KindWebSession:
+ switch watch.SubKind {
+ case types.KindAppSession:
+ collect, err := newAppSessionCollection(c.AppSession, watch)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ out.appSessions = collect
+ out.byKind[resourceKind] = out.appSessions
+ case types.KindSnowflakeSession:
+ collect, err := newSnowflakeSessionCollection(c.SnowflakeSession, watch)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ out.snowflakeSessions = collect
+ out.byKind[resourceKind] = out.snowflakeSessions
+ case types.KindSAMLIdPSession:
+ collect, err := newSAMLIdPSessionCollection(c.SAMLIdPSession, watch)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ out.samlIdPSessions = collect
+ out.byKind[resourceKind] = out.samlIdPSessions
+ case types.KindWebSession:
+ collect, err := newWebSessionCollection(c.WebSession, watch)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ out.webSessions = collect
+ out.byKind[resourceKind] = out.webSessions
+ }
}
}
diff --git a/lib/cache/generic_operations.go b/lib/cache/generic_operations.go
index 2b21e47f8e21f..5b3de6b532d36 100644
--- a/lib/cache/generic_operations.go
+++ b/lib/cache/generic_operations.go
@@ -57,7 +57,10 @@ func (g genericGetter[T, I]) get(ctx context.Context, identifier string) (T, err
}
out, err := rg.store.get(g.index, identifier)
- return g.clone(out), trace.Wrap(err)
+ if err != nil {
+ return t, trace.Wrap(err)
+ }
+ return g.clone(out), nil
}
// genericLister is a helper to retrieve a page of items from a cache collection.
diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go
index e821e7d80ffc0..32dd49c9278e2 100644
--- a/lib/cache/legacy_collections.go
+++ b/lib/cache/legacy_collections.go
@@ -112,7 +112,6 @@ type legacyCollections struct {
accessListMembers collectionReader[accessListMembersGetter]
accessListReviews collectionReader[accessListReviewsGetter]
tunnelConnections collectionReader[tunnelConnectionGetter]
- appSessions collectionReader[appSessionGetter]
databaseObjects collectionReader[services.DatabaseObjectsGetter]
discoveryConfigs collectionReader[services.DiscoveryConfigsGetter]
installers collectionReader[installerGetter]
@@ -125,12 +124,8 @@ type legacyCollections struct {
networkRestrictions collectionReader[networkRestrictionGetter]
proxies collectionReader[services.ProxyGetter]
remoteClusters collectionReader[remoteClusterGetter]
- samlIdPServiceProviders collectionReader[samlIdPServiceProviderGetter]
- samlIdPSessions collectionReader[samlIdPSessionGetter]
- snowflakeSessions collectionReader[snowflakeSessionGetter]
uiConfigs collectionReader[uiConfigGetter]
userLoginStates collectionReader[services.UserLoginStatesGetter]
- webSessions collectionReader[webSessionGetter]
webTokens collectionReader[webTokenGetter]
dynamicWindowsDesktops collectionReader[dynamicWindowsDesktopsGetter]
accessGraphSettings collectionReader[accessGraphSettingsGetter]
@@ -190,45 +185,7 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect
return nil, trace.BadParameter("missing parameter DynamicAccess")
}
collections.byKind[resourceKind] = &genericCollection[types.AccessRequest, noReader, accessRequestExecutor]{cache: c, watch: watch}
- case types.KindWebSession:
- switch watch.SubKind {
- case types.KindAppSession:
- if c.AppSession == nil {
- return nil, trace.BadParameter("missing parameter AppSession")
- }
- collections.appSessions = &genericCollection[types.WebSession, appSessionGetter, appSessionExecutor]{
- cache: c,
- watch: watch,
- }
- collections.byKind[resourceKind] = collections.appSessions
- case types.KindSnowflakeSession:
- if c.SnowflakeSession == nil {
- return nil, trace.BadParameter("missing parameter SnowflakeSession")
- }
- collections.snowflakeSessions = &genericCollection[types.WebSession, snowflakeSessionGetter, snowflakeSessionExecutor]{
- cache: c,
- watch: watch,
- }
- collections.byKind[resourceKind] = collections.snowflakeSessions
- case types.KindSAMLIdPSession:
- if c.SAMLIdPSession == nil {
- return nil, trace.BadParameter("missing parameter SAMLIdPSession")
- }
- collections.samlIdPSessions = &genericCollection[types.WebSession, samlIdPSessionGetter, samlIdPSessionExecutor]{
- cache: c,
- watch: watch,
- }
- collections.byKind[resourceKind] = collections.samlIdPSessions
- case types.KindWebSession:
- if c.WebSession == nil {
- return nil, trace.BadParameter("missing parameter WebSession")
- }
- collections.webSessions = &genericCollection[types.WebSession, webSessionGetter, webSessionExecutor]{
- cache: c,
- watch: watch,
- }
- collections.byKind[resourceKind] = collections.webSessions
- }
+
case types.KindWebToken:
if c.WebToken == nil {
return nil, trace.BadParameter("missing parameter WebToken")
@@ -283,15 +240,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect
watch: watch,
}
collections.byKind[resourceKind] = collections.dynamicWindowsDesktops
- case types.KindSAMLIdPServiceProvider:
- if c.SAMLIdPServiceProviders == nil {
- return nil, trace.BadParameter("missing parameter SAMLIdPServiceProviders")
- }
- collections.samlIdPServiceProviders = &genericCollection[types.SAMLIdPServiceProvider, samlIdPServiceProviderGetter, samlIdPServiceProvidersExecutor]{
- cache: c,
- watch: watch,
- }
- collections.byKind[resourceKind] = collections.samlIdPServiceProviders
case types.KindIntegration:
if c.Integrations == nil {
return nil, trace.BadParameter("missing parameter Integrations")
@@ -679,216 +627,6 @@ func (databaseObjectExecutor) getReader(cache *Cache, cacheOK bool) services.Dat
var _ executor[*dbobjectv1.DatabaseObject, services.DatabaseObjectsGetter] = databaseObjectExecutor{}
-type appSessionExecutor struct{}
-
-func (appSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) {
- var (
- startKey string
- sessions []types.WebSession
- )
- for {
- webSessions, nextKey, err := cache.AppSession.ListAppSessions(ctx, 0, startKey, "")
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- if !loadSecrets {
- for i := 0; i < len(webSessions); i++ {
- webSessions[i] = webSessions[i].WithoutSecrets()
- }
- }
-
- sessions = append(sessions, webSessions...)
-
- if nextKey == "" {
- break
- }
- startKey = nextKey
- }
- return sessions, nil
-}
-
-func (appSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error {
- return cache.appSessionCache.UpsertAppSession(ctx, resource)
-}
-
-func (appSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error {
- return cache.appSessionCache.DeleteAllAppSessions(ctx)
-}
-
-func (appSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
- return cache.appSessionCache.DeleteAppSession(ctx, types.DeleteAppSessionRequest{
- SessionID: resource.GetName(),
- })
-}
-
-func (appSessionExecutor) isSingleton() bool { return false }
-
-func (appSessionExecutor) getReader(cache *Cache, cacheOK bool) appSessionGetter {
- if cacheOK {
- return cache.appSessionCache
- }
- return cache.Config.AppSession
-}
-
-type appSessionGetter interface {
- GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error)
- ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error)
-}
-
-var _ executor[types.WebSession, appSessionGetter] = appSessionExecutor{}
-
-type snowflakeSessionExecutor struct{}
-
-func (snowflakeSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) {
- webSessions, err := cache.SnowflakeSession.GetSnowflakeSessions(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- if !loadSecrets {
- for i := 0; i < len(webSessions); i++ {
- webSessions[i] = webSessions[i].WithoutSecrets()
- }
- }
-
- return webSessions, nil
-}
-
-func (snowflakeSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error {
- return cache.snowflakeSessionCache.UpsertSnowflakeSession(ctx, resource)
-}
-
-func (snowflakeSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error {
- return cache.snowflakeSessionCache.DeleteAllSnowflakeSessions(ctx)
-}
-
-func (snowflakeSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
- return cache.snowflakeSessionCache.DeleteSnowflakeSession(ctx, types.DeleteSnowflakeSessionRequest{
- SessionID: resource.GetName(),
- })
-}
-
-func (snowflakeSessionExecutor) isSingleton() bool { return false }
-
-func (snowflakeSessionExecutor) getReader(cache *Cache, cacheOK bool) snowflakeSessionGetter {
- if cacheOK {
- return cache.snowflakeSessionCache
- }
- return cache.Config.SnowflakeSession
-}
-
-type snowflakeSessionGetter interface {
- GetSnowflakeSession(context.Context, types.GetSnowflakeSessionRequest) (types.WebSession, error)
-}
-
-var _ executor[types.WebSession, snowflakeSessionGetter] = snowflakeSessionExecutor{}
-
-//nolint:revive // Because we want this to be IdP.
-type samlIdPSessionExecutor struct{}
-
-func (samlIdPSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) {
- var (
- startKey string
- sessions []types.WebSession
- )
- for {
- webSessions, nextKey, err := cache.SAMLIdPSession.ListSAMLIdPSessions(ctx, 0, startKey, "")
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- if !loadSecrets {
- for i := 0; i < len(webSessions); i++ {
- webSessions[i] = webSessions[i].WithoutSecrets()
- }
- }
-
- sessions = append(sessions, webSessions...)
-
- if nextKey == "" {
- break
- }
- startKey = nextKey
- }
- return sessions, nil
-}
-
-func (samlIdPSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error {
- return cache.samlIdPSessionCache.UpsertSAMLIdPSession(ctx, resource)
-}
-
-func (samlIdPSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error {
- return cache.samlIdPSessionCache.DeleteAllSAMLIdPSessions(ctx)
-}
-
-func (samlIdPSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
- return cache.samlIdPSessionCache.DeleteSAMLIdPSession(ctx, types.DeleteSAMLIdPSessionRequest{
- SessionID: resource.GetName(),
- })
-}
-
-func (samlIdPSessionExecutor) isSingleton() bool { return false }
-
-func (samlIdPSessionExecutor) getReader(cache *Cache, cacheOK bool) samlIdPSessionGetter {
- if cacheOK {
- return cache.samlIdPSessionCache
- }
- return cache.Config.SAMLIdPSession
-}
-
-type samlIdPSessionGetter interface {
- GetSAMLIdPSession(context.Context, types.GetSAMLIdPSessionRequest) (types.WebSession, error)
-}
-
-var _ executor[types.WebSession, samlIdPSessionGetter] = samlIdPSessionExecutor{}
-
-type webSessionExecutor struct{}
-
-func (webSessionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebSession, error) {
- webSessions, err := cache.WebSession.List(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- if !loadSecrets {
- for i := 0; i < len(webSessions); i++ {
- webSessions[i] = webSessions[i].WithoutSecrets()
- }
- }
-
- return webSessions, nil
-}
-
-func (webSessionExecutor) upsert(ctx context.Context, cache *Cache, resource types.WebSession) error {
- return cache.webSessionCache.Upsert(ctx, resource)
-}
-
-func (webSessionExecutor) deleteAll(ctx context.Context, cache *Cache) error {
- return cache.webSessionCache.DeleteAll(ctx)
-}
-
-func (webSessionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
- return cache.webSessionCache.Delete(ctx, types.DeleteWebSessionRequest{
- SessionID: resource.GetName(),
- })
-}
-
-func (webSessionExecutor) isSingleton() bool { return false }
-
-func (webSessionExecutor) getReader(cache *Cache, cacheOK bool) webSessionGetter {
- if cacheOK {
- return cache.webSessionCache
- }
- return cache.Config.WebSession
-}
-
-type webSessionGetter interface {
- Get(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error)
-}
-
-var _ executor[types.WebSession, webSessionGetter] = webSessionExecutor{}
-
type webTokenExecutor struct{}
func (webTokenExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.WebToken, error) {
@@ -1314,64 +1052,6 @@ func (userTasksExecutor) getReader(cache *Cache, cacheOK bool) userTasksGetter {
var _ executor[*usertasksv1.UserTask, userTasksGetter] = userTasksExecutor{}
-//nolint:revive // Because we want this to be IdP.
-type samlIdPServiceProvidersExecutor struct{}
-
-func (samlIdPServiceProvidersExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.SAMLIdPServiceProvider, error) {
- var (
- startKey string
- sps []types.SAMLIdPServiceProvider
- )
- for {
- var samlProviders []types.SAMLIdPServiceProvider
- var err error
- samlProviders, startKey, err = cache.SAMLIdPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, startKey)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- sps = append(sps, samlProviders...)
-
- if startKey == "" {
- break
- }
- }
-
- return sps, nil
-}
-
-func (samlIdPServiceProvidersExecutor) upsert(ctx context.Context, cache *Cache, resource types.SAMLIdPServiceProvider) error {
- err := cache.samlIdPServiceProvidersCache.CreateSAMLIdPServiceProvider(ctx, resource)
- if trace.IsAlreadyExists(err) {
- err = cache.samlIdPServiceProvidersCache.UpdateSAMLIdPServiceProvider(ctx, resource)
- }
- return trace.Wrap(err)
-}
-
-func (samlIdPServiceProvidersExecutor) deleteAll(ctx context.Context, cache *Cache) error {
- return cache.samlIdPServiceProvidersCache.DeleteAllSAMLIdPServiceProviders(ctx)
-}
-
-func (samlIdPServiceProvidersExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
- return cache.samlIdPServiceProvidersCache.DeleteSAMLIdPServiceProvider(ctx, resource.GetName())
-}
-
-func (samlIdPServiceProvidersExecutor) isSingleton() bool { return false }
-
-func (samlIdPServiceProvidersExecutor) getReader(cache *Cache, cacheOK bool) samlIdPServiceProviderGetter {
- if cacheOK {
- return cache.samlIdPServiceProvidersCache
- }
- return cache.Config.SAMLIdPServiceProviders
-}
-
-type samlIdPServiceProviderGetter interface {
- ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error)
- GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error)
-}
-
-var _ executor[types.SAMLIdPServiceProvider, samlIdPServiceProviderGetter] = samlIdPServiceProvidersExecutor{}
-
// collectionReader extends the collection interface, adding routing capabilities.
type collectionReader[R any] interface {
legacyCollection
diff --git a/lib/cache/saml_idp.go b/lib/cache/saml_idp.go
new file mode 100644
index 0000000000000..bf67a3f6b62a4
--- /dev/null
+++ b/lib/cache/saml_idp.go
@@ -0,0 +1,194 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/services"
+)
+
+type samlIdPServiceProviderIndex string
+
+const samlIdPServiceProviderNameIndex samlIdPServiceProviderIndex = "name"
+
+func newSAMLIdPServiceProviderCollection(upstream services.SAMLIdPServiceProviders, w types.WatchKind) (*collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex], error) {
+ if upstream == nil {
+ return nil, trace.BadParameter("missing parameter SAMLIdPSession")
+ }
+
+ return &collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{
+ store: newStore(map[samlIdPServiceProviderIndex]func(types.SAMLIdPServiceProvider) string{
+ samlIdPServiceProviderNameIndex: func(r types.SAMLIdPServiceProvider) string {
+ return r.GetMetadata().Name
+ },
+ }),
+ fetcher: func(ctx context.Context, loadSecrets bool) ([]types.SAMLIdPServiceProvider, error) {
+ var startKey string
+ var sps []types.SAMLIdPServiceProvider
+ for {
+ var samlProviders []types.SAMLIdPServiceProvider
+ var err error
+ samlProviders, startKey, err = upstream.ListSAMLIdPServiceProviders(ctx, 0, startKey)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ sps = append(sps, samlProviders...)
+
+ if startKey == "" {
+ break
+ }
+ }
+
+ return sps, nil
+ },
+ headerTransform: func(hdr *types.ResourceHeader) types.SAMLIdPServiceProvider {
+ return &types.SAMLIdPServiceProviderV1{
+ ResourceHeader: types.ResourceHeader{
+ Kind: hdr.Kind,
+ Version: hdr.Version,
+ Metadata: types.Metadata{
+ Name: hdr.Metadata.Name,
+ },
+ },
+ }
+ },
+ watch: w,
+ }, nil
+}
+
+// ListSAMLIdPServiceProviders returns a paginated list of SAML IdP service provider resources.
+func (c *Cache) ListSAMLIdPServiceProviders(ctx context.Context, pageSize int, pageToken string) ([]types.SAMLIdPServiceProvider, string, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/ListSAMLIdPServiceProviders")
+ defer span.End()
+
+ lister := genericLister[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{
+ cache: c,
+ collection: c.collections.samlIdPServiceProviders,
+ index: samlIdPServiceProviderNameIndex,
+ defaultPageSize: 200,
+ upstreamList: c.Config.SAMLIdPServiceProviders.ListSAMLIdPServiceProviders,
+ nextToken: func(t types.SAMLIdPServiceProvider) string {
+ return t.GetMetadata().Name
+ },
+ clone: types.SAMLIdPServiceProvider.Copy,
+ }
+ out, next, err := lister.list(ctx, pageSize, pageToken)
+ return out, next, trace.Wrap(err)
+}
+
+// GetSAMLIdPServiceProvider returns the specified SAML IdP service provider resources.
+func (c *Cache) GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPServiceProvider")
+ defer span.End()
+
+ getter := genericGetter[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]{
+ cache: c,
+ collection: c.collections.samlIdPServiceProviders,
+ index: samlIdPServiceProviderNameIndex,
+ upstreamGet: c.Config.SAMLIdPServiceProviders.GetSAMLIdPServiceProvider,
+ clone: types.SAMLIdPServiceProvider.Copy,
+ }
+ out, err := getter.get(ctx, name)
+ return out, trace.Wrap(err)
+}
+
+type samlIdPSessionIndex string
+
+const samlIdPSessionNameIndex samlIdPSessionIndex = "name"
+
+func newSAMLIdPSessionCollection(upstream services.SAMLIdPSession, w types.WatchKind) (*collection[types.WebSession, samlIdPSessionIndex], error) {
+ if upstream == nil {
+ return nil, trace.BadParameter("missing parameter SAMLIdPSession")
+ }
+
+ return &collection[types.WebSession, samlIdPSessionIndex]{
+ store: newStore(map[samlIdPSessionIndex]func(types.WebSession) string{
+ samlIdPSessionNameIndex: func(r types.WebSession) string {
+ return r.GetMetadata().Name
+ },
+ }),
+ fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
+ var startKey string
+ var sessions []types.WebSession
+ for {
+ webSessions, nextKey, err := upstream.ListSAMLIdPSessions(ctx, 0, startKey, "")
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if !loadSecrets {
+ for i := 0; i < len(webSessions); i++ {
+ webSessions[i] = webSessions[i].WithoutSecrets()
+ }
+ }
+
+ sessions = append(sessions, webSessions...)
+
+ if nextKey == "" {
+ break
+ }
+ startKey = nextKey
+ }
+ return sessions, nil
+ },
+ headerTransform: func(hdr *types.ResourceHeader) types.WebSession {
+ return &types.WebSessionV2{
+ Kind: hdr.Kind,
+ SubKind: hdr.SubKind,
+ Version: hdr.Version,
+ Metadata: types.Metadata{
+ Name: hdr.Metadata.Name,
+ },
+ }
+ },
+ watch: w,
+ }, nil
+}
+
+// GetSAMLIdPSession gets a SAML IdP session.
+func (c *Cache) GetSAMLIdPSession(ctx context.Context, req types.GetSAMLIdPSessionRequest) (types.WebSession, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPSession")
+ defer span.End()
+
+ var upstreamRead bool
+ getter := genericGetter[types.WebSession, samlIdPSessionIndex]{
+ cache: c,
+ collection: c.collections.samlIdPSessions,
+ index: samlIdPSessionNameIndex,
+ upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) {
+ upstreamRead = true
+
+ session, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: s})
+ return session, trace.Wrap(err)
+ },
+ clone: types.WebSession.Copy,
+ }
+ out, err := getter.get(ctx, req.SessionID)
+ if trace.IsNotFound(err) && !upstreamRead {
+ // fallback is sane because method is never used
+ // in construction of derivative caches.
+ if item, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, req); err == nil {
+ return item, nil
+ }
+ }
+ return out, trace.Wrap(err)
+}
diff --git a/lib/cache/saml_idp_test.go b/lib/cache/saml_idp_test.go
new file mode 100644
index 0000000000000..10b6ab64482a9
--- /dev/null
+++ b/lib/cache/saml_idp_test.go
@@ -0,0 +1,104 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+)
+
+// TestSAMLIdPServiceProviders tests that CRUD operations on SAML IdP service provider resources are
+// replicated from the backend to the cache.
+func TestSAMLIdPServiceProviders(t *testing.T) {
+ t.Parallel()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ testResources(t, p, testFuncs[types.SAMLIdPServiceProvider]{
+ newResource: func(name string) (types.SAMLIdPServiceProvider, error) {
+ return types.NewSAMLIdPServiceProvider(
+ types.Metadata{
+ Name: name,
+ },
+ types.SAMLIdPServiceProviderSpecV1{
+ EntityDescriptor: testEntityDescriptor,
+ EntityID: "IAMShowcase",
+ })
+ },
+ create: p.samlIDPServiceProviders.CreateSAMLIdPServiceProvider,
+ list: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
+ results, _, err := p.samlIDPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, "")
+ return results, err
+ },
+ cacheGet: p.cache.GetSAMLIdPServiceProvider,
+ cacheList: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
+ results, _, err := p.cache.ListSAMLIdPServiceProviders(ctx, 0, "")
+ return results, err
+ },
+ update: p.samlIDPServiceProviders.UpdateSAMLIdPServiceProvider,
+ deleteAll: p.samlIDPServiceProviders.DeleteAllSAMLIdPServiceProviders,
+ })
+}
+
+func TestSAMLIdPSessions(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ for i := 0; i < 31; i++ {
+ err := p.samlIdPSessionsS.UpsertSAMLIdPSession(t.Context(), &types.WebSessionV2{
+ Kind: types.KindWebSession,
+ SubKind: types.KindSAMLIdPSession,
+ Version: types.V2,
+ Metadata: types.Metadata{
+ Name: "saml-session" + strconv.Itoa(i+1),
+ },
+ Spec: types.WebSessionSpecV2{
+ User: "fish",
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ for i := 0; i < 31; i++ {
+ session, err := p.cache.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: "saml-session" + strconv.Itoa(i+1)})
+ assert.NoError(t, err)
+ assert.NotNil(t, session)
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+
+ require.NoError(t, p.samlIdPSessionsS.DeleteAllSAMLIdPSessions(ctx))
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ for i := 0; i < 31; i++ {
+ session, err := p.cache.GetSAMLIdPSession(ctx, types.GetSAMLIdPSessionRequest{SessionID: "saml-session" + strconv.Itoa(i+1)})
+ assert.Error(t, err)
+ assert.Nil(t, session)
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+}
diff --git a/lib/cache/web_session.go b/lib/cache/web_session.go
new file mode 100644
index 0000000000000..efdec6c24236c
--- /dev/null
+++ b/lib/cache/web_session.go
@@ -0,0 +1,320 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "context"
+ "iter"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/utils/sortcache"
+)
+
+type webSessionIndex string
+
+const webSessionNameIndex webSessionIndex = "name"
+
+func newWebSessionCollection(upstream types.WebSessionInterface, w types.WatchKind) (*collection[types.WebSession, webSessionIndex], error) {
+ if upstream == nil {
+ return nil, trace.BadParameter("missing parameter SAMLIdPSession")
+ }
+
+ return &collection[types.WebSession, webSessionIndex]{
+ store: newStore(map[webSessionIndex]func(types.WebSession) string{
+ webSessionNameIndex: func(r types.WebSession) string {
+ return r.GetMetadata().Name
+ },
+ }),
+ fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
+ webSessions, err := upstream.List(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if !loadSecrets {
+ for i := 0; i < len(webSessions); i++ {
+ webSessions[i] = webSessions[i].WithoutSecrets()
+ }
+ }
+
+ return webSessions, nil
+ },
+ headerTransform: func(hdr *types.ResourceHeader) types.WebSession {
+ return &types.WebSessionV2{
+ Kind: hdr.Kind,
+ SubKind: hdr.SubKind,
+ Version: hdr.Version,
+ Metadata: types.Metadata{
+ Name: hdr.Metadata.Name,
+ },
+ }
+ },
+ watch: w,
+ }, nil
+}
+
+// GetWebSession gets a regular web session.
+func (c *Cache) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetWebSession")
+ defer span.End()
+
+ var upstreamRead bool
+ getter := genericGetter[types.WebSession, webSessionIndex]{
+ cache: c,
+ collection: c.collections.webSessions,
+ index: webSessionNameIndex,
+ upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) {
+ upstreamRead = true
+
+ session, err := c.Config.WebSession.Get(ctx, types.GetWebSessionRequest{SessionID: s})
+ return session, trace.Wrap(err)
+ },
+ clone: types.WebSession.Copy,
+ }
+ out, err := getter.get(ctx, req.SessionID)
+ if trace.IsNotFound(err) && !upstreamRead {
+ // fallback is sane because method is never used
+ // in construction of derivative caches.
+ if sess, err := c.Config.WebSession.Get(ctx, req); err == nil {
+ c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream",
+ "session_kind", sess.GetSubKind(),
+ "session", sess.GetName(),
+ )
+ return sess, nil
+ }
+ }
+ return out, trace.Wrap(err)
+}
+
+type appSessionIndex string
+
+const (
+ appSessionNameIndex appSessionIndex = "name"
+ appSessionUserIndex appSessionIndex = "user"
+)
+
+func newAppSessionCollection(upstream services.AppSession, w types.WatchKind) (*collection[types.WebSession, appSessionIndex], error) {
+ if upstream == nil {
+ return nil, trace.BadParameter("missing parameter AppSession")
+ }
+
+ return &collection[types.WebSession, appSessionIndex]{
+ store: newStore(map[appSessionIndex]func(types.WebSession) string{
+ appSessionNameIndex: func(r types.WebSession) string {
+ return r.GetMetadata().Name
+ },
+ appSessionUserIndex: func(r types.WebSession) string {
+ return r.GetUser() + "/" + r.GetMetadata().Name
+ },
+ }),
+ fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
+ var startKey string
+ var sessions []types.WebSession
+
+ for {
+ webSessions, nextKey, err := upstream.ListAppSessions(ctx, 0, startKey, "")
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if !loadSecrets {
+ for i := 0; i < len(webSessions); i++ {
+ webSessions[i] = webSessions[i].WithoutSecrets()
+ }
+ }
+
+ sessions = append(sessions, webSessions...)
+
+ if nextKey == "" {
+ break
+ }
+ startKey = nextKey
+ }
+ return sessions, nil
+ },
+ headerTransform: func(hdr *types.ResourceHeader) types.WebSession {
+ return &types.WebSessionV2{
+ Kind: hdr.Kind,
+ SubKind: hdr.SubKind,
+ Version: hdr.Version,
+ Metadata: types.Metadata{
+ Name: hdr.Metadata.Name,
+ },
+ }
+ },
+ watch: w,
+ }, nil
+}
+
+// GetAppSession gets an application web session.
+func (c *Cache) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetAppSession")
+ defer span.End()
+
+ var upstreamRead bool
+ getter := genericGetter[types.WebSession, appSessionIndex]{
+ cache: c,
+ collection: c.collections.appSessions,
+ index: appSessionNameIndex,
+ upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) {
+ upstreamRead = true
+
+ session, err := c.Config.AppSession.GetAppSession(ctx, types.GetAppSessionRequest{SessionID: s})
+ return session, trace.Wrap(err)
+ },
+ clone: types.WebSession.Copy,
+ }
+ out, err := getter.get(ctx, req.SessionID)
+ if trace.IsNotFound(err) && !upstreamRead {
+ // fallback is sane because method is never used
+ // in construction of derivative caches.
+ if sess, err := c.Config.AppSession.GetAppSession(ctx, req); err == nil {
+ c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream",
+ "session_kind", sess.GetSubKind(),
+ "session", sess.GetName(),
+ )
+ return sess, nil
+ }
+ }
+ return out, trace.Wrap(err)
+}
+
+// ListAppSessions returns a page of application web sessions.
+func (c *Cache) ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/ListAppSessions")
+ defer span.End()
+
+ rg, err := acquireReadGuard(c, c.collections.appSessions)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ defer rg.Release()
+
+ if !rg.ReadCache() {
+ out, next, err := c.Config.AppSession.ListAppSessions(ctx, pageSize, pageToken, user)
+ return out, next, trace.Wrap(err)
+ }
+
+ // Adjust page size, so it can't be too large.
+ const maxSessionPageSize = 200
+ if pageSize <= 0 || pageSize > maxSessionPageSize {
+ pageSize = maxSessionPageSize
+ }
+
+ var sessions iter.Seq[types.WebSession]
+ if user == "" {
+ sessions = rg.store.resources(appSessionNameIndex, pageToken, "")
+ } else {
+ startKey := user + "/"
+ endKey := sortcache.NextKey(startKey)
+ if pageToken != "" {
+ startKey += pageToken
+ }
+
+ sessions = rg.store.resources(appSessionUserIndex, startKey, endKey)
+ }
+
+ var out []types.WebSession
+ for sess := range sessions {
+ if len(out) == pageSize {
+ return out, sess.GetName(), nil
+ }
+
+ out = append(out, sess.Copy())
+ }
+
+ return out, "", nil
+}
+
+type snowflakeSessionIndex string
+
+const snowflakeSessionNameIndex snowflakeSessionIndex = "name"
+
+func newSnowflakeSessionCollection(upstream services.SnowflakeSession, w types.WatchKind) (*collection[types.WebSession, snowflakeSessionIndex], error) {
+ if upstream == nil {
+ return nil, trace.BadParameter("missing parameter AppSession")
+ }
+
+ return &collection[types.WebSession, snowflakeSessionIndex]{
+ store: newStore(map[snowflakeSessionIndex]func(types.WebSession) string{
+ snowflakeSessionNameIndex: func(r types.WebSession) string {
+ return r.GetMetadata().Name
+ },
+ }),
+ fetcher: func(ctx context.Context, loadSecrets bool) ([]types.WebSession, error) {
+ webSessions, err := upstream.GetSnowflakeSessions(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if !loadSecrets {
+ for i := 0; i < len(webSessions); i++ {
+ webSessions[i] = webSessions[i].WithoutSecrets()
+ }
+ }
+
+ return webSessions, nil
+ },
+ headerTransform: func(hdr *types.ResourceHeader) types.WebSession {
+ return &types.WebSessionV2{
+ Kind: hdr.Kind,
+ SubKind: hdr.SubKind,
+ Version: hdr.Version,
+ Metadata: types.Metadata{
+ Name: hdr.Metadata.Name,
+ },
+ }
+ },
+ watch: w,
+ }, nil
+}
+
+// GetSnowflakeSession gets Snowflake web session.
+func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeSessionRequest) (types.WebSession, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSession")
+ defer span.End()
+
+ var upstreamRead bool
+ getter := genericGetter[types.WebSession, snowflakeSessionIndex]{
+ cache: c,
+ collection: c.collections.snowflakeSessions,
+ index: snowflakeSessionNameIndex,
+ upstreamGet: func(ctx context.Context, s string) (types.WebSession, error) {
+ upstreamRead = true
+
+ session, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: s})
+ return session, trace.Wrap(err)
+ },
+ clone: types.WebSession.Copy,
+ }
+ out, err := getter.get(ctx, req.SessionID)
+ if trace.IsNotFound(err) && !upstreamRead {
+ // fallback is sane because method is never used
+ // in construction of derivative caches.
+ if sess, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, req); err == nil {
+ c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream",
+ "session_kind", sess.GetSubKind(),
+ "session", sess.GetName(),
+ )
+ return sess, nil
+ }
+ }
+ return out, trace.Wrap(err)
+}
diff --git a/lib/cache/web_session_test.go b/lib/cache/web_session_test.go
new file mode 100644
index 0000000000000..0939fb810acdc
--- /dev/null
+++ b/lib/cache/web_session_test.go
@@ -0,0 +1,220 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+)
+
+func TestAppSessions(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ for i := 0; i < 31; i++ {
+ err := p.appSessionS.UpsertAppSession(t.Context(), &types.WebSessionV2{
+ Kind: types.KindWebSession,
+ SubKind: types.KindAppSession,
+ Version: types.V2,
+ Metadata: types.Metadata{
+ Name: "app-session" + strconv.Itoa(i+1),
+ },
+ Spec: types.WebSessionSpecV2{
+ User: "fish",
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ for i := 0; i < 3; i++ {
+ err := p.appSessionS.UpsertAppSession(t.Context(), &types.WebSessionV2{
+ Kind: types.KindWebSession,
+ SubKind: types.KindAppSession,
+ Version: types.V2,
+ Metadata: types.Metadata{
+ Name: "app-session" + strconv.Itoa(i+100),
+ },
+ Spec: types.WebSessionSpecV2{
+ User: "llama",
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ expected, next, err := p.appSessionS.ListAppSessions(ctx, 0, "", "")
+ assert.NoError(t, err)
+ assert.Empty(t, next)
+ assert.Len(t, expected, 34)
+
+ cached, next, err := p.cache.ListAppSessions(ctx, 0, "", "")
+ assert.NoError(t, err)
+ assert.Empty(t, next)
+ assert.Len(t, cached, 34)
+ }, 15*time.Second, 100*time.Millisecond)
+
+ session, err := p.cache.GetAppSession(ctx, types.GetAppSessionRequest{
+ SessionID: "app-session100",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, session)
+ require.Equal(t, "llama", session.GetUser())
+
+ session, err = p.cache.GetAppSession(ctx, types.GetAppSessionRequest{
+ SessionID: "app-session1",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, session)
+ require.Equal(t, "fish", session.GetUser())
+
+ var sessions []types.WebSession
+ for pageToken := ""; ; {
+ cached, next, err := p.cache.ListAppSessions(ctx, 1, pageToken, "llama")
+ if !assert.NoError(t, err) {
+ return
+ }
+ sessions = append(sessions, cached...)
+ pageToken = next
+ if next == "" {
+ break
+ }
+ }
+ assert.Len(t, sessions, 3)
+
+ sessions = nil
+ for pageToken := ""; ; {
+ cached, next, err := p.cache.ListAppSessions(ctx, 7, pageToken, "fish")
+ if !assert.NoError(t, err) {
+ return
+ }
+ sessions = append(sessions, cached...)
+ pageToken = next
+ if next == "" {
+ break
+ }
+ }
+ assert.Len(t, sessions, 31)
+
+ require.NoError(t, p.appSessionS.DeleteAllAppSessions(ctx))
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ cached, next, err := p.cache.ListAppSessions(ctx, 0, "", "")
+ assert.NoError(t, err)
+ assert.Empty(t, next)
+ assert.Empty(t, cached)
+ }, 15*time.Second, 100*time.Millisecond)
+}
+
+func TestWebSessions(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ for i := 0; i < 31; i++ {
+ err := p.webSessionS.Upsert(t.Context(), &types.WebSessionV2{
+ Kind: types.KindWebSession,
+ SubKind: types.KindWebSession,
+ Version: types.V2,
+ Metadata: types.Metadata{
+ Name: "web-session" + strconv.Itoa(i+1),
+ },
+ Spec: types.WebSessionSpecV2{
+ User: "fish",
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ expected, err := p.webSessionS.List(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, expected, 31)
+
+ for _, session := range expected {
+ cached, err := p.cache.GetWebSession(ctx, types.GetWebSessionRequest{SessionID: session.GetName()})
+ assert.NoError(t, err)
+ assert.Empty(t, cmp.Diff(session, cached))
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+
+ require.NoError(t, p.webSessionS.DeleteAll(ctx))
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ for i := 0; i < 31; i++ {
+ session, err := p.cache.GetWebSession(ctx, types.GetWebSessionRequest{SessionID: "web-session" + strconv.Itoa(i+1)})
+ assert.Error(t, err)
+ assert.Nil(t, session)
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+}
+
+func TestSnowflakeSessions(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ for i := 0; i < 31; i++ {
+ err := p.snowflakeSessionS.UpsertSnowflakeSession(t.Context(), &types.WebSessionV2{
+ Kind: types.KindWebSession,
+ SubKind: types.KindSnowflakeSession,
+ Version: types.V2,
+ Metadata: types.Metadata{
+ Name: "snow-session" + strconv.Itoa(i+1),
+ },
+ Spec: types.WebSessionSpecV2{
+ User: "fish",
+ },
+ })
+ require.NoError(t, err)
+ }
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ expected, err := p.snowflakeSessionS.GetSnowflakeSessions(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, expected, 31)
+
+ for _, session := range expected {
+ cached, err := p.cache.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: session.GetName()})
+ assert.NoError(t, err)
+ assert.Empty(t, cmp.Diff(session, cached))
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+
+ require.NoError(t, p.snowflakeSessionS.DeleteAllSnowflakeSessions(ctx))
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ for i := 0; i < 31; i++ {
+ session, err := p.cache.GetSnowflakeSession(ctx, types.GetSnowflakeSessionRequest{SessionID: "snow-session" + strconv.Itoa(i+1)})
+ assert.Error(t, err)
+ assert.Nil(t, session)
+ }
+ }, 15*time.Second, 100*time.Millisecond)
+}