Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions framework/configstore/dlock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,13 @@ func setupLockTestStore(t *testing.T) *RDBConfigStore {
err = db.AutoMigrate(&tables.TableDistributedLock{})
require.NoError(t, err, "Failed to migrate test database")

return &RDBConfigStore{
db: db,
logger: newMockLogger(),
s := &RDBConfigStore{logger: newMockLogger()}
s.db.Store(db)
s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
return fn(ctx, s.DB())
}
s.refreshPoolFn = func(ctx context.Context) error { return nil }
return s
}

// =============================================================================
Expand Down Expand Up @@ -241,7 +244,7 @@ func TestUpdateLockExpiry_ExpiredLock(t *testing.T) {
ExpiresAt: time.Now().UTC().Add(-1 * time.Second),
}
// Directly insert the expired lock
err := store.db.Create(lock).Error
err := store.DB().Create(lock).Error
require.NoError(t, err)

// Try to extend expired lock
Expand Down Expand Up @@ -327,11 +330,11 @@ func TestCleanupExpiredLocks_Success(t *testing.T) {
}

for _, l := range expiredLocks {
err := store.db.Create(&l).Error
err := store.DB().Create(&l).Error
require.NoError(t, err)
}
for _, l := range validLocks {
err := store.db.Create(&l).Error
err := store.DB().Create(&l).Error
require.NoError(t, err)
}

Expand Down Expand Up @@ -383,7 +386,7 @@ func TestCleanupExpiredLockByKey_Success(t *testing.T) {
HolderID: "holder-1",
ExpiresAt: time.Now().UTC().Add(-1 * time.Minute),
}
err := store.db.Create(&lock).Error
err := store.DB().Create(&lock).Error
require.NoError(t, err)

// Cleanup specific expired lock
Expand Down Expand Up @@ -505,7 +508,7 @@ func TestDistributedLockManager_CleanupExpiredLocks(t *testing.T) {
HolderID: "holder-1",
ExpiresAt: time.Now().UTC().Add(-1 * time.Minute),
}
err := store.db.Create(&lock).Error
err := store.DB().Create(&lock).Error
require.NoError(t, err)

count, err := manager.CleanupExpiredLocks(ctx)
Expand Down Expand Up @@ -565,7 +568,7 @@ func TestDistributedLock_TryLock_CleansUpExpired(t *testing.T) {
HolderID: "old-holder",
ExpiresAt: time.Now().UTC().Add(-1 * time.Minute),
}
err := store.db.Create(&expiredLock).Error
err := store.DB().Create(&expiredLock).Error
require.NoError(t, err)

