Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions lib/services/local/userpreferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"

"github.com/gravitational/trace"
"google.golang.org/protobuf/reflect/protoreflect"

userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/lib/backend"
Expand All @@ -36,12 +37,12 @@ type UserPreferencesService struct {
func DefaultUserPreferences() *userpreferencesv1.UserPreferences {
return &userpreferencesv1.UserPreferences{
Assist: &userpreferencesv1.AssistUserPreferences{
PreferredLogins: nil,
PreferredLogins: []string{},
ViewMode: userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_DOCKED,
},
Theme: userpreferencesv1.Theme_THEME_LIGHT,
Onboard: &userpreferencesv1.OnboardUserPreferences{
PreferredResources: nil,
PreferredResources: []userpreferencesv1.Resource{},
},
}
}
Expand Down Expand Up @@ -84,7 +85,9 @@ func (u *UserPreferencesService) UpsertUserPreferences(ctx context.Context, req
preferences = DefaultUserPreferences()
}

mergePreferences(preferences, req.Preferences)
if err := overwriteValues(preferences, req.Preferences); err != nil {
return trace.Wrap(err)
}

item, err := createBackendItem(req.Username, preferences)
if err != nil {
Expand All @@ -108,7 +111,15 @@ func (u *UserPreferencesService) getUserPreferences(ctx context.Context, usernam
return nil, trace.Wrap(err)
}

return &p, nil
// Appy the default values to the existing preferences.
// This allows updating the preferences schema without returning empty values
// for new fields in the existing preferences.
df := DefaultUserPreferences()
if err := overwriteValues(df, &p); err != nil {
return nil, trace.Wrap(err)
}

return df, nil
}

// backendKey returns the backend key for the user preferences for the given username.
Expand Down Expand Up @@ -142,36 +153,36 @@ func createBackendItem(username string, preferences *userpreferencesv1.UserPrefe
return item, nil
}

// mergePreferences merges the values from src into dest.
func mergePreferences(dest, src *userpreferencesv1.UserPreferences) {
if src.Theme != userpreferencesv1.Theme_THEME_UNSPECIFIED {
dest.Theme = src.Theme
}
// overwriteValues overwrites the values in dst with the values in src.
// This function uses proto.Ranges internally to iterate over the fields in src.
// Because of this, only non-nil/empty fields in src will overwrite the values in dst.
func overwriteValues(dst, src protoreflect.ProtoMessage) error {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add a comment mentioning that it only works because proto.range only visits non-nil/non-empty fields?

It would become clear to anyone using it how it works.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

d := dst.ProtoReflect()
s := src.ProtoReflect()

if src.Assist != nil {
mergeAssistUserPreferences(dest.Assist, src.Assist)
dName := d.Descriptor().FullName().Name()
sName := s.Descriptor().FullName().Name()
// If the names don't match, then the types don't match, so we can't overwrite.
if dName != sName {
return trace.BadParameter("dst and src must be the same type")
}

if src.Onboard != nil {
mergeOnboardUserPreferences(dest.Onboard, src.Onboard)
}
overwriteValuesRecursive(d, s)

return nil
}

// mergeAssistUserPreferences merges src preferences into the given dest assist user preferences.
func mergeAssistUserPreferences(dest, src *userpreferencesv1.AssistUserPreferences) {
if src.PreferredLogins != nil {
dest.PreferredLogins = src.PreferredLogins
}

if src.ViewMode != userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_UNSPECIFIED {
dest.ViewMode = src.ViewMode
}
}
// overwriteValuesRecursive recursively overwrites the values in dst with the values in src.
// It's a helper function for overwriteValues.
func overwriteValuesRecursive(dst, src protoreflect.Message) {
src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
switch {
case fd.Message() != nil:
overwriteValuesRecursive(dst.Mutable(fd).Message(), src.Get(fd).Message())
default:
dst.Set(fd, src.Get(fd))
}

// mergeOnboardUserPreferences merges src preferences into the given dest onboard user preferences.
func mergeOnboardUserPreferences(dest, src *userpreferencesv1.OnboardUserPreferences) {
if src.PreferredResources != nil {
dest.PreferredResources = src.PreferredResources
}
return true
})
}
54 changes: 52 additions & 2 deletions lib/services/local/userpreferences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ package local_test

import (
"context"
"encoding/json"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"

userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/services/local"
)
Expand All @@ -39,6 +44,8 @@ func newUserPreferencesService(t *testing.T) *local.UserPreferencesService {
}

func TestUserPreferencesCRUD2(t *testing.T) {
t.Parallel()

ctx := context.Background()
defaultPref := local.DefaultUserPreferences()
username := "something"
Expand Down Expand Up @@ -75,6 +82,9 @@ func TestUserPreferencesCRUD2(t *testing.T) {
Assist: &userpreferencesv1.AssistUserPreferences{
PreferredLogins: []string{"foo", "bar"},
},
Onboard: &userpreferencesv1.OnboardUserPreferences{
PreferredResources: []userpreferencesv1.Resource{},
},
},
},
expected: &userpreferencesv1.UserPreferences{
Expand Down Expand Up @@ -152,14 +162,18 @@ func TestUserPreferencesCRUD2(t *testing.T) {
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

identity := newUserPreferencesService(t)

res, err := identity.GetUserPreferences(ctx, &userpreferencesv1.GetUserPreferencesRequest{
Username: username,
})
require.NoError(t, err)
require.Equal(t, defaultPref, res.Preferences)
// Clone the proto as the accessing fields for some reason modifies the state.
require.Empty(t, cmp.Diff(defaultPref, proto.Clone(res.Preferences), protocmp.Transform()))

if test.req != nil {
err := identity.UpsertUserPreferences(ctx, test.req)
Expand All @@ -171,7 +185,43 @@ func TestUserPreferencesCRUD2(t *testing.T) {
})

require.NoError(t, err)
require.Equal(t, test.expected, res.Preferences)
require.Empty(t, cmp.Diff(test.expected, res.Preferences, protocmp.Transform()))
})
}
}

func TestLayoutUpdate(t *testing.T) {
t.Parallel()

ctx := context.Background()
identity := newUserPreferencesService(t)

outdatedPrefs := &userpreferencesv1.UserPreferences{
Assist: &userpreferencesv1.AssistUserPreferences{
PreferredLogins: []string{"foo", "bar"},
},
}
val, err := json.Marshal(outdatedPrefs)
require.NoError(t, err)

// Insert the outdated preferences directly into the backend
// to simulate a previous version of the preferences.
_, err = identity.Put(ctx, backend.Item{
Key: backend.Key("user_preferences", "test"),
Value: val,
})
require.NoError(t, err)

// Get the preferences and ensure that the layout is updated.
prefs, err := identity.GetUserPreferences(ctx, &userpreferencesv1.GetUserPreferencesRequest{
Username: "test",
})
require.NoError(t, err)
// The layout should be updated to the latest version (values should not be nil).
require.NotNil(t, prefs.Preferences.Onboard)
// Non-existing values should be set to the default value.
require.Equal(t, userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_DOCKED, prefs.Preferences.Assist.ViewMode)
require.Equal(t, userpreferencesv1.Theme_THEME_LIGHT, prefs.Preferences.Theme)
// Existing values should be preserved.
require.Equal(t, []string{"foo", "bar"}, prefs.Preferences.Assist.PreferredLogins)
}