diff --git a/framework/configstore/dlock_test.go b/framework/configstore/dlock_test.go index 019abc5624..16a4e892a9 100644 --- a/framework/configstore/dlock_test.go +++ b/framework/configstore/dlock_test.go @@ -90,10 +90,13 @@ func setupLockTestStore(t *testing.T) *RDBConfigStore { err = db.AutoMigrate(&tables.TableDistributedLock{}) require.NoError(t, err, "Failed to migrate test database") - return &RDBConfigStore{ - db: db, - logger: newMockLogger(), + s := &RDBConfigStore{logger: newMockLogger()} + s.db.Store(db) + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + return s } // ============================================================================= @@ -241,7 +244,7 @@ func TestUpdateLockExpiry_ExpiredLock(t *testing.T) { ExpiresAt: time.Now().UTC().Add(-1 * time.Second), } // Directly insert the expired lock - err := store.db.Create(lock).Error + err := store.DB().Create(lock).Error require.NoError(t, err) // Try to extend expired lock @@ -327,11 +330,11 @@ func TestCleanupExpiredLocks_Success(t *testing.T) { } for _, l := range expiredLocks { - err := store.db.Create(&l).Error + err := store.DB().Create(&l).Error require.NoError(t, err) } for _, l := range validLocks { - err := store.db.Create(&l).Error + err := store.DB().Create(&l).Error require.NoError(t, err) } @@ -383,7 +386,7 @@ func TestCleanupExpiredLockByKey_Success(t *testing.T) { HolderID: "holder-1", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&lock).Error + err := store.DB().Create(&lock).Error require.NoError(t, err) // Cleanup specific expired lock @@ -505,7 +508,7 @@ func TestDistributedLockManager_CleanupExpiredLocks(t *testing.T) { HolderID: "holder-1", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&lock).Error + err := store.DB().Create(&lock).Error require.NoError(t, err) count, err := manager.CleanupExpiredLocks(ctx) @@ -565,7 +568,7 @@ func TestDistributedLock_TryLock_CleansUpExpired(t *testing.T) { HolderID: "old-holder", ExpiresAt: time.Now().UTC().Add(-1 * time.Minute), } - err := store.db.Create(&expiredLock).Error + err := store.DB().Create(&expiredLock).Error require.NoError(t, err) // New lock should be able to acquire after cleanup @@ -772,7 +775,7 @@ func TestDistributedLock_Extend_StolenLock(t *testing.T) { require.NoError(t, err) // Simulate lock being stolen by another process - err = store.db.Model(&tables.TableDistributedLock{}). + err = store.DB().Model(&tables.TableDistributedLock{}). Where("lock_key = ?", "test-lock"). Update("holder_id", "another-holder").Error require.NoError(t, err) @@ -844,7 +847,7 @@ func TestDistributedLock_IsHeld_StolenByAnotherHolder(t *testing.T) { require.NoError(t, err) // Simulate lock being stolen by another process - err = store.db.Model(&tables.TableDistributedLock{}). + err = store.DB().Model(&tables.TableDistributedLock{}). Where("lock_key = ?", "test-lock"). Update("holder_id", "another-holder").Error require.NoError(t, err) @@ -866,7 +869,7 @@ func TestDistributedLock_IsHeld_DeletedFromDB(t *testing.T) { require.NoError(t, err) // Delete lock directly from database - err = store.db.Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error + err = store.DB().Where("lock_key = ?", "test-lock").Delete(&tables.TableDistributedLock{}).Error require.NoError(t, err) held, err := lock.IsHeld(ctx) diff --git a/framework/configstore/encryption.go b/framework/configstore/encryption.go index b8818cb9ba..b2de668abe 100644 --- a/framework/configstore/encryption.go +++ b/framework/configstore/encryption.go @@ -101,7 +101,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) var count int for { var keys []tables.TableKey - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&keys).Error; err != nil { @@ -110,7 +110,7 @@ func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) if len(keys) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range keys { if err := tx.Save(&keys[i]).Error; err != nil { return err @@ -131,7 +131,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, var count int for { var vks []tables.TableVirtualKey - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&vks).Error; err != nil { @@ -140,7 +140,7 @@ func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, if len(vks) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range vks { if err := tx.Save(&vks[i]).Error; err != nil { return err @@ -161,7 +161,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err var count int for { var sessions []tables.SessionsTable - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&sessions).Error; err != nil { @@ -170,7 +170,7 @@ func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, err if len(sessions) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range sessions { if err := tx.Save(&sessions[i]).Error; err != nil { return err @@ -191,7 +191,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, var count int for { var tokens []tables.TableOauthToken - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&tokens).Error; err != nil { @@ -200,7 +200,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, if len(tokens) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range tokens { if err := tx.Save(&tokens[i]).Error; err != nil { return err @@ -221,7 +221,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, var count int for { var configs []tables.TableOauthConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&configs).Error; err != nil { @@ -230,7 +230,7 @@ func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, if len(configs) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range configs { if err := tx.Save(&configs[i]).Error; err != nil { return err @@ -251,7 +251,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e var count int for { var clients []tables.TableMCPClient - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&clients).Error; err != nil { @@ -260,7 +260,7 @@ func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, e if len(clients) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range clients { if err := tx.Save(&clients[i]).Error; err != nil { return err @@ -282,7 +282,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i var count int for { var providers []tables.TableProvider - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&providers).Error; err != nil { @@ -291,7 +291,7 @@ func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (i if len(providers) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range providers { if err := tx.Save(&providers[i]).Error; err != nil { return err @@ -313,7 +313,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) var count int for { var configs []tables.TableVectorStoreConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&configs).Error; err != nil { @@ -322,7 +322,7 @@ func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) if len(configs) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range configs { if err := tx.Save(&configs[i]).Error; err != nil { return err @@ -344,7 +344,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro var count int for { var plugins []tables.TablePlugin - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText). Limit(encryptionBatchSize). Find(&plugins).Error; err != nil { @@ -353,7 +353,7 @@ func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, erro if len(plugins) == 0 { break } - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { for i := range plugins { if err := tx.Save(&plugins[i]).Error; err != nil { return err diff --git a/framework/configstore/encryption_test.go b/framework/configstore/encryption_test.go index 9ac36baede..e7ad6272b1 100644 --- a/framework/configstore/encryption_test.go +++ b/framework/configstore/encryption_test.go @@ -54,10 +54,12 @@ func setupEncryptionTestStore(t *testing.T) (*RDBConfigStore, *gorm.DB) { ) require.NoError(t, err) - store := &RDBConfigStore{ - db: db, - logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + store := &RDBConfigStore{logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo)} + store.db.Store(db) + store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, store.DB()) } + store.refreshPoolFn = func(ctx context.Context) error { return nil } return store, db } diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index a0352382f8..39637d2dfe 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -72,6 +72,22 @@ func (l *migrationLock) release(ctx context.Context) { l.conn.Close() } +// RunSingleMigration applies a single gormigrate migration on the given +// *gorm.DB. Mirrors (*RDBConfigStore).RunMigration but takes the *gorm.DB +// directly, so downstream consumers (bifrost-enterprise, plugins) can run +// their migrations inside a MigrateOnFreshConnection callback without having +// to reach the throwaway pool through the ConfigStore abstraction. +func RunSingleMigration(ctx context.Context, db *gorm.DB, migration *migrator.Migration) error { + if db == nil { + return fmt.Errorf("db cannot be nil") + } + if migration == nil { + return fmt.Errorf("migration cannot be nil") + } + m := migrator.New(db.WithContext(ctx), migrator.DefaultOptions, []*migrator.Migration{migration}) + return m.Migrate() +} + // Migrate performs the necessary database migrations. func triggerMigrations(ctx context.Context, db *gorm.DB) error { // Acquire advisory lock to serialize migrations across cluster nodes. diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go index b03afaa7ff..31abb58798 100644 --- a/framework/configstore/migrations_test.go +++ b/framework/configstore/migrations_test.go @@ -1122,10 +1122,12 @@ func setupFullMigrationDB(t *testing.T) (*RDBConfigStore, *gorm.DB) { err = triggerMigrations(ctx, db) require.NoError(t, err, "triggerMigrations should succeed on a fresh DB") - store := &RDBConfigStore{ - db: db, - logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + store := &RDBConfigStore{logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo)} + store.db.Store(db) + store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, store.DB()) } + store.refreshPoolFn = func(ctx context.Context) error { return nil } return store, db } diff --git a/framework/configstore/postgres.go b/framework/configstore/postgres.go index ecc016b68d..b88edf143b 100644 --- a/framework/configstore/postgres.go +++ b/framework/configstore/postgres.go @@ -21,12 +21,67 @@ type PostgresConfig struct { MaxOpenConns int `json:"max_open_conns"` } +// buildPostgresDSN assembles a libpq-style DSN from the validated config. +func buildPostgresDSN(config *PostgresConfig) string { + return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", + config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), + config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) +} + +// openPostresConnection opens a *gorm.DB against the configured Postgres instance +// using the shared bifrost logger. Used for both the throwaway migration pool +// and the runtime pool. +func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) { + return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{ + Logger: newGormLogger(logger), + }) +} + +// closeDbConn closes the *sql.DB backing a *gorm.DB, logging any error. +// Used in error paths and for the throwaway migration pool. +func closeDbConn(db *gorm.DB, logger schemas.Logger) { + sqlDB, err := db.DB() + if err != nil { + logger.Error("failed to resolve *sql.DB for close: %v", err) + return + } + if err := sqlDB.Close(); err != nil { + logger.Error("failed to close DB connection: %v", err) + } +} + +// applyPostgresPoolTuning applies MaxIdleConns / MaxOpenConns from config to +// the supplied *gorm.DB, falling back to defaults when the config leaves the +// field at zero. +func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + maxIdleConns := config.MaxIdleConns + if maxIdleConns == 0 { + maxIdleConns = 5 + } + sqlDB.SetMaxIdleConns(maxIdleConns) + maxOpenConns := config.MaxOpenConns + if maxOpenConns == 0 { + maxOpenConns = 50 + } + sqlDB.SetMaxOpenConns(maxOpenConns) + return nil +} + // newPostgresConfigStore creates a new Postgres config store. +// +// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not +// change result type"): a throwaway migration pool runs DDL and is closed +// immediately, then a fresh runtime pool is opened. The runtime pool's +// connections never see pre-migration schema, so their cached prepared-plans +// stay valid for the life of the process. func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) { if config == nil { return nil, fmt.Errorf("config is required") } - // Validate required config if config.Host == nil || config.Host.GetValue() == "" { return nil, fmt.Errorf("postgres host is required") } @@ -45,53 +100,69 @@ func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger if config.SSLMode == nil || config.SSLMode.GetValue() == "" { return nil, fmt.Errorf("postgres ssl mode is required") } - dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) - db, err := gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - }), &gorm.Config{ - Logger: newGormLogger(logger), - }) + dsn := buildPostgresDSN(config) + + // Throwaway pool for schema migrations. Closing it before the runtime pool + // opens guarantees no cached prepared-plan survives the DDL. + mDb, err := openPostresConnection(dsn, logger) if err != nil { return nil, err } + if err := triggerMigrations(ctx, mDb); err != nil { + closeDbConn(mDb, logger) + return nil, err + } + closeDbConn(mDb, logger) - // Configure connection pool - sqlDB, err := db.DB() + // Runtime pool. Opens against post-migration schema. + db, err := openPostresConnection(dsn, logger) if err != nil { return nil, err } - // Set MaxIdleConns (default: 5) - maxIdleConns := config.MaxIdleConns - if maxIdleConns == 0 { - maxIdleConns = 5 + if err := applyPostgresPoolTuning(db, config); err != nil { + closeDbConn(db, logger) + return nil, err } - sqlDB.SetMaxIdleConns(maxIdleConns) - // Set MaxOpenConns (default: 50) - maxOpenConns := config.MaxOpenConns - if maxOpenConns == 0 { - maxOpenConns = 50 + d := &RDBConfigStore{logger: logger} + d.db.Store(db) + + // migrateOnFreshFn: downstream consumers (e.g. bifrost-enterprise) run + // their migrations via this hook on a throwaway pool that closes after fn. + d.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + tempDB, err := openPostresConnection(dsn, logger) + if err != nil { + return err + } + defer closeDbConn(tempDB, logger) + return fn(ctx, tempDB) } - sqlDB.SetMaxOpenConns(maxOpenConns) - d := &RDBConfigStore{db: db, logger: logger} - // Run migrations - if err := triggerMigrations(ctx, db); err != nil { - // Closing the DB connection - if sqlDB, dbErr := db.DB(); dbErr == nil { - if closeErr := sqlDB.Close(); closeErr != nil { - logger.Error("failed to close DB connection: %v", closeErr) - } + // refreshPoolFn: open fresh runtime pool first (so a failure leaves the + // existing pool in place), swap atomically, then close the old pool. + // sql.DB.Close blocks until in-flight queries finish, so callers already + // using the old pool complete safely. + d.refreshPoolFn = func(ctx context.Context) error { + newDB, err := openPostresConnection(dsn, logger) + if err != nil { + return fmt.Errorf("failed to open fresh runtime pool: %w", err) } - return nil, err + if err := applyPostgresPoolTuning(newDB, config); err != nil { + closeDbConn(newDB, logger) + return fmt.Errorf("failed to tune fresh runtime pool: %w", err) + } + oldDB := d.db.Swap(newDB) + if oldDB != nil { + closeDbConn(oldDB, logger) + } + return nil } - // Encrypt any plaintext rows if encryption is enabled + + // Encrypt any plaintext rows if encryption is enabled. Runs on the + // runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it + // installs remain valid until the next external migration batch. if err := d.EncryptPlaintextRows(ctx); err != nil { - if sqlDB, dbErr := db.DB(); dbErr == nil { - if closeErr := sqlDB.Close(); closeErr != nil { - logger.Error("failed to close DB connection: %v", closeErr) - } - } + closeDbConn(db, logger) return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err) } return d, nil diff --git a/framework/configstore/prompts.go b/framework/configstore/prompts.go index e760351b95..c30dacd75a 100644 --- a/framework/configstore/prompts.go +++ b/framework/configstore/prompts.go @@ -27,7 +27,7 @@ func isUniqueConstraintError(err error) bool { // GetFolders gets all folders func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) { var folders []tables.TableFolder - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Order("created_at DESC"). Find(&folders).Error; err != nil { return nil, err @@ -36,7 +36,7 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, // Get prompts count for each folder for i := range folders { var count int64 - if err := s.db.WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil { + if err := s.DB().WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil { return nil, err } folders[i].PromptsCount = int(count) @@ -48,7 +48,7 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, // GetFolderByID gets a folder by ID func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) { var folder tables.TableFolder - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). First(&folder, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -60,12 +60,12 @@ func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables. // CreateFolder creates a new folder func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error { - return s.db.WithContext(ctx).Create(folder).Error + return s.DB().WithContext(ctx).Create(folder).Error } // UpdateFolder updates a folder func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error { - res := s.db.WithContext(ctx).Where("id = ?", folder.ID).Save(folder) + res := s.DB().WithContext(ctx).Where("id = ?", folder.ID).Save(folder) if res.Error != nil { return res.Error } @@ -79,7 +79,7 @@ func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableF // PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot // alter foreign key constraints after table creation. func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Check folder exists var folder tables.TableFolder if err := tx.First(&folder, "id = ?", id).Error; err != nil { @@ -90,7 +90,7 @@ func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { } // PostgreSQL: ON DELETE CASCADE handles all child deletions - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&folder).Error } @@ -135,7 +135,7 @@ func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error { // GetPrompts gets all prompts, optionally filtered by folder ID func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) { var prompts []tables.TablePrompt - query := s.db.WithContext(ctx). + query := s.DB().WithContext(ctx). Preload("Folder"). Order("created_at DESC") @@ -150,7 +150,7 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta // Get latest version for each prompt for i := range prompts { var latestVersion tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", prompts[i].ID, true). First(&latestVersion).Error; err != nil { @@ -168,7 +168,7 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta // GetPromptByID gets a prompt by ID with latest version func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) { var prompt tables.TablePrompt - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Folder"). First(&prompt, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -179,7 +179,7 @@ func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables. // Get latest version var latestVersion tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", prompt.ID, true). First(&latestVersion).Error; err != nil { @@ -195,13 +195,13 @@ func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables. // CreatePrompt creates a new prompt func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { - return s.db.WithContext(ctx).Create(prompt).Error + return s.DB().WithContext(ctx).Create(prompt).Error } // UpdatePrompt updates a prompt func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { // Use Select to explicitly include FolderID so GORM writes NULL when it's nil - res := s.db.WithContext(ctx). + res := s.DB().WithContext(ctx). Model(prompt). Where("id = ?", prompt.ID). Select("Name", "FolderID", "UpdatedAt"). @@ -219,7 +219,7 @@ func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TableP // PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot // alter foreign key constraints after table creation. func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Check prompt exists var prompt tables.TablePrompt if err := tx.First(&prompt, "id = ?", id).Error; err != nil { @@ -230,7 +230,7 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { } // PostgreSQL: ON DELETE CASCADE handles all child deletions - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&prompt).Error } @@ -258,7 +258,7 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { // GetAllPromptVersions returns every version across all prompts in a single query. func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Order("prompt_id ASC, version_number DESC"). Find(&versions).Error; err != nil { @@ -270,7 +270,7 @@ func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.Tab // GetPromptVersions gets all versions for a prompt func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ?", promptID). Order("version_number DESC"). @@ -283,7 +283,7 @@ func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) // GetPromptVersionByID gets a version by ID func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { var version tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Prompt"). First(&version, "id = ?", id).Error; err != nil { @@ -298,7 +298,7 @@ func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*ta // GetLatestPromptVersion gets the latest version for a prompt func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) { var version tables.TablePromptVersion - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Where("prompt_id = ? AND is_latest = ?", promptID, true). First(&version).Error; err != nil { @@ -315,7 +315,7 @@ func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID st func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error { const maxRetries = 3 for attempt := 0; attempt < maxRetries; attempt++ { - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Get the next version number var maxVersionNumber int if err := tx.Model(&tables.TablePromptVersion{}). @@ -364,7 +364,7 @@ func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *table // DeletePromptVersion deletes a version and promotes the previous version to latest if needed. // PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade. func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Get the version to check if it's latest var version tables.TablePromptVersion if err := tx.First(&version, "id = ?", id).Error; err != nil { @@ -375,7 +375,7 @@ func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error } // SQLite: manually delete version messages (PostgreSQL CASCADE handles this) - if s.db.Dialector.Name() != "postgres" { + if s.DB().Dialector.Name() != "postgres" { if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil { return err } @@ -413,7 +413,7 @@ func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error // GetPromptSessions gets all sessions for a prompt func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) { var sessions []tables.TablePromptSession - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Version"). Where("prompt_id = ?", promptID). @@ -427,7 +427,7 @@ func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) // GetPromptSessionByID gets a session by ID func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) { var session tables.TablePromptSession - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). Preload("Prompt"). Preload("Version"). @@ -442,7 +442,7 @@ func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*ta // CreatePromptSession creates a new session func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Verify version belongs to the same prompt if set if session.VersionID != nil { var version tables.TablePromptVersion @@ -484,7 +484,7 @@ func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *table // UpdatePromptSession updates a session and its messages func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // Verify version belongs to the same prompt if set if session.VersionID != nil { var version tables.TablePromptVersion @@ -530,7 +530,7 @@ func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *table // RenamePromptSession updates only the name of a session func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error { - result := s.db.WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name) + result := s.DB().WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name) if result.Error != nil { return result.Error } @@ -543,7 +543,7 @@ func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name // DeletePromptSession deletes a session and its messages. // PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade. func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var session tables.TablePromptSession if err := tx.First(&session, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -553,7 +553,7 @@ func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error } // PostgreSQL: ON DELETE CASCADE handles message deletion - if s.db.Dialector.Name() == "postgres" { + if s.DB().Dialector.Name() == "postgres" { return tx.Delete(&session).Error } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index c5f1c26ed5..b19a6b3d86 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "time" "github.com/bytedance/sonic" @@ -14,16 +15,21 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" - "github.com/maximhq/bifrost/framework/migrator" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" "gorm.io/gorm/clause" ) // RDBConfigStore represents a configuration store that uses a relational database. +// +// The runtime *gorm.DB is held behind an atomic.Pointer so RefreshConnectionPool +// can swap it out without tearing callers down. migrateOnFreshFn and refreshPoolFn +// are backend-specific hooks installed by the constructor (postgres vs sqlite). type RDBConfigStore struct { - db *gorm.DB - logger schemas.Logger + db atomic.Pointer[gorm.DB] + logger schemas.Logger + migrateOnFreshFn func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error + refreshPoolFn func(ctx context.Context) error } // getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil. @@ -156,7 +162,7 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableClientConfig{}).Error; err != nil { return err } @@ -166,12 +172,51 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC // Ping checks if the database is reachable. func (s *RDBConfigStore) Ping(ctx context.Context) error { - return s.db.WithContext(ctx).Exec("SELECT 1").Error + return s.DB().WithContext(ctx).Exec("SELECT 1").Error } -// DB returns the underlying database connection. +// DB returns the current runtime database connection. The returned pointer is +// only valid for the duration of the caller's operation — after a +// RefreshConnectionPool call, future DB() calls return a fresh *gorm.DB backed +// by a different *sql.DB pool. Callers that issue multiple operations should +// call DB() per operation rather than caching the pointer. func (s *RDBConfigStore) DB() *gorm.DB { - return s.db + return s.db.Load() +} + +// RunMigration opens a throwaway connection against the same +// backing database, invokes fn with it, and closes the connection. Use this +// for DDL that must not leave cached prepared-statement plans on the runtime +// pool. After fn returns, callers should invoke RefreshConnectionPool if the +// migration altered tables the runtime pool has already queried. +// +// For SQLite, the throwaway concept doesn't apply (no server-side plan cache, +// single-writer file lock), so this runs fn against the existing *gorm.DB. +// +// Returns an error if the store was constructed without a migration hook +// wired — e.g. a direct `&RDBConfigStore{}` literal that skipped the +// newPostgresConfigStore / newSqliteConfigStore constructor. An explicit +// error is safer than a silent fallback to the runtime pool: running DDL +// on the runtime pool would reintroduce SQLSTATE 0A000. +func (s *RDBConfigStore) RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + if s.migrateOnFreshFn == nil { + return fmt.Errorf("configstore: migration hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") + } + return s.migrateOnFreshFn(ctx, fn) +} + +// RefreshConnectionPool closes the runtime pool and opens a fresh one against +// the same configuration. In-flight queries on the old pool complete before +// it closes; subsequent DB() calls return the new pool, whose connections +// carry no cached plans. SQLite is a no-op. +// +// Returns an error if the store was constructed without a refresh hook wired +// (same rationale as RunMigration). +func (s *RDBConfigStore) RefreshConnectionPool(ctx context.Context) error { + if s.refreshPoolFn == nil { + return fmt.Errorf("configstore: refresh hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") + } + return s.refreshPoolFn(ctx) } // parseGormError parses GORM errors to provide user-friendly error messages. @@ -273,7 +318,7 @@ func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tabl // GetFrameworkConfig retrieves the framework configuration from the database. func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error) { var dbConfig tables.TableFrameworkConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -285,7 +330,7 @@ func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableF // GetClientConfig retrieves the client configuration from the database. func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, error) { var dbConfig tables.TableClientConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -334,7 +379,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for providerName, providerConfig := range providers { dbProvider := tables.TableProvider{ @@ -497,7 +542,7 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Find the existing provider var dbProvider tables.TableProvider @@ -648,7 +693,7 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Create a deep copy of the config to avoid modifying the original configCopy, err := deepCopy(config) @@ -748,7 +793,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Find the existing provider var dbProvider tables.TableProvider @@ -790,7 +835,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo // GetProvidersConfig retrieves the provider configuration from the database. func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) { var dbProviders []tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { return nil, err } if len(dbProviders) == 0 { @@ -827,7 +872,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo // GetProviderConfig retrieves the provider configuration from the database. func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) { var dbProvider tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -857,7 +902,7 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas // GetProviderKeys retrieves all keys for a provider ordered by creation time. func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { var dbKeys []tables.TableKey - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Table("config_providers"). Select("config_keys.*"). Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id"). @@ -906,7 +951,7 @@ func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB // GetProviderKey retrieves a single key for a provider. func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { - dbKey, err := s.getProviderKeyByName(ctx, s.db, provider, keyID) + dbKey, err := s.getProviderKeyByName(ctx, s.DB(), provider, keyID) if err != nil { return nil, err } @@ -921,7 +966,7 @@ func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var dbProvider tables.TableProvider if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { @@ -946,7 +991,7 @@ func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID) @@ -982,7 +1027,7 @@ func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } providerIDSubquery := txDB.Model(&tables.TableProvider{}). @@ -1005,7 +1050,7 @@ func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas // GetProviders retrieves all providers from the database with their governance relationships. func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) { var providers []tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { return nil, err } return providers, nil @@ -1014,7 +1059,7 @@ func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvid // GetProvider retrieves a provider by name from the database with governance relationships. func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) { var providerInfo tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1026,7 +1071,7 @@ func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.Model // GetProviderByName retrieves a provider by name from the database with governance relationships. func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) { var provider tables.TableProvider - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1041,7 +1086,7 @@ func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*t func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, description string) error { // Update key-level status (for keyed providers) if keyID != "" { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Model(&tables.TableKey{}). Where("key_id = ?", keyID). Updates(map[string]interface{}{ @@ -1059,7 +1104,7 @@ func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.Mode // Update provider-level status (for keyless providers) if provider != "" { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Model(&tables.TableProvider{}). Where("name = ?", string(provider)). Updates(map[string]interface{}{ @@ -1082,14 +1127,14 @@ func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.Mode func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient // Get all MCP clients - if err := s.db.WithContext(ctx).Find(&dbMCPClients).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&dbMCPClients).Error; err != nil { return nil, err } if len(dbMCPClients) == 0 { return nil, nil } var clientConfig tables.TableClientConfig - if err := s.db.WithContext(ctx).First(&clientConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&clientConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return MCP config with default ToolManagerConfig if no client config exists // This will never happen, but just in case. @@ -1163,7 +1208,7 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, // GetMCPClientsPaginated retrieves MCP clients with pagination and optional search. func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableMCPClient{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableMCPClient{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -1202,7 +1247,7 @@ func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPC // GetMCPClientByID retrieves an MCP client by ID from the database. func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient - if err := s.db.WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1214,7 +1259,7 @@ func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tabl // GetMCPClientByName retrieves an MCP client by name from the database. func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient - if err := s.db.WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1225,7 +1270,7 @@ func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (* // CreateMCPClientConfig creates a new MCP client configuration in the database. func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Check if a client with the same name already exists if _, err := s.GetMCPClientByName(ctx, clientConfig.Name); err == nil { return fmt.Errorf("MCP client with name '%s' already exists", clientConfig.Name) @@ -1262,7 +1307,7 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig // UpdateMCPClientConfig updates an existing MCP client configuration in the database. func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { @@ -1376,7 +1421,7 @@ func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, cli if err != nil { return fmt.Errorf("failed to marshal tool name mapping: %w", err) } - return s.db.WithContext(ctx). + return s.DB().WithContext(ctx). Model(&tables.TableMCPClient{}). Where("client_id = ?", clientID). Updates(map[string]interface{}{ @@ -1388,7 +1433,7 @@ func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, cli // DeleteMCPClientConfig deletes an MCP client configuration from the database. func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { @@ -1411,7 +1456,7 @@ func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) e // GetVectorStoreConfig retrieves the vector store configuration from the database. func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error) { var vectorStoreTableConfig tables.TableVectorStoreConfig - if err := s.db.WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return default cache configuration return nil, nil @@ -1427,7 +1472,7 @@ func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore // UpdateVectorStoreConfig updates the vector store configuration in the database. func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { // Delete existing cache config if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableVectorStoreConfig{}).Error; err != nil { return err @@ -1449,7 +1494,7 @@ func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *ve // GetLogsStoreConfig retrieves the logs store configuration from the database. func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error) { var dbConfig tables.TableLogStoreConfig - if err := s.db.WithContext(ctx).First(&dbConfig).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -1467,7 +1512,7 @@ func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Conf // UpdateLogsStoreConfig updates the logs store configuration in the database. func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.DB().Transaction(func(tx *gorm.DB) error { if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableLogStoreConfig{}).Error; err != nil { return err } @@ -1487,7 +1532,7 @@ func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logs // GetConfig retrieves a specific config from the database. func (s *RDBConfigStore) GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error) { var config tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&config, "key = ?", key).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&config, "key = ?", key).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1502,7 +1547,7 @@ func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableG if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Save(config).Error } @@ -1510,7 +1555,7 @@ func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableG // GetModelPrices retrieves all model pricing records from the database. func (s *RDBConfigStore) GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error) { var modelPrices []tables.TableModelPricing - if err := s.db.WithContext(ctx).Find(&modelPrices).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&modelPrices).Error; err != nil { return nil, err } return modelPrices, nil @@ -1524,7 +1569,7 @@ func (s *RDBConfigStore) UpsertModelPrices(ctx context.Context, pricing *tables. if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } db := txDB.WithContext(ctx) @@ -1543,14 +1588,14 @@ func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error } func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) { var overrides []tables.TablePricingOverride - q := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + q := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if filters.ScopeKind != nil { q = q.Where("scope_kind = ?", *filters.ScopeKind) } @@ -1570,7 +1615,7 @@ func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters Pricin } func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -1620,7 +1665,7 @@ func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, param func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) { var override tables.TablePricingOverride - if err := s.db.WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1634,7 +1679,7 @@ func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(override).Error; err != nil { return s.parseGormError(err) @@ -1647,7 +1692,7 @@ func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(override).Error; err != nil { return s.parseGormError(err) @@ -1660,7 +1705,7 @@ func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, t if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id) if res.Error != nil { @@ -1677,7 +1722,7 @@ func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, t // GetModelParameters returns all stored model parameter rows. func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error) { var rows []tables.TableModelParameters - if err := s.db.WithContext(ctx).Find(&rows).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&rows).Error; err != nil { return nil, err } return rows, nil @@ -1686,7 +1731,7 @@ func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.Table // GetModelParametersByModel retrieves model parameters for a specific model. func (s *RDBConfigStore) GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error) { var params tables.TableModelParameters - if err := s.db.WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1703,7 +1748,7 @@ func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tabl if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } db := txDB.WithContext(ctx) @@ -1720,7 +1765,7 @@ func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tabl func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error) { var plugins []*tables.TablePlugin - if err := s.db.WithContext(ctx).Find(&plugins).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&plugins).Error; err != nil { return nil, err } return plugins, nil @@ -1728,7 +1773,7 @@ func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, func (s *RDBConfigStore) GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error) { var plugin tables.TablePlugin - if err := s.db.WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -1743,7 +1788,7 @@ func (s *RDBConfigStore) CreatePlugin(ctx context.Context, plugin *tables.TableP if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { @@ -1763,7 +1808,7 @@ func (s *RDBConfigStore) UpsertPlugin(ctx context.Context, plugin *tables.TableP if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { @@ -1802,7 +1847,7 @@ func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TableP txDB = tx[0] localTx = false } else { - txDB = s.db.Begin() + txDB = s.DB().Begin() localTx = true } // Mark plugin as custom if path is not empty @@ -1835,7 +1880,7 @@ func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*g if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", name).Error } @@ -1847,12 +1892,12 @@ func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []strin var virtualKeys []tables.TableVirtualKey if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error + err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error + err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error if err != nil { return nil, err } @@ -1903,7 +1948,7 @@ func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirt var virtualKeys []tables.TableVirtualKey // Preload all relationships for complete information - if err := preloadVirtualKeyBaseRelations(s.db.WithContext(ctx)). + if err := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)). Order("created_at ASC"). Find(&virtualKeys).Error; err != nil { return nil, err @@ -1914,7 +1959,7 @@ func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirt // GetVirtualKeysPaginated retrieves virtual keys with pagination, filtering, and search support. func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error) { // Build base query with filters - baseQuery := s.db.WithContext(ctx).Model(&tables.TableVirtualKey{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableVirtualKey{}) // Virtual keys are either customer-scoped or team-scoped, never both. // When both filters are provided, use OR to match keys belonging to either. @@ -1998,7 +2043,7 @@ func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params Vir // GetVirtualKey retrieves a virtual key from the database. func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) { var virtualKey tables.TableVirtualKey - if err := preloadVirtualKeyDetailRelations(s.db.WithContext(ctx)). + if err := preloadVirtualKeyDetailRelations(s.DB().WithContext(ctx)). First(&virtualKey, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -2012,7 +2057,7 @@ func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables. func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey - query := preloadVirtualKeyBaseRelations(s.db.WithContext(ctx)) + query := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)) // Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2035,7 +2080,7 @@ func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) func (s *RDBConfigStore) GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey - baseQuery := s.db.WithContext(ctx).Preload("Budgets").Preload("RateLimit") + baseQuery := s.DB().WithContext(ctx).Preload("Budgets").Preload("RateLimit") if err := baseQuery.Session(&gorm.Session{}).Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fallback: try plaintext lookup for rows not yet migrated @@ -2058,7 +2103,7 @@ func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *table if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil { return s.parseGormError(err) @@ -2072,7 +2117,7 @@ func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *table if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Check if record exists by ID or Name @@ -2106,7 +2151,7 @@ func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tabl return []tables.TableKey{}, nil } var keys []tables.TableKey - if err := s.db.WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil { return nil, err } return keys, nil @@ -2115,7 +2160,7 @@ func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tabl // GetKeysByProvider retrieves all keys for a specific provider func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error) { var keys []tables.TableKey - if err := s.db.WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil { return nil, err } return keys, nil @@ -2125,12 +2170,12 @@ func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) { var keys []tables.TableKey if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error + err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error + err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error if err != nil { return nil, err } @@ -2158,7 +2203,7 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ( // DeleteVirtualKey deletes a virtual key from the database. func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var virtualKey tables.TableVirtualKey if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2243,7 +2288,7 @@ func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error // GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database. func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) { var virtualKey tables.TableVirtualKey - if err := s.db.WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []tables.TableVirtualKeyProviderConfig{}, nil } @@ -2253,7 +2298,7 @@ func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtu return nil, nil } var providerConfigs []tables.TableVirtualKeyProviderConfig - if err := s.db.WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil { return nil, err } return providerConfigs, nil @@ -2265,7 +2310,7 @@ func (s *RDBConfigStore) CreateVirtualKeyProviderConfig(ctx context.Context, vir if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Store keys before create keysToAssociate := virtualKeyProviderConfig.Keys @@ -2336,7 +2381,7 @@ func (s *RDBConfigStore) UpdateVirtualKeyProviderConfig(ctx context.Context, vir if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // Store keys before save @@ -2411,7 +2456,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } // First fetch the provider config to get budget and rate limit IDs var providerConfig tables.TableVirtualKeyProviderConfig @@ -2443,7 +2488,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id // GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database. func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) { var virtualKey tables.TableVirtualKey - if err := s.db.WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []tables.TableVirtualKeyMCPConfig{}, nil } @@ -2453,7 +2498,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey return nil, nil } var mcpConfigs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { return nil, err } return mcpConfigs, nil @@ -2462,7 +2507,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey // GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client. func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) { var configs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { return nil, err } return configs, nil @@ -2474,7 +2519,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Conte return nil, nil } var configs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { + if err := s.DB().WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { return nil, err } return configs, nil @@ -2487,7 +2532,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context return nil, nil } var configs []tables.TableVirtualKeyMCPConfig - err := s.db.WithContext(ctx). + err := s.DB().WithContext(ctx). Preload("MCPClient"). Joins("JOIN config_mcp_clients ON config_mcp_clients.id = governance_virtual_key_mcp_configs.mcp_client_id"). Where("config_mcp_clients.client_id IN ?", clientIDs). @@ -2504,7 +2549,7 @@ func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualK if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) @@ -2518,7 +2563,7 @@ func (s *RDBConfigStore) UpdateVirtualKeyMCPConfig(ctx context.Context, virtualK if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) @@ -2532,7 +2577,7 @@ func (s *RDBConfigStore) DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "id = ?", id).Error } @@ -2542,7 +2587,7 @@ const teamSelectWithVKCount = "governance_teams.*, (SELECT COUNT(*) FROM governa // GetTeams retrieves all teams from the database. func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error) { // Preload relationships for complete information - query := s.db.WithContext(ctx). + query := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budget").Preload("RateLimit") // Optional filtering by customer @@ -2558,7 +2603,7 @@ func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tab // GetTeamsPaginated retrieves teams with pagination, filtering, and search support. func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableTeam{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableTeam{}) if params.CustomerID != "" { baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID) @@ -2600,7 +2645,7 @@ func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQuer // GetTeam retrieves a specific team from the database. func (s *RDBConfigStore) GetTeam(ctx context.Context, id string) (*tables.TableTeam, error) { var team tables.TableTeam - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budget").Preload("RateLimit"). First(&team, "id = ?", id).Error; err != nil { @@ -2618,7 +2663,7 @@ func (s *RDBConfigStore) CreateTeam(ctx context.Context, team *tables.TableTeam, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(team).Error; err != nil { return s.parseGormError(err) @@ -2632,7 +2677,7 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(team).Error; err != nil { return s.parseGormError(err) @@ -2642,7 +2687,7 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, // DeleteTeam deletes a team from the database. func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var team tables.TableTeam if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&team, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2689,7 +2734,7 @@ func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { // GetCustomers retrieves all customers from the database. func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) { var customers []tables.TableCustomer - if err := preloadCustomerRelations(s.db.WithContext(ctx), ""). + if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). Order("created_at ASC"). Find(&customers).Error; err != nil { return nil, err @@ -2699,7 +2744,7 @@ func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustom // GetCustomersPaginated retrieves customers with pagination and optional search filtering. func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableCustomer{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableCustomer{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) @@ -2731,7 +2776,7 @@ func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params Custo // GetCustomer retrieves a specific customer from the database. func (s *RDBConfigStore) GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) { var customer tables.TableCustomer - if err := preloadCustomerRelations(s.db.WithContext(ctx), ""). + if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound @@ -2747,7 +2792,7 @@ func (s *RDBConfigStore) CreateCustomer(ctx context.Context, customer *tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(customer).Error; err != nil { return s.parseGormError(err) @@ -2761,7 +2806,7 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(customer).Error; err != nil { return s.parseGormError(err) @@ -2771,7 +2816,7 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta // DeleteCustomer deletes a customer from the database. func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { - if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var customer tables.TableCustomer if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -2822,7 +2867,7 @@ func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { // GetRateLimits retrieves all rate limits from the database. func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { var rateLimits []tables.TableRateLimit - if err := s.db.WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil { + if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil { return nil, err } return rateLimits, nil @@ -2834,7 +2879,7 @@ func (s *RDBConfigStore) GetRateLimit(ctx context.Context, id string, tx ...*gor if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var rateLimit tables.TableRateLimit if err := txDB.WithContext(ctx).First(&rateLimit, "id = ?", id).Error; err != nil { @@ -2852,7 +2897,7 @@ func (s *RDBConfigStore) CreateRateLimit(ctx context.Context, rateLimit *tables. if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(rateLimit).Error; err != nil { return s.parseGormError(err) @@ -2866,7 +2911,7 @@ func (s *RDBConfigStore) UpdateRateLimit(ctx context.Context, rateLimit *tables. if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(rateLimit).Error; err != nil { return s.parseGormError(err) @@ -2880,7 +2925,7 @@ func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, rl := range rateLimits { if err := txDB.WithContext(ctx).Save(rl).Error; err != nil { @@ -2896,7 +2941,7 @@ func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...* if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) @@ -2907,7 +2952,7 @@ func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...* // GetBudgets retrieves all budgets from the database. func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) { var budgets []tables.TableBudget - if err := s.db.WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil { + if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil { return nil, err } return budgets, nil @@ -2919,7 +2964,7 @@ func (s *RDBConfigStore) GetBudget(ctx context.Context, id string, tx ...*gorm.D if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } var budget tables.TableBudget if err := txDB.WithContext(ctx).First(&budget, "id = ?", id).Error; err != nil { @@ -2937,7 +2982,7 @@ func (s *RDBConfigStore) CreateBudget(ctx context.Context, budget *tables.TableB if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(budget).Error; err != nil { return s.parseGormError(err) @@ -2951,7 +2996,7 @@ func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.Ta if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, b := range budgets { if err := txDB.WithContext(ctx).Save(b).Error; err != nil { @@ -2967,7 +3012,7 @@ func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableB if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(budget).Error; err != nil { return s.parseGormError(err) @@ -2981,7 +3026,7 @@ func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gor if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) @@ -2992,7 +3037,7 @@ func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gor // UpdateBudgetUsage updates only the current_usage field of a budget. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableBudget{}). Where("id = ?", id). @@ -3009,7 +3054,7 @@ func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, curre // UpdateRateLimitUsage updates only the usage fields of a rate limit. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableRateLimit{}). Where("id = ?", id). @@ -3029,7 +3074,7 @@ func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, to // loadRoutingRulesOrdered loads routing rules with Targets preloaded, using consistent ordering: // rules by priority ASC, created_at DESC, id ASC; targets by weight DESC for deterministic ordering. func (s *RDBConfigStore) loadRoutingRulesOrdered(ctx context.Context, dest *[]tables.TableRoutingRule, scopes ...func(*gorm.DB) *gorm.DB) error { - q := s.db.WithContext(ctx). + q := s.DB().WithContext(ctx). Preload("Targets", func(db *gorm.DB) *gorm.DB { return db.Order("weight DESC"). Order("COALESCE(provider, '') ASC"). @@ -3054,7 +3099,7 @@ func (s *RDBConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRou // GetRoutingRulesPaginated retrieves routing rules with pagination and optional search filtering. func (s *RDBConfigStore) GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableRoutingRule{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableRoutingRule{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -3135,12 +3180,12 @@ func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []stri var routingRules []tables.TableRoutingRule if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error + err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error + err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error if err != nil { return nil, err } @@ -3150,7 +3195,7 @@ func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []stri // CreateRoutingRule creates a new routing rule in the database. func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3199,7 +3244,7 @@ func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.Tab // UpdateRoutingRule updates an existing routing rule in the database. // It enforces the same unique-priority-per-scope invariant as CreateRoutingRule. func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3250,7 +3295,7 @@ func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.Tab // DeleteRoutingRule deletes a routing rule and its targets from the database. func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error { - database := s.db + database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } @@ -3273,7 +3318,7 @@ func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx .. // GetModelConfigs retrieves all model configs from the database. func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) { var modelConfigs []tables.TableModelConfig - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { return nil, err } return modelConfigs, nil @@ -3281,7 +3326,7 @@ func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableMod // GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support. func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) { - baseQuery := s.db.WithContext(ctx).Model(&tables.TableModelConfig{}) + baseQuery := s.DB().WithContext(ctx).Model(&tables.TableModelConfig{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" @@ -3322,7 +3367,7 @@ func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params Mo // GetModelConfig retrieves a specific model config from the database by model name and optional provider. func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig - query := s.db.WithContext(ctx).Where("model_name = ?", modelName) + query := s.DB().WithContext(ctx).Where("model_name = ?", modelName) if provider != nil { query = query.Where("provider = ?", *provider) } else { @@ -3340,7 +3385,7 @@ func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, p // GetModelConfigByID retrieves a specific model config from the database by ID. func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig - if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { + if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } @@ -3355,7 +3400,7 @@ func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil { return s.parseGormError(err) @@ -3369,7 +3414,7 @@ func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tab if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil { return s.parseGormError(err) @@ -3383,7 +3428,7 @@ func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs [] if len(tx) > 0 { txDB = tx[0] } else { - txDB = s.db + txDB = s.DB() } for _, mc := range modelConfigs { if err := txDB.WithContext(ctx).Save(mc).Error; err != nil { @@ -3395,7 +3440,7 @@ func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs [] // DeleteModelConfig deletes a model config from the database. func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // First fetch the model config to get budget and rate limit IDs var modelConfig tables.TableModelConfig if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil { @@ -3443,7 +3488,7 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo var pricingOverrides []tables.TablePricingOverride var governanceConfigs []tables.TableGovernanceConfig - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Preload("ProviderConfigs"). Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, key_id, models_json, provider") @@ -3451,34 +3496,34 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo Find(&virtualKeys).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx). + if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Find(&teams).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&customers).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&customers).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&budgets).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&budgets).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&rateLimits).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&modelConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&modelConfigs).Error; err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&providers).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&providers).Error; err != nil { return nil, err } if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil { return nil, err } - if err := s.db.WithContext(ctx).Find(&pricingOverrides).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&pricingOverrides).Error; err != nil { return nil, err } // Fetching governance config for username and password - if err := s.db.WithContext(ctx).Find(&governanceConfigs).Error; err != nil { + if err := s.DB().WithContext(ctx).Find(&governanceConfigs).Error; err != nil { return nil, err } // Check if any config is present @@ -3533,22 +3578,22 @@ func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) var password *string var isEnabled bool var disableAuthOnInference bool - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } - if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } @@ -3566,7 +3611,7 @@ func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) // UpdateAuthConfig updates the auth configuration in the database. func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfig) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigAdminUsernameKey, Value: config.AdminUserName.GetValue(), @@ -3598,7 +3643,7 @@ func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfi // GetProxyConfig retrieves the proxy configuration from the database. func (s *RDBConfigStore) GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error) { var configEntry tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3645,7 +3690,7 @@ func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.G if err != nil { return fmt.Errorf("failed to marshal proxy config: %w", err) } - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigProxyKey, Value: string(configJSON), }).Error @@ -3654,7 +3699,7 @@ func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.G // GetRestartRequiredConfig retrieves the restart required configuration from the database. func (s *RDBConfigStore) GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error) { var configEntry tables.TableGovernanceConfig - if err := s.db.WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3676,7 +3721,7 @@ func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *t if err != nil { return fmt.Errorf("failed to marshal restart required config: %w", err) } - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: string(configJSON), }).Error @@ -3684,7 +3729,7 @@ func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *t // ClearRestartRequiredConfig clears the restart required configuration in the database. func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error { - return s.db.WithContext(ctx).Save(&tables.TableGovernanceConfig{ + return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: `{"required":false,"reason":""}`, }).Error @@ -3694,11 +3739,11 @@ func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error { func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables.SessionsTable, error) { var session tables.SessionsTable tokenHash := encrypt.HashSHA256(token) - err := s.db.WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error + err := s.DB().WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fall back to plaintext lookup for backward compatibility - if err := s.db.WithContext(ctx).First(&session, "token = ?", token).Error; err != nil { + if err := s.DB().WithContext(ctx).First(&session, "token = ?", token).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -3713,31 +3758,31 @@ func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables. // CreateSession creates a new session in the database. func (s *RDBConfigStore) CreateSession(ctx context.Context, session *tables.SessionsTable) error { - return s.db.WithContext(ctx).Create(session).Error + return s.DB().WithContext(ctx).Create(session).Error } // DeleteSession deletes a session from the database. func (s *RDBConfigStore) DeleteSession(ctx context.Context, token string) error { tokenHash := encrypt.HashSHA256(token) - result := s.db.WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash) + result := s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { // Fall back to plaintext lookup for backward compatibility - return s.db.WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error + return s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error } return nil } // FlushSessions flushes all sessions from the database. func (s *RDBConfigStore) FlushSessions(ctx context.Context) error { - return s.db.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error + return s.DB().WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error } // ExecuteTransaction executes a transaction. func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { - return s.db.WithContext(ctx).Transaction(fn) + return s.DB().WithContext(ctx).Transaction(fn) } // RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound @@ -3769,12 +3814,12 @@ func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx contex // doesTableExist checks if a table exists in the database. func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool { - return s.db.WithContext(ctx).Migrator().HasTable(tableName) + return s.DB().WithContext(ctx).Migrator().HasTable(tableName) } // removeNullKeys removes null keys from the database. func (s *RDBConfigStore) removeNullKeys(ctx context.Context) error { - return s.db.WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error + return s.DB().WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error } // removeDuplicateKeysAndNullKeys removes duplicate keys based on key_id and value combination @@ -3793,7 +3838,7 @@ func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) err s.logger.Debug("deleting duplicate keys from the database") // Find and delete duplicate keys, keeping only the one with the smallest ID // This query deletes all records except the one with the minimum ID for each (key_id, value) pair - result := s.db.WithContext(ctx).Exec(` + result := s.DB().WithContext(ctx).Exec(` DELETE FROM config_keys WHERE id NOT IN ( SELECT MIN(id) @@ -3809,18 +3854,9 @@ func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) err return nil } -// RunMigration runs a migration. -func (s *RDBConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { - if migration == nil { - return fmt.Errorf("migration cannot be nil") - } - m := migrator.New(s.db, migrator.DefaultOptions, []*migrator.Migration{migration}) - return m.Migrate() -} - // Close closes the SQLite config store. func (s *RDBConfigStore) Close(ctx context.Context) error { - sqlDB, err := s.db.DB() + sqlDB, err := s.DB().DB() if err != nil { return err } @@ -3836,7 +3872,7 @@ func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableD } // Use GORM clause-based insert for dialect-appropriate SQL - result := s.db.WithContext(ctx).Clauses( + result := s.DB().WithContext(ctx).Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "lock_key"}}, DoNothing: true, @@ -3854,7 +3890,7 @@ func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableD // GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist. func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error) { var lock tables.TableDistributedLock - result := s.db.WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock) + result := s.DB().WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -3869,7 +3905,7 @@ func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.T // UpdateLockExpiry updates the expiration time for an existing lock. // Only succeeds if the holder ID matches the current lock holder. func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error { - result := s.db.WithContext(ctx).Model(&tables.TableDistributedLock{}). + result := s.DB().WithContext(ctx).Model(&tables.TableDistributedLock{}). Where("lock_key = ? AND holder_id = ? AND expires_at > ?", lockKey, holderID, time.Now().UTC()). Update("expires_at", expiresAt) @@ -3887,7 +3923,7 @@ func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID // ReleaseLock deletes a lock if the holder ID matches. // Returns true if the lock was released, false if it wasn't held by the given holder. func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("lock_key = ? AND holder_id = ?", lockKey, holderID). Delete(&tables.TableDistributedLock{}) @@ -3901,7 +3937,7 @@ func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID stri // CleanupExpiredLocks removes all locks that have expired. // Returns the number of locks cleaned up. func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("expires_at < ?", time.Now().UTC()). Delete(&tables.TableDistributedLock{}) @@ -3915,7 +3951,7 @@ func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) // CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired. // Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired. func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error) { - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("lock_key = ? AND expires_at < ?", lockKey, time.Now().UTC()). Delete(&tables.TableDistributedLock{}) @@ -3931,7 +3967,7 @@ func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey st // GetOauthConfigByID retrieves an OAuth config by its ID func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("id = ?", id).First(&config) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3945,7 +3981,7 @@ func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*ta // State is unique per OAuth flow (used for CSRF protection on callback) func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("state = ?", state).First(&config) + result := s.DB().WithContext(ctx).Where("state = ?", state).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3958,7 +3994,7 @@ func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string // GetOauthTokenByID retrieves an OAuth token by its ID func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error) { var token tables.TableOauthToken - result := s.db.WithContext(ctx).Where("id = ?", id).First(&token) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -3970,7 +4006,7 @@ func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tab // CreateOauthConfig creates a new OAuth config func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { - result := s.db.WithContext(ctx).Create(config) + result := s.DB().WithContext(ctx).Create(config) if result.Error != nil { return fmt.Errorf("failed to create oauth config: %w", result.Error) } @@ -3979,7 +4015,7 @@ func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.T // CreateOauthToken creates a new OAuth token func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { - result := s.db.WithContext(ctx).Create(token) + result := s.DB().WithContext(ctx).Create(token) if result.Error != nil { return fmt.Errorf("failed to create oauth token: %w", result.Error) } @@ -3988,7 +4024,7 @@ func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.Tab // UpdateOauthConfig updates an existing OAuth config func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { - result := s.db.WithContext(ctx).Save(config) + result := s.DB().WithContext(ctx).Save(config) if result.Error != nil { return fmt.Errorf("failed to update oauth config: %w", result.Error) } @@ -3997,7 +4033,7 @@ func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.T // UpdateOauthToken updates an existing OAuth token func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { - result := s.db.WithContext(ctx).Save(token) + result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return fmt.Errorf("failed to update oauth token: %w", result.Error) } @@ -4006,7 +4042,7 @@ func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.Tab // DeleteOauthToken deletes an OAuth token by its ID func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth token: %w", result.Error) } @@ -4016,7 +4052,7 @@ func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error // GetExpiringOauthTokens retrieves tokens that are expiring before the given time func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error) { var tokens []*tables.TableOauthToken - result := s.db.WithContext(ctx). + result := s.DB().WithContext(ctx). Where("expires_at < ?", before). Find(&tokens) if result.Error != nil { @@ -4028,7 +4064,7 @@ func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time // GetOauthConfigByTokenID retrieves an OAuth config that references a specific token func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig - result := s.db.WithContext(ctx).Where("token_id = ?", tokenID).First(&config) + result := s.DB().WithContext(ctx).Where("token_id = ?", tokenID).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4043,7 +4079,7 @@ func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID st // GetOauthUserSessionByID retrieves a per-user OAuth session by its ID func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4056,7 +4092,7 @@ func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) // GetOauthUserSessionByState retrieves a per-user OAuth session by its state token func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("state = ?", state).First(&session) + result := s.DB().WithContext(ctx).Where("state = ?", state).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4070,7 +4106,7 @@ func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state s // Returns nil if the session doesn't exist or has already been claimed by another request. func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession - result := s.db.WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) + result := s.DB().WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4078,7 +4114,7 @@ func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error) } // Atomically transition from "pending" to "claiming" to prevent concurrent claims - updateResult := s.db.WithContext(ctx).Model(&tables.TableOauthUserSession{}). + updateResult := s.DB().WithContext(ctx).Model(&tables.TableOauthUserSession{}). Where("id = ? AND status = ?", session.ID, "pending"). Update("status", "claiming") if updateResult.Error != nil { @@ -4095,7 +4131,7 @@ func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession tokenHash := encrypt.HashSHA256(sessionToken) - result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) + result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4107,7 +4143,7 @@ func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, // CreateOauthUserSession creates a new per-user OAuth session func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { - result := s.db.WithContext(ctx).Create(session) + result := s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create oauth user session: %w", result.Error) } @@ -4116,7 +4152,7 @@ func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *ta // UpdateOauthUserSession updates an existing per-user OAuth session func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { - result := s.db.WithContext(ctx).Save(session) + result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update oauth user session: %w", result.Error) } @@ -4133,11 +4169,11 @@ func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtua var result *gorm.DB if userID != "" { - result = s.db.WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) } else if virtualKeyID != "" { - result = s.db.WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) } else if sessionToken != "" { - result = s.db.WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) + result = s.DB().WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) } else { return nil, nil } @@ -4154,7 +4190,7 @@ func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtua func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { var token tables.TableOauthUserToken tokenHash := encrypt.HashSHA256(sessionToken) - result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) + result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4170,7 +4206,7 @@ func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, se func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { // Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing // duplicate tokens when concurrent requests race on the same identity+client pair. - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if token.UserID != nil && *token.UserID != "" { var existing tables.TableOauthUserToken err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error @@ -4202,7 +4238,7 @@ func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables // UpdateOauthUserToken updates an existing per-user OAuth token func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { - result := s.db.WithContext(ctx).Save(token) + result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return fmt.Errorf("failed to update oauth user token: %w", result.Error) } @@ -4211,7 +4247,7 @@ func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables // DeleteOauthUserToken deletes a per-user OAuth token by its ID func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user token: %w", result.Error) } @@ -4220,7 +4256,7 @@ func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) er // DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { - result := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) + result := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error) } @@ -4232,7 +4268,7 @@ func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, m // GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id. func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { var client tables.TablePerUserOAuthClient - result := s.db.WithContext(ctx).Where("client_id = ?", clientID).First(&client) + result := s.DB().WithContext(ctx).Where("client_id = ?", clientID).First(&client) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4244,7 +4280,7 @@ func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, cl // CreatePerUserOAuthClient creates a new dynamically registered OAuth client. func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { - result := s.db.WithContext(ctx).Create(client) + result := s.DB().WithContext(ctx).Create(client) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth client: %w", result.Error) } @@ -4255,7 +4291,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *t func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession tokenHash := encrypt.HashSHA256(accessToken) - result := s.db.WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { + result := s.DB().WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, value, encryption_status") }).First(&session) if result.Error != nil { @@ -4270,7 +4306,7 @@ func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context // GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID. func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession - result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4282,7 +4318,7 @@ func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id stri // CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session. func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { - result := s.db.WithContext(ctx).Create(session) + result := s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth session: %w", result.Error) } @@ -4291,7 +4327,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session // UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity). func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { - result := s.db.WithContext(ctx).Save(session) + result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth session: %w", result.Error) } @@ -4300,7 +4336,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session // DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID. func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error) } @@ -4311,7 +4347,7 @@ func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id strin func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { var codeRecord tables.TablePerUserOAuthCode codeHash := encrypt.HashSHA256(code) - result := s.db.WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) + result := s.DB().WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4323,7 +4359,7 @@ func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code str // CreatePerUserOAuthCode creates a new authorization code record. func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { - result := s.db.WithContext(ctx).Create(code) + result := s.DB().WithContext(ctx).Create(code) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth code: %w", result.Error) } @@ -4335,7 +4371,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *table func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { codeHash := encrypt.HashSHA256(code) var codeRecord tables.TablePerUserOAuthCode - result := s.db.WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord) + result := s.DB().WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4343,7 +4379,7 @@ func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error) } // Atomically mark as used - updateResult := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). + updateResult := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). Where("id = ? AND used = ?", codeRecord.ID, false). Update("used", true) if updateResult.Error != nil { @@ -4358,7 +4394,7 @@ func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) // UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used). func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { - result := s.db.WithContext(ctx).Save(code) + result := s.DB().WithContext(ctx).Save(code) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth code: %w", result.Error) } @@ -4370,7 +4406,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *table // GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID. func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { var flow tables.TablePerUserOAuthPendingFlow - result := s.db.WithContext(ctx).Where("id = ?", id).First(&flow) + result := s.DB().WithContext(ctx).Where("id = ?", id).First(&flow) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4382,7 +4418,7 @@ func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id stri // CreatePerUserOAuthPendingFlow persists a new pending consent flow. func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { - result := s.db.WithContext(ctx).Create(flow) + result := s.DB().WithContext(ctx).Create(flow) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error) } @@ -4391,7 +4427,7 @@ func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow // UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step). func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { - result := s.db.WithContext(ctx).Save(flow) + result := s.DB().WithContext(ctx).Save(flow) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error) } @@ -4400,7 +4436,7 @@ func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow // DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted. func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { - result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) + result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error) } @@ -4409,14 +4445,14 @@ func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id s func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { now := time.Now().UTC() - result := s.db.WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + result := s.DB().WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) } if result.RowsAffected == 0 { // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). var count int64 - if err := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { + if err := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) } if count > 0 { @@ -4430,7 +4466,7 @@ func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id // and creates the authorization code in a single transaction. func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { var rowsAffected int64 - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // 1. Consume the pending flow (atomic idempotency guard). // Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check. now := time.Now().UTC() @@ -4479,8 +4515,8 @@ func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Contex // linked to this gateway session ID. This supports per-service proxy tokens // (e.g. "flow::") where each MCP service gets its own hash. var tokens []tables.TableOauthUserToken - subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) - result := s.db.WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) + subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.DB().WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) if result.Error != nil { return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error) } @@ -4510,8 +4546,8 @@ func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.C // Update all tokens whose session_token_hash matches any upstream session // linked to this gateway session ID. - subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) - result := s.db.WithContext(ctx).Model(&tables.TableOauthUserToken{}). + subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.DB().WithContext(ctx).Model(&tables.TableOauthUserToken{}). Where("session_token_hash IN (?)", subquery). Updates(updates) if result.Error != nil { diff --git a/framework/configstore/rdb_test.go b/framework/configstore/rdb_test.go index 4877dd02fc..48325f82f2 100644 --- a/framework/configstore/rdb_test.go +++ b/framework/configstore/rdb_test.go @@ -53,10 +53,13 @@ func setupRDBTestStore(t *testing.T) *RDBConfigStore { err = db.SetupJoinTable(&tables.TableVirtualKeyProviderConfig{}, "Keys", &tables.TableVirtualKeyProviderConfigKey{}) require.NoError(t, err, "Failed to setup join table") - return &RDBConfigStore{ - db: db, - logger: nil, + s := &RDBConfigStore{logger: nil} + s.db.Store(db) + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + return s } // ============================================================================= @@ -718,7 +721,7 @@ func TestCreateVirtualKeyProviderConfig_WithKeys(t *testing.T) { // Load with keys var configWithKeys tables.TableVirtualKeyProviderConfig - err = store.db.Preload("Keys").First(&configWithKeys, "id = ?", configs[0].ID).Error + err = store.DB().Preload("Keys").First(&configWithKeys, "id = ?", configs[0].ID).Error require.NoError(t, err) assert.Len(t, configWithKeys.Keys, 1) } @@ -1203,7 +1206,7 @@ func createTestPromptTree(t *testing.T, store *RDBConfigStore, ctx context.Conte func countRows(t *testing.T, store *RDBConfigStore, model interface{}) int64 { t.Helper() var count int64 - require.NoError(t, store.db.Model(model).Count(&count).Error) + require.NoError(t, store.DB().Model(model).Count(&count).Error) return count } @@ -1389,7 +1392,7 @@ func TestDeletePromptSession(t *testing.T) { // Session messages for that session should be gone var msgCount int64 - require.NoError(t, store.db.Model(&tables.TablePromptSessionMessage{}).Where("session_id = ?", sessionID).Count(&msgCount).Error) + require.NoError(t, store.DB().Model(&tables.TablePromptSessionMessage{}).Where("session_id = ?", sessionID).Count(&msgCount).Error) assert.Equal(t, int64(0), msgCount) }) diff --git a/framework/configstore/sqlite.go b/framework/configstore/sqlite.go index 4c4cbe8594..9482801d08 100644 --- a/framework/configstore/sqlite.go +++ b/framework/configstore/sqlite.go @@ -35,7 +35,16 @@ func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger sche return nil, err } logger.Debug("db opened for configstore") - s := &RDBConfigStore{db: db, logger: logger} + s := &RDBConfigStore{logger: logger} + s.db.Store(db) + // SQLite has no server-side prepared-plan cache, and opening a second + // handle on the same file would contend for the single-writer lock — + // so both hooks operate on the existing *gorm.DB. + s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { + return fn(ctx, s.DB()) + } + s.refreshPoolFn = func(ctx context.Context) error { return nil } + logger.Debug("running migration to remove duplicate keys") // Run migration to remove duplicate keys before AutoMigrate if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil { diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 3fbb678159..16cedc6b6a 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -9,7 +9,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/logstore" - "github.com/maximhq/bifrost/framework/migrator" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" ) @@ -393,8 +392,25 @@ type ConfigStore interface { // DB returns the underlying database connection. DB() *gorm.DB - // Migration manager - RunMigration(ctx context.Context, migration *migrator.Migration) error + // RunMigration opens a throwaway *gorm.DB against the same + // backing database, invokes fn with it, and closes the connection. Use + // this for DDL (typically downstream-consumer migrations) that must not + // leave cached prepared-statement plans on the runtime pool. + // + // After fn returns successfully, callers should invoke + // RefreshConnectionPool if the migration altered tables the runtime pool + // has already queried — otherwise SQLSTATE 0A000 can surface on reads + // whose cached plans predate the DDL. + // + // For SQLite backends, this is a pass-through that runs fn on the + // existing connection (no server-side plan cache, single-writer lock). + RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error + + // RefreshConnectionPool tears down the runtime pool and opens a fresh + // one against the same configuration. In-flight queries on the old + // pool complete before it closes; subsequent DB() calls return the new + // pool, whose connections carry no cached plans. SQLite is a no-op. + RefreshConnectionPool(ctx context.Context) error // Cleanup Close(ctx context.Context) error diff --git a/framework/logstore/asyncjob_test.go b/framework/logstore/asyncjob_test.go index df71d7befe..c569fe0f31 100644 --- a/framework/logstore/asyncjob_test.go +++ b/framework/logstore/asyncjob_test.go @@ -86,14 +86,10 @@ func waitForJobStatus(t *testing.T, store LogStore, jobID string) *AsyncJob { func TestSubmitJob_PropagatesContextValues(t *testing.T) { executor := newTestAsyncExecutor(t) - // Simulate original request context values - contextValues := map[any]any{ - schemas.BifrostContextKeyVirtualKey: "sk-bf-test", - schemas.BifrostContextKey("x-bf-prom-env"): "production", - schemas.BifrostContextKey("x-bf-eh-custom"): "custom-value", - } - - var capturedCtx *schemas.BifrostContext + capturedCtx := schemas.NewBifrostContext(context.Background(), <-time.After(1*time.Minute)) + capturedCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test") + capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-eh-custom"), "custom-value") + capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-prom-env"), "production") var done atomic.Bool operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { diff --git a/framework/logstore/postgres.go b/framework/logstore/postgres.go index df78b1735d..183d466554 100644 --- a/framework/logstore/postgres.go +++ b/framework/logstore/postgres.go @@ -24,6 +24,13 @@ type PostgresConfig struct { } // newPostgresLogStore creates a new Postgres log store. +// +// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not +// change result type"): a throwaway pool runs the version check and schema +// migrations and is closed immediately, then a fresh runtime pool is opened +// for query traffic and the async index / matview builders. The runtime +// pool's connections never see pre-migration schema, so their cached +// prepared-plans stay valid for the life of the process. func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (LogStore, error) { if config == nil { return nil, fmt.Errorf("config is required") @@ -48,11 +55,56 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch return nil, fmt.Errorf("postgres ssl mode is required") } dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue()) - db, err := gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - }), &gorm.Config{ - Logger: newGormLogger(logger), - }) + + openPool := func() (*gorm.DB, error) { + return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{ + Logger: newGormLogger(logger), + }) + } + + // closePoolStrict returns the close error so callers can abort startup + // when the throwaway migration pool doesn't tear down cleanly — a half- + // closed pool weakens the guarantee that no cached plans survive DDL. + closePool := func(db *gorm.DB) error { + if db == nil { + return nil + } + sqlDB, err := db.DB() + if err != nil { + return err + } + return sqlDB.Close() + } + + // Throwaway pool for the version gate and schema migrations. Closing it + // before the runtime pool opens guarantees no cached plan survives DDL. + mDb, err := openPool() + if err != nil { + return nil, err + } + + // Postgres version gate: refuse to start below 16 (matviews, partitioning, + // and some JSON operators we rely on depend on 16+). + var pgVersionNum int + if err := mDb.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil { + _ = closePool(mDb) + return nil, err + } + if pgVersionNum < 160000 { + _ = closePool(mDb) + return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher") + } + + if err := triggerMigrations(ctx, mDb); err != nil { + _ = closePool(mDb) + return nil, err + } + if err := closePool(mDb); err != nil { + return nil, fmt.Errorf("close migration db connection: %w", err) + } + + // Runtime pool. Opens against post-migration schema. + db, err := openPool() if err != nil { return nil, err } @@ -60,6 +112,7 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch // Configure connection pool sqlDB, err := db.DB() if err != nil { + closePool(db) return nil, err } // Set MaxIdleConns (default: 5) @@ -77,25 +130,6 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch sqlDB.SetMaxOpenConns(maxOpenConns) d := &RDBLogStore{db: db, logger: logger} - // Check version of postgres, if is lower than 16, throw fatal error - var pgVersionNum int - if err := db.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil { - sqlDB.Close() - return nil, err - } - if pgVersionNum < 160000 { - sqlDB.Close() - return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher") - } - - // Run migrations - if err := triggerMigrations(ctx, db); err != nil { - if sqlDB, sqlErr := db.DB(); sqlErr == nil { - sqlDB.Close() - } - return nil, err - } - // Run all index builds sequentially in a single goroutine to prevent // deadlocks from concurrent CREATE INDEX CONCURRENTLY on the same table. // Each function is idempotent and acquires its own advisory lock for