// New lock should be able to acquire after cleanup
Expand Down Expand Up @@ -772,7 +775,7 @@ func TestDistributedLock_Extend_StolenLock(t *testing.T) {
require.NoError(t, err)

// Simulate lock being stolen by another process
err = store.db.Model(&tables.TableDistributedLock{}).
err = store.DB().Model(&tables.TableDistributedLock{}).
Where("lock_key = ?", "test-lock").
Update("holder_id", "another-holder").Error
require.NoError(t, err)
Expand Down Expand Up @@ -844,7 +847,7 @@ func TestDistributedLock_IsHeld_StolenByAnotherHolder(t *testing.T) {
require.NoError(t, err)

// Simulate lock being stolen by another process
err = store.db.Model(&tables.TableDistributedLock{}).
err = store.DB().Model(&tables.TableDistributedLock{}).
Where("lock_key = ?", "test-lock").
Update("holder_id", "another-holder").Error
require.NoError(t, err)
Expand All @@ -866,7 +869,7 @@ func TestDistributedLock_IsHeld_DeletedFromDB(t *testing.T) {
require.NoError(t, err)

// Delete lock directly from database
err = store.db.Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error
err = store.DB().Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error
require.NoError(t, err)

held, err := lock.IsHeld(ctx)
Expand Down
36 changes: 18 additions & 18 deletions framework/configstore/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error)
var count int
for {
var keys []tables.TableKey
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&keys).Error; err != nil {
Expand All @@ -110,7 +110,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error)
if len(keys) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range keys {
if err := tx.Save(&keys[i]).Error; err != nil {
return err
Expand All @@ -131,7 +131,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int,
var count int
for {
var vks []tables.TableVirtualKey
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&vks).Error; err != nil {
Expand All @@ -140,7 +140,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int,
if len(vks) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range vks {
if err := tx.Save(&vks[i]).Error; err != nil {
return err
Expand All @@ -161,7 +161,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err
var count int
for {
var sessions []tables.SessionsTable
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&sessions).Error; err != nil {
Expand All @@ -170,7 +170,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err
if len(sessions) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range sessions {
if err := tx.Save(&sessions[i]).Error; err != nil {
return err
Expand All @@ -191,7 +191,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int,
var count int
for {
var tokens []tables.TableOauthToken
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&tokens).Error; err != nil {
Expand All @@ -200,7 +200,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int,
if len(tokens) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range tokens {
if err := tx.Save(&tokens[i]).Error; err != nil {
return err
Expand All @@ -221,7 +221,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int,
var count int
for {
var configs []tables.TableOauthConfig
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&configs).Error; err != nil {
Expand All @@ -230,7 +230,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int,
if len(configs) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range configs {
if err := tx.Save(&configs[i]).Error; err != nil {
return err
Expand All @@ -251,7 +251,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e
var count int
for {
var clients []tables.TableMCPClient
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&clients).Error; err != nil {
Expand All @@ -260,7 +260,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e
if len(clients) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range clients {
if err := tx.Save(&clients[i]).Error; err != nil {
return err
Expand All @@ -282,7 +282,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i
var count int
for {
var providers []tables.TableProvider
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&providers).Error; err != nil {
Expand All @@ -291,7 +291,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i
if len(providers) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range providers {
if err := tx.Save(&providers[i]).Error; err != nil {
return err
Expand All @@ -313,7 +313,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context)
var count int
for {
var configs []tables.TableVectorStoreConfig
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&configs).Error; err != nil {
Expand All @@ -322,7 +322,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context)
if len(configs) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range configs {
if err := tx.Save(&configs[i]).Error; err != nil {
return err
Expand All @@ -344,7 +344,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro
var count int
for {
var plugins []tables.TablePlugin
if err := s.db.WithContext(ctx).
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&plugins).Error; err != nil {
Expand All @@ -353,7 +353,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro
if len(plugins) == 0 {
break
}
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range plugins {
if err := tx.Save(&plugins[i]).Error; err != nil {
return err
Expand Down
8 changes: 5 additions & 3 deletions framework/configstore/encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ func setupEncryptionTestStore(t *testing.T) (*RDBConfigStore, *gorm.DB) {
)
require.NoError(t, err)

store := &RDBConfigStore{
db: db,
logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo),
store := &RDBConfigStore{logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo)}
store.db.Store(db)
store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
return fn(ctx, store.DB())
}
store.refreshPoolFn = func(ctx context.Context) error { return nil }
return store, db
}

Expand Down
16 changes: 16 additions & 0 deletions framework/configstore/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ func (l *migrationLock) release(ctx context.Context) {
l.conn.Close()
}

// RunSingleMigration applies a single gormigrate migration on the given
// *gorm.DB. Mirrors (*RDBConfigStore).RunMigration but takes the *gorm.DB
// directly, so downstream consumers (bifrost-enterprise, plugins) can run
// their migrations inside a MigrateOnFreshConnection callback without having
// to reach the throwaway pool through the ConfigStore abstraction.
func RunSingleMigration(ctx context.Context, db *gorm.DB, migration *migrator.Migration) error {
if db == nil {
return fmt.Errorf("db cannot be nil")
}
if migration == nil {
return fmt.Errorf("migration cannot be nil")
}
m := migrator.New(db.WithContext(ctx), migrator.DefaultOptions, []*migrator.Migration{migration})
return m.Migrate()
}
Comment thread
akshaydeo marked this conversation as resolved.

// Migrate performs the necessary database migrations.
func triggerMigrations(ctx context.Context, db *gorm.DB) error {
// Acquire advisory lock to serialize migrations across cluster nodes.
Expand Down
8 changes: 5 additions & 3 deletions framework/configstore/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,12 @@ func setupFullMigrationDB(t *testing.T) (*RDBConfigStore, *gorm.DB) {
err = triggerMigrations(ctx, db)
require.NoError(t, err, "triggerMigrations should succeed on a fresh DB")

store := &RDBConfigStore{
db: db,
logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo),
store := &RDBConfigStore{logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo)}
store.db.Store(db)
store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
return fn(ctx, store.DB())
}
store.refreshPoolFn = func(ctx context.Context) error { return nil }
return store, db
}

Expand Down
Loading
Loading