Skip to content

Commit

Permalink
pg: change db into tx
Browse files Browse the repository at this point in the history
  • Loading branch information
alexisvisco committed May 29, 2024
1 parent 13aceee commit 785ad7c
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 43 deletions.
12 changes: 6 additions & 6 deletions pkg/schema/pg/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (p *Schema) AddColumn(tableName schema.TableName, columnName string, column

query := fmt.Sprintf("ALTER TABLE %s ADD %s", options.Table, p.column(options))

_, err := p.DB.ExecContext(p.Context.Context, query)
_, err := p.TX.ExecContext(p.Context.Context, query)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding column: %w", err))
return
Expand Down Expand Up @@ -237,7 +237,7 @@ func (p *Schema) AddColumnComment(tableName schema.TableName, columnName string,
},
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding column comment: %w", err))
return
Expand Down Expand Up @@ -265,7 +265,7 @@ func (p *Schema) RenameColumn(tableName schema.TableName, oldColumnName, newColu

query := fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, oldColumnName, newColumnName)

_, err := p.DB.ExecContext(p.Context.Context, query)
_, err := p.TX.ExecContext(p.Context.Context, query)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while renaming column: %w", err))
return
Expand Down Expand Up @@ -330,7 +330,7 @@ func (p *Schema) DropColumn(tableName schema.TableName, columnName string, opt .
"if_exists": utils.StrFuncPredicate(options.IfExists, "IF EXISTS"),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(query))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(query))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while dropping column: %w", err))
return
Expand Down Expand Up @@ -388,7 +388,7 @@ func (p *Schema) ChangeColumnType(tableName schema.TableName, columnName string,
"using": utils.StrFuncPredicate(options.Using != "", fmt.Sprintf("USING %s", options.Using)),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(query))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(query))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while changing column: %w", err))
return
Expand Down Expand Up @@ -434,7 +434,7 @@ func (p *Schema) ChangeColumnDefault(tableName schema.TableName, columnName, def
"default": utils.StrFunc(options.Value),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(query))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(query))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while changing column default: %w", err))
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/schema/pg/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (p *Schema) AddCheckConstraint(tableName schema.TableName, constraintName s
}

query := fmt.Sprintf("ALTER TABLE %s ADD %s", options.Table.String(), p.checkConstraint(options))
_, err := p.DB.ExecContext(p.Context.Context, query)
_, err := p.TX.ExecContext(p.Context.Context, query)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding check constraint: %w", err))
return
Expand Down Expand Up @@ -190,7 +190,7 @@ func (p *Schema) AddForeignKeyConstraint(fromTable, toTable schema.TableName, op
}
}

_, err := p.DB.ExecContext(p.Context.Context,
_, err := p.TX.ExecContext(p.Context.Context,
fmt.Sprintf("ALTER TABLE %s ADD %s", options.FromTable, p.foreignKeyConstraint(options)))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding foreign key: %w", err))
Expand Down Expand Up @@ -330,7 +330,7 @@ func (p *Schema) dropConstraint(table schema.TableName, constraintName string, i
"constraint_name": utils.StrFunc(constraintName),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(query))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(query))
if err != nil {
return fmt.Errorf("error while dropping foreign key: %w", err)
}
Expand Down Expand Up @@ -384,7 +384,7 @@ func (p *Schema) AddPrimaryKeyConstraint(tableName schema.TableName, columns []s
}

sql := fmt.Sprintf("ALTER TABLE %s ADD %s", options.Table, p.primaryKeyConstraint(options))
_, err := p.DB.ExecContext(p.Context.Context, sql)
_, err := p.TX.ExecContext(p.Context.Context, sql)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding primary key: %w", err))
return
Expand Down
14 changes: 7 additions & 7 deletions pkg/schema/pg/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (s *Schema) CreateEnum(name string, values []string, opts ...schema.CreateE
},
}

