Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(options): allow customizing the names of the migration tables #38

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ postgres:
-p 5432:5432 \
postgres:11

.PHONY: psql
psql:
@echo "---> Running psql"
psql -h localhost -p 5432 -U $(TEST_DATABASE_USER) -d $(TEST_DATABASE_NAME)

.PHONY: release
release:
@echo "---> Creating new release"
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ files to be saved in (which will be the same directory of the main package, e.g.
`example`), an instance of `*pg.DB`, and `os.Args`; and log any potential errors
that could be returned.

Once this has been set up, then you can use the `create`, `migrate`, `status`, `rollback`,
`help` commands like so:
You can also call `migrations.RunWithOptions` to configure the way that the
migrations run (e.g. customize the name of the migration tables).

Once this has been set up, then you can use the `create`, `migrate`, `status`,
`rollback`, `help` commands like so:

```
$ go run example/*.go create create_users_table
Expand Down
2 changes: 1 addition & 1 deletion create.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func init() {
}
`

func create(directory, name string) error {
func (m *migrator) create(directory, name string) error {
version := time.Now().UTC().Format(timeFormat)
fullname := fmt.Sprintf("%s_%s", version, name)
filename := path.Join(directory, fullname+".go")
Expand Down
3 changes: 2 additions & 1 deletion create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ func TestCreate(t *testing.T) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
tmp := os.TempDir()
name := fmt.Sprintf("create_test_migration_%d", r.Int())
m := newMigrator(nil, RunOptions{})

err := create(tmp, name)
err := m.create(tmp, name)
assert.Nil(t, err)

files, err := os.ReadDir(tmp)
Expand Down
109 changes: 58 additions & 51 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func Register(name string, up, down func(orm.DB) error, opts MigrationOptions) {
})
}

func migrate(db *pg.DB) (err error) {
func (m *migrator) migrate() (err error) {
// sort the registered migrations by name (which will sort by the
// timestamp in their names)
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Name < migrations[j].Name
})

// look at the migrations table to see the already run migrations
completed, err := getCompletedMigrations(db)
completed, err := m.getCompletedMigrations()
if err != nil {
return err
}
Expand All @@ -46,118 +46,125 @@ func migrate(db *pg.DB) (err error) {
}

// acquire the migration lock from the migrations_lock table
err = acquireLock(db)
err = m.acquireLock()
if err != nil {
return err
}
defer func() {
e := releaseLock(db)
e := m.releaseLock()
if e != nil && err == nil {
err = e
}
}()

// find the last batch number
batch, err := getLastBatchNumber(db)
batch, err := m.getLastBatchNumber()
if err != nil {
return err
}
batch++

fmt.Printf("Running batch %d with %d migration(s)...\n", batch, len(uncompleted))

