Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/types/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"time"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/utils"
)

// WebSessionsGetter provides access to web sessions
Expand Down Expand Up @@ -115,6 +117,8 @@ type WebSession interface {
// requirement.
// See [TrustedDeviceRequirement].
GetTrustedDeviceRequirement() TrustedDeviceRequirement
// Copy returns a clone of the session resource.
Copy() WebSession
}

// NewWebSession returns new instance of the web session based on the V2 spec
Expand All @@ -138,6 +142,11 @@ func (ws *WebSessionV2) GetKind() string {
return ws.Kind
}

// Copy returns a clone of the session resource.
func (ws *WebSessionV2) Copy() WebSession {
return utils.CloneProtoMsg(ws)
}

// GetSubKind gets resource SubKind
func (ws *WebSessionV2) GetSubKind() string {
return ws.SubKind
Expand Down
171 changes: 0 additions & 171 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,8 @@ type Cache struct {
restrictionsCache services.Restrictions
crownJewelsCache services.CrownJewels
databaseObjectsCache *local.DatabaseObjectService
appSessionCache services.AppSession
snowflakeSessionCache services.SnowflakeSession
samlIdPSessionCache services.SAMLIdPSession //nolint:revive // Because we want this to be IdP.
webSessionCache types.WebSessionInterface
webTokenCache types.WebTokenInterface
dynamicWindowsDesktopsCache services.DynamicWindowsDesktops
samlIdPServiceProvidersCache services.SAMLIdPServiceProviders //nolint:revive // Because we want this to be IdP.
userGroupsCache services.UserGroups
integrationsCache services.Integrations
userTasksCache services.UserTasks
Expand Down Expand Up @@ -922,13 +917,6 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}

//nolint:revive // Because we want this to be IdP.
samlIdPServiceProvidersCache, err := local.NewSAMLIdPServiceProviderService(config.Backend)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

userGroupsCache, err := local.NewUserGroupService(config.Backend)
if err != nil {
cancel()
Expand Down Expand Up @@ -1064,14 +1052,9 @@ func New(config Config) (*Cache, error) {
presenceCache: local.NewPresenceService(config.Backend),
restrictionsCache: local.NewRestrictionsService(config.Backend),
crownJewelsCache: crownJewelCache,
appSessionCache: identityService,
snowflakeSessionCache: identityService,
samlIdPSessionCache: identityService,
webSessionCache: identityService.WebSessions(),
webTokenCache: identityService.WebTokens(),
dynamicWindowsDesktopsCache: dynamicDesktopsService,
accessMontoringRuleCache: accessMonitoringRuleCache,
samlIdPServiceProvidersCache: samlIdPServiceProvidersCache,
userGroupsCache: userGroupsCache,
integrationsCache: integrationsCache,
userTasksCache: userTasksCache,
Expand Down Expand Up @@ -2047,105 +2030,6 @@ func (c *Cache) GetStaticHostUser(ctx context.Context, name string) (*userprovis
return rg.reader.GetStaticHostUser(ctx, name)
}

// GetAppSession gets an application web session.
func (c *Cache) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetAppSession")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
sess, err := rg.reader.GetAppSession(ctx, req)
if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if sess, err := c.Config.AppSession.GetAppSession(ctx, req); err == nil {
c.Logger.DebugContext(ctx, "Cache was forced to load session from upstream",
"session_kind", sess.GetSubKind(),
"session", sess.GetName(),
)
return sess, nil
}
}

return sess, trace.Wrap(err)
}

// ListAppSessions returns a page of application web sessions.
func (c *Cache) ListAppSessions(ctx context.Context, pageSize int, pageToken, user string) ([]types.WebSession, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListAppSessions")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.appSessions)
if err != nil {
return nil, "", trace.Wrap(err)
}
defer rg.Release()
return rg.reader.ListAppSessions(ctx, pageSize, pageToken, user)
}

// GetSnowflakeSession gets Snowflake web session.
func (c *Cache) GetSnowflakeSession(ctx context.Context, req types.GetSnowflakeSessionRequest) (types.WebSession, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetSnowflakeSession")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.snowflakeSessions)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()

sess, err := rg.reader.GetSnowflakeSession(ctx, req)
if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if sess, err := c.Config.SnowflakeSession.GetSnowflakeSession(ctx, req); err == nil {
c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
"session_kind", sess.GetSubKind(),
"session", sess.GetName(),
)
return sess, nil
}
}

