diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 15e276892..58713e72b 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -3,6 +3,7 @@ package cassandra import ( "errors" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -45,7 +46,7 @@ type Config struct { type Cassandra struct { session *gocql.Session - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config @@ -182,15 +183,16 @@ func (c *Cassandra) Close() error { } func (c *Cassandra) Lock() error { - if c.isLocked { + if !c.isLocked.CAS(false, true) { return database.ErrLocked } - c.isLocked = true return nil } func (c *Cassandra) Unlock() error { - c.isLocked = false + if !c.isLocked.CAS(true, false) { + return database.ErrNotLocked + } return nil } diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index b612aec59..625658ea8 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -285,7 +285,7 @@ func (ch *ClickHouse) Lock() error { } func (ch *ClickHouse) Unlock() error { if !ch.isLocked.CAS(true, false) { - return database.ErrLocked + return database.ErrNotLocked } return nil diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 24cc6471f..935131ab3 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -46,7 +47,7 @@ type Config struct { type CockroachDb struct { db *sql.DB - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config @@ -152,71 +153,67 @@ func (c *CockroachDb) Close() error { // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 func (c *CockroachDb) Lock() error { - err := crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) { - aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) - if err != nil { - return err - } - - query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1" - rows, err := tx.Query(query, aid) - if err != nil { - return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} - } - defer func() { - if errClose := rows.Close(); errClose != nil { - err = multierror.Append(err, errClose) + return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) { + return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) { + aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) + if err != nil { + return err } - }() - // If row exists at all, lock is present - locked := rows.Next() - if locked && !c.config.ForceLock { - return database.ErrLocked - } + query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1" + rows, err := tx.Query(query, aid) + if err != nil { + return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} + } + defer func() { + if errClose := rows.Close(); errClose != nil { + err = multierror.Append(err, errClose) + } + }() + + // If row exists at all, lock is present + locked := rows.Next() + if locked && !c.config.ForceLock { + return database.ErrLocked + } - query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)" - if _, err := tx.Exec(query, aid); err != nil { - return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} - } + query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)" + if _, err := tx.Exec(query, aid); err != nil { + return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} + } - return nil + return nil + }) }) - - if err != nil { - return err - } else { - c.isLocked = true - return nil - } } // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 func (c *CockroachDb) Unlock() error { - aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) - if err != nil { - return err - } + return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) { + aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) + if err != nil { + return err + } - // In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until - // a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances - query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1" - if _, err := c.db.Exec(query, aid); err != nil { - if e, ok := err.(*pq.Error); ok { - // 42P01 is "UndefinedTableError" in CockroachDB - // https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go - if e.Code == "42P01" { - // On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema - c.isLocked = false - return nil + // In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until + // a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances + query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1" + if _, err := c.db.Exec(query, aid); err != nil { + if e, ok := err.(*pq.Error); ok { + // 42P01 is "UndefinedTableError" in CockroachDB + // https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go + if e.Code == "42P01" { + // On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema + return nil + } } + + return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} } - return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} - } - c.isLocked = false - return nil + return nil + }) } func (c *CockroachDb) Run(migration io.Reader) error { diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index ca393e22b..41ccc33d3 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -10,6 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4/database" "github.com/hashicorp/go-multierror" _ "github.com/nakagami/firebirdsql" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -36,7 +37,7 @@ type Firebird struct { // Locking and unlocking need to use the same connection conn *sql.Conn db *sql.DB - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config @@ -106,15 +107,16 @@ func (f *Firebird) Close() error { } func (f *Firebird) Lock() error { - if f.isLocked { + if !f.isLocked.CAS(false, true) { return database.ErrLocked } - f.isLocked = true return nil } func (f *Firebird) Unlock() error { - f.isLocked = false + if !f.isLocked.CAS(true, false) { + return database.ErrNotLocked + } return nil } diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 17ca804f2..3e18fd44b 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -10,6 +10,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.uber.org/atomic" "io" "io/ioutil" "net/url" @@ -40,9 +41,10 @@ var ( ) type Mongo struct { - client *mongo.Client - db *mongo.Database - config *Config + client *mongo.Client + db *mongo.Database + config *Config + isLocked atomic.Bool } type Locking struct { @@ -327,55 +329,60 @@ func (m *Mongo) ensureVersionTable() (err error) { // Utilizes advisory locking on the config.LockingCollection collection // This uses a unique index on the `locking_key` field. func (m *Mongo) Lock() error { - if !m.config.Locking.Enabled { - return nil - } - pid := os.Getpid() - hostname, err := os.Hostname() - if err != nil { - hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) - } + return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { + if !m.config.Locking.Enabled { + return nil + } - newLockObj := lockObj{ - Key: lockKeyUniqueValue, - Pid: pid, - Hostname: hostname, - CreatedAt: time.Now(), - } - operation := func() error { - timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) - _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) - defer cancelFunc() - return err - } - exponentialBackOff := backoff.NewExponentialBackOff() - duration := time.Duration(m.config.Locking.Timeout) * time.Second - exponentialBackOff.MaxElapsedTime = duration - exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second + pid := os.Getpid() + hostname, err := os.Hostname() + if err != nil { + hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) + } - err = backoff.Retry(operation, exponentialBackOff) - if err != nil { - return database.ErrLocked - } + newLockObj := lockObj{ + Key: lockKeyUniqueValue, + Pid: pid, + Hostname: hostname, + CreatedAt: time.Now(), + } + operation := func() error { + timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) + _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) + defer cancelFunc() + return err + } + exponentialBackOff := backoff.NewExponentialBackOff() + duration := time.Duration(m.config.Locking.Timeout) * time.Second + exponentialBackOff.MaxElapsedTime = duration + exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second - return nil + err = backoff.Retry(operation, exponentialBackOff) + if err != nil { + return database.ErrLocked + } + return nil + }) } + func (m *Mongo) Unlock() error { - if !m.config.Locking.Enabled { - return nil - } + return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { + if !m.config.Locking.Enabled { + return nil + } - filter := findFilter{ - Key: lockKeyUniqueValue, - } + filter := findFilter{ + Key: lockKeyUniqueValue, + } - ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) - _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) + _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) + defer cancel() - if err != nil { - return err - } - return nil + if err != nil { + return err + } + return nil + }) } diff --git a/database/mongodb/mongodb_test.go b/database/mongodb/mongodb_test.go index c73da46c4..f15f74113 100644 --- a/database/mongodb/mongodb_test.go +++ b/database/mongodb/mongodb_test.go @@ -221,18 +221,7 @@ func TestLockWorks(t *testing.T) { t.Fatal(err) } - // disable locking, validate wer can lock twice - mc.config.Locking.Enabled = false - err = mc.Lock() - if err != nil { - t.Fatal(err) - } - err = mc.Lock() - if err != nil { - t.Fatal(err) - } - - // re-enable locking, + // enable locking, //try to hit a lock conflict mc.config.Locking.Enabled = true mc.config.Locking.Timeout = 1 diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 586df2494..29bb9a276 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -49,7 +50,7 @@ type Mysql struct { // just do everything over a single conn anyway. conn *sql.Conn db *sql.DB - isLocked bool + isLocked atomic.Bool config *Config } @@ -251,62 +252,53 @@ func (m *Mysql) Close() error { } func (m *Mysql) Lock() error { - if m.isLocked { - return database.ErrLocked - } - - if m.config.NoLock { - m.isLocked = true - return nil - } + return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { + if m.config.NoLock { + return nil + } + aid, err := database.GenerateAdvisoryLockId( + fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) + if err != nil { + return err + } - aid, err := database.GenerateAdvisoryLockId( - fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) - if err != nil { - return err - } + query := "SELECT GET_LOCK(?, 10)" + var success bool + if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } - query := "SELECT GET_LOCK(?, 10)" - var success bool - if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} - } + if !success { + return database.ErrLocked + } - if success { - m.isLocked = true return nil - } - - return database.ErrLocked + }) } func (m *Mysql) Unlock() error { - if !m.isLocked { - return nil - } - - if m.config.NoLock { - m.isLocked = false - return nil - } + return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { + if m.config.NoLock { + return nil + } - aid, err := database.GenerateAdvisoryLockId( - fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) - if err != nil { - return err - } + aid, err := database.GenerateAdvisoryLockId( + fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) + if err != nil { + return err + } - query := `SELECT RELEASE_LOCK(?)` - if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } + query := `SELECT RELEASE_LOCK(?)` + if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } - // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), - // in which case isLocked should be true until the timeout expires -- synchronizing - // these states is likely not worth trying to do; reconsider the necessity of isLocked. + // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), + // in which case isLocked should be true until the timeout expires -- synchronizing + // these states is likely not worth trying to do; reconsider the necessity of isLocked. - m.isLocked = false - return nil + return nil + }) } func (m *Mysql) Run(migration io.Reader) error { diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 423ac1a08..fe709ef04 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -58,7 +59,7 @@ type Postgres struct { // Locking and unlocking need to use the same connection conn *sql.Conn db *sql.DB - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config @@ -220,41 +221,34 @@ func (p *Postgres) Close() error { // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { - if p.isLocked { - return database.ErrLocked - } - - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } - - // This will wait indefinitely until the lock can be acquired. - query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} - } + return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } - p.isLocked = true - return nil + // This will wait indefinitely until the lock can be acquired. + query := `SELECT pg_advisory_lock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } + return nil + }) } func (p *Postgres) Unlock() error { - if !p.isLocked { - return nil - } - - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } + return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } - query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } - p.isLocked = false - return nil + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + return nil + }) } func (p *Postgres) Run(migration io.Reader) error { diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 0e384fe36..d59bfb524 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -57,7 +58,7 @@ type Postgres struct { // Locking and unlocking need to use the same connection conn *sql.Conn db *sql.DB - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config @@ -214,41 +215,35 @@ func (p *Postgres) Close() error { // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { - if p.isLocked { - return database.ErrLocked - } - - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } + return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } - // This will wait indefinitely until the lock can be acquired. - query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} - } + // This will wait indefinitely until the lock can be acquired. + query := `SELECT pg_advisory_lock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } - p.isLocked = true - return nil + return nil + }) } func (p *Postgres) Unlock() error { - if !p.isLocked { - return nil - } - - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } + return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } - query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } - p.isLocked = false - return nil + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + return nil + }) } func (p *Postgres) Run(migration io.Reader) error { diff --git a/database/ql/ql.go b/database/ql/ql.go index 5b2dbe355..1c4c49be6 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "github.com/hashicorp/go-multierror" + "go.uber.org/atomic" "io" "io/ioutil" "strings" @@ -34,7 +35,7 @@ type Config struct { type Ql struct { db *sql.DB - isLocked bool + isLocked atomic.Bool config *Config } @@ -166,17 +167,15 @@ func (m *Ql) Drop() (err error) { return nil } func (m *Ql) Lock() error { - if m.isLocked { + if !m.isLocked.CAS(false, true) { return database.ErrLocked } - m.isLocked = true return nil } func (m *Ql) Unlock() error { - if !m.isLocked { - return nil + if !m.isLocked.CAS(true, false) { + return database.ErrNotLocked } - m.isLocked = false return nil } func (m *Ql) Run(migration io.Reader) error { diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 1f10a29a4..d4539b8a2 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -36,7 +37,7 @@ type Config struct { } type Redshift struct { - isLocked bool + isLocked atomic.Bool conn *sql.Conn db *sql.DB @@ -126,15 +127,16 @@ func (p *Redshift) Close() error { // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html func (p *Redshift) Lock() error { - if p.isLocked { + if !p.isLocked.CAS(false, true) { return database.ErrLocked } - p.isLocked = true return nil } func (p *Redshift) Unlock() error { - p.isLocked = false + if !p.isLocked.CAS(true, false) { + return database.ErrNotLocked + } return nil } diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index 2ad794cec..53d7ca282 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -37,7 +38,7 @@ type Config struct { } type Snowflake struct { - isLocked bool + isLocked atomic.Bool conn *sql.Conn db *sql.DB @@ -158,15 +159,16 @@ func (p *Snowflake) Close() error { } func (p *Snowflake) Lock() error { - if p.isLocked { + if !p.isLocked.CAS(false, true) { return database.ErrLocked } - p.isLocked = true return nil } func (p *Snowflake) Unlock() error { - p.isLocked = false + if !p.isLocked.CAS(true, false) { + return database.ErrNotLocked + } return nil } diff --git a/database/sqlcipher/sqlcipher.go b/database/sqlcipher/sqlcipher.go index 53e97446a..782eed24b 100644 --- a/database/sqlcipher/sqlcipher.go +++ b/database/sqlcipher/sqlcipher.go @@ -3,6 +3,7 @@ package sqlcipher import ( "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -34,7 +35,7 @@ type Config struct { type Sqlite struct { db *sql.DB - isLocked bool + isLocked atomic.Bool config *Config } @@ -177,18 +178,16 @@ func (m *Sqlite) Drop() (err error) { } func (m *Sqlite) Lock() error { - if m.isLocked { + if !m.isLocked.CAS(false, true) { return database.ErrLocked } - m.isLocked = true return nil } func (m *Sqlite) Unlock() error { - if !m.isLocked { - return nil + if !m.isLocked.CAS(true, false) { + return database.ErrNotLocked } - m.isLocked = false return nil } diff --git a/database/sqlite/sqlite.go b/database/sqlite/sqlite.go index 581b87d28..d33c60e46 100644 --- a/database/sqlite/sqlite.go +++ b/database/sqlite/sqlite.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -34,7 +35,7 @@ type Config struct { type Sqlite struct { db *sql.DB - isLocked bool + isLocked atomic.Bool config *Config } @@ -177,18 +178,16 @@ func (m *Sqlite) Drop() (err error) { } func (m *Sqlite) Lock() error { - if m.isLocked { + if !m.isLocked.CAS(false, true) { return database.ErrLocked } - m.isLocked = true return nil } func (m *Sqlite) Unlock() error { - if !m.isLocked { - return nil + if !m.isLocked.CAS(true, false) { + return database.ErrNotLocked } - m.isLocked = false return nil } diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 4d40f3ecf..65aa6e74c 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -3,6 +3,7 @@ package sqlite3 import ( "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -34,7 +35,7 @@ type Config struct { type Sqlite struct { db *sql.DB - isLocked bool + isLocked atomic.Bool config *Config } @@ -177,18 +178,16 @@ func (m *Sqlite) Drop() (err error) { } func (m *Sqlite) Lock() error { - if m.isLocked { + if !m.isLocked.CAS(false, true) { return database.ErrLocked } - m.isLocked = true return nil } func (m *Sqlite) Unlock() error { - if !m.isLocked { - return nil + if !m.isLocked.CAS(true, false) { + return database.ErrNotLocked } - m.isLocked = false return nil } diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index b90619ff9..0f8252f3e 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" @@ -47,7 +48,7 @@ type SQLServer struct { // Locking and unlocking need to use the same connection conn *sql.Conn db *sql.DB - isLocked bool + isLocked atomic.Bool // Open and WithInstance need to garantuee that config is never nil config *Config @@ -154,50 +155,44 @@ func (ss *SQLServer) Close() error { // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. func (ss *SQLServer) Lock() error { - if ss.isLocked { - return database.ErrLocked - } - - aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) - if err != nil { - return err - } - - // This will either obtain the lock immediately and return true, - // or return false if the lock cannot be acquired immediately. - // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 - query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` + return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) + if err != nil { + return err + } - var status mssql.ReturnStatus - if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { - ss.isLocked = true - return nil - } else if err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} - } else { - return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} - } + // This will either obtain the lock immediately and return true, + // or return false if the lock cannot be acquired immediately. + // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 + query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` + + var status mssql.ReturnStatus + if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { + return nil + } else if err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } else { + return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} + } + }) } // Unlock froms the migration lock from the database func (ss *SQLServer) Unlock() error { - if !ss.isLocked { - return nil - } - - aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) - if err != nil { - return err - } + return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { + aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) + if err != nil { + return err + } - // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 - query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` - if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } - ss.isLocked = false + // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 + query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` + if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } - return nil + return nil + }) } // Run the migrations for the database diff --git a/database/stub/stub.go b/database/stub/stub.go index f05e5443b..238ce8ba6 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -1,6 +1,7 @@ package stub import ( + "go.uber.org/atomic" "io" "io/ioutil" "reflect" @@ -19,7 +20,7 @@ type Stub struct { MigrationSequence []string LastRunMigration []byte // todo: make []string IsDirty bool - IsLocked bool + isLocked atomic.Bool Config *Config } @@ -49,15 +50,16 @@ func (s *Stub) Close() error { } func (s *Stub) Lock() error { - if s.IsLocked { + if !s.isLocked.CAS(false, true) { return database.ErrLocked } - s.IsLocked = true return nil } func (s *Stub) Unlock() error { - s.IsLocked = false + if !s.isLocked.CAS(true, false) { + return database.ErrNotLocked + } return nil } diff --git a/database/util.go b/database/util.go index 976ad3534..de66d5b80 100644 --- a/database/util.go +++ b/database/util.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "go.uber.org/atomic" "hash/crc32" "strings" ) @@ -17,3 +18,16 @@ func GenerateAdvisoryLockId(databaseName string, additionalNames ...string) (str sum = sum * uint32(advisoryLockIDSalt) return fmt.Sprint(sum), nil } + +// CasRestoreOnErr CAS wrapper to automatically restore the lock state on error +func CasRestoreOnErr(lock *atomic.Bool, o, n bool, casErr error, f func() error) error { + if !lock.CAS(o, n) { + return casErr + } + if err := f(); err != nil { + // Automatically unlock/lock on error + lock.Store(o) + return err + } + return nil +} diff --git a/database/util_test.go b/database/util_test.go index 13cba46d8..3f1dc73ae 100644 --- a/database/util_test.go +++ b/database/util_test.go @@ -1,6 +1,8 @@ package database import ( + "errors" + "go.uber.org/atomic" "testing" ) @@ -45,3 +47,60 @@ func TestGenerateAdvisoryLockId(t *testing.T) { }) } } + +func TestCasRestoreOnErr(t *testing.T) { + casErr := errors.New("test lock CAS failure") + fErr := errors.New("test callback error") + + testcases := []struct { + name string + lock *atomic.Bool + from bool + to bool + expectLock bool + fErr error + expectError error + }{ + { + name: "Test positive CAS lock", + lock: atomic.NewBool(false), + from: false, + to: true, + expectLock: true, + fErr: nil, + expectError: nil, + }, + { + name: "Test negative CAS lock", + lock: atomic.NewBool(true), + from: false, + to: true, + expectLock: true, + fErr: nil, + expectError: casErr, + }, + { + name: "Test negative with callback lock", + lock: atomic.NewBool(false), + from: false, + to: true, + expectLock: false, + fErr: fErr, + expectError: fErr, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + if err := CasRestoreOnErr(tc.lock, tc.from, tc.to, casErr, func() error { + return tc.fErr + }); err != tc.expectError { + t.Error("Incorrect error value returned") + } + + if tc.lock.Load() != tc.expectLock { + t.Error("Incorrect state of lock") + } + }) + } +}