Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/types/userloginstate/user_login_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/api/types/header/convert/legacy"
"github.com/gravitational/teleport/api/types/trait"
"github.com/gravitational/teleport/api/utils"
)

// UserLoginState is the ephemeral user login state. This will hold data to differentiate
Expand Down Expand Up @@ -99,6 +100,13 @@ func (u *UserLoginState) CheckAndSetDefaults() error {
return nil
}

// Clone returns a copy of the member.
func (u *UserLoginState) Clone() *UserLoginState {
var copy *UserLoginState
utils.StrictObjectToStruct(u, &copy)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does StrictObjectToStruct export a pointer to a pointer?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I shamelessly copied this from access lists. It looks like all existing variants of this resource style use the same pattern though. So if it's not correct we need to fix it everywhere.

func (a *AccessList) CloneResource() types.ResourceWithLabels {
var copy *AccessList
utils.StrictObjectToStruct(a, &copy)
return copy
}

func (a *DiscoveryConfig) CloneResource() types.ResourceWithLabels {
var copy *DiscoveryConfig
utils.StrictObjectToStruct(a, &copy)
return copy
}

func (a *ExternalAuditStorage) Clone() *ExternalAuditStorage {
var copy *ExternalAuditStorage
utils.StrictObjectToStruct(a, &copy)
return copy
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we just pass it to (json.Decoder).Decode so we really don't need the double pointer, though it will work just fine.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we changed this to

	var copy UserLoginState
	utils.StrictObjectToStruct(u, &copy)
	return &copy

then ((*UserLoginState)(nil)).Clone() will be equivalent to &UserLoginState{}, right now it returns (*UserLoginState)(nil).

Obviously neither option is in any way sane and we should write or derive some clone function - and awalterschulze/goderive keeps distinguishing itself in making the entirely wrong choices once again, since it will unconditionally use unsafe and reflect to clone private fields of types which is absolutely wild - but in the meantime the way StrictObjectToStruct is being used is the least worst way.

return copy
}

// GetOriginalRoles returns the original roles that the user login state was derived from.
func (u *UserLoginState) GetOriginalRoles() []string {
return u.Spec.OriginalRoles
Expand Down
81 changes: 0 additions & 81 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ import (
kubewaitingcontainerpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/kubewaitingcontainer/v1"
provisioningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/provisioning/v1"
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
"github.com/gravitational/teleport/api/internalutils/stream"
apitracing "github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/discoveryconfig"
"github.com/gravitational/teleport/api/types/secreports"
"github.com/gravitational/teleport/api/types/userloginstate"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/backend"
Expand Down Expand Up @@ -508,11 +506,9 @@ type Cache struct {
databaseObjectsCache *local.DatabaseObjectService
dynamicWindowsDesktopsCache services.DynamicWindowsDesktops
userGroupsCache services.UserGroups
userTasksCache services.UserTasks
discoveryConfigsCache services.DiscoveryConfigs
headlessAuthenticationsCache services.HeadlessAuthenticationService
secReportsCache services.SecReports
userLoginStateCache services.UserLoginStates
eventsFanout *services.FanoutV2
lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2]
kubeWaitingContsCache *local.KubeWaitingContainerService
Expand Down Expand Up @@ -918,12 +914,6 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}

userTasksCache, err := local.NewUserTasksService(config.Backend)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

discoveryConfigsCache, err := local.NewDiscoveryConfigService(config.Backend)
if err != nil {
cancel()
Expand All @@ -936,12 +926,6 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}

userLoginStatesCache, err := local.NewUserLoginStateService(config.Backend)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

databaseObjectsCache, err := local.NewDatabaseObjectService(config.Backend)
if err != nil {
cancel()
Expand Down Expand Up @@ -1018,11 +1002,9 @@ func New(config Config) (*Cache, error) {
restrictionsCache: local.NewRestrictionsService(config.Backend),
dynamicWindowsDesktopsCache: dynamicDesktopsService,
userGroupsCache: userGroupsCache,
userTasksCache: userTasksCache,
discoveryConfigsCache: discoveryConfigsCache,
headlessAuthenticationsCache: identityService,
secReportsCache: secReportsCache,
userLoginStateCache: userLoginStatesCache,
databaseObjectsCache: databaseObjectsCache,
eventsFanout: fanout,
lowVolumeEventsFanout: utils.NewRoundRobin(lowVolumeFanouts),
Expand Down Expand Up @@ -1904,32 +1886,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne
return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage)
}

// ListUserTasks returns a list of UserTask resources.
func (c *Cache) ListUserTasks(ctx context.Context, pageSize int64, nextKey string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListUserTasks")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userTasks)
if err != nil {
return nil, "", trace.Wrap(err)
}
defer rg.Release()
return rg.reader.ListUserTasks(ctx, pageSize, nextKey, filters)
}

// GetUserTask returns the specified UserTask resource.
func (c *Cache) GetUserTask(ctx context.Context, name string) (*usertasksv1.UserTask, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetUserTask")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userTasks)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
return rg.reader.GetUserTask(ctx, name)
}

// ListDiscoveryConfigs returns a paginated list of all DiscoveryConfig resources.
func (c *Cache) ListDiscoveryConfigs(ctx context.Context, pageSize int, nextKey string) ([]*discoveryconfig.DiscoveryConfig, string, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListDiscoveryConfigs")
Expand Down Expand Up @@ -2073,43 +2029,6 @@ func (c *Cache) ListSecurityReportsStates(ctx context.Context, pageSize int, nex
return rg.reader.ListSecurityReportsStates(ctx, pageSize, nextKey)
}

// GetUserLoginStates returns the all user login state resources.
func (c *Cache) GetUserLoginStates(ctx context.Context) ([]*userloginstate.UserLoginState, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginStates")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userLoginStates)
if err != nil {
return nil, trace.Wrap(err)
}
defer rg.Release()
return rg.reader.GetUserLoginStates(ctx)
}

// GetUserLoginState returns the specified user login state resource.
func (c *Cache) GetUserLoginState(ctx context.Context, name string) (*userloginstate.UserLoginState, error) {
ctx, span := c.Tracer.Start(ctx, "cache/GetUserLoginState")
defer span.End()

rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.userLoginStates)
if err != nil {
return nil, trace.Wrap(err)
}

uls, err := rg.reader.GetUserLoginState(ctx, name)
if trace.IsNotFound(err) && rg.IsCacheRead() {
// release read lock early
rg.Release()
// fallback is sane because method is never used
// in construction of derivative caches.
if uls, err := c.Config.UserLoginStates.GetUserLoginState(ctx, name); err == nil {
return uls, nil
}
}
defer rg.Release()
return uls, trace.Wrap(err)
}

// ListResources is a part of auth.Cache implementation
func (c *Cache) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) {
ctx, span := c.Tracer.Start(ctx, "cache/ListResources")
Expand Down
55 changes: 0 additions & 55 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1223,34 +1223,6 @@ func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.Databas
return database
}

// TestUserTasks tests that CRUD operations on user notification resources are
// replicated from the backend to the cache.
func TestUserTasks(t *testing.T) {
t.Parallel()

p := newTestPack(t, ForAuth)
t.Cleanup(p.Close)

testResources153(t, p, testFuncs153[*usertasksv1.UserTask]{
newResource: func(name string) (*usertasksv1.UserTask, error) {
return newUserTasks(t), nil
},
create: func(ctx context.Context, item *usertasksv1.UserTask) error {
_, err := p.userTasks.CreateUserTask(ctx, item)
return trace.Wrap(err)
},
list: func(ctx context.Context) ([]*usertasksv1.UserTask, error) {
items, _, err := p.userTasks.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{})
return items, trace.Wrap(err)
},
cacheList: func(ctx context.Context) ([]*usertasksv1.UserTask, error) {
items, _, err := p.userTasks.ListUserTasks(ctx, 0, "", &usertasksv1.ListUserTasksFilters{})
return items, trace.Wrap(err)
},
deleteAll: p.userTasks.DeleteAllUserTasks,
})
}

