From 8471afa7c8b8464fd3fe8e5632436bffca2c21d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Proch=C3=A1zka?= Date: Sat, 7 Sep 2024 21:50:52 +0200 Subject: [PATCH] return transaction rollback error --- drivers/mysql.go | 31 +++++-------------------------- drivers/postgres.go | 30 +++++------------------------- drivers/sqlite.go | 30 +++++------------------------- drivers/utils.go | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 76 deletions(-) create mode 100644 drivers/utils.go diff --git a/drivers/mysql.go b/drivers/mysql.go index 4bb1045..4283057 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -8,7 +8,6 @@ import ( "github.com/xo/dburl" - "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -438,7 +437,7 @@ func (db *MySQL) ExecuteDMLStatement(query string) (result string, err error) { } func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -475,7 +474,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: queryStr := "UPDATE " queryStr += db.formatTableName(change.Database, change.Table) @@ -500,7 +499,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " queryStr += db.formatTableName(change.Database, change.Table) @@ -511,30 +510,10 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) + queries = append(queries, newQuery) } } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - defer trx.Rollback() - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err - } - } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *MySQL) SetProvider(provider string) { diff --git a/drivers/postgres.go b/drivers/postgres.go index 54e42a1..d404b39 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -723,7 +723,7 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { } func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -770,7 +770,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: queryStr := "UPDATE " + formattedTableName @@ -794,7 +794,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " + formattedTableName queryStr += fmt.Sprintf(" WHERE %s = $1", change.PrimaryKeyColumnName) @@ -804,30 +804,10 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) + queries = append(queries, newQuery) } } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - defer trx.Rollback() - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err - } - } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *Postgres) SetProvider(provider string) { diff --git a/drivers/sqlite.go b/drivers/sqlite.go index 84b891d..2c67d51 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -468,7 +468,7 @@ func (db *SQLite) ExecuteDMLStatement(query string) (result string, err error) { } func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -506,7 +506,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: queryStr := "UPDATE " queryStr += db.formatTableName(change.Table) @@ -531,7 +531,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " queryStr += db.formatTableName(change.Table) @@ -542,30 +542,10 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) + queries = append(queries, newQuery) } } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - defer trx.Rollback() - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err - } - } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *SQLite) SetProvider(provider string) { diff --git a/drivers/utils.go b/drivers/utils.go new file mode 100644 index 0000000..d880917 --- /dev/null +++ b/drivers/utils.go @@ -0,0 +1,34 @@ +package drivers + +import ( + "database/sql" + "errors" + + "github.com/jorgerojas26/lazysql/helpers/logger" + "github.com/jorgerojas26/lazysql/models" +) + +func queriesInTransaction(db *sql.DB, queries []models.Query) (err error) { + trx, err := db.Begin() + if err != nil { + return err + } + defer func() { + rErr := trx.Rollback() + // sql.ErrTxDone is returned when trx.Commit was already called + if !errors.Is(rErr, sql.ErrTxDone) { + err = errors.Join(err, rErr) + } + }() + + for _, query := range queries { + logger.Info(query.Query, map[string]any{"args": query.Args}) + if _, err := trx.Exec(query.Query, query.Args...); err != nil { + return err + } + } + if err := trx.Commit(); err != nil { + return err + } + return nil +}