return sess, trace.Wrap(err)
}

// GetSAMLIdPSession gets a SAML IdP session.
func (c *Cache) GetSAMLIdPSession(ctx context.Context, req types.GetSAMLIdPSessionRequest) (types.WebSession, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPSession")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPSessions)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()

sess, err := rg.reader.GetSAMLIdPSession(ctx, req)
if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if sess, err := c.Config.SAMLIdPSession.GetSAMLIdPSession(ctx, req); err == nil {
c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
"session_kind", sess.GetSubKind(),
"session", sess.GetName(),
)
return sess, nil
}
}

return sess, trace.Wrap(err)
}

func (c *Cache) GetDatabaseObject(ctx context.Context, name string) (*dbobjectv1.DatabaseObject, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetDatabaseObject")
defer span.End()
Expand All @@ -2170,35 +2054,6 @@ func (c *Cache) ListDatabaseObjects(ctx context.Context, size int, pageToken str
return rg.reader.ListDatabaseObjects(ctx, size, pageToken)
}

// GetWebSession gets a regular web session.
func (c *Cache) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetWebSession")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.webSessions)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()

sess, err := rg.reader.Get(ctx, req)

if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if sess, err := c.Config.WebSession.Get(ctx, req); err == nil {
c.Logger.DebugContext(ctx, "Cache was forced to load sessionfrom upstream",
"session_kind", sess.GetSubKind(),
"session", sess.GetName(),
)
return sess, nil
}
}
return sess, trace.Wrap(err)
}

// GetWebToken gets a web token.
func (c *Cache) GetWebToken(ctx context.Context, req types.GetWebTokenRequest) (types.WebToken, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetWebToken")
Expand Down Expand Up @@ -2290,32 +2145,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne
return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage)
}

// ListSAMLIdPServiceProviders returns a paginated list of SAML IdP service provider resources.
func (c *Cache) ListSAMLIdPServiceProviders(ctx context.Context, pageSize int, nextKey string) ([]types.SAMLIdPServiceProvider, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListSAMLIdPServiceProviders")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders)
if err != nil {
return nil, "", trace.Wrap(err)
}
defer rg.Release()
return rg.reader.ListSAMLIdPServiceProviders(ctx, pageSize, nextKey)
}

// GetSAMLIdPServiceProvider returns the specified SAML IdP service provider resources.
func (c *Cache) GetSAMLIdPServiceProvider(ctx context.Context, name string) (types.SAMLIdPServiceProvider, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetSAMLIdPServiceProvider")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.samlIdPServiceProviders)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
return rg.reader.GetSAMLIdPServiceProvider(ctx, name)
}

// ListIntegrations returns a paginated list of all Integrations resources.
func (c *Cache) ListIntegrations(ctx context.Context, pageSize int, nextKey string) ([]types.Integration, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListIntegrations")
Expand Down
34 changes: 0 additions & 34 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,40 +1288,6 @@ func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.Databas
return database
}

