diff --git a/.flyctl/cmd/pg_unregister/main.go b/.flyctl/cmd/pg_unregister/main.go new file mode 100644 index 00000000..d5083c45 --- /dev/null +++ b/.flyctl/cmd/pg_unregister/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + "github.com/fly-apps/postgres-flex/pkg/flypg" + "github.com/fly-apps/postgres-flex/pkg/utils" +) + +func main() { + encodedArg := os.Args[1] + hostnameBytes, err := base64.StdEncoding.DecodeString(encodedArg) + if err != nil { + utils.WriteError(fmt.Errorf("failed to decode hostname: %v", err)) + os.Exit(1) + return + } + + node, err := flypg.NewNode() + if err != nil { + utils.WriteError(err) + os.Exit(1) + return + } + + if err := node.UnregisterMemberByHostname(context.Background(), string(hostnameBytes)); err != nil { + utils.WriteError(fmt.Errorf("failed to unregister member: %v", err)) + os.Exit(1) + return + } + + utils.WriteOutput("Member has been succesfully unregistered", "") +} diff --git a/Dockerfile b/Dockerfile index 95f99c85..ac254637 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,8 @@ RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/event_handler ./cmd/event_h RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/failover_validation ./cmd/failover_validation RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/standby_cleaner ./cmd/standby_cleaner +RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/pg_unregister ./.flyctl/cmd/pg_unregister + RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/start ./cmd/start COPY ./bin/* /fly/bin/ diff --git a/cmd/event_handler/main.go b/cmd/event_handler/main.go index 3e541c49..5f3bdeed 100644 --- a/cmd/event_handler/main.go +++ b/cmd/event_handler/main.go @@ -35,7 +35,7 @@ func main() { fmt.Printf("failed initialize cluster state store. %v", err) } - member, err := cs.FindMember(int32(*nodeID)) + member, err := cs.FindMemberByID(int32(*nodeID)) if err != nil { fmt.Printf("failed to find member %v: %s", *nodeID, err) } @@ -64,7 +64,7 @@ func main() { fmt.Printf("failed to parse new member id: %s", err) } - member, err := cs.FindMember(int32(newMemberID)) + member, err := cs.FindMemberByID(int32(newMemberID)) if err != nil { fmt.Printf("failed to find member in consul: %s", err) } diff --git a/cmd/standby_cleaner/main.go b/cmd/standby_cleaner/main.go index c841f203..9625b896 100644 --- a/cmd/standby_cleaner/main.go +++ b/cmd/standby_cleaner/main.go @@ -7,10 +7,15 @@ import ( "time" "github.com/fly-apps/postgres-flex/pkg/flypg" - "github.com/fly-apps/postgres-flex/pkg/flypg/state" + "github.com/fly-apps/postgres-flex/pkg/flypg/admin" + "github.com/jackc/pgx/v4" ) -var Minute int64 = 60 +var ( + monitorFrequency = time.Minute * 5 + // TODO - Make this configurable and/or extend this to 12-24 hours. + deadMemberRemovalThreshold = time.Hour * 1 +) func main() { ctx := context.Background() @@ -20,55 +25,95 @@ func main() { os.Exit(1) } + // TODO - We should connect using the flypgadmin user so we can differentiate between + // internal admin connection usage and the actual repmgr process. conn, err := flypgNode.RepMgr.NewLocalConnection(ctx) if err != nil { fmt.Printf("failed to open local connection: %s\n", err) os.Exit(1) } - ticker := time.NewTicker(5 * time.Second) + seenAt := map[int]time.Time{} + + ticker := time.NewTicker(monitorFrequency) defer ticker.Stop() - seenAt := map[int]int64{} + for { + select { + case <-ticker.C: + role, err := flypgNode.RepMgr.CurrentRole(ctx, conn) + if err != nil { + fmt.Printf("Failed to check role: %s\n", err) + continue + } - for _ = range ticker.C { - role, err := flypgNode.RepMgr.CurrentRole(ctx, conn) - if err != nil { - fmt.Printf("Failed to check role: %s", err) - continue - } - if role != "primary" { - continue - } - standbys, err := flypgNode.RepMgr.Standbys(ctx, conn) - if err != nil { - fmt.Printf("Failed to get standbys: %s", err) - continue - } - for _, standby := range standbys { - newConn, err := flypgNode.RepMgr.NewRemoteConnection(ctx, standby.Ip) + if role != flypg.PrimaryRoleName { + continue + } + + standbys, err := flypgNode.RepMgr.Standbys(ctx, conn) if err != nil { - if time.Now().Unix()-seenAt[standby.Id] >= 10*Minute { - cs, err := state.NewClusterState() - if err != nil { - fmt.Printf("failed initialize cluster state store. %v", err) - } + fmt.Printf("Failed to query standbys: %s\n", err) + continue + } - err = flypgNode.RepMgr.UnregisterStandby(standby.Id) - if err != nil { - fmt.Printf("Failed to unregister %d: %s", standby.Id, err) - continue - } - delete(seenAt, standby.Id) + for _, standby := range standbys { + newConn, err := flypgNode.RepMgr.NewRemoteConnection(ctx, standby.Ip) + defer newConn.Close(ctx) + if err != nil { + // TODO - Verify the exception that's getting thrown. + if time.Now().Sub(seenAt[standby.Id]) >= deadMemberRemovalThreshold { + if err := flypgNode.UnregisterMemberByID(ctx, int32(standby.Id)); err != nil { + fmt.Printf("failed to unregister member %d: %v\n", standby.Id, err.Error()) + continue + } - // Remove from Consul - if err = cs.UnregisterMember(int32(standby.Id)); err != nil { - fmt.Printf("Failed to unregister %d from consul: %s", standby.Id, err) + delete(seenAt, standby.Id) } + + continue } - } else { - seenAt[standby.Id] = time.Now().Unix() - newConn.Close(ctx) + + seenAt[standby.Id] = time.Now() + } + + removeOrphanedReplicationSlots(ctx, conn, standbys) + } + } +} + +func removeOrphanedReplicationSlots(ctx context.Context, conn *pgx.Conn, standbys []flypg.Standby) { + var orphanedSlots []admin.ReplicationSlot + + slots, err := admin.ListReplicationSlots(ctx, conn) + if err != nil { + fmt.Printf("failed to list replication slots: %s", err) + } + + // An orphaned replication slot is defined as an inactive replication slot that is no longer tied to + // and existing repmgr member. + for _, slot := range slots { + matchFound := false + for _, standby := range standbys { + if slot.MemberID == int32(standby.Id) { + matchFound = true + } + } + + if !matchFound && !slot.Active { + orphanedSlots = append(orphanedSlots, slot) + } + } + + if len(orphanedSlots) > 0 { + fmt.Printf("%d orphaned replication slot(s) detected\n", len(orphanedSlots)) + + for _, slot := range orphanedSlots { + fmt.Printf("Dropping replication slot: %s\n", slot.Name) + + if err := admin.DropReplicationSlot(ctx, conn, slot.Name); err != nil { + fmt.Printf("failed to drop replication slot %s: %v\n", slot.Name, err) + continue } } } diff --git a/pkg/flypg/admin/admin.go b/pkg/flypg/admin/admin.go index 2218c347..e8bf217e 100644 --- a/pkg/flypg/admin/admin.go +++ b/pkg/flypg/admin/admin.go @@ -3,6 +3,8 @@ package admin import ( "context" "fmt" + "strconv" + "strings" "github.com/jackc/pgx/v4" ) @@ -75,6 +77,60 @@ func DeleteDatabase(ctx context.Context, pg *pgx.Conn, name string) error { return nil } +type ReplicationSlot struct { + MemberID int32 + Name string + Type string + Active bool + WalStatus string +} + +func ListReplicationSlots(ctx context.Context, pg *pgx.Conn) ([]ReplicationSlot, error) { + sql := fmt.Sprintf("SELECT slot_name, slot_type, active, wal_status from pg_replication_slots;") + rows, err := pg.Query(ctx, sql) + defer rows.Close() + if err != nil { + return nil, err + } + + var slots []ReplicationSlot + + for rows.Next() { + var slot ReplicationSlot + if err := rows.Scan(&slot.Name, &slot.Type, &slot.Active, &slot.WalStatus); err != nil { + return nil, err + } + + // Extract the repmgr member id from the slot name. + // Slot name has the following format: repmgr_slot_ + slotArr := strings.Split(slot.Name, "_") + if slotArr[0] == "repmgr" { + idStr := slotArr[2] + + num, err := strconv.ParseInt(idStr, 10, 32) + if err != nil { + return nil, err + } + + slot.MemberID = int32(num) + slots = append(slots, slot) + } + } + + return slots, nil +} + +func DropReplicationSlot(ctx context.Context, pg *pgx.Conn, name string) error { + sql := fmt.Sprintf("SELECT pg_drop_replication_slot('%s');", name) + + _, err := pg.Exec(ctx, sql) + if err != nil { + return err + } + + return nil +} + func EnableExtension(ctx context.Context, pg *pgx.Conn, extension string) error { sql := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", extension) _, err := pg.Exec(context.Background(), sql) diff --git a/pkg/flypg/node.go b/pkg/flypg/node.go index 059c6fd8..b9105c24 100644 --- a/pkg/flypg/node.go +++ b/pkg/flypg/node.go @@ -305,7 +305,6 @@ func (n *Node) PostInit(ctx context.Context) error { return fmt.Errorf("failed to register member with consul: %s", err) } } - // Requery the primaryIP from consul in case the primary was assigned above. primary, err = cs.PrimaryMember() if err != nil { @@ -324,6 +323,34 @@ func (n *Node) NewLocalConnection(ctx context.Context, database string) (*pgx.Co return openConnection(ctx, host, database, n.OperatorCredentials) } +func (n *Node) UnregisterMemberByHostname(ctx context.Context, hostname string) error { + cs, err := state.NewClusterState() + if err != nil { + fmt.Printf("failed initialize cluster state store. %v", err) + } + + member, err := cs.FindMemberByHostname(hostname) + if err != nil { + return err + } + + return n.unregisterNode(ctx, cs, member) +} + +func (n *Node) UnregisterMemberByID(ctx context.Context, id int32) error { + cs, err := state.NewClusterState() + if err != nil { + fmt.Printf("failed initialize cluster state store. %v", err) + } + + member, err := cs.FindMemberByID(id) + if err != nil { + return err + } + + return n.unregisterNode(ctx, cs, member) +} + func (n *Node) isInitialized() bool { _, err := os.Stat(n.DataDir) if os.IsNotExist(err) { @@ -384,6 +411,25 @@ func (n *Node) createRequiredUsers(ctx context.Context, conn *pgx.Conn) error { return nil } +func (n *Node) unregisterNode(ctx context.Context, cs *state.ClusterState, member *state.Member) error { + if member == nil { + return state.ErrMemberNotFound + } + + // Unregister from repmgr + err := n.RepMgr.UnregisterStandby(int(member.ID)) + if err != nil { + return fmt.Errorf("failed to unregister member %d from repmgr: %s", member.ID, err) + } + + // Unregister from consul + if err := cs.UnregisterMember(member.ID); err != nil { + return fmt.Errorf("failed to unregister member %d from consul: %v", member.ID, err) + } + + return nil +} + type HBAEntry struct { Type string Database string diff --git a/pkg/flypg/state/cluster.go b/pkg/flypg/state/cluster.go index ab6c7119..0b82c664 100644 --- a/pkg/flypg/state/cluster.go +++ b/pkg/flypg/state/cluster.go @@ -80,12 +80,20 @@ func (c *ClusterState) UnregisterMember(id int32) error { return err } - // Rebuild the members slice and exclude the target member. + // Rebuild member slice without the target member + exists := false var members []*Member for _, member := range cluster.Members { - if member.ID != id { - members = append(members, member) + if member.ID == id { + exists = true + continue } + + members = append(members, member) + } + + if !exists { + return nil } cluster.Members = members @@ -148,7 +156,7 @@ func (c *ClusterState) PrimaryMember() (*Member, error) { return nil, nil } -func (c *ClusterState) FindMember(id int32) (*Member, error) { +func (c *ClusterState) FindMemberByID(id int32) (*Member, error) { cluster, _, err := c.clusterData() if err != nil { return nil, err @@ -163,6 +171,21 @@ func (c *ClusterState) FindMember(id int32) (*Member, error) { return nil, nil } +func (c *ClusterState) FindMemberByHostname(hostname string) (*Member, error) { + cluster, _, err := c.clusterData() + if err != nil { + return nil, err + } + + for _, member := range cluster.Members { + if member.Hostname == hostname { + return member, nil + } + } + + return nil, nil +} + func (c *ClusterState) clusterData() (*ClusterData, uint64, error) { var ( cluster ClusterData diff --git a/pkg/utils/response.go b/pkg/utils/response.go new file mode 100644 index 00000000..4f3063f8 --- /dev/null +++ b/pkg/utils/response.go @@ -0,0 +1,39 @@ +package utils + +import ( + "encoding/json" + "fmt" + "os" +) + +type Response struct { + Success bool `json:"success"` + Message string `json:"message"` + Data string `json:"data"` +} + +func WriteError(err error) { + resp := &Response{ + Success: false, + Message: err.Error(), + } + sendToStdout(resp) +} + +func WriteOutput(message, data string) { + resp := &Response{ + Success: true, + Message: message, + Data: data, + } + sendToStdout(resp) +} + +func sendToStdout(resp *Response) { + e, err := json.Marshal(resp) + if err != nil { + fmt.Println(err.Error()) + } + fmt.Println(string(e)) + os.Exit(0) +}