diff --git a/lib/services/local/userpreferences.go b/lib/services/local/userpreferences.go index 98badee39a828..f237c35787790 100644 --- a/lib/services/local/userpreferences.go +++ b/lib/services/local/userpreferences.go @@ -23,6 +23,7 @@ import ( "encoding/json" "github.com/gravitational/trace" + "google.golang.org/protobuf/reflect/protoreflect" userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1" "github.com/gravitational/teleport/lib/backend" @@ -36,12 +37,12 @@ type UserPreferencesService struct { func DefaultUserPreferences() *userpreferencesv1.UserPreferences { return &userpreferencesv1.UserPreferences{ Assist: &userpreferencesv1.AssistUserPreferences{ - PreferredLogins: nil, + PreferredLogins: []string{}, ViewMode: userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_DOCKED, }, Theme: userpreferencesv1.Theme_THEME_LIGHT, Onboard: &userpreferencesv1.OnboardUserPreferences{ - PreferredResources: nil, + PreferredResources: []userpreferencesv1.Resource{}, }, } } @@ -84,7 +85,9 @@ func (u *UserPreferencesService) UpsertUserPreferences(ctx context.Context, req preferences = DefaultUserPreferences() } - mergePreferences(preferences, req.Preferences) + if err := overwriteValues(preferences, req.Preferences); err != nil { + return trace.Wrap(err) + } item, err := createBackendItem(req.Username, preferences) if err != nil { @@ -108,7 +111,15 @@ func (u *UserPreferencesService) getUserPreferences(ctx context.Context, usernam return nil, trace.Wrap(err) } - return &p, nil + // Appy the default values to the existing preferences. + // This allows updating the preferences schema without returning empty values + // for new fields in the existing preferences. + df := DefaultUserPreferences() + if err := overwriteValues(df, &p); err != nil { + return nil, trace.Wrap(err) + } + + return df, nil } // backendKey returns the backend key for the user preferences for the given username. @@ -142,36 +153,36 @@ func createBackendItem(username string, preferences *userpreferencesv1.UserPrefe return item, nil } -// mergePreferences merges the values from src into dest. -func mergePreferences(dest, src *userpreferencesv1.UserPreferences) { - if src.Theme != userpreferencesv1.Theme_THEME_UNSPECIFIED { - dest.Theme = src.Theme - } +// overwriteValues overwrites the values in dst with the values in src. +// This function uses proto.Ranges internally to iterate over the fields in src. +// Because of this, only non-nil/empty fields in src will overwrite the values in dst. +func overwriteValues(dst, src protoreflect.ProtoMessage) error { + d := dst.ProtoReflect() + s := src.ProtoReflect() - if src.Assist != nil { - mergeAssistUserPreferences(dest.Assist, src.Assist) + dName := d.Descriptor().FullName().Name() + sName := s.Descriptor().FullName().Name() + // If the names don't match, then the types don't match, so we can't overwrite. + if dName != sName { + return trace.BadParameter("dst and src must be the same type") } - if src.Onboard != nil { - mergeOnboardUserPreferences(dest.Onboard, src.Onboard) - } + overwriteValuesRecursive(d, s) + return nil } -// mergeAssistUserPreferences merges src preferences into the given dest assist user preferences. -func mergeAssistUserPreferences(dest, src *userpreferencesv1.AssistUserPreferences) { - if src.PreferredLogins != nil { - dest.PreferredLogins = src.PreferredLogins - } - - if src.ViewMode != userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_UNSPECIFIED { - dest.ViewMode = src.ViewMode - } -} +// overwriteValuesRecursive recursively overwrites the values in dst with the values in src. +// It's a helper function for overwriteValues. +func overwriteValuesRecursive(dst, src protoreflect.Message) { + src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + switch { + case fd.Message() != nil: + overwriteValuesRecursive(dst.Mutable(fd).Message(), src.Get(fd).Message()) + default: + dst.Set(fd, src.Get(fd)) + } -// mergeOnboardUserPreferences merges src preferences into the given dest onboard user preferences. -func mergeOnboardUserPreferences(dest, src *userpreferencesv1.OnboardUserPreferences) { - if src.PreferredResources != nil { - dest.PreferredResources = src.PreferredResources - } + return true + }) } diff --git a/lib/services/local/userpreferences_test.go b/lib/services/local/userpreferences_test.go index 29057fd9dbf5e..0bcf015c773a4 100644 --- a/lib/services/local/userpreferences_test.go +++ b/lib/services/local/userpreferences_test.go @@ -18,12 +18,17 @@ package local_test import ( "context" + "encoding/json" "testing" + "github.com/google/go-cmp/cmp" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/services/local" ) @@ -39,6 +44,8 @@ func newUserPreferencesService(t *testing.T) *local.UserPreferencesService { } func TestUserPreferencesCRUD2(t *testing.T) { + t.Parallel() + ctx := context.Background() defaultPref := local.DefaultUserPreferences() username := "something" @@ -75,6 +82,9 @@ func TestUserPreferencesCRUD2(t *testing.T) { Assist: &userpreferencesv1.AssistUserPreferences{ PreferredLogins: []string{"foo", "bar"}, }, + Onboard: &userpreferencesv1.OnboardUserPreferences{ + PreferredResources: []userpreferencesv1.Resource{}, + }, }, }, expected: &userpreferencesv1.UserPreferences{ @@ -152,14 +162,18 @@ func TestUserPreferencesCRUD2(t *testing.T) { } for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { + t.Parallel() + identity := newUserPreferencesService(t) res, err := identity.GetUserPreferences(ctx, &userpreferencesv1.GetUserPreferencesRequest{ Username: username, }) require.NoError(t, err) - require.Equal(t, defaultPref, res.Preferences) + // Clone the proto as the accessing fields for some reason modifies the state. + require.Empty(t, cmp.Diff(defaultPref, proto.Clone(res.Preferences), protocmp.Transform())) if test.req != nil { err := identity.UpsertUserPreferences(ctx, test.req) @@ -171,7 +185,43 @@ func TestUserPreferencesCRUD2(t *testing.T) { }) require.NoError(t, err) - require.Equal(t, test.expected, res.Preferences) + require.Empty(t, cmp.Diff(test.expected, res.Preferences, protocmp.Transform())) }) } } + +func TestLayoutUpdate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + identity := newUserPreferencesService(t) + + outdatedPrefs := &userpreferencesv1.UserPreferences{ + Assist: &userpreferencesv1.AssistUserPreferences{ + PreferredLogins: []string{"foo", "bar"}, + }, + } + val, err := json.Marshal(outdatedPrefs) + require.NoError(t, err) + + // Insert the outdated preferences directly into the backend + // to simulate a previous version of the preferences. + _, err = identity.Put(ctx, backend.Item{ + Key: backend.Key("user_preferences", "test"), + Value: val, + }) + require.NoError(t, err) + + // Get the preferences and ensure that the layout is updated. + prefs, err := identity.GetUserPreferences(ctx, &userpreferencesv1.GetUserPreferencesRequest{ + Username: "test", + }) + require.NoError(t, err) + // The layout should be updated to the latest version (values should not be nil). + require.NotNil(t, prefs.Preferences.Onboard) + // Non-existing values should be set to the default value. + require.Equal(t, userpreferencesv1.AssistViewMode_ASSIST_VIEW_MODE_DOCKED, prefs.Preferences.Assist.ViewMode) + require.Equal(t, userpreferencesv1.Theme_THEME_LIGHT, prefs.Preferences.Theme) + // Existing values should be preserved. + require.Equal(t, []string{"foo", "bar"}, prefs.Preferences.Assist.PreferredLogins) +}