Skip to content

Commit

Permalink
logger: implement logger option
Browse files Browse the repository at this point in the history
Add a Logger interface, along with a LoggerFunc and a new WithLogger option.

This would enable:

- Customizing migrator output
- Integrating a 3rd party logging tool

By default, migrator will work as before, just outputting to stdout.

Closes #6
  • Loading branch information
lopezator committed Apr 10, 2020
1 parent beba8dc commit bee4efd
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,23 @@ func main() {
```

Notes on examples above:
- Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if want to customize the table name `migrator.TableName("my_migrations")` can be passed to `migrator.New` function as an additional option.

- Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if you want to customize the table name `migrator.TableName("my_migrations")` can be passed to `migrator.New` function as an additional option.

### Logging

By default, migrator prints applying/applied migration info to stdout.
If that's enough for you, you can skip this section.

If you need some special formatting or want to use a 3rd party logging library, this could be done by using `WithLogger` option as follows:

```go
logger := migrator.WithLogger(migrator.LoggerFunc(func(msg string, args ...interface{}) {
// Your code here
})))
```

Then you will only need to pass the logger as an option to `migrator.New`.

### Looking for more examples?

Expand Down
42 changes: 33 additions & 9 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"os"
)

const defaultTableName = "migrations"

// Migrator is the migrator implementation
type Migrator struct {
tableName string
logger Logger
migrations []interface{}
}

Expand All @@ -24,6 +27,26 @@ func TableName(tableName string) Option {
}
}

// Logger interface
type Logger interface {
Printf(string, ...interface{})
}

// LoggerFunc is a bridge between Logger and any third party logger
type LoggerFunc func(string, ...interface{})

// Printf implements Logger interface
func (f LoggerFunc) Printf(msg string, args ...interface{}) {
f(msg, args...)
}

// WithLogger creates an option to allow overriding the stdout logging
func WithLogger(logger Logger) Option {
return func(m *Migrator) {
m.logger = logger
}
}

// Migrations creates an option with provided migrations
func Migrations(migrations ...interface{}) Option {
return func(m *Migrator) {
Expand All @@ -34,6 +57,7 @@ func Migrations(migrations ...interface{}) Option {
// New creates a new migrator instance
func New(opts ...Option) (*Migrator, error) {
m := &Migrator{
logger: log.New(os.Stdout, "migrator: ", 0),
tableName: defaultTableName,
}
for _, opt := range opts {
Expand Down Expand Up @@ -83,13 +107,13 @@ func (m *Migrator) Migrate(db *sql.DB) error {
// plan migrations
for idx, migration := range m.migrations[count:len(m.migrations)] {
insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String())
switch m := migration.(type) {
switch mig := migration.(type) {
case *Migration:
if err := migrate(db, insertVersion, m); err != nil {
if err := migrate(db, m.logger, insertVersion, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
case *MigrationNoTx:
if err := migrateNoTx(db, insertVersion, m); err != nil {
if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
}
Expand Down Expand Up @@ -149,7 +173,7 @@ func (m *MigrationNoTx) String() string {
return m.Name
}

func migrate(db *sql.DB, insertVersion string, migration *Migration) error {
func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error {
tx, err := db.Begin()
if err != nil {
return err
Expand All @@ -163,27 +187,27 @@ func migrate(db *sql.DB, insertVersion string, migration *Migration) error {
}
err = tx.Commit()
}()
fmt.Println(fmt.Sprintf("migrator: applying migration named '%s'...", migration.Name))
logger.Printf("applying migration named '%s'...", migration.Name)
if err = migration.Func(tx); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err = tx.Exec(insertVersion); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
fmt.Println(fmt.Sprintf("migrator: applied migration named '%s'", migration.Name))
logger.Printf("applied migration named '%s'", migration.Name)

return err
}

func migrateNoTx(db *sql.DB, insertVersion string, migration *MigrationNoTx) error {
fmt.Println(fmt.Sprintf("migrator: applying no tx migration named '%s'...", migration.Name))
func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error {
logger.Printf("applying no tx migration named '%s'...", migration.Name)
if err := migration.Func(db); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err := db.Exec(insertVersion); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
fmt.Println(fmt.Sprintf("migrator: applied no tx migration named '%s'...", migration.Name))
logger.Printf("applied no tx migration named '%s'", migration.Name)

return nil
}
5 changes: 3 additions & 2 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package migrator
import (
"database/sql"
"fmt"
"log"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -167,7 +168,7 @@ func TestBadMigrate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := migrate(db, "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error {
if err := migrate(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error {
return nil
}}); err == nil {
t.Fatal("BAD INSERT VERSION should fail!")
Expand All @@ -179,7 +180,7 @@ func TestBadMigrateNoTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := migrateNoTx(db, "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error {
if err := migrateNoTx(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error {
return nil
}}); err == nil {
t.Fatal("BAD INSERT VERSION should fail!")
Expand Down

0 comments on commit bee4efd

Please sign in to comment.