From e3b142a10ff87c28b9520ee49f83e4a9b687e96f Mon Sep 17 00:00:00 2001 From: Daniel Prueitt Date: Wed, 20 Mar 2024 17:04:51 +0900 Subject: [PATCH 1/2] feat: Add support for different schemas in Postgresql driver --- components/ResultsTable.go | 28 ++++++++++++++++-- components/Tree.go | 2 +- drivers/postgres.go | 60 +++++++++++++++++++++++--------------- 3 files changed, 62 insertions(+), 28 deletions(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 9742f2b..6abaa11 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -1103,6 +1103,7 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string, func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { provider := table.DBDriver.GetProvider() columns := table.GetColumns() + constraints := table.GetConstraints() primaryKeyColumnName := "" primaryKeyValue := "" @@ -1131,18 +1132,39 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { case "postgres": keyColumnIndex := -1 + constraintTypeColumnIndex := -1 + constraintNameColumnIndex := -1 + pKeyName := "" primaryKeyColumnIndex := -1 + for i, constraint := range constraints[0] { + if constraint == "constraint_type" { + constraintTypeColumnIndex = i + } + if constraint == "column_name" { + constraintNameColumnIndex = i + } + } + + for _, col := range constraints { + if col[constraintTypeColumnIndex] == "PRIMARY KEY" { + pKeyName = col[constraintNameColumnIndex] + break + } + } + + primaryKeyColumnName = pKeyName for i, col := range columns[0] { - if col == "column_default" { + if col == "column_name" { keyColumnIndex = i + break } } for i, col := range columns { - if strings.Contains(col[keyColumnIndex], "nextval") { + if col[keyColumnIndex] == pKeyName { primaryKeyColumnIndex = i - 1 - primaryKeyColumnName = col[0] + break } } diff --git a/components/Tree.go b/components/Tree.go index c9b0360..1a02ff6 100644 --- a/components/Tree.go +++ b/components/Tree.go @@ -152,7 +152,7 @@ func (tree *Tree) updateNodes(children map[string][]string, node *tview.TreeNode for _, child := range values { childNode := tview.NewTreeNode(child) childNode.SetExpanded(defaultExpanded) - childNode.SetReference(node.GetText()) + childNode.SetReference(fmt.Sprintf("%s.%s", node.GetText(), key)) childNode.SetColor(tview.Styles.SecondaryTextColor) if rootNode != nil { rootNode.AddChild(childNode) diff --git a/drivers/postgres.go b/drivers/postgres.go index c966e6b..9fe953f 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -21,13 +21,11 @@ func (db *Postgres) Connect(urlstr string) (err error) { db.SetProvider("postgres") db.Connection, err = dburl.Open(urlstr) - if err != nil { return err } err = db.Connection.Ping() - if err != nil { return err } @@ -79,8 +77,9 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err } func (db *Postgres) GetTableColumns(database, table string) (results [][]string, error error) { - tableName := strings.Split(table, ".")[1] - rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableName)) + tableSchema := strings.Split(table, ".")[1] + tableName := strings.Split(table, ".")[2] + rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableSchema, tableName)) if err != nil { return results, err } @@ -114,12 +113,14 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, func (db *Postgres) GetConstraints(table string) (constraints [][]string, error error) { splitTableString := strings.Split(table, ".") - tableName := splitTableString[1] + tableSchema := splitTableString[1] + tableName := splitTableString[2] rows, err := db.Connection.Query(fmt.Sprintf(` SELECT tc.constraint_name, - kcu.column_name + kcu.column_name, + tc.constraint_type FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name @@ -128,8 +129,9 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error AND ccu.table_schema = tc.table_schema WHERE NOT tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = '%s' AND tc.table_name = '%s' - `, tableName)) + `, tableSchema, tableName)) if err != nil { return constraints, err } @@ -164,7 +166,8 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error error) { splitTableString := strings.Split(table, ".") - tableName := splitTableString[1] + tableSchema := splitTableString[1] + tableName := splitTableString[2] rows, err := db.Connection.Query(fmt.Sprintf(` SELECT @@ -180,8 +183,9 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = '%s' AND tc.table_name = '%s' - `, tableName)) + `, tableSchema, tableName)) if err != nil { return foreignKeys, err } @@ -215,12 +219,17 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error } func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { + splitTableString := strings.Split(table, ".") + tableSchema := splitTableString[1] + tableName := splitTableString[2] + rows, err := db.Connection.Query(fmt.Sprintf(` SELECT i.relname AS index_name, a.attname AS column_name, am.amname AS type FROM + pg_namespace n, pg_class t, pg_class i, pg_index ix, @@ -233,11 +242,13 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { and a.attnum = ANY(ix.indkey) and t.relkind = 'r' and am.oid = i.relam + and n.oid = t.relnamespace + and n.nspname = '%s' and t.relname = '%s' ORDER BY t.relname, i.relname - `, table)) + `, tableSchema, tableName)) if err != nil { return indexes, err } @@ -270,7 +281,7 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re defaultLimit := 300 splitTableString := strings.Split(table, ".") - tableName := splitTableString[1] + tableName := strings.Join(splitTableString[1:], ".") isPaginationEnabled := offset >= 0 && limit >= 0 @@ -399,7 +410,8 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Group changes by RowId and Table for _, change := range changes { if change.Type == "UPDATE" { - tableName := strings.Split(change.Table, ".")[1] + splitTableString := strings.Split(change.Table, ".") + tableName := strings.Join(splitTableString[1:], ".") key := fmt.Sprintf("%s|%s|%s", tableName, change.PrimaryKeyColumnName, change.PrimaryKeyValue) groupedUpdated[key] = append(groupedUpdated[key], change) } else if change.Type == "DELETE" { @@ -424,19 +436,20 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Merge all column updates updateClause := strings.Join(columns, ", ") - query := fmt.Sprintf("UPDATE %s SET %s WHERE '%s' = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue) + query := fmt.Sprintf("UPDATE %s SET %s WHERE \"%s\" = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue) queries = append(queries, query) } - for _, delete := range groupedDeletes { + for _, del := range groupedDeletes { statementType := "" query := "" statementType = "DELETE FROM" - tableName := strings.Split(delete.Table, ".")[1] + splitTableString := strings.Split(del.Table, ".") + tableName := strings.Join(splitTableString[1:], ".") - query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, tableName, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) + query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, tableName, del.PrimaryKeyColumnName, del.PrimaryKeyValue) if query != "" { queries = append(queries, query) @@ -447,29 +460,29 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts values := make([]string, 0, len(insert.Values)) for _, value := range insert.Values { - _, error := strconv.ParseFloat(value, 64) + _, err := strconv.ParseFloat(value, 64) - if strings.ToLower(value) != "default" && error != nil { + if strings.ToLower(value) != "default" && err != nil { values = append(values, fmt.Sprintf("'%s'", value)) } else { values = append(values, value) } } - tableName := strings.Split(insert.Table, ".")[1] + splitTableString := strings.Split(insert.Table, ".") + tableName := strings.Join(splitTableString[1:], ".") query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) queries = append(queries, query) } - tx, error := db.Connection.Begin() - if error != nil { - return error + tx, err := db.Connection.Begin() + if err != nil { + return err } for _, query := range queries { _, err = tx.Exec(query) - if err != nil { tx.Rollback() @@ -478,7 +491,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts } err = tx.Commit() - if err != nil { return err } From 1d34f13128055f7692ee71e26f9a2da354cadc2e Mon Sep 17 00:00:00 2001 From: Daniel Prueitt Date: Wed, 20 Mar 2024 17:04:51 +0900 Subject: [PATCH 2/2] feat: Add support for different schemas in Postgresql driver --- components/ResultsTable.go | 28 +++++++++++-- components/Tree.go | 2 +- drivers/postgres.go | 86 ++++++++++++++++++++------------------ 3 files changed, 72 insertions(+), 44 deletions(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 9742f2b..6abaa11 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -1103,6 +1103,7 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string, func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { provider := table.DBDriver.GetProvider() columns := table.GetColumns() + constraints := table.GetConstraints() primaryKeyColumnName := "" primaryKeyValue := "" @@ -1131,18 +1132,39 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { case "postgres": keyColumnIndex := -1 + constraintTypeColumnIndex := -1 + constraintNameColumnIndex := -1 + pKeyName := "" primaryKeyColumnIndex := -1 + for i, constraint := range constraints[0] { + if constraint == "constraint_type" { + constraintTypeColumnIndex = i + } + if constraint == "column_name" { + constraintNameColumnIndex = i + } + } + + for _, col := range constraints { + if col[constraintTypeColumnIndex] == "PRIMARY KEY" { + pKeyName = col[constraintNameColumnIndex] + break + } + } + + primaryKeyColumnName = pKeyName for i, col := range columns[0] { - if col == "column_default" { + if col == "column_name" { keyColumnIndex = i + break } } for i, col := range columns { - if strings.Contains(col[keyColumnIndex], "nextval") { + if col[keyColumnIndex] == pKeyName { primaryKeyColumnIndex = i - 1 - primaryKeyColumnName = col[0] + break } } diff --git a/components/Tree.go b/components/Tree.go index c9b0360..63e9da0 100644 --- a/components/Tree.go +++ b/components/Tree.go @@ -152,7 +152,7 @@ func (tree *Tree) updateNodes(children map[string][]string, node *tview.TreeNode for _, child := range values { childNode := tview.NewTreeNode(child) childNode.SetExpanded(defaultExpanded) - childNode.SetReference(node.GetText()) + childNode.SetReference(key) childNode.SetColor(tview.Styles.SecondaryTextColor) if rootNode != nil { rootNode.AddChild(childNode) diff --git a/drivers/postgres.go b/drivers/postgres.go index c966e6b..84b67f6 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -21,13 +21,11 @@ func (db *Postgres) Connect(urlstr string) (err error) { db.SetProvider("postgres") db.Connection, err = dburl.Open(urlstr) - if err != nil { return err } err = db.Connection.Ping() - if err != nil { return err } @@ -79,8 +77,9 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err } func (db *Postgres) GetTableColumns(database, table string) (results [][]string, error error) { + tableSchema := strings.Split(table, ".")[0] tableName := strings.Split(table, ".")[1] - rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableName)) + rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableSchema, tableName)) if err != nil { return results, err } @@ -114,12 +113,14 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, func (db *Postgres) GetConstraints(table string) (constraints [][]string, error error) { splitTableString := strings.Split(table, ".") + tableSchema := splitTableString[0] tableName := splitTableString[1] rows, err := db.Connection.Query(fmt.Sprintf(` SELECT tc.constraint_name, - kcu.column_name + kcu.column_name, + tc.constraint_type FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name @@ -128,8 +129,9 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error AND ccu.table_schema = tc.table_schema WHERE NOT tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = '%s' AND tc.table_name = '%s' - `, tableName)) + `, tableSchema, tableName)) if err != nil { return constraints, err } @@ -164,6 +166,7 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error error) { splitTableString := strings.Split(table, ".") + tableSchema := splitTableString[0] tableName := splitTableString[1] rows, err := db.Connection.Query(fmt.Sprintf(` @@ -180,8 +183,9 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = '%s' AND tc.table_name = '%s' - `, tableName)) + `, tableSchema, tableName)) if err != nil { return foreignKeys, err } @@ -215,12 +219,17 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error } func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { + splitTableString := strings.Split(table, ".") + tableSchema := splitTableString[0] + tableName := splitTableString[1] + rows, err := db.Connection.Query(fmt.Sprintf(` SELECT i.relname AS index_name, a.attname AS column_name, am.amname AS type FROM + pg_namespace n, pg_class t, pg_class i, pg_index ix, @@ -233,11 +242,13 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { and a.attnum = ANY(ix.indkey) and t.relkind = 'r' and am.oid = i.relam + and n.oid = t.relnamespace + and n.nspname = '%s' and t.relname = '%s' ORDER BY t.relname, i.relname - `, table)) + `, tableSchema, tableName)) if err != nil { return indexes, err } @@ -268,24 +279,20 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) { defaultLimit := 300 - - splitTableString := strings.Split(table, ".") - tableName := splitTableString[1] - isPaginationEnabled := offset >= 0 && limit >= 0 if limit != 0 { defaultLimit = limit } - query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", tableName, defaultLimit, offset) + query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", table, defaultLimit, offset) if where != "" { - query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", tableName, where, defaultLimit, offset) + query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", table, where, defaultLimit, offset) } if sort != "" { - query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", tableName, where, sort, defaultLimit, offset) + query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", table, where, sort, defaultLimit, offset) } paginatedRows, err := db.Connection.Query(query) @@ -294,7 +301,7 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re } if isPaginationEnabled { - queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", tableName, where) + queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where) rows := db.Connection.QueryRow(queryWithoutLimit) @@ -302,7 +309,10 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re return records, totalRecords, err } - rows.Scan(&totalRecords) + err = rows.Scan(&totalRecords) + if err != nil { + return nil, 0, err + } defer paginatedRows.Close() } @@ -332,28 +342,28 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re } func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) { - query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE '%s' = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue) + query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE \"%s\" = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue) _, err = db.Connection.Exec(query) return err } func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) { - query := fmt.Sprintf("DELETE FROM %s WHERE '%s' = '%s'", table, primaryKeyColumnName, primaryKeyValue) + query := fmt.Sprintf("DELETE FROM %s WHERE \"%s\" = '%s'", table, primaryKeyColumnName, primaryKeyValue) _, err = db.Connection.Exec(query) return err } func (db *Postgres) ExecuteDMLStatement(query string) (result string, err error) { - res, error := db.Connection.Exec(query) + res, err := db.Connection.Exec(query) - if error != nil { - return result, error + if err != nil { + return result, err } else { rowsAffected, _ := res.RowsAffected() - return fmt.Sprintf("%d rows affected", rowsAffected), error + return fmt.Sprintf("%d rows affected", rowsAffected), err } } @@ -375,7 +385,10 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { rowValues[i] = new(sql.RawBytes) } - rows.Scan(rowValues...) + err = rows.Scan(rowValues...) + if err != nil { + return nil, err + } var row []string for _, col := range rowValues { @@ -399,8 +412,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Group changes by RowId and Table for _, change := range changes { if change.Type == "UPDATE" { - tableName := strings.Split(change.Table, ".")[1] - key := fmt.Sprintf("%s|%s|%s", tableName, change.PrimaryKeyColumnName, change.PrimaryKeyValue) + key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue) groupedUpdated[key] = append(groupedUpdated[key], change) } else if change.Type == "DELETE" { groupedDeletes = append(groupedDeletes, change) @@ -424,19 +436,17 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Merge all column updates updateClause := strings.Join(columns, ", ") - query := fmt.Sprintf("UPDATE %s SET %s WHERE '%s' = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue) + query := fmt.Sprintf("UPDATE %s SET %s WHERE \"%s\" = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue) queries = append(queries, query) } - for _, delete := range groupedDeletes { + for _, del := range groupedDeletes { statementType := "" query := "" statementType = "DELETE FROM" - tableName := strings.Split(delete.Table, ".")[1] - - query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, tableName, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) + query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, del.Table, del.PrimaryKeyColumnName, del.PrimaryKeyValue) if query != "" { queries = append(queries, query) @@ -447,29 +457,26 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts values := make([]string, 0, len(insert.Values)) for _, value := range insert.Values { - _, error := strconv.ParseFloat(value, 64) + _, err := strconv.ParseFloat(value, 64) - if strings.ToLower(value) != "default" && error != nil { + if strings.ToLower(value) != "default" && err != nil { values = append(values, fmt.Sprintf("'%s'", value)) } else { values = append(values, value) } } - tableName := strings.Split(insert.Table, ".")[1] - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) - + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) queries = append(queries, query) } - tx, error := db.Connection.Begin() - if error != nil { - return error + tx, err := db.Connection.Begin() + if err != nil { + return err } for _, query := range queries { _, err = tx.Exec(query) - if err != nil { tx.Rollback() @@ -478,7 +485,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts } err = tx.Commit() - if err != nil { return err }