// TestSAMLIdPServiceProviders tests that CRUD operations on SAML IdP service provider resources are
// replicated from the backend to the cache.
func TestSAMLIdPServiceProviders(t *testing.T) {
t.Parallel()

p := newTestPack(t, ForAuth)
t.Cleanup(p.Close)

testResources(t, p, testFuncs[types.SAMLIdPServiceProvider]{
newResource: func(name string) (types.SAMLIdPServiceProvider, error) {
return types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: name,
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: testEntityDescriptor,
EntityID: "IAMShowcase",
})
},
create: p.samlIDPServiceProviders.CreateSAMLIdPServiceProvider,
list: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
results, _, err := p.samlIDPServiceProviders.ListSAMLIdPServiceProviders(ctx, 0, "")
return results, err
},
cacheGet: p.cache.GetSAMLIdPServiceProvider,
cacheList: func(ctx context.Context) ([]types.SAMLIdPServiceProvider, error) {
results, _, err := p.cache.ListSAMLIdPServiceProviders(ctx, 0, "")
return results, err
},
update: p.samlIDPServiceProviders.UpdateSAMLIdPServiceProvider,
deleteAll: p.samlIDPServiceProviders.DeleteAllSAMLIdPServiceProviders,
})
}

// TestLocks tests that CRUD operations on lock resources are
// replicated from the backend to the cache.
func TestLocks(t *testing.T) {
Expand Down
2 changes: 0 additions & 2 deletions lib/cache/cert_authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ func TestNodeCAFiltering(t *testing.T) {
DynamicAccess: p.cache.dynamicAccessCache,
Presence: p.cache.presenceCache,
Restrictions: p.cache.restrictionsCache,
AppSession: p.cache.appSessionCache,
WebSession: p.cache.webSessionCache,
WebToken: p.cache.webTokenCache,
DynamicWindowsDesktops: p.cache.dynamicWindowsDesktopsCache,
SAMLIdPServiceProviders: p.samlIDPServiceProviders,
Expand Down
48 changes: 48 additions & 0 deletions lib/cache/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ type collections struct {
autoUpdateRollout *collection[*autoupdatev1.AutoUpdateAgentRollout, autoUpdateAgentRolloutIndex]
oktaImportRules *collection[types.OktaImportRule, oktaImportRuleIndex]
oktaAssignments *collection[types.OktaAssignment, oktaAssignmentIndex]
samlIdPServiceProviders *collection[types.SAMLIdPServiceProvider, samlIdPServiceProviderIndex]
samlIdPSessions *collection[types.WebSession, samlIdPSessionIndex]
webSessions *collection[types.WebSession, webSessionIndex]
appSessions *collection[types.WebSession, appSessionIndex]
snowflakeSessions *collection[types.WebSession, snowflakeSessionIndex]
}

// setupCollections ensures that the appropriate [collection] is
Expand Down Expand Up @@ -392,6 +397,49 @@ func setupCollections(c Config) (*collections, error) {

out.oktaAssignments = collect
out.byKind[resourceKind] = out.oktaAssignments
case types.KindSAMLIdPServiceProvider:
collect, err := newSAMLIdPServiceProviderCollection(c.SAMLIdPServiceProviders, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.samlIdPServiceProviders = collect
out.byKind[resourceKind] = out.samlIdPServiceProviders
case types.KindWebSession:
switch watch.SubKind {
case types.KindAppSession:
collect, err := newAppSessionCollection(c.AppSession, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.appSessions = collect
out.byKind[resourceKind] = out.appSessions
case types.KindSnowflakeSession:
collect, err := newSnowflakeSessionCollection(c.SnowflakeSession, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.snowflakeSessions = collect
out.byKind[resourceKind] = out.snowflakeSessions
case types.KindSAMLIdPSession:
collect, err := newSAMLIdPSessionCollection(c.SAMLIdPSession, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.samlIdPSessions = collect
out.byKind[resourceKind] = out.samlIdPSessions
case types.KindWebSession:
collect, err := newWebSessionCollection(c.WebSession, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.webSessions = collect
out.byKind[resourceKind] = out.webSessions
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion lib/cache/generic_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ func (g genericGetter[T, I]) get(ctx context.Context, identifier string) (T, err
}

out, err := rg.store.get(g.index, identifier)
return g.clone(out), trace.Wrap(err)
if err != nil {
return t, trace.Wrap(err)
}
return g.clone(out), nil
}

// genericLister is a helper to retrieve a page of items from a cache collection.
Expand Down
Loading
Loading