diff --git a/drivers/mysql.go b/drivers/mysql.go index d1523df..43b4b6c 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -60,7 +60,7 @@ func (db *MySQL) GetDatabases() ([]string, error) { } func (db *MySQL) GetTables(database string) (map[string][]string, error) { - rows, err := db.Connection.Query("SHOW TABLES FROM " + database) + rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database)) tables := make(map[string][]string) @@ -79,6 +79,8 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { } func (db *MySQL) GetTableColumns(database, table string) (results [][]string, err error) { + table = db.formatTableName(table) + rows, err := db.Connection.Query("DESCRIBE " + table) if err != nil { return results, err @@ -112,6 +114,8 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er } func (db *MySQL) GetConstraints(table string) (results [][]string, err error) { + table = db.formatTableName(table) + splitTableString := strings.Split(table, ".") database := splitTableString[0] tableName := splitTableString[1] @@ -150,6 +154,7 @@ func (db *MySQL) GetConstraints(table string) (results [][]string, err error) { } func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) { + table = db.formatTableName(table) splitTableString := strings.Split(table, ".") database := splitTableString[0] tableName := splitTableString[1] @@ -188,6 +193,7 @@ func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) { } func (db *MySQL) GetIndexes(table string) (results [][]string, err error) { + table = db.formatTableName(table) rows, err := db.Connection.Query("SHOW INDEX FROM " + table) if err != nil { return results, err @@ -218,6 +224,7 @@ func (db *MySQL) GetIndexes(table string) (results [][]string, err error) { } func (db *MySQL) GetRecords(table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) { + table = db.formatTableName(table) defaultLimit := 300 isPaginationEnabled := offset >= 0 && limit >= 0 @@ -313,6 +320,7 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { // TODO: Rewrites this logic to use the primary key instead of the id func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error { + table = db.formatTableName(table) query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE %s = \"%s\"", table, column, value, primaryKeyColumnName, primaryKeyValue) _, err := db.Connection.Exec(query) @@ -321,6 +329,7 @@ func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primar // TODO: Rewrites this logic to use the primary key instead of the id func (db *MySQL) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) error { + table = db.formatTableName(table) query := fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\"", table, primaryKeyColumnName, primaryKeyValue) _, err := db.Connection.Exec(query) @@ -363,7 +372,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m // Split key into table and rowId splitted := strings.Split(key, "|") - table := splitted[0] + table := db.formatTableName(splitted[0]) primaryKeyColumnName := splitted[1] primaryKeyValue := splitted[2] @@ -385,7 +394,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m statementType = "DELETE FROM" - query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, delete.Table, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) + query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, db.formatTableName(delete.Table), delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) if query != "" { queries = append(queries, query) @@ -405,7 +414,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m } } - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) queries = append(queries, query) } @@ -418,7 +427,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m for _, query := range queries { _, err = tx.Exec(query) - if err != nil { tx.Rollback() @@ -427,7 +435,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m } err = tx.Commit() - if err != nil { return err } @@ -442,3 +449,18 @@ func (db *MySQL) SetProvider(provider string) { func (db *MySQL) GetProvider() string { return db.Provider } + +func (db *MySQL) formatTableName(tableName string) string { + splittedTableName := strings.Split(tableName, ".") + + if len(splittedTableName) == 1 { + return tableName + } + + database := splittedTableName[0] + table := splittedTableName[1] + + formattedTableName := fmt.Sprintf("`%s`.`%s`", database, table) + + return formattedTableName +}