_, err := s.DB.ExecContext(s.Context.Context, replacer.Replace(q))
_, err := s.TX.ExecContext(s.Context.Context, replacer.Replace(q))
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while creating enum: %w", err))
return
Expand Down Expand Up @@ -101,7 +101,7 @@ func (s *Schema) AddEnumValue(name string, value string, opts ...schema.AddEnumV
fmt.Sprintf("AFTER %s", QuoteValue(options.AfterValue))),
}

_, err := s.NotInTx.ExecContext(s.Context.Context, replacer.Replace(q))
_, err := s.DB.ExecContext(s.Context.Context, replacer.Replace(q))
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while adding enum value: %w", err))
return
Expand Down Expand Up @@ -146,7 +146,7 @@ func (s *Schema) DropEnum(name string, opts ...schema.DropEnumOptions) {
"if_exists": utils.StrFuncPredicate(options.IfExists, "IF EXISTS"),
}

_, err := s.DB.ExecContext(s.Context.Context, replacer.Replace(q))
_, err := s.TX.ExecContext(s.Context.Context, replacer.Replace(q))
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while dropping enum: %w", err))
return
Expand Down Expand Up @@ -189,7 +189,7 @@ ORDER BY
values = append(values, *schemaName)
}

rows, err := s.DB.QueryContext(s.Context.Context, replacer.Replace(q), values...)
rows, err := s.TX.QueryContext(s.Context.Context, replacer.Replace(q), values...)
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while finding enum usage: %w", err))
return nil
Expand Down Expand Up @@ -233,7 +233,7 @@ ORDER BY enumsortorder`
args = append(args, *schemaName)
}

rows, err := s.DB.QueryContext(s.Context.Context, replacer.Replace(q), args...)
rows, err := s.TX.QueryContext(s.Context.Context, replacer.Replace(q), args...)
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while listing enum values: %w", err))
return nil
Expand Down Expand Up @@ -283,7 +283,7 @@ func (s *Schema) RenameEnum(oldName, newName string, opts ...schema.RenameEnumOp
"new_enum_name": utils.StrFunc(QuoteIdent(newName)),
}

_, err := s.DB.ExecContext(s.Context.Context, replacer.Replace(q))
_, err := s.TX.ExecContext(s.Context.Context, replacer.Replace(q))
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while renaming enum: %w", err))
return
Expand Down Expand Up @@ -325,7 +325,7 @@ func (s *Schema) RenameEnumValue(name, oldName, newName string, opts ...schema.R
"new_value": utils.StrFunc(QuoteValue(newName)),
}

_, err := s.DB.ExecContext(s.Context.Context, replacer.Replace(q))
_, err := s.TX.ExecContext(s.Context.Context, replacer.Replace(q))
if err != nil {
s.Context.RaiseError(fmt.Errorf("error while renaming enum value: %w", err))
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/pg/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestPostgres_RenameValue(t *testing.T) {
})

// insert some data
_, err := p.DB.ExecContext(p.Context.Context,
_, err := p.TX.ExecContext(p.Context.Context,
"INSERT INTO tst_pg_rename_enum_value_0.articles (status) VALUES ('active');")
require.NoError(t, err)

Expand Down
10 changes: 5 additions & 5 deletions pkg/schema/pg/exists.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (p *Schema) ConstraintExist(tableName schema.TableName, constraintName stri
var result bool
query := "SELECT EXISTS(SELECT 1 FROM information_schema.table_constraints WHERE table_name = $1 AND constraint_name = $2 and constraint_schema = $3)"

row := p.DB.QueryRowContext(p.Context.Context, query, tableName.Name(), constraintName, tableName.Schema())
row := p.TX.QueryRowContext(p.Context.Context, query, tableName.Name(), constraintName, tableName.Schema())
if err := row.Scan(&result); err != nil {
p.Context.RaiseError(fmt.Errorf("error while checking if constraint exists: %w", err))
return false
Expand All @@ -25,7 +25,7 @@ func (p *Schema) ColumnExist(tableName schema.TableName, columnName string) bool
var result bool
query := "SELECT EXISTS(SELECT 1 FROM information_schema.columns WHERE table_name = $1 AND column_name = $2 and table_schema = $3)"

row, err := p.DB.QueryContext(p.Context.Context, query, tableName.Name(), columnName, tableName.Schema())
row, err := p.TX.QueryContext(p.Context.Context, query, tableName.Name(), columnName, tableName.Schema())
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while checking if column exists: %w", err))
return false
Expand All @@ -44,7 +44,7 @@ func (p *Schema) TableExist(tableName schema.TableName) bool {
var result bool
query := "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2)"

row := p.DB.QueryRowContext(p.Context.Context, query, tableName.Name(), tableName.Schema())
row := p.TX.QueryRowContext(p.Context.Context, query, tableName.Name(), tableName.Schema())
if err := row.Scan(&result); err != nil {
p.Context.RaiseError(fmt.Errorf("error while checking if table exists: %w", err))
return false
Expand All @@ -58,7 +58,7 @@ func (p *Schema) IndexExist(tableName schema.TableName, indexName string) bool {
var result bool
query := "SELECT EXISTS(SELECT 1 FROM pg_indexes WHERE tablename = $1 AND indexname = $2 and schemaname = $3)"

row := p.DB.QueryRowContext(p.Context.Context, query, tableName.Name(), indexName, tableName.Schema())
row := p.TX.QueryRowContext(p.Context.Context, query, tableName.Name(), indexName, tableName.Schema())
if err := row.Scan(&result); err != nil {
p.Context.RaiseError(fmt.Errorf("error while checking if index exists: %w", err))
return false
Expand All @@ -71,7 +71,7 @@ func (p *Schema) PrimaryKeyExist(tableName schema.TableName) bool {
var result bool
query := "SELECT EXISTS(SELECT 1 FROM information_schema.table_constraints WHERE table_name = $1 AND constraint_type = 'PRIMARY KEY')"

row := p.DB.QueryRowContext(context.Background(), query, tableName.Name())
row := p.TX.QueryRowContext(context.Background(), query, tableName.Name())
if err := row.Scan(&result); err != nil {
p.Context.RaiseError(fmt.Errorf("error while checking if primary key exists: %w", err))
return false
Expand Down
4 changes: 2 additions & 2 deletions pkg/schema/pg/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (p *Schema) AddIndexConstraint(table schema.TableName, columns []string, op
},
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding index: %w", err))
return
Expand Down Expand Up @@ -245,7 +245,7 @@ func (p *Schema) DropIndex(table schema.TableName, columns []string, opt ...sche
},
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while dropping index: %w", err))
return
Expand Down
27 changes: 16 additions & 11 deletions pkg/schema/pg/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,22 @@ import (
)

type Schema struct {
DB schema.DB
NotInTx schema.DB
// TX is the transaction to execute the queries.
TX schema.DB

// DB is a database connection but not in a transaction.
DB schema.DB

Context *schema.MigratorContext

// ReversibleMigrationExecutor is a helper to execute reversible migrations in change method.
*schema.ReversibleMigrationExecutor
}

func NewPostgres(ctx *schema.MigratorContext, tx schema.DB, db schema.DB) *Schema {
return &Schema{
DB: tx,
NotInTx: db,
TX: tx,
DB: db,
Context: ctx,
ReversibleMigrationExecutor: schema.NewReversibleMigrationExecutor(ctx),
}
Expand All @@ -29,15 +34,15 @@ func (p *Schema) rollbackMode() *Schema {
ctx := *p.Context
ctx.MigrationDirection = types.MigrationDirectionNotReversible
return &Schema{
TX: p.TX,
DB: p.DB,
NotInTx: p.NotInTx,
Context: &ctx,
ReversibleMigrationExecutor: schema.NewReversibleMigrationExecutor(&ctx),
}
}

func (p *Schema) Exec(query string, args ...interface{}) {
_, err := p.DB.ExecContext(p.Context.Context, query, args...)
_, err := p.TX.ExecContext(p.Context.Context, query, args...)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while executing query: %w", err))
return
Expand Down Expand Up @@ -73,7 +78,7 @@ func (p *Schema) AddExtension(name string, option ...schema.ExtensionOptions) {
"schema": utils.StrFuncPredicate(options.Schema != "", fmt.Sprintf("SCHEMA %s", options.Schema)),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding extension: %w", err))
return
Expand Down Expand Up @@ -128,7 +133,7 @@ func (p *Schema) DropExtension(name string, opt ...schema.DropExtensionOptions)
"name": utils.StrFunc(p.toExtension(options.ExtensionName)),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while dropping extension: %w", err))
return
Expand All @@ -146,7 +151,7 @@ func (p *Schema) AddVersion(version string) {
"version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql), version)
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql), version)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding version: %w", err))
return
Expand All @@ -164,7 +169,7 @@ func (p *Schema) RemoveVersion(version string) {
"version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql), version)
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql), version)
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while removing version: %w", err))
return
Expand All @@ -181,7 +186,7 @@ func (p *Schema) FindAppliedVersions() []string {
"version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()),
}

rows, err := p.DB.QueryContext(p.Context.Context, replacer.Replace(sql))
rows, err := p.TX.QueryContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while fetching applied versions: %w", err))
return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/schema/pg/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ where c.table_schema = $1
and c.table_name = $2
order by column_name;`

