Skip to content

Commit

Permalink
Merge pull request #67 from jorgerojas26/mysql-numeric-table-name
Browse files Browse the repository at this point in the history
fix: format table name to include backticks
  • Loading branch information
jorgerojas26 authored Jul 10, 2024
2 parents 848dbf6 + 10a4b1a commit 78c1bd1
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (db *MySQL) GetDatabases() ([]string, error) {
}

func (db *MySQL) GetTables(database string) (map[string][]string, error) {
rows, err := db.Connection.Query("SHOW TABLES FROM " + database)
rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database))

tables := make(map[string][]string)

Expand All @@ -79,6 +79,8 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) {
}

func (db *MySQL) GetTableColumns(database, table string) (results [][]string, err error) {
table = db.formatTableName(table)

rows, err := db.Connection.Query("DESCRIBE " + table)
if err != nil {
return results, err
Expand Down Expand Up @@ -112,6 +114,8 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er
}

func (db *MySQL) GetConstraints(table string) (results [][]string, err error) {
table = db.formatTableName(table)

splitTableString := strings.Split(table, ".")
database := splitTableString[0]
tableName := splitTableString[1]
Expand Down Expand Up @@ -150,6 +154,7 @@ func (db *MySQL) GetConstraints(table string) (results [][]string, err error) {
}

func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) {
table = db.formatTableName(table)
splitTableString := strings.Split(table, ".")
database := splitTableString[0]
tableName := splitTableString[1]
Expand Down Expand Up @@ -188,6 +193,7 @@ func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) {
}

func (db *MySQL) GetIndexes(table string) (results [][]string, err error) {
table = db.formatTableName(table)
rows, err := db.Connection.Query("SHOW INDEX FROM " + table)
if err != nil {
return results, err
Expand Down Expand Up @@ -218,6 +224,7 @@ func (db *MySQL) GetIndexes(table string) (results [][]string, err error) {
}

func (db *MySQL) GetRecords(table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) {
table = db.formatTableName(table)
defaultLimit := 300

isPaginationEnabled := offset >= 0 && limit >= 0
Expand Down Expand Up @@ -313,6 +320,7 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) {

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error {
table = db.formatTableName(table)
query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE %s = \"%s\"", table, column, value, primaryKeyColumnName, primaryKeyValue)
_, err := db.Connection.Exec(query)

Expand All @@ -321,6 +329,7 @@ func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primar

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *MySQL) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) error {
table = db.formatTableName(table)
query := fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\"", table, primaryKeyColumnName, primaryKeyValue)
_, err := db.Connection.Exec(query)

Expand Down Expand Up @@ -363,7 +372,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m

// Split key into table and rowId
splitted := strings.Split(key, "|")
table := splitted[0]
table := db.formatTableName(splitted[0])
primaryKeyColumnName := splitted[1]
primaryKeyValue := splitted[2]

Expand All @@ -385,7 +394,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m

statementType = "DELETE FROM"

query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, delete.Table, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue)
query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, db.formatTableName(delete.Table), delete.PrimaryKeyColumnName, delete.PrimaryKeyValue)

if query != "" {
queries = append(queries, query)
Expand All @@ -405,7 +414,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
}
}

query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", "))
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", "))

queries = append(queries, query)
}
Expand All @@ -418,7 +427,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
for _, query := range queries {

_, err = tx.Exec(query)

if err != nil {
tx.Rollback()

Expand All @@ -427,7 +435,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
}

err = tx.Commit()

if err != nil {
return err
}
Expand All @@ -442,3 +449,18 @@ func (db *MySQL) SetProvider(provider string) {
func (db *MySQL) GetProvider() string {
return db.Provider
}

func (db *MySQL) formatTableName(tableName string) string {
splittedTableName := strings.Split(tableName, ".")

if len(splittedTableName) == 1 {
return tableName
}

database := splittedTableName[0]
table := splittedTableName[1]

formattedTableName := fmt.Sprintf("`%s`.`%s`", database, table)

return formattedTableName
}

0 comments on commit 78c1bd1

Please sign in to comment.