Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1402,9 +1402,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return fmt.Errorf("error saving groups: %w", err)
}

addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)

user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
Expand Down
62 changes: 62 additions & 0 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}

func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Setenv("NETBIRD_STORE_ENGINE", "postgres")
claims := auth.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
Expand Down Expand Up @@ -945,6 +946,18 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Fatal(err)
}

a, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}

a.Groups = genGroups()

err = am.Store.SaveAccount(context.Background(), a)
if err != nil {
b.Fatal(err)
}

users := genUsers("priv", 100)

acc, err := am.Store.GetAccount(context.Background(), id)
Expand Down Expand Up @@ -1005,6 +1018,41 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {

}

func genGroups() map[string]*types.Group {
return map[string]*types.Group{
"one": {
Name: "one",
},
"two": {
Name: "two",
},
"three": {
Name: "three",
},
"four": {
Name: "four",
},
"five": {
Name: "five",
},
"six": {
Name: "six",
},
"seven": {
Name: "seven",
},
"eight": {
Name: "eight",
},
"nine": {
Name: "nine",
},
"ten": {
Name: "ten",
},
}
}
Comment thread
pascal-fischer marked this conversation as resolved.

func genUsers(p string, n int) map[string]*types.User {
users := map[string]*types.User{}
now := time.Now()
Expand Down Expand Up @@ -1723,6 +1771,13 @@ func TestAccount_Copy(t *testing.T) {
Id: "user1",
Role: types.UserRoleAdmin,
AutoGroups: []string{"group1"},
Groups: []*types.GroupUser{
{
AccountID: "account1",
UserID: "user1",
GroupID: "group1",
},
},
PATs: map[string]*types.PersonalAccessToken{
"pat1": {
ID: "pat1",
Expand All @@ -1742,6 +1797,13 @@ func TestAccount_Copy(t *testing.T) {
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
GroupUsers: []types.GroupUser{
{
AccountID: "account1",
UserID: "user1",
GroupID: "group1",
},
},
},
},
Policies: []*types.Policy{
Expand Down
24 changes: 17 additions & 7 deletions management/server/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,13 +380,6 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user

err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, nil, err
Expand All @@ -400,6 +393,23 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)

account, err = am.Store.GetAccount(context.Background(), accountID)
if err != nil {
return nil, nil, err
}

account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user

err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, nil, err
}

acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err
Expand Down
100 changes: 100 additions & 0 deletions management/server/migration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,103 @@ func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {

return nil
}

// CleanupOrphanedIDs removes non-existent IDs from the JSON array column.
// T is the type of the model that contains the list.
// This migration cleans up the lists field by removing IDs that no longer exist in the target table.
func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error {
var sourceModel T
var fkModel S

if !db.Migrator().HasTable(&sourceModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel)
return nil
}

if !db.Migrator().HasTable(&fkModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel)
return nil
}

stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&sourceModel)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table

if !db.Migrator().HasColumn(&sourceModel, columnName) {
log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}

if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}

// Get all valid IDs from the fk table
var validIDs []string
if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil {
return fmt.Errorf("fetch valid group IDs: %w", err)
}

validIDMap := make(map[string]bool, len(validIDs))
for _, id := range validIDs {
validIDMap[id] = true
}

updatedCount := 0
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" || jsonValue == "null" {
continue
}

var list []string
if err := json.Unmarshal([]byte(jsonValue), &list); err != nil {
log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err)
continue
}

if len(list) == 0 {
continue
}

// Filter out non-existent IDs
cleanedList := make([]string, 0, len(list))
for _, groupID := range list {
if validIDMap[groupID] {
cleanedList = append(cleanedList, groupID)
}
}

// Only update if there were orphaned ids removed
if len(cleanedList) != len(list) {
cleanedJSON, err := json.Marshal(cleanedList)
if err != nil {
return fmt.Errorf("marshal cleaned %s: %w", columnName, err)
}

if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil {
return fmt.Errorf("update row with id %v: %w", row["id"], err)
}
updatedCount++
}
}

if updatedCount > 0 {
log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName)
} else {
log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName)
}

return nil
}); err != nil {
return err
}

log.WithContext(ctx).Infof("Cleanup of orphaned %s from table %s completed", columnName, tableName)
return nil
}
Loading
Loading