Skip to content

Commit

Permalink
fix: error handing for row cursors
Browse files Browse the repository at this point in the history
`(*sql.Rows).Next()` can return `false` due to an error.
The error must be handled by checking `(*sql.Rows).Err()`.

The transaction must be ended by commit or rollback to close the db connection. Calling rollback after commit is noop.

Don't call a query in the middle of cursor, as it may lead to a deadlock. The inner query may wait for a free connection, and the connection from the cursor can't be freed because it's stuck by the inner query waiting.
  • Loading branch information
prochac committed Sep 4, 2024
1 parent 7cd8dd2 commit 4fb4a36
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 323 deletions.
108 changes: 38 additions & 70 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ func (db *MySQL) GetDatabases() ([]string, error) {
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

for rows.Next() {
Expand All @@ -62,6 +56,9 @@ func (db *MySQL) GetDatabases() ([]string, error) {
databases = append(databases, database)
}
}
if err := rows.Err(); err != nil {
return nil, err
}

return databases, nil
}
Expand All @@ -72,20 +69,12 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) {
}

rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database))

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

tables := make(map[string][]string)

if err != nil {
return nil, err
}
defer rows.Close()

tables := make(map[string][]string)
for rows.Next() {
var table string
err = rows.Scan(&table)
Expand All @@ -95,6 +84,9 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) {

tables[database] = append(tables[database], table)
}
if err := rows.Err(); err != nil {
return nil, err
}

return tables, nil
}
Expand All @@ -115,12 +107,6 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

columns, err := rows.Columns()
Expand Down Expand Up @@ -149,6 +135,9 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er

results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}

return
}
Expand All @@ -168,12 +157,6 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

columns, err := rows.Columns()
Expand Down Expand Up @@ -201,6 +184,9 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err

results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}

return
}
Expand All @@ -220,12 +206,6 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

columns, err := rows.Columns()
Expand Down Expand Up @@ -253,6 +233,9 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err

results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}

return
}
Expand All @@ -273,12 +256,6 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

columns, err := rows.Columns()
Expand Down Expand Up @@ -306,6 +283,9 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err

results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}

return
}
Expand Down Expand Up @@ -340,30 +320,8 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i
if err != nil {
return nil, 0, err
}

rowsErr := paginatedRows.Err()

if rowsErr != nil {
return nil, 0, rowsErr
}

defer paginatedRows.Close()

countQuery := "SELECT COUNT(*) FROM "
countQuery += fmt.Sprintf("`%s`.", database)
countQuery += fmt.Sprintf("`%s`", table)

rows := db.Connection.QueryRow(countQuery)

if err != nil {
return nil, 0, err
}

err = rows.Scan(&totalRecords)
if err != nil {
return nil, 0, err
}

columns, err := paginatedRows.Columns()
if err != nil {
return nil, 0, err
Expand All @@ -388,7 +346,20 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i
}

paginatedResults = append(paginatedResults, row)

}
if err := paginatedRows.Err(); err != nil {
return nil, 0, err
}
// close to release the connection
if err := paginatedRows.Close(); err != nil {
return nil, 0, err
}
countQuery := "SELECT COUNT(*) FROM "
countQuery += fmt.Sprintf("`%s`.", database)
countQuery += fmt.Sprintf("`%s`", table)
row := db.Connection.QueryRow(countQuery)
if err := row.Scan(&totalRecords); err != nil {
return nil, 0, err
}

return
Expand All @@ -399,12 +370,6 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) {
if err != nil {
return nil, err
}

rowsErr := rows.Err()
if rowsErr != nil {
return nil, rowsErr
}

defer rows.Close()

columns, err := rows.Columns()
Expand All @@ -431,7 +396,9 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) {
}

results = append(results, row)

}
if err := rows.Err(); err != nil {
return nil, err
}

return
Expand Down Expand Up @@ -552,6 +519,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
if err != nil {
return err
}
defer trx.Rollback()

for _, query := range query {
logger.Info(query.Query, map[string]any{"args": query.Args})
Expand Down
Loading

0 comments on commit 4fb4a36

Please sign in to comment.