diff --git a/drivers/mysql.go b/drivers/mysql.go index 1c47412..4bb1045 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -44,12 +44,6 @@ func (db *MySQL) GetDatabases() ([]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() for rows.Next() { @@ -62,6 +56,9 @@ func (db *MySQL) GetDatabases() ([]string, error) { databases = append(databases, database) } } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -72,20 +69,12 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { } rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database)) - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - - defer rows.Close() - - tables := make(map[string][]string) - if err != nil { return nil, err } + defer rows.Close() + tables := make(map[string][]string) for rows.Next() { var table string err = rows.Scan(&table) @@ -95,6 +84,9 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { tables[database] = append(tables[database], table) } + if err := rows.Err(); err != nil { + return nil, err + } return tables, nil } @@ -115,12 +107,6 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -149,6 +135,9 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -168,12 +157,6 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -201,6 +184,9 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -220,12 +206,6 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -253,6 +233,9 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -273,12 +256,6 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -306,6 +283,9 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -340,30 +320,8 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i if err != nil { return nil, 0, err } - - rowsErr := paginatedRows.Err() - - if rowsErr != nil { - return nil, 0, rowsErr - } - defer paginatedRows.Close() - countQuery := "SELECT COUNT(*) FROM " - countQuery += fmt.Sprintf("`%s`.", database) - countQuery += fmt.Sprintf("`%s`", table) - - rows := db.Connection.QueryRow(countQuery) - - if err != nil { - return nil, 0, err - } - - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } - columns, err := paginatedRows.Columns() if err != nil { return nil, 0, err @@ -388,7 +346,20 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i } paginatedResults = append(paginatedResults, row) - + } + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { + return nil, 0, err + } + countQuery := "SELECT COUNT(*) FROM " + countQuery += fmt.Sprintf("`%s`.", database) + countQuery += fmt.Sprintf("`%s`", table) + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err } return @@ -399,12 +370,6 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -431,7 +396,9 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return @@ -552,6 +519,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) if err != nil { return err } + defer trx.Rollback() for _, query := range query { logger.Info(query.Query, map[string]any{"args": query.Args}) diff --git a/drivers/postgres.go b/drivers/postgres.go index a3fede9..54e42a1 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -67,16 +67,8 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { if err != nil { return nil, err } - defer rows.Close() - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - return nil, err - } - for rows.Next() { var database string err := rows.Scan(&database) @@ -85,6 +77,9 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { } databases = append(databases, database) } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -113,29 +108,23 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err query := "SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = $1" rows, err := db.Connection.Query(query, database) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } - - defer rows.Close() - - for rows.Next() { - var tableName string - var tableSchema string - - err = rows.Scan(&tableName, &tableSchema) - - tables[tableSchema] = append(tables[tableSchema], tableName) - + for rows.Next() { + var ( + tableName string + tableSchema string + ) + if err := rows.Scan(&tableName, &tableSchema); err != nil { + return nil, err } + tables[tableSchema] = append(tables[tableSchema], tableName) } - - if err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -176,46 +165,35 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, query := "SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = $1 AND table_schema = $2 AND table_name = $3 ORDER by ordinal_position" rows, err := db.Connection.Query(query, database, tableSchema, tableName) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { - - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - defer rows.Close() + results = append(results, columns) - columns, columnsError := rows.Columns() + for rows.Next() { + rowValues := make([]interface{}, len(columns)) - if columnsError != nil { - err = columnsError + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - results = append(results, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - results = append(results, row) + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - } + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } - if err != nil { - return nil, err + results = append(results, row) } return @@ -268,43 +246,36 @@ func (db *Postgres) GetConstraints(database, table string) (constraints [][]stri AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + constraints = append(constraints, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - constraints = append(constraints, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - constraints = append(constraints, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + constraints = append(constraints, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -359,43 +330,36 @@ func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]stri AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + foreignKeys = append(foreignKeys, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - foreignKeys = append(foreignKeys, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - foreignKeys = append(foreignKeys, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + foreignKeys = append(foreignKeys, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -459,43 +423,36 @@ func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err t.relname, i.relname `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + indexes = append(indexes, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - indexes = append(indexes, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - indexes = append(indexes, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + indexes = append(indexes, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -555,57 +512,48 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi query += " LIMIT $1 OFFSET $2" paginatedRows, err := db.Connection.Query(query, limit, offset) + if err != nil { + return nil, 0, err + } + defer paginatedRows.Close() - if paginatedRows != nil { - - rowsErr := paginatedRows.Err() + columns, columnsError := paginatedRows.Columns() + if columnsError != nil { + return nil, 0, columnsError + } - defer paginatedRows.Close() + records = append(records, columns) - if rowsErr != nil { - err = rowsErr + for paginatedRows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - countQuery := "SELECT COUNT(*) FROM " - countQuery += formattedTableName - - rows := db.Connection.QueryRow(countQuery) - - rowsErr = rows.Err() - - if rowsErr != nil { - err = rowsErr + if err := paginatedRows.Scan(rowValues...); err != nil { + return nil, 0, err } - err = rows.Scan(&totalRecords) - - columns, columnsError := paginatedRows.Columns() - - if columnsError != nil { - err = columnsError + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - records = append(records, columns) + records = append(records, row) - for paginatedRows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = paginatedRows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - records = append(records, row) - - } + } + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { + return nil, 0, err } - if err != nil { + countQuery := "SELECT COUNT(*) FROM " + countQuery += formattedTableName + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { return nil, 0, err } @@ -740,15 +688,8 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - defer rows.Close() - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } - columns, err := rows.Columns() if err != nil { return nil, err @@ -773,7 +714,9 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return @@ -869,6 +812,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err if err != nil { return err } + defer trx.Rollback() for _, query := range query { logger.Info(query.Query, map[string]any{"args": query.Args}) diff --git a/drivers/sqlite.go b/drivers/sqlite.go index 24fc476..84b891d 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -45,12 +45,6 @@ func (db *SQLite) GetDatabases() ([]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() for rows.Next() { @@ -65,6 +59,9 @@ func (db *SQLite) GetDatabases() ([]string, error) { databases = append(databases, dbName) } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -78,12 +75,6 @@ func (db *SQLite) GetTables(database string) (map[string][]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() tables := make(map[string][]string) @@ -97,6 +88,9 @@ func (db *SQLite) GetTables(database string) (map[string][]string, error) { tables[database] = append(tables[database], table) } + if err := rows.Err(); err != nil { + return nil, err + } return tables, nil } @@ -110,12 +104,6 @@ func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err erro if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -148,6 +136,9 @@ func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err erro results = append(results, row[1:]) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -164,12 +155,6 @@ func (db *SQLite) GetConstraints(_, table string) (results [][]string, err error if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -201,6 +186,9 @@ func (db *SQLite) GetConstraints(_, table string) (results [][]string, err error results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -214,12 +202,6 @@ func (db *SQLite) GetForeignKeys(_, table string) (results [][]string, err error if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -251,6 +233,9 @@ func (db *SQLite) GetForeignKeys(_, table string) (results [][]string, err error results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -264,12 +249,6 @@ func (db *SQLite) GetIndexes(_, table string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -301,6 +280,9 @@ func (db *SQLite) GetIndexes(_, table string) (results [][]string, err error) { results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -331,29 +313,8 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p if err != nil { return nil, 0, err } - - rowsErr := paginatedRows.Err() - - if rowsErr != nil { - return nil, 0, rowsErr - } - defer paginatedRows.Close() - countQuery := "SELECT COUNT(*) FROM " - countQuery += db.formatTableName(table) - - rows := db.Connection.QueryRow(countQuery) - - if err != nil { - return nil, 0, err - } - - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } - columns, err := paginatedRows.Columns() if err != nil { return nil, 0, err @@ -378,7 +339,20 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p } paginatedResults = append(paginatedResults, row) + } + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { + return nil, 0, err + } + countQuery := "SELECT COUNT(*) FROM " + countQuery += db.formatTableName(table) + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err } return @@ -389,12 +363,6 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -425,7 +393,9 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return @@ -580,6 +550,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error if err != nil { return err } + defer trx.Rollback() for _, query := range query { logger.Info(query.Query, map[string]any{"args": query.Args})