Skip to content

Commit

Permalink
Merge pull request #45 from danprueitt/support-tables-in-different-sc…
Browse files Browse the repository at this point in the history
…hemas

Support tables in different schemas
  • Loading branch information
jorgerojas26 authored Apr 6, 2024
2 parents 7bed807 + e515348 commit 14e0513
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 44 deletions.
28 changes: 25 additions & 3 deletions components/ResultsTable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string,
func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) {
provider := table.DBDriver.GetProvider()
columns := table.GetColumns()
constraints := table.GetConstraints()

primaryKeyColumnName := ""
primaryKeyValue := ""
Expand Down Expand Up @@ -1131,18 +1132,39 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) {

case "postgres":
keyColumnIndex := -1
constraintTypeColumnIndex := -1
constraintNameColumnIndex := -1
pKeyName := ""
primaryKeyColumnIndex := -1

for i, constraint := range constraints[0] {
if constraint == "constraint_type" {
constraintTypeColumnIndex = i
}
if constraint == "column_name" {
constraintNameColumnIndex = i
}
}

for _, col := range constraints {
if col[constraintTypeColumnIndex] == "PRIMARY KEY" {
pKeyName = col[constraintNameColumnIndex]
break
}
}

primaryKeyColumnName = pKeyName
for i, col := range columns[0] {
if col == "column_default" {
if col == "column_name" {
keyColumnIndex = i
break
}
}

for i, col := range columns {
if strings.Contains(col[keyColumnIndex], "nextval") {
if col[keyColumnIndex] == pKeyName {
primaryKeyColumnIndex = i - 1
primaryKeyColumnName = col[0]
break
}
}

Expand Down
2 changes: 1 addition & 1 deletion components/Tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (tree *Tree) updateNodes(children map[string][]string, node *tview.TreeNode
for _, child := range values {
childNode := tview.NewTreeNode(child)
childNode.SetExpanded(defaultExpanded)
childNode.SetReference(node.GetText())
childNode.SetReference(key)
childNode.SetColor(tview.Styles.SecondaryTextColor)
if rootNode != nil {
rootNode.AddChild(childNode)
Expand Down
86 changes: 46 additions & 40 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ func (db *Postgres) Connect(urlstr string) (err error) {
db.SetProvider("postgres")

db.Connection, err = dburl.Open(urlstr)

if err != nil {
return err
}

err = db.Connection.Ping()

if err != nil {
return err
}
Expand Down Expand Up @@ -79,8 +77,9 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err
}

func (db *Postgres) GetTableColumns(database, table string) (results [][]string, error error) {
tableSchema := strings.Split(table, ".")[0]
tableName := strings.Split(table, ".")[1]
rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableName))
rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableSchema, tableName))
if err != nil {
return results, err
}
Expand Down Expand Up @@ -114,12 +113,14 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string,

func (db *Postgres) GetConstraints(table string) (constraints [][]string, error error) {
splitTableString := strings.Split(table, ".")
tableSchema := splitTableString[0]
tableName := splitTableString[1]

rows, err := db.Connection.Query(fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name
kcu.column_name,
tc.constraint_type
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
Expand All @@ -128,8 +129,9 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error
AND ccu.table_schema = tc.table_schema
WHERE
NOT tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = '%s'
AND tc.table_name = '%s'
`, tableName))
`, tableSchema, tableName))
if err != nil {
return constraints, err
}
Expand Down Expand Up @@ -164,6 +166,7 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, error

func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error error) {
splitTableString := strings.Split(table, ".")
tableSchema := splitTableString[0]
tableName := splitTableString[1]

rows, err := db.Connection.Query(fmt.Sprintf(`
Expand All @@ -180,8 +183,9 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error
AND ccu.table_schema = tc.table_schema
WHERE
tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = '%s'
AND tc.table_name = '%s'
`, tableName))
`, tableSchema, tableName))
if err != nil {
return foreignKeys, err
}
Expand Down Expand Up @@ -215,12 +219,17 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, error
}

func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) {
splitTableString := strings.Split(table, ".")
tableSchema := splitTableString[0]
tableName := splitTableString[1]

rows, err := db.Connection.Query(fmt.Sprintf(`
SELECT
i.relname AS index_name,
a.attname AS column_name,
am.amname AS type
FROM
pg_namespace n,
pg_class t,
pg_class i,
pg_index ix,
Expand All @@ -233,11 +242,13 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) {
and a.attnum = ANY(ix.indkey)
and t.relkind = 'r'
and am.oid = i.relam
and n.oid = t.relnamespace
and n.nspname = '%s'
and t.relname = '%s'
ORDER BY
t.relname,
i.relname
`, table))
`, tableSchema, tableName))
if err != nil {
return indexes, err
}
Expand Down Expand Up @@ -268,24 +279,20 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) {

func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) {
defaultLimit := 300

splitTableString := strings.Split(table, ".")
tableName := splitTableString[1]

isPaginationEnabled := offset >= 0 && limit >= 0

if limit != 0 {
defaultLimit = limit
}

query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", tableName, defaultLimit, offset)
query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", table, defaultLimit, offset)

if where != "" {
query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", tableName, where, defaultLimit, offset)
query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", table, where, defaultLimit, offset)
}

if sort != "" {
query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", tableName, where, sort, defaultLimit, offset)
query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", table, where, sort, defaultLimit, offset)
}

paginatedRows, err := db.Connection.Query(query)
Expand All @@ -294,15 +301,18 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re
}

