diff --git a/lib/srv/db/cloud/users/helpers.go b/lib/srv/db/cloud/users/helpers.go index 7d074bdc66e1b..620b7019183b2 100644 --- a/lib/srv/db/cloud/users/helpers.go +++ b/lib/srv/db/cloud/users/helpers.go @@ -29,16 +29,22 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// lookupMap is a mapping of database objects to their managed users. +// lookupEntry is the entry value for lookupMap. +type lookupEntry struct { + database types.Database + users []User +} + +// lookupMap is a mapping of database names to their managed users. type lookupMap struct { - byDatabase map[types.Database][]User - mu sync.RWMutex + byName map[string]lookupEntry + mu sync.RWMutex } // newLookupMap creates a new lookup map. func newLookupMap() *lookupMap { return &lookupMap{ - byDatabase: make(map[types.Database][]User), + byName: make(map[string]lookupEntry), } } @@ -47,7 +53,7 @@ func (m *lookupMap) getDatabaseUser(database types.Database, username string) (U m.mu.RLock() defer m.mu.RUnlock() - for _, user := range m.byDatabase[database] { + for _, user := range m.byName[database.GetName()].users { if user.GetDatabaseUsername() == username { return user, true } @@ -61,9 +67,12 @@ func (m *lookupMap) setDatabaseUsers(database types.Database, users []User) { defer m.mu.Unlock() if len(users) > 0 { - m.byDatabase[database] = users + m.byName[database.GetName()] = lookupEntry{ + database: database, + users: users, + } } else { - delete(m.byDatabase, database) + delete(m.byName, database.GetName()) // Short circuit. if len(database.GetManagedUsers()) == 0 { @@ -79,27 +88,29 @@ func (m *lookupMap) setDatabaseUsers(database types.Database, users []User) { database.SetManagedUsers(usernames) } -// removeUnusedDatabases removes unused databases by comparing with provided -// active databases. -func (m *lookupMap) removeUnusedDatabases(activeDatabases types.Databases) { +func (m *lookupMap) removeIfURIChanged(database types.Database) { m.mu.Lock() defer m.mu.Unlock() - for database := range m.byDatabase { - if isActive := findDatabase(activeDatabases, database); !isActive { - delete(m.byDatabase, database) - } + current, ok := m.byName[database.GetName()] + if !ok || current.database.GetURI() == database.GetURI() { + return } + delete(m.byName, database.GetName()) } -// findDatabase finds the database object in provided list of databases. -func findDatabase(databases types.Databases, database types.Database) bool { - for i := range databases { - if databases[i] == database { - return true +// removeUnusedDatabases removes unused databases by comparing with provided +// active databases. +func (m *lookupMap) removeUnusedDatabases(activeDatabases types.Databases) { + m.mu.Lock() + defer m.mu.Unlock() + + activeDatabasesMap := activeDatabases.ToMap() + for databaseName := range m.byName { + if _, isActive := activeDatabasesMap[databaseName]; !isActive { + delete(m.byName, databaseName) } } - return false } // usersByID returns a map of users by their IDs. @@ -108,8 +119,8 @@ func (m *lookupMap) usersByID() map[string]User { defer m.mu.RUnlock() usersByID := make(map[string]User) - for _, users := range m.byDatabase { - for _, user := range users { + for _, entry := range m.byName { + for _, user := range entry.users { usersByID[user.GetID()] = user } } diff --git a/lib/srv/db/cloud/users/helpers_test.go b/lib/srv/db/cloud/users/helpers_test.go index e26774c5a312e..c09b67e778c06 100644 --- a/lib/srv/db/cloud/users/helpers_test.go +++ b/lib/srv/db/cloud/users/helpers_test.go @@ -77,6 +77,18 @@ func TestLookupMap(t *testing.T) { "userID3": user3, }, lookup.usersByID()) }) + + t.Run("removeIfURIChanged", func(t *testing.T) { + // URI does not change. No users should be removed. + lookup.removeIfURIChanged(db3) + require.Equal(t, map[string]User{ + "userID3": user3, + }, lookup.usersByID()) + + // Now replace with a RDS. + lookup.removeIfURIChanged(mustCreateRDSDatabase(t, "db3")) + require.Empty(t, lookup.usersByID()) + }) } func TestGenRandomPassword(t *testing.T) { diff --git a/lib/srv/db/cloud/users/users.go b/lib/srv/db/cloud/users/users.go index 0eac604c1d804..e03b60785ed55 100644 --- a/lib/srv/db/cloud/users/users.go +++ b/lib/srv/db/cloud/users/users.go @@ -202,6 +202,10 @@ func (u *Users) setupAllDatabasesAndRotatePassowrds(ctx context.Context, allData // rotate user passwords. func (u *Users) setupDatabasesAndRotatePasswords(ctx context.Context, databases types.Databases, updateMeta bool) { for _, database := range databases { + // Reset cache in case the same database name is now used for a + // different database server. + u.lookup.removeIfURIChanged(database) + fetcher, found := u.fetchersByType[database.GetType()] if !found { continue diff --git a/lib/srv/db/cloud/users/users_test.go b/lib/srv/db/cloud/users/users_test.go index 94d29a8a3686b..25e3501da7364 100644 --- a/lib/srv/db/cloud/users/users_test.go +++ b/lib/srv/db/cloud/users/users_test.go @@ -119,12 +119,23 @@ func TestUsers(t *testing.T) { // Validate db6 is same as before. requireDatabaseWithManagedUsers(t, users, db6, []string{"alice", "bob"}) }) + + t.Run("new database with same name", func(t *testing.T) { + newDB6 := mustCreateRDSDatabase(t, "db6") + users.setupDatabaseAndRotatePasswords(ctx, newDB6) + + // Make sure no users are cached for "db6". + _, err := users.GetPassword(context.Background(), db6, "alice") + require.Error(t, err) + }) } func requireDatabaseWithManagedUsers(t *testing.T, users *Users, db types.Database, managedUsers []string) { require.Equal(t, managedUsers, db.GetManagedUsers()) for _, username := range managedUsers { - password, err := users.GetPassword(context.TODO(), db, username) + // Usually a copy of the proxied database is passed to the engine + // instead of the same object. + password, err := users.GetPassword(context.Background(), db.Copy(), username) require.NoError(t, err) require.NotEmpty(t, password) }