Skip to content

Commit

Permalink
return transaction rollback error
Browse files Browse the repository at this point in the history
  • Loading branch information
prochac committed Sep 7, 2024
1 parent 4fb4a36 commit 8471afa
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 76 deletions.
31 changes: 5 additions & 26 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/xo/dburl"

"github.com/jorgerojas26/lazysql/helpers/logger"
"github.com/jorgerojas26/lazysql/models"
)

Expand Down Expand Up @@ -438,7 +437,7 @@ func (db *MySQL) ExecuteDMLStatement(query string) (result string, err error) {
}

func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) {
var query []models.Query
var queries []models.Query

for _, change := range changes {
columnNames := []string{}
Expand Down Expand Up @@ -475,7 +474,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
Args: values,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlUpdateType:
queryStr := "UPDATE "
queryStr += db.formatTableName(change.Database, change.Table)
Expand All @@ -500,7 +499,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
Args: args,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlDeleteType:
queryStr := "DELETE FROM "
queryStr += db.formatTableName(change.Database, change.Table)
Expand All @@ -511,30 +510,10 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
Args: []interface{}{change.PrimaryKeyValue},
}

query = append(query, newQuery)
queries = append(queries, newQuery)
}
}

trx, err := db.Connection.Begin()
if err != nil {
return err
}
defer trx.Rollback()

for _, query := range query {
logger.Info(query.Query, map[string]any{"args": query.Args})
_, err := trx.Exec(query.Query, query.Args...)
if err != nil {
return err
}
}

err = trx.Commit()
if err != nil {
return err
}

return nil
return queriesInTransaction(db.Connection, queries)
}

func (db *MySQL) SetProvider(provider string) {
Expand Down
30 changes: 5 additions & 25 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) {
}

func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err error) {
var query []models.Query
var queries []models.Query

for _, change := range changes {
columnNames := []string{}
Expand Down Expand Up @@ -770,7 +770,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err
Args: values,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlUpdateType:
queryStr := "UPDATE " + formattedTableName

Expand All @@ -794,7 +794,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err
Args: args,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlDeleteType:
queryStr := "DELETE FROM " + formattedTableName
queryStr += fmt.Sprintf(" WHERE %s = $1", change.PrimaryKeyColumnName)
Expand All @@ -804,30 +804,10 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err
Args: []interface{}{change.PrimaryKeyValue},
}

query = append(query, newQuery)
queries = append(queries, newQuery)
}
}

trx, err := db.Connection.Begin()
if err != nil {
return err
}
defer trx.Rollback()

for _, query := range query {
logger.Info(query.Query, map[string]any{"args": query.Args})
_, err := trx.Exec(query.Query, query.Args...)
if err != nil {
return err
}
}

err = trx.Commit()
if err != nil {
return err
}

return nil
return queriesInTransaction(db.Connection, queries)
}

func (db *Postgres) SetProvider(provider string) {
Expand Down
30 changes: 5 additions & 25 deletions drivers/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func (db *SQLite) ExecuteDMLStatement(query string) (result string, err error) {
}

func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error) {
var query []models.Query
var queries []models.Query

for _, change := range changes {
columnNames := []string{}
Expand Down Expand Up @@ -506,7 +506,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error
Args: values,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlUpdateType:
queryStr := "UPDATE "
queryStr += db.formatTableName(change.Table)
Expand All @@ -531,7 +531,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error
Args: args,
}

query = append(query, newQuery)
queries = append(queries, newQuery)
case models.DmlDeleteType:
queryStr := "DELETE FROM "
queryStr += db.formatTableName(change.Table)
Expand All @@ -542,30 +542,10 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error
Args: []interface{}{change.PrimaryKeyValue},
}

query = append(query, newQuery)
queries = append(queries, newQuery)
}
}

trx, err := db.Connection.Begin()
if err != nil {
return err
}
defer trx.Rollback()

for _, query := range query {
logger.Info(query.Query, map[string]any{"args": query.Args})
_, err := trx.Exec(query.Query, query.Args...)
if err != nil {
return err
}
}

err = trx.Commit()
if err != nil {
return err
}

return nil
return queriesInTransaction(db.Connection, queries)
}

func (db *SQLite) SetProvider(provider string) {
Expand Down
34 changes: 34 additions & 0 deletions drivers/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package drivers

import (
"database/sql"
"errors"

"github.com/jorgerojas26/lazysql/helpers/logger"
"github.com/jorgerojas26/lazysql/models"
)

func queriesInTransaction(db *sql.DB, queries []models.Query) (err error) {
trx, err := db.Begin()
if err != nil {
return err
}
defer func() {
rErr := trx.Rollback()
// sql.ErrTxDone is returned when trx.Commit was already called
if !errors.Is(rErr, sql.ErrTxDone) {
err = errors.Join(err, rErr)
}
}()

for _, query := range queries {
logger.Info(query.Query, map[string]any{"args": query.Args})
if _, err := trx.Exec(query.Query, query.Args...); err != nil {
return err
}
}
if err := trx.Commit(); err != nil {
return err
}
return nil
}

0 comments on commit 8471afa

Please sign in to comment.