if isPaginationEnabled {
queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", tableName, where)
queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where)

rows := db.Connection.QueryRow(queryWithoutLimit)

if err != nil {
return records, totalRecords, err
}

rows.Scan(&totalRecords)
err = rows.Scan(&totalRecords)
if err != nil {
return nil, 0, err
}

defer paginatedRows.Close()
}
Expand Down Expand Up @@ -332,28 +342,28 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re
}

func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) {
query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE '%s' = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue)
query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE \"%s\" = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

return err
}

func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) {
query := fmt.Sprintf("DELETE FROM %s WHERE '%s' = '%s'", table, primaryKeyColumnName, primaryKeyValue)
query := fmt.Sprintf("DELETE FROM %s WHERE \"%s\" = '%s'", table, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

return err
}

func (db *Postgres) ExecuteDMLStatement(query string) (result string, err error) {
res, error := db.Connection.Exec(query)
res, err := db.Connection.Exec(query)

if error != nil {
return result, error
if err != nil {
return result, err
} else {
rowsAffected, _ := res.RowsAffected()

return fmt.Sprintf("%d rows affected", rowsAffected), error
return fmt.Sprintf("%d rows affected", rowsAffected), err
}
}

Expand All @@ -375,7 +385,10 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) {
rowValues[i] = new(sql.RawBytes)
}

rows.Scan(rowValues...)
err = rows.Scan(rowValues...)
if err != nil {
return nil, err
}

var row []string
for _, col := range rowValues {
Expand All @@ -399,8 +412,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
// Group changes by RowId and Table
for _, change := range changes {
if change.Type == "UPDATE" {
tableName := strings.Split(change.Table, ".")[1]
key := fmt.Sprintf("%s|%s|%s", tableName, change.PrimaryKeyColumnName, change.PrimaryKeyValue)
key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue)
groupedUpdated[key] = append(groupedUpdated[key], change)
} else if change.Type == "DELETE" {
groupedDeletes = append(groupedDeletes, change)
Expand All @@ -424,19 +436,17 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
// Merge all column updates
updateClause := strings.Join(columns, ", ")

query := fmt.Sprintf("UPDATE %s SET %s WHERE '%s' = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue)
query := fmt.Sprintf("UPDATE %s SET %s WHERE \"%s\" = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue)

queries = append(queries, query)
}

for _, delete := range groupedDeletes {
for _, del := range groupedDeletes {
statementType := ""
query := ""

statementType = "DELETE FROM"
tableName := strings.Split(delete.Table, ".")[1]

query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, tableName, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue)
query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, del.Table, del.PrimaryKeyColumnName, del.PrimaryKeyValue)

if query != "" {
queries = append(queries, query)
Expand All @@ -447,29 +457,26 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
values := make([]string, 0, len(insert.Values))

for _, value := range insert.Values {
_, error := strconv.ParseFloat(value, 64)
_, err := strconv.ParseFloat(value, 64)

if strings.ToLower(value) != "default" && error != nil {
if strings.ToLower(value) != "default" && err != nil {
values = append(values, fmt.Sprintf("'%s'", value))
} else {
values = append(values, value)
}
}

tableName := strings.Split(insert.Table, ".")[1]
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(insert.Columns, ", "), strings.Join(values, ", "))

query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", "))
queries = append(queries, query)
}

tx, error := db.Connection.Begin()
if error != nil {
return error
tx, err := db.Connection.Begin()
if err != nil {
return err
}

for _, query := range queries {
_, err = tx.Exec(query)

if err != nil {
tx.Rollback()

Expand All @@ -478,7 +485,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
}

err = tx.Commit()

if err != nil {
return err
}
Expand Down

0 comments on commit 14e0513

Please sign in to comment.