Skip to content

Commit

Permalink
Merge pull request jorgerojas26#65 from jorgerojas26/postgres-connect…
Browse files Browse the repository at this point in the history
…ions-fix

Postgres connections fix
  • Loading branch information
jorgerojas26 authored Jul 10, 2024
2 parents 8357452 + a13f171 commit c89e6f1
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 15 deletions.
14 changes: 9 additions & 5 deletions components/Home.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,13 @@ func (home *Home) subscribeToTreeChanges() {
home.TabbedPane.AppendTab(tableName, table)
}

table.FetchRecords()
table.FetchRecords(func() {
home.focusLeftWrapper()
})

home.focusRightWrapper()
if table.state.error == "" {
home.focusRightWrapper()
}

app.App.ForceDraw()
}
Expand Down Expand Up @@ -202,7 +206,7 @@ func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKe

if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || table.Menu == nil) && !table.Pagination.GetIsFirstPage() && !table.GetIsLoading() {
table.Pagination.SetOffset(table.Pagination.GetOffset() - table.Pagination.GetLimit())
table.FetchRecords()
table.FetchRecords(nil)

}

Expand All @@ -216,7 +220,7 @@ func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKe

if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || table.Menu == nil) && !table.Pagination.GetIsLastPage() && !table.GetIsLoading() {
table.Pagination.SetOffset(table.Pagination.GetOffset() + table.Pagination.GetLimit())
table.FetchRecords()
table.FetchRecords(nil)
}
}
}
Expand Down Expand Up @@ -287,7 +291,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey {
home.ListOfDbChanges = []models.DbDmlChange{}
home.ListOfDbInserts = []models.DbInsert{}

table.FetchRecords()
table.FetchRecords(nil)
home.Tree.ForceRemoveHighlight()

}
Expand Down
8 changes: 4 additions & 4 deletions components/ResultsTable.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ func (table *ResultsTable) subscribeToFilterChanges() {
switch stateChange.Key {
case "Filter":
if stateChange.Value != "" {
rows := table.FetchRecords()
rows := table.FetchRecords(nil)

if len(rows) > 1 {
table.Menu.SetSelectedOption(1)
Expand All @@ -591,7 +591,7 @@ func (table *ResultsTable) subscribeToFilterChanges() {
}

} else {
table.FetchRecords()
table.FetchRecords(nil)

table.SetInputCapture(table.tableInputCapture)
App.SetFocus(table)
Expand Down Expand Up @@ -863,7 +863,7 @@ func (table *ResultsTable) SetSortedBy(column string, direction string) {
}
}

func (table *ResultsTable) FetchRecords() [][]string {
func (table *ResultsTable) FetchRecords(onError func()) [][]string {
tableName := table.GetDBReference()

table.SetLoading(true)
Expand All @@ -877,7 +877,7 @@ func (table *ResultsTable) FetchRecords() [][]string {
records, totalRecords, err := table.DBDriver.GetRecords(tableName, where, sort, table.Pagination.GetOffset(), table.Pagination.GetLimit())

if err != nil {
table.SetError(err.Error(), nil)
table.SetError(err.Error(), onError)
table.SetLoading(false)
} else {
if table.GetIsFiltering() {
Expand Down
102 changes: 96 additions & 6 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ import (
)

type Postgres struct {
Connection *sql.DB
Provider string
Connection *sql.DB
Provider string
CurrentDatabase string
PreviousDatabase string
Urlstr string
}

const (
DEFAULT_PORT = "5432"
)

func (db *Postgres) Connect(urlstr string) (err error) {
db.SetProvider("postgres")

Expand All @@ -30,6 +37,22 @@ func (db *Postgres) Connect(urlstr string) (err error) {
return err
}

db.Urlstr = urlstr

// get current database

rows := db.Connection.QueryRow("SELECT current_database();")

database := ""

err = rows.Scan(&database)

db.CurrentDatabase = database
db.PreviousDatabase = database
if err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -58,8 +81,24 @@ func (db *Postgres) GetDatabases() (databases []string, err error) {
func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) {
tables = make(map[string][]string)

switchDatabase := false

if database != db.CurrentDatabase {
err = db.SwitchDatabase(database)
if err != nil {
return tables, err
}
switchDatabase = true
}

rows, err := db.Connection.Query(fmt.Sprintf("SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = '%s'", database))
if err != nil {
if switchDatabase {
err = db.SwitchDatabase(db.PreviousDatabase)
if err != nil {
return tables, err
}
}
return tables, err
}

Expand Down Expand Up @@ -278,6 +317,7 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) {
}

func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) {
table = db.formatTableName(table)
defaultLimit := 300
isPaginationEnabled := offset >= 0 && limit >= 0

Expand Down Expand Up @@ -342,13 +382,15 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re
}

func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) {
table = db.formatTableName(table)
query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE \"%s\" = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

return err
}

func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) {
table = db.formatTableName(table)
query := fmt.Sprintf("DELETE FROM %s WHERE \"%s\" = '%s'", table, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

Expand Down Expand Up @@ -412,7 +454,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
// Group changes by RowId and Table
for _, change := range changes {
if change.Type == "UPDATE" {
key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue)
key := fmt.Sprintf("%s|%s|%s", db.formatTableName(change.Table), change.PrimaryKeyColumnName, change.PrimaryKeyValue)
groupedUpdated[key] = append(groupedUpdated[key], change)
} else if change.Type == "DELETE" {
groupedDeletes = append(groupedDeletes, change)
Expand All @@ -425,7 +467,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts

// Split key into table and rowId
splitted := strings.Split(key, "|")
table := splitted[0]
table := db.formatTableName(splitted[0])
PrimaryKeyColumnName := splitted[1]
primaryKeyValue := splitted[2]

Expand All @@ -446,7 +488,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
query := ""

statementType = "DELETE FROM"
query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, del.Table, del.PrimaryKeyColumnName, del.PrimaryKeyValue)
query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, db.formatTableName(del.Table), del.PrimaryKeyColumnName, del.PrimaryKeyValue)

if query != "" {
queries = append(queries, query)
Expand All @@ -466,7 +508,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
}
}

query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", "))
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", "))
queries = append(queries, query)
}

Expand Down Expand Up @@ -499,3 +541,51 @@ func (db *Postgres) SetProvider(provider string) {
func (db *Postgres) GetProvider() string {
return db.Provider
}

func (db *Postgres) SwitchDatabase(database string) error {
parsedConn, err := dburl.Parse(db.Urlstr)
if err != nil {
return err
}

user := parsedConn.User.Username()
password, _ := parsedConn.User.Password()
host := parsedConn.Host
port := parsedConn.Port()
dbname := parsedConn.Path

if port == "" {
port = DEFAULT_PORT
}

if dbname == "" {
dbname = database
}

connection, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname))
if err != nil {
return err
}

db.Connection.Close()
db.Connection = connection
db.PreviousDatabase = db.CurrentDatabase
db.CurrentDatabase = database

return nil
}

func (db *Postgres) formatTableName(table string) string {
splittedTableName := strings.Split(table, ".")

if len(splittedTableName) == 1 {
return table
}

schema := splittedTableName[0]
tableName := splittedTableName[1]

formattedTableName := fmt.Sprintf("\"%s\".\"%s\"", schema, tableName)

return formattedTableName
}

0 comments on commit c89e6f1

Please sign in to comment.