Skip to content

Commit

Permalink
migrator: add functional options
Browse files Browse the repository at this point in the history
This allows customizing migrator behaviour by exposing it's configurable
parts through functional options. This change starts by adding a
`Migrations` and `TableName` options to fullfil current requirements.

Closes #13
Superseeds #14
  • Loading branch information
glerchundi authored and lopezator committed Aug 11, 2019
1 parent 32f7aaa commit 400b0c9
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 74 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ prepare: setup-env mod-download
.PHONY: mod-download
mod-download:
@echo "Running download..."
go mod download GOPROXY="$(GOPROXY)"
GOPROXY="$(GOPROXY)" go mod download

.PHONY: sanity-check
sanity-check: golangci-lint
Expand All @@ -34,4 +34,4 @@ golangci-lint:
.PHONY: test
test:
@echo "Running tests..."
2>&1 POSTGRES_URL="$(POSTGRES_URL)" MYSQL_URL="$(MYSQL_URL)" go test -v -tags="unit integration" -coverprofile=coverage.txt -covermode=atomic
2>&1 POSTGRES_URL="$(POSTGRES_URL)" MYSQL_URL="$(MYSQL_URL)" go test -v -tags="unit integration" -coverprofile=coverage.txt -covermode=atomic
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,17 @@ import (

func main() {
m := migrator.New(
&migrator.Migration{
Name: "Create table foo",
Func: func(tx *sql.Tx) error {
if _, err := tx.Exec("CREATE TABLE foo (id INT PRIMARY KEY)"); err != nil {
return err
}
return nil
Migrations(
&migrator.Migration{
Name: "Create table foo",
Func: func(tx *sql.Tx) error {
if _, err := tx.Exec("CREATE TABLE foo (id INT PRIMARY KEY)"); err != nil {
return err
}
return nil
},
},
},
),
)

// Migrate up
Expand All @@ -72,7 +74,7 @@ func main() {
```

Notes on examples above:
- Migrator creates/manages a table named `migrations` to keep track of the applied versions
- Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if want to customize the table name `TableName("my_migrations")` can be also used.

### Looking for more examples?

Expand Down
55 changes: 44 additions & 11 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,54 @@ import (
"fmt"
)

const tableName = "migrations"
const defaultTableName = "migrations"

// Migrator is the migrator implementation
type Migrator struct {
tableName string
migrations []interface{}
}

// Option sets options such migrations or table name.
type Option func(*Migrator)

// TableName creates an option to allow overriding the default table name
func TableName(tableName string) Option {
return func(m *Migrator) {
m.tableName = tableName
}
}

// Migrations creates an option with provided migrations
func Migrations(migrations ...interface{}) Option {
return func(m *Migrator) {
m.migrations = migrations
}
}

// New creates a new migrator instance
func New(migrations ...interface{}) (*Migrator, error) {
for _, m := range migrations {
func New(opts ...Option) (*Migrator, error) {
m := &Migrator{
tableName: defaultTableName,
}
for _, opt := range opts {
opt(m)
}

if len(m.migrations) == 0 {
return nil, errors.New("migrator: migrations must be provided")
}

for _, m := range m.migrations {
switch m.(type) {
case *Migration:
case *MigrationNoTx:
default:
return nil, errors.New("migrator: invalid migration type")
}
}
return &Migrator{migrations: migrations}, nil

return m, nil
}

// Migrate applies all available migrations
Expand All @@ -35,13 +65,13 @@ func (m *Migrator) Migrate(db *sql.DB) error {
version VARCHAR(255) NOT NULL,
PRIMARY KEY (id)
);
`, tableName))
`, m.tableName))
if err != nil {
return err
}

// count applied migrations
count, err := countApplied(db)
count, err := countApplied(db, m.tableName)
if err != nil {
return err
}
Expand All @@ -52,7 +82,7 @@ func (m *Migrator) Migrate(db *sql.DB) error {

// plan migrations
for idx, migration := range m.migrations[count:len(m.migrations)] {
insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", tableName, idx+count, migration.(fmt.Stringer).String())
insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String())
switch m := migration.(type) {
case *Migration:
if err := migrate(db, insertVersion, m); err != nil {
Expand All @@ -69,12 +99,15 @@ func (m *Migrator) Migrate(db *sql.DB) error {
}

// Pending returns all pending (not yet applied) migrations
func (m *Migrator) Pending(db *sql.DB) []interface{} {
count, _ := countApplied(db)
return m.migrations[count:len(m.migrations)]
func (m *Migrator) Pending(db *sql.DB) ([]interface{}, error) {
count, err := countApplied(db, m.tableName)
if err != nil {
return nil, err
}
return m.migrations[count:len(m.migrations)], nil
}

func countApplied(db *sql.DB) (int, error) {
func countApplied(db *sql.DB, tableName string) (int, error) {
// count applied migrations
var count int
rows, err := db.Query(fmt.Sprintf("SELECT count(*) FROM %s", tableName))
Expand Down
104 changes: 52 additions & 52 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,45 @@ import (
_ "github.com/lib/pq" // postgres driver
)

func migrateTest(driverName, url string) error {
migrator, err := New(
&Migration{
Name: "Using tx, encapsulate two queries",
Func: func(tx *sql.Tx) error {
if _, err := tx.Exec("CREATE TABLE foo (id INT PRIMARY KEY)"); err != nil {
return err
}
if _, err := tx.Exec("INSERT INTO foo (id) VALUES (1)"); err != nil {
return err
}
return nil
},
var migrations = []interface{}{
&Migration{
Name: "Using tx, encapsulate two queries",
Func: func(tx *sql.Tx) error {
if _, err := tx.Exec("CREATE TABLE foo (id INT PRIMARY KEY)"); err != nil {
return err
}
if _, err := tx.Exec("INSERT INTO foo (id) VALUES (1)"); err != nil {
return err
}
return nil
},
&MigrationNoTx{
Name: "Using db, execute one query",
Func: func(db *sql.DB) error {
if _, err := db.Exec("INSERT INTO foo (id) VALUES (2)"); err != nil {
return err
}
return nil
},
},
&MigrationNoTx{
Name: "Using db, execute one query",
Func: func(db *sql.DB) error {
if _, err := db.Exec("INSERT INTO foo (id) VALUES (2)"); err != nil {
return err
}
return nil
},
&Migration{
Name: "Using tx, one embedded query",
Func: func(tx *sql.Tx) error {
query, err := _escFSString(false, "/testdata/0_bar.sql")
if err != nil {
return err
}
if _, err := tx.Exec(query); err != nil {
return err
}
return nil
},
},
&Migration{
Name: "Using tx, one embedded query",
Func: func(tx *sql.Tx) error {
query, err := _escFSString(false, "/testdata/0_bar.sql")
if err != nil {
return err
}
if _, err := tx.Exec(query); err != nil {
return err
}
return nil
},
)
},
}

func migrateTest(driverName, url string) error {
migrator, err := New(Migrations(migrations...))
if err != nil {
return err
}
Expand Down Expand Up @@ -89,7 +91,7 @@ func TestMigrationNumber(t *testing.T) {
if err != nil {
t.Fatal(err)
}
count, err := countApplied(db)
count, err := countApplied(db, defaultTableName)
if err != nil {
t.Fatal(err)
}
Expand All @@ -99,7 +101,7 @@ func TestMigrationNumber(t *testing.T) {
}

func TestDatabaseNotFound(t *testing.T) {
migrator, err := New(&Migration{})
migrator, err := New(Migrations(&Migration{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -114,7 +116,7 @@ func TestBadMigrations(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
_, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", defaultTableName))
if err != nil {
t.Fatal(err)
}
Expand All @@ -126,27 +128,27 @@ func TestBadMigrations(t *testing.T) {
}{
{
name: "bad tx migration",
input: mustMigrator(New(&Migration{
input: mustMigrator(New(Migrations(&Migration{
Name: "bad tx migration",
Func: func(tx *sql.Tx) error {
if _, err := tx.Exec("FAIL FAST"); err != nil {
return err
}
return nil
},
})),
}))),
},
{
name: "bad db migration",
input: mustMigrator(New(&MigrationNoTx{
input: mustMigrator(New(Migrations(&MigrationNoTx{
Name: "bad db migration",
Func: func(db *sql.DB) error {
if _, err := db.Exec("FAIL FAST"); err != nil {
return err
}
return nil
},
})),
}))),
},
}

Expand Down Expand Up @@ -189,7 +191,7 @@ func TestBadMigrationNumber(t *testing.T) {
if err != nil {
t.Fatal(err)
}
migrator := mustMigrator(New(
migrator := mustMigrator(New(Migrations(
&Migration{
Name: "bad migration number",
Func: func(tx *sql.Tx) error {
Expand All @@ -199,7 +201,7 @@ func TestBadMigrationNumber(t *testing.T) {
return nil
},
},
))
)))
if err := migrator.Migrate(db); err == nil {
t.Fatalf("BAD MIGRATION NUMBER should fail: %v", err)
}
Expand All @@ -210,12 +212,7 @@ func TestPending(t *testing.T) {
if err != nil {
t.Fatal(err)
}
migrator, _ := New()
pending := migrator.Pending(db)
if len(pending) != 0 {
t.Fatalf("pending migrations should be 0")
}
migrator = mustMigrator(New(
migrator := mustMigrator(New(Migrations(
&Migration{
Name: "Using tx, create baz table",
Func: func(tx *sql.Tx) error {
Expand All @@ -225,9 +222,12 @@ func TestPending(t *testing.T) {
return nil
},
},
))
pending = migrator.Pending(db)
)))
pending, err := migrator.Pending(db)
if err != nil {
t.Fatal(err)
}
if len(pending) != 1 {
t.Fatalf("pending migrations should be 1")
t.Fatalf("pending migrations should be 1, got %d", len(pending))
}
}

0 comments on commit 400b0c9

Please sign in to comment.