Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/services/local/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ func (l *globalSessionDataLimiter) Reset() {
l.lastReset = time.Time{}
l.mu.Unlock()
}

const (
WebPrefix = webPrefix
UsersPrefix = usersPrefix
ParamsPrefix = paramsPrefix
)
83 changes: 51 additions & 32 deletions lib/services/local/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,58 +441,77 @@ func (s *IdentityService) UpsertUser(ctx context.Context, user types.User) (type
return user, nil
}

// CompareAndSwapUser updates a user, but fails if the value (as exists in the
// backend) differs from the provided `existing` value. If the existing value
// CompareAndSwapUser updates a user, but fails if the user (as exists in the
// backend) differs from the provided `existing` user. If the existing user
// matches, returns no error, otherwise returns `trace.CompareFailed`.
func (s *IdentityService) CompareAndSwapUser(ctx context.Context, new, existing types.User) error {
if new.GetName() != existing.GetName() {
return trace.BadParameter("name mismatch between new and existing user")
}
if err := services.ValidateUser(new); err != nil {
return trace.Wrap(err)
}

newRaw, ok := new.WithoutSecrets().(types.User)
newWithoutSecrets, ok := new.WithoutSecrets().(types.User)
if !ok {
return trace.BadParameter("Invalid user type %T", new)
return trace.BadParameter("invalid new user type %T (this is a bug)", new)
}
rev := new.GetRevision()
newValue, err := services.MarshalUser(newRaw)
if err != nil {
return trace.Wrap(err)

existingWithoutSecrets, ok := existing.WithoutSecrets().(types.User)
if !ok {
return trace.BadParameter("invalid existing user type %T (this is a bug)", existing)
}
newItem := backend.Item{

item := backend.Item{
Key: backend.Key(webPrefix, usersPrefix, new.GetName(), paramsPrefix),
Value: newValue,
Value: nil, // avoid marshaling new until we pass one comparison
Expires: new.Expiry(),
ID: new.GetResourceID(),
Revision: rev,
Revision: "",
}

existingRaw, ok := existing.WithoutSecrets().(types.User)
if !ok {
return trace.BadParameter("Invalid user type %T", existing)
}
existingValue, err := services.MarshalUser(existingRaw)
if err != nil {
return trace.Wrap(err)
}
existingItem := backend.Item{
Key: backend.Key(webPrefix, usersPrefix, existing.GetName(), paramsPrefix),
Value: existingValue,
}
// one retry because ConditionalUpdate could occasionally spuriously fail,
// another retry because a single retry would be weird
const iterationLimit = 3
for i := 0; i < iterationLimit; i++ {
const withoutSecrets = false
currentWithoutSecrets, err := s.GetUser(ctx, new.GetName(), withoutSecrets)
if err != nil {
if trace.IsNotFound(err) {
return trace.CompareFailed("user %v did not match expected existing value", new.GetName())
}
return trace.Wrap(err)
}

_, err = s.CompareAndSwap(ctx, existingItem, newItem)
if err != nil {
if trace.IsCompareFailed(err) {
if !services.UsersEquals(existingWithoutSecrets, currentWithoutSecrets) {
return trace.CompareFailed("user %v did not match expected existing value", new.GetName())
}
return trace.Wrap(err)
}

if auth := new.GetLocalAuth(); auth != nil {
if err = s.upsertLocalAuthSecrets(ctx, new.GetName(), *auth); err != nil {
if item.Value == nil {
v, err := services.MarshalUser(newWithoutSecrets)
if err != nil {
return trace.Wrap(err)
}
item.Value = v
}

item.Revision = currentWithoutSecrets.GetRevision()

if _, err = s.Backend.ConditionalUpdate(ctx, item); err != nil {
if trace.IsCompareFailed(err) {
continue
}
return trace.Wrap(err)
}

if auth := new.GetLocalAuth(); auth != nil {
Comment thread
codingllama marked this conversation as resolved.
if err = s.upsertLocalAuthSecrets(ctx, new.GetName(), *auth); err != nil {
return trace.Wrap(err)
}
}
return nil
}
return nil

return trace.LimitExceeded("failed to update user within %v iterations", iterationLimit)
}

// GetUser returns a user by name
Expand Down
55 changes: 55 additions & 0 deletions lib/services/local/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"crypto/x509"
"encoding/base32"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"slices"
Expand Down Expand Up @@ -1139,3 +1140,57 @@ func TestIdentityService_ListUsers(t *testing.T) {
})
require.Empty(t, cmp.Diff(expectedUsers, retrieved, cmpopts.SortSlices(devicesSort)), "not all users returned from listing operation")
}

func TestCompareAndSwapUser(t *testing.T) {
t.Parallel()
require := require.New(t)
ctx := context.Background()

identity := newIdentityService(t, clockwork.NewFakeClock())

bob1, err := types.NewUser("bob")
require.NoError(err)
bob1.SetLogins([]string{"bob"})

bob2, err := types.NewUser("bob")
require.NoError(err)
bob2.SetLogins([]string{"bob", "alice"})

require.False(services.UsersEquals(bob1, bob2))

currentBob, err := identity.UpsertUser(ctx, bob1)
require.NoError(err)
require.True(services.UsersEquals(currentBob, bob1))

currentBob, err = identity.GetUser(ctx, "bob", false)
require.NoError(err)
require.True(services.UsersEquals(currentBob, bob1))

err = identity.CompareAndSwapUser(ctx, bob2, bob1)
require.NoError(err)

currentBob, err = identity.GetUser(ctx, "bob", false)
require.NoError(err)
require.True(services.UsersEquals(currentBob, bob2))

item, err := identity.Backend.Get(ctx, backend.Key(local.WebPrefix, local.UsersPrefix, "bob", local.ParamsPrefix))
require.NoError(err)
var m map[string]any
require.NoError(json.Unmarshal(item.Value, &m))
m["deprecated_field"] = 42
item.Value, err = json.Marshal(m)
require.NoError(err)
_, err = identity.Backend.Put(ctx, *item)
require.NoError(err)

currentBob, err = identity.GetUser(ctx, "bob", false)
require.NoError(err)
require.True(services.UsersEquals(currentBob, bob2))

err = identity.CompareAndSwapUser(ctx, bob1, bob2)
require.NoError(err)

currentBob, err = identity.GetUser(ctx, "bob", false)
require.NoError(err)
require.True(services.UsersEquals(currentBob, bob1))
}