diff --git a/internal/flypg/admin/admin.go b/internal/flypg/admin/admin.go index a851cf85..ca754c86 100644 --- a/internal/flypg/admin/admin.go +++ b/internal/flypg/admin/admin.go @@ -74,13 +74,14 @@ func ChangePassword(ctx context.Context, pg *pgx.Conn, username, password string func CreateDatabaseWithOwner(ctx context.Context, pg *pgx.Conn, name, owner string) error { dbInfo, err := FindDatabase(ctx, pg, name) - if err != nil && err != pgx.ErrNoRows { + if err != nil { return err } - // Database already exists. + if dbInfo != nil { return nil } + sql := fmt.Sprintf("CREATE DATABASE %s OWNER %s;", name, owner) _, err = pg.Exec(ctx, sql) @@ -89,7 +90,7 @@ func CreateDatabaseWithOwner(ctx context.Context, pg *pgx.Conn, name, owner stri func CreateDatabase(ctx context.Context, pg *pgx.Conn, name string) error { dbInfo, err := FindDatabase(ctx, pg, name) - if err != nil && err != pgx.ErrNoRows { + if err != nil { return err } // Database already exists. @@ -197,7 +198,7 @@ func ListDatabases(ctx context.Context, pg *pgx.Conn) ([]DbInfo, error) { } defer rows.Close() - values := []DbInfo{} + var values []DbInfo for rows.Next() { di := DbInfo{} @@ -210,23 +211,19 @@ func ListDatabases(ctx context.Context, pg *pgx.Conn) ([]DbInfo, error) { return values, nil } -func FindDatabase(ctx context.Context, pg *pgx.Conn, name string) (*DbInfo, error) { - sql := ` - SELECT - datname, - (SELECT array_agg(u.usename::text order by u.usename) FROM pg_user u WHERE has_database_privilege(u.usename, d.datname, 'CONNECT')) as allowed_users - FROM pg_database d WHERE d.datname='%s'; - ` - - sql = fmt.Sprintf(sql, name) - row := pg.QueryRow(ctx, sql) - - db := new(DbInfo) - if err := row.Scan(&db.Name, &db.Users); err != nil { +func FindDatabase(ctx context.Context, conn *pgx.Conn, name string) (*DbInfo, error) { + dbs, err := ListDatabases(ctx, conn) + if err != nil { return nil, err } - return db, nil + for _, db := range dbs { + if db.Name == name { + return &db, nil + } + } + + return nil, nil } type UserInfo struct { @@ -260,7 +257,7 @@ func ListUsers(ctx context.Context, pg *pgx.Conn) ([]UserInfo, error) { } defer rows.Close() - values := []UserInfo{} + var values []UserInfo for rows.Next() { ui := UserInfo{} diff --git a/internal/flypg/node.go b/internal/flypg/node.go index 64c91f0a..0453e373 100644 --- a/internal/flypg/node.go +++ b/internal/flypg/node.go @@ -329,7 +329,7 @@ func (n *Node) PostInit(ctx context.Context) error { // Setup repmgr database and extension if err := n.RepMgr.enable(ctx, conn); err != nil { - fmt.Printf("failed to setup repmgr: %s\n", err) + return fmt.Errorf("failed to enable repmgr: %s", err) } // Register ourself as the primary