rows, err := p.DB.QueryContext(context.Background(), query, tableName.Schema(), tableName.Name())
rows, err := p.TX.QueryContext(context.Background(), query, tableName.Schema(), tableName.Name())
require.NoError(t, err)

defer rows.Close() // Ensure rows are closed
Expand Down
8 changes: 4 additions & 4 deletions pkg/schema/pg/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (p *Schema) CreateTable(tableName schema.TableName, f func(*PostgresTableDe
"table_options": utils.StrFuncPredicate(options.Option != "", options.Option),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(q))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(q))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while creating table: %w", err))
return
Expand Down Expand Up @@ -435,7 +435,7 @@ func (p *Schema) DropTable(tableName schema.TableName, opts ...schema.DropTableO
"table_name": utils.StrFunc(tableName.String()),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(q))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(q))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while dropping table: %w", err))
return
Expand Down Expand Up @@ -466,7 +466,7 @@ func (p *Schema) RenameTable(oldTableName, newTableName schema.TableName) {
"new_table_name": utils.StrFunc(newTableName.String()),
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(q))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(q))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while renaming table: %w", err))
return
Expand Down Expand Up @@ -518,7 +518,7 @@ func (p *Schema) AddTableComment(tableName schema.TableName, comment *string, op
},
}

_, err := p.DB.ExecContext(p.Context.Context, replacer.Replace(sql))
_, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql))
if err != nil {
p.Context.RaiseError(fmt.Errorf("error while adding column comment: %w", err))
return
Expand Down
4 changes: 2 additions & 2 deletions pkg/schema/pg/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func assertTableExist(t *testing.T, p *Schema, table schema.TableName) {
var exists bool
err := p.DB.QueryRowContext(context.Background(), `SELECT EXISTS (
err := p.TX.QueryRowContext(context.Background(), `SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = $1
Expand All @@ -23,7 +23,7 @@ func assertTableExist(t *testing.T, p *Schema, table schema.TableName) {

func assertTableNotExist(t *testing.T, p *Schema, table schema.TableName) {
var exists bool
err := p.DB.QueryRowContext(context.Background(), `SELECT EXISTS (
err := p.TX.QueryRowContext(context.Background(), `SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = $1
Expand Down

0 comments on commit 785ad7c

Please sign in to comment.