func newUserTasks(t *testing.T) *usertasksv1.UserTask {
t.Helper()

Expand Down Expand Up @@ -1397,33 +1369,6 @@ func TestSecurityReportState(t *testing.T) {
})
}

// TestUserLoginStates tests that CRUD operations on user login state resources are
// replicated from the backend to the cache.
func TestUserLoginStates(t *testing.T) {
t.Parallel()

p := newTestPack(t, ForAuth)
t.Cleanup(p.Close)

testResources(t, p, testFuncs[*userloginstate.UserLoginState]{
newResource: func(name string) (*userloginstate.UserLoginState, error) {
return newUserLoginState(t, name), nil
},
create: func(ctx context.Context, uls *userloginstate.UserLoginState) error {
_, err := p.userLoginStates.UpsertUserLoginState(ctx, uls)
return trace.Wrap(err)
},
list: p.userLoginStates.GetUserLoginStates,
cacheGet: p.cache.GetUserLoginState,
cacheList: p.cache.GetUserLoginStates,
update: func(ctx context.Context, uls *userloginstate.UserLoginState) error {
_, err := p.userLoginStates.UpsertUserLoginState(ctx, uls)
return trace.Wrap(err)
},
deleteAll: p.userLoginStates.DeleteAllUserLoginStates,
})
}

func TestDatabaseObjects(t *testing.T) {
t.Parallel()

Expand Down
21 changes: 21 additions & 0 deletions lib/cache/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ import (
identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1"
machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
notificationsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/notifications/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/userloginstate"
)

// collectionHandler is used by the [Cache] to seed the initial
Expand Down Expand Up @@ -113,6 +115,8 @@ type collections struct {
locks *collection[types.Lock, lockIndex]
tunnelConnections *collection[types.TunnelConnection, tunnelConnectionIndex]
remoteClusters *collection[types.RemoteCluster, remoteClusterIndex]
userTasks *collection[*usertasksv1.UserTask, userTaskIndex]
userLoginStates *collection[*userloginstate.UserLoginState, userLoginStateIndex]
}

// setupCollections ensures that the appropriate [collection] is
Expand Down Expand Up @@ -449,6 +453,7 @@ func setupCollections(c Config) (*collections, error) {

out.samlIdPSessions = collect
out.byKind[resourceKind] = out.samlIdPSessions

case types.KindWebSession:
collect, err := newWebSessionCollection(c.WebSession, watch)
if err != nil {
Expand Down Expand Up @@ -570,6 +575,22 @@ func setupCollections(c Config) (*collections, error) {

out.remoteClusters = collect
out.byKind[resourceKind] = out.remoteClusters
case types.KindUserTask:
collect, err := newUserTaskCollection(c.UserTasks, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.userTasks = collect
out.byKind[resourceKind] = out.userTasks
case types.KindUserLoginState:
collect, err := newUserLoginStateCollection(c.UserLoginStates, watch)
if err != nil {
return nil, trace.Wrap(err)
}

out.userLoginStates = collect
out.byKind[resourceKind] = out.userLoginStates
}
}

Expand Down
Loading
Loading