Skip to content

Support multi-statement execution for PostgreSQL #495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions database/postgres/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
Expand All @@ -27,3 +29,10 @@
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.

## Multi-statement mode

In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
64 changes: 51 additions & 13 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
multierror "github.com/hashicorp/go-multierror"
"github.com/lib/pq"
)
Expand All @@ -25,7 +26,12 @@ func init() {
database.Register("postgresql", &db)
}

var DefaultMigrationsTable = "schema_migrations"
var (
multiStmtDelimiter = []byte(";")

DefaultMigrationsTable = "schema_migrations"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)

var (
ErrNilConfig = fmt.Errorf("no config")
Expand All @@ -35,10 +41,12 @@ var (
)

type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
StatementTimeout time.Duration
MigrationsTable string
DatabaseName string
SchemaName string
StatementTimeout time.Duration
MultiStatementEnabled bool
MultiStatementMaxSize int
}

type Postgres struct {
Expand Down Expand Up @@ -132,10 +140,23 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}
}

multiStatementMaxSize := DefaultMultiStatementMaxSize
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The init code, and constants duplicate each other across the drivers, so probably not a bad idea to extract this into a common place (or even introduce new MultiStatementDriver interface).

Is this what you meant by init code?
What were you thinking in terms of a MultiStatementDriver interface?

if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
if multiStatementMaxSize <= 0 {
multiStatementMaxSize = DefaultMultiStatementMaxSize
}
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
})

if err != nil {
Expand Down Expand Up @@ -194,18 +215,36 @@ func (p *Postgres) Unlock() error {
}

func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
return p.runStatement(migr)
}

func (p *Postgres) runStatement(statement []byte) error {
ctx := context.Background()
if p.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
// run migration
query := string(migr[:])
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
Expand All @@ -223,11 +262,10 @@ func (p *Postgres) Run(migration io.Reader) error {
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}

return nil
}

Expand Down
40 changes: 37 additions & 3 deletions database/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ var (
}
)

func pgConnectionString(host, port string) string {
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port)
func pgConnectionString(host, port string, options ...string) string {
options = append(options, "sslmode=disable")
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
}

func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
Expand Down Expand Up @@ -122,7 +123,7 @@ func TestMigrate(t *testing.T) {
})
}

func TestMultiStatement(t *testing.T) {
func TestMultipleStatements(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
Expand Down Expand Up @@ -155,6 +156,39 @@ func TestMultiStatement(t *testing.T) {
})
}

func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port, "x-multi-statement=true")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

// make sure created index exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}

func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down