diff --git a/internal/api/handle_admin.go b/internal/api/handle_admin.go index 76943229..8d1972e7 100644 --- a/internal/api/handle_admin.go +++ b/internal/api/handle_admin.go @@ -221,25 +221,11 @@ func (s *Server) handleViewPostgresSettings(w http.ResponseWriter, r *http.Reque defer close() - internal, err := flypg.ReadFromFile(s.node.PGConfig.InternalConfigFile()) + all, err := s.node.PGConfig.CurrentConfig() if err != nil { renderErr(w, err) return } - user, err := flypg.ReadFromFile(s.node.PGConfig.UserConfigFile()) - if err != nil { - renderErr(w, err) - return - } - - all := map[string]interface{}{} - - for k, v := range internal { - all[k] = v - } - for k, v := range user { - all[k] = v - } var in []string @@ -266,26 +252,12 @@ func (s *Server) handleViewPostgresSettings(w http.ResponseWriter, r *http.Reque } func (s *Server) handleViewBouncerSettings(w http.ResponseWriter, r *http.Request) { - internal, err := flypg.ReadFromFile(s.node.PGBouncer.InternalConfigFile()) - if err != nil { - renderErr(w, err) - return - } - user, err := flypg.ReadFromFile(s.node.PGBouncer.UserConfigFile()) + all, err := s.node.PGBouncer.CurrentConfig() if err != nil { renderErr(w, err) return } - all := map[string]interface{}{} - - for k, v := range internal { - all[k] = v - } - for k, v := range user { - all[k] = v - } - var in []string if err := json.NewDecoder(r.Body).Decode(&in); err != nil { @@ -307,26 +279,12 @@ func (s *Server) handleViewBouncerSettings(w http.ResponseWriter, r *http.Reques } func (s *Server) handleViewRepmgrSettings(w http.ResponseWriter, r *http.Request) { - internal, err := flypg.ReadFromFile(s.node.RepMgr.InternalConfigFile()) - if err != nil { - renderErr(w, err) - return - } - user, err := flypg.ReadFromFile(s.node.RepMgr.UserConfigFile()) + all, err := s.node.RepMgr.CurrentConfig() if err != nil { renderErr(w, err) return } - all := map[string]interface{}{} - - for k, v := range internal { - all[k] = v - } - for k, v := range user { - all[k] = v - } - var in []string if err := json.NewDecoder(r.Body).Decode(&in); err != nil { diff --git a/internal/flypg/config.go b/internal/flypg/config.go index 87c48ba7..6702f528 100644 --- a/internal/flypg/config.go +++ b/internal/flypg/config.go @@ -19,6 +19,7 @@ type Config interface { UserConfig() ConfigMap SetUserConfig(configMap ConfigMap) ConsulKey() string + CurrentConfig() (ConfigMap, error) } func WriteUserConfig(c Config, consul *state.Store) error { diff --git a/internal/flypg/flypg.go b/internal/flypg/flypg.go index 515eb3e3..ef362526 100644 --- a/internal/flypg/flypg.go +++ b/internal/flypg/flypg.go @@ -55,6 +55,28 @@ func (c *FlyPGConfig) UserConfigFile() string { return c.userConfigFilePath } +func (c *FlyPGConfig) CurrentConfig() (ConfigMap, error) { + internal, err := ReadFromFile(c.InternalConfigFile()) + if err != nil { + return nil, err + } + user, err := ReadFromFile(c.UserConfigFile()) + if err != nil { + return nil, err + } + + all := ConfigMap{} + + for k, v := range internal { + all[k] = v + } + for k, v := range user { + all[k] = v + } + + return all, nil +} + func (c *FlyPGConfig) initialize() error { c.SetDefaults() diff --git a/internal/flypg/pg.go b/internal/flypg/pg.go index 050c4cb0..eba4229a 100644 --- a/internal/flypg/pg.go +++ b/internal/flypg/pg.go @@ -55,6 +55,28 @@ func (c *PGConfig) UserConfigFile() string { return c.userConfigFilePath } +func (c *PGConfig) CurrentConfig() (ConfigMap, error) { + internal, err := ReadFromFile(c.InternalConfigFile()) + if err != nil { + return nil, err + } + user, err := ReadFromFile(c.UserConfigFile()) + if err != nil { + return nil, err + } + + all := ConfigMap{} + + for k, v := range internal { + all[k] = v + } + for k, v := range user { + all[k] = v + } + + return all, nil +} + func NewConfig(dataDir string, port int) *PGConfig { return &PGConfig{ dataDir: dataDir, diff --git a/internal/flypg/pgbouncer.go b/internal/flypg/pgbouncer.go index 238e04fb..9383c04b 100644 --- a/internal/flypg/pgbouncer.go +++ b/internal/flypg/pgbouncer.go @@ -12,6 +12,12 @@ import ( "github.com/jackc/pgx/v5" ) +const ( + transactionPooler = "transaction" + sessionPooler = "session" + statementPooler = "statement" +) + type PGBouncer struct { PrivateIP string Credentials Credentials @@ -68,6 +74,37 @@ func (p *PGBouncer) ConfigurePrimary(ctx context.Context, primary string, reload return nil } +func (p *PGBouncer) CurrentConfig() (ConfigMap, error) { + internal, err := ReadFromFile(p.InternalConfigFile()) + if err != nil { + return nil, err + } + user, err := ReadFromFile(p.UserConfigFile()) + if err != nil { + return nil, err + } + + all := ConfigMap{} + + for k, v := range internal { + all[k] = v + } + for k, v := range user { + all[k] = v + } + + return all, nil +} + +func (p *PGBouncer) poolMode() (string, error) { + conf, err := p.CurrentConfig() + if err != nil { + return "", err + } + + return conf["pool_mode"].(string), nil +} + func (p *PGBouncer) initialize() error { cmdStr := fmt.Sprintf("mkdir -p %s", p.ConfigPath) if err := utils.RunCommand(cmdStr); err != nil { @@ -157,6 +194,40 @@ func (p *PGBouncer) forceReconnect(ctx context.Context, databases []string) erro return nil } +func (p *PGBouncer) killConnections(ctx context.Context, databases []string) error { + conn, err := p.NewConnection(ctx) + if err != nil { + return err + } + defer conn.Close(ctx) + + for _, db := range databases { + _, err = conn.Exec(ctx, fmt.Sprintf("KILL %s;", db)) + if err != nil { + return err + } + } + + return nil +} + +func (p *PGBouncer) resumeConnections(ctx context.Context, databases []string) error { + conn, err := p.NewConnection(ctx) + if err != nil { + return err + } + defer conn.Close(ctx) + + for _, db := range databases { + _, err = conn.Exec(ctx, fmt.Sprintf("RESUME %s;", db)) + if err != nil { + return err + } + } + + return nil +} + func (p *PGBouncer) NewConnection(ctx context.Context) (*pgx.Conn, error) { host := net.JoinHostPort(p.PrivateIP, strconv.Itoa(p.Port)) return openConnection(ctx, host, "pgbouncer", p.Credentials) diff --git a/internal/flypg/readonly.go b/internal/flypg/readonly.go index a43ad711..107b2208 100644 --- a/internal/flypg/readonly.go +++ b/internal/flypg/readonly.go @@ -148,8 +148,26 @@ func changeReadOnlyState(ctx context.Context, n *Node, enable bool) error { } defer bConn.Close(ctx) - if err := n.PGBouncer.forceReconnect(ctx, dbNames); err != nil { - return fmt.Errorf("failed to force connection reset: %s", err) + poolMode, err := n.PGBouncer.poolMode() + if err != nil { + return fmt.Errorf("failed to resolve active pool mode: %s", err) + } + + switch poolMode { + case transactionPooler, statementPooler: + if err := n.PGBouncer.forceReconnect(ctx, dbNames); err != nil { + return fmt.Errorf("failed to force connection reset: %s", err) + } + case sessionPooler: + if err := n.PGBouncer.killConnections(ctx, dbNames); err != nil { + return fmt.Errorf("failed to kill connections: %s", err) + } + + if err := n.PGBouncer.resumeConnections(ctx, dbNames); err != nil { + return fmt.Errorf("failed to resume connections: %s", err) + } + default: + return fmt.Errorf("failed to resolve valid pooler. found: %s", poolMode) } return nil diff --git a/internal/flypg/repmgr.go b/internal/flypg/repmgr.go index 73526d04..bd555b08 100644 --- a/internal/flypg/repmgr.go +++ b/internal/flypg/repmgr.go @@ -58,6 +58,28 @@ func (r *RepMgr) SetUserConfig(configMap ConfigMap) { r.userConfig = configMap } +func (r *RepMgr) CurrentConfig() (ConfigMap, error) { + internal, err := ReadFromFile(r.InternalConfigFile()) + if err != nil { + return nil, err + } + user, err := ReadFromFile(r.UserConfigFile()) + if err != nil { + return nil, err + } + + all := ConfigMap{} + + for k, v := range internal { + all[k] = v + } + for k, v := range user { + all[k] = v + } + + return all, nil +} + func (r *RepMgr) ConsulKey() string { return "repmgr" }