for _, m := range uncompleted {
m.Batch = batch
for _, mig := range uncompleted {
var err error
if m.DisableTransaction {
err = m.Up(db)
if mig.DisableTransaction {
err = mig.Up(m.db)
} else {
err = db.RunInTransaction(db.Context(), func(tx *pg.Tx) error {
return m.Up(tx)
err = m.db.RunInTransaction(m.db.Context(), func(tx *pg.Tx) error {
return mig.Up(tx)
})
}
if err != nil {
return fmt.Errorf("%s: %s", m.Name, err)
return fmt.Errorf("%s: %s", mig.Name, err)
}

m.CompletedAt = time.Now()
_, err = db.Model(m).Insert()
migrationMap := map[string]interface{}{
"name": mig.Name,
"batch": batch,
"completed_at": time.Now(),
}
_, err = m.db.
Model(&migrationMap).
Table(m.opts.MigrationsTableName).
Insert()
if err != nil {
return fmt.Errorf("%s: %s", m.Name, err)
return fmt.Errorf("%s: %s", mig.Name, err)
}
fmt.Printf("Finished running %q\n", m.Name)
fmt.Printf("Finished running %q\n", mig.Name)
}

return nil
}

func getCompletedMigrations(db orm.DB) ([]*migration, error) {
func (m *migrator) getCompletedMigrations() ([]*migration, error) {
var completed []*migration

err := db.
Model(&completed).
err := orm.NewQuery(m.db).
Table(m.opts.MigrationsTableName).
Order("id").
Select()
Select(&completed)
if err != nil {
return nil, err
}

return completed, nil
}

func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration {
subsetMap := map[string]bool{}

for _, c := range subset {
subsetMap[c.Name] = true
}

var d []*migration

for _, a := range all {
if subsetMap[a.Name] == wantCompleted {
d = append(d, a)
}
}

return d
}

func acquireLock(db *pg.DB) error {
l := lock{ID: lockID, IsLocked: true}

result, err := db.Model(&l).
func (m *migrator) acquireLock() error {
l := map[string]interface{}{"is_locked": true}
result, err := m.db.
Model(&l).
Table(m.opts.MigrationLockTableName).
Column("is_locked").
WherePK().
Where("id = ?", lockID).
Where("is_locked = ?", false).
Update()

if err != nil {
return err
}

if result.RowsAffected() == 0 {
return ErrAlreadyLocked
}

return nil
}

func releaseLock(db orm.DB) error {
l := lock{ID: lockID, IsLocked: false}
_, err := db.Model(&l).
WherePK().
func (m *migrator) releaseLock() error {
l := map[string]interface{}{"is_locked": false}
_, err := m.db.
Model(&l).
Table(m.opts.MigrationLockTableName).
Column("is_locked").
Where("id = ?", lockID).
Update()
return err
}

func getLastBatchNumber(db orm.DB) (int32, error) {
func (m *migrator) getLastBatchNumber() (int32, error) {
var res struct{ Batch int32 }
err := db.Model(&migration{}).
err := orm.NewQuery(m.db).
Table(m.opts.MigrationsTableName).
ColumnExpr("COALESCE(MAX(batch), 0) AS batch").
Select(&res)
if err != nil {
return 0, err
}
return res.Batch, nil
}

func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration {
subsetMap := map[string]bool{}
for _, c := range subset {
subsetMap[c.Name] = true
}

var d []*migration
for _, a := range all {
if subsetMap[a.Name] == wantCompleted {
d = append(d, a)
}
}

return d
}
26 changes: 14 additions & 12 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ func TestMigrate(t *testing.T) {
User: os.Getenv("TEST_DATABASE_USER"),
Database: os.Getenv("TEST_DATABASE_NAME"),
})

db.AddQueryHook(logQueryHook{})
m := newMigrator(db, RunOptions{})

err := ensureMigrationTables(db)
err := m.ensureMigrationTables()
require.Nil(t, err)

defer clearMigrations(t, db)
Expand All @@ -65,7 +65,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: noopMigration, Down: noopMigration},
}

err := migrate(db)
err := m.migrate()
assert.Nil(tt, err)

assert.Equal(tt, "123", migrations[0].Name)
Expand All @@ -83,7 +83,7 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(migrations[0]).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

var m []*migration
Expand All @@ -105,7 +105,7 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(&migrations).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

count, err := db.Model(&migration{}).Where("batch = 2").Count()
Expand All @@ -121,11 +121,11 @@ func TestMigrate(t *testing.T) {
{Name: "456", Up: noopMigration, Down: noopMigration},
}

err := acquireLock(db)
err := m.acquireLock()
assert.Nil(tt, err)
defer releaseLock(db)
defer m.releaseLock()

err = migrate(db)
err = m.migrate()
assert.Equal(tt, ErrAlreadyLocked, err)
})

Expand All @@ -141,10 +141,10 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(migrations[0]).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

batch, err := getLastBatchNumber(db)
batch, err := m.getLastBatchNumber()
assert.Nil(tt, err)
assert.Equal(tt, batch, int32(6))

Expand All @@ -160,7 +160,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: false},
}

err := migrate(db)
err := m.migrate()
assert.EqualError(tt, err, "123: error")

assertTable(tt, db, "test_table", false)
Expand All @@ -173,7 +173,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: true},
}

err := migrate(db)
err := m.migrate()
assert.EqualError(tt, err, "123: error")

assertTable(tt, db, "test_table", true)
Expand Down Expand Up @@ -207,6 +207,8 @@ func clearMigrations(t *testing.T, db *pg.DB) {

_, err := db.Exec("DELETE FROM migrations")
assert.Nil(t, err)
_, err = db.Exec("UPDATE migration_lock SET is_locked = FALSE")
assert.Nil(t, err)
_, err = db.Exec("DROP TABLE IF EXISTS test_table")
assert.Nil(t, err)
}
Expand Down
Loading
Loading