diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 00000000..2b003949 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,10 @@ +version = 1 + +[[analyzers]] +name = "shell" + +[[analyzers]] +name = "go" + + [analyzers.meta] + import_root = "github.com/fly-apps/postgres-flex" diff --git a/cmd/admin_server/main.go b/cmd/admin_server/main.go index 4d571581..2ea08bc0 100644 --- a/cmd/admin_server/main.go +++ b/cmd/admin_server/main.go @@ -2,14 +2,10 @@ package main import ( "github.com/fly-apps/postgres-flex/internal/api" - "github.com/fly-apps/postgres-flex/internal/flypg" ) func main() { - node, err := flypg.NewNode() - if err != nil { + if err := api.StartHttpServer(); err != nil { panic(err) } - - api.StartHttpServer(node) } diff --git a/cmd/event_handler/main.go b/cmd/event_handler/main.go index c717cda9..d0e23af6 100644 --- a/cmd/event_handler/main.go +++ b/cmd/event_handler/main.go @@ -15,17 +15,24 @@ import ( const eventLogFile = "/data/event.log" func main() { + ctx := context.Background() + + if err := processEvent(ctx); err != nil { + log.Println(err) + os.Exit(1) + } +} + +func processEvent(ctx context.Context) error { event := flag.String("event", "", "event type") nodeID := flag.Int("node-id", 0, "the node id") success := flag.String("success", "", "success (1) failure (0)") details := flag.String("details", "", "details") flag.Parse() - ctx := context.Background() - - logFile, err := os.OpenFile(eventLogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + logFile, err := os.OpenFile(eventLogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) if err != nil { - fmt.Printf("failed to open event log: %s", err) + return fmt.Errorf("failed to open event log: %s", err) } defer logFile.Close() @@ -34,46 +41,40 @@ func main() { node, err := flypg.NewNode() if err != nil { - log.Printf("failed to initialize node: %s", err) - os.Exit(1) + return fmt.Errorf("failed to initialize node: %s", err) } switch *event { case "child_node_disconnect", "child_node_reconnect", "child_node_new_connect": conn, err := node.RepMgr.NewLocalConnection(ctx) if err != nil { - log.Printf("failed to open local connection: %s", err) - os.Exit(1) + return fmt.Errorf("failed to open local connection: %s", err) } defer conn.Close(ctx) member, err := node.RepMgr.Member(ctx, conn) if err != nil { - log.Printf("failed to resolve member: %s", err) - os.Exit(1) + return fmt.Errorf("failed to resolve member: %s", err) } if member.Role != flypg.PrimaryRoleName { // We should never get here. log.Println("skipping since we are not the primary") - os.Exit(0) + return nil } if err := evaluateClusterState(ctx, conn, node); err != nil { - log.Printf("failed to evaluate cluster state: %s", err) - os.Exit(0) + return fmt.Errorf("failed to evaluate cluster state: %s", err) } - - os.Exit(0) - default: - // noop } + + return nil } func evaluateClusterState(ctx context.Context, conn *pgx.Conn, node *flypg.Node) error { primary, err := flypg.PerformScreening(ctx, conn, node) if errors.Is(err, flypg.ErrZombieDiagnosisUndecided) || errors.Is(err, flypg.ErrZombieDiscovered) { - if err := flypg.Quarantine(ctx, conn, node, primary); err != nil { + if err := flypg.Quarantine(ctx, node, primary); err != nil { return fmt.Errorf("failed to quarantine failed primary: %s", err) } return fmt.Errorf("primary has been quarantined: %s", err) diff --git a/cmd/monitor/monitor_cluster_state.go b/cmd/monitor/monitor_cluster_state.go index df2e1431..ef6e7da9 100644 --- a/cmd/monitor/monitor_cluster_state.go +++ b/cmd/monitor/monitor_cluster_state.go @@ -39,7 +39,7 @@ func clusterStateMonitorTick(ctx context.Context, node *flypg.Node) error { primary, err := flypg.PerformScreening(ctx, conn, node) if errors.Is(err, flypg.ErrZombieDiagnosisUndecided) || errors.Is(err, flypg.ErrZombieDiscovered) { - if err := flypg.Quarantine(ctx, conn, node, primary); err != nil { + if err := flypg.Quarantine(ctx, node, primary); err != nil { return fmt.Errorf("failed to quarantine failed primary: %s", err) } return fmt.Errorf("primary has been quarantined: %s", err) diff --git a/cmd/monitor/monitor_dead_members.go b/cmd/monitor/monitor_dead_members.go index 5deff5a5..98d28702 100644 --- a/cmd/monitor/monitor_dead_members.go +++ b/cmd/monitor/monitor_dead_members.go @@ -78,7 +78,7 @@ func deadMemberMonitorTick(ctx context.Context, node *flypg.Node, seenAt map[int // TODO - Verify the exception that's getting thrown. if time.Since(seenAt[standby.ID]) >= deadMemberRemovalThreshold { log.Printf("Removing dead member: %s\n", standby.Hostname) - if err := node.RepMgr.UnregisterMember(ctx, standby); err != nil { + if err := node.RepMgr.UnregisterMember(standby); err != nil { log.Printf("failed to unregister member %s: %v", standby.Hostname, err) continue } diff --git a/cmd/pg_unregister/main.go b/cmd/pg_unregister/main.go index 063ea01a..1dfdf7e9 100644 --- a/cmd/pg_unregister/main.go +++ b/cmd/pg_unregister/main.go @@ -11,43 +11,42 @@ import ( ) func main() { + ctx := context.Background() + + if err := processUnregistration(ctx); err != nil { + utils.WriteError(err) + os.Exit(1) + } + + utils.WriteOutput("Member has been succesfully unregistered", "") +} + +func processUnregistration(ctx context.Context) error { encodedArg := os.Args[1] hostnameBytes, err := base64.StdEncoding.DecodeString(encodedArg) if err != nil { - utils.WriteError(fmt.Errorf("failed to decode hostname: %v", err)) - os.Exit(1) - return + return fmt.Errorf("failed to decode hostname: %v", err) } - ctx := context.Background() - node, err := flypg.NewNode() if err != nil { - utils.WriteError(err) - os.Exit(1) - return + return fmt.Errorf("faied to initialize node: %s", err) } conn, err := node.RepMgr.NewLocalConnection(ctx) if err != nil { - utils.WriteError(fmt.Errorf("failed to connect to local db: %s", err)) - os.Exit(1) - return + return fmt.Errorf("failed to connect to local db: %s", err) } defer conn.Close(ctx) member, err := node.RepMgr.MemberByHostname(ctx, conn, string(hostnameBytes)) if err != nil { - utils.WriteError(fmt.Errorf("failed to resolve member: %s", err)) - os.Exit(1) - return + return fmt.Errorf("failed to resolve member: %s", err) } - if err := node.RepMgr.UnregisterMember(ctx, *member); err != nil { - utils.WriteError(fmt.Errorf("failed to unregister member: %v", err)) - os.Exit(1) - return + if err := node.RepMgr.UnregisterMember(*member); err != nil { + return fmt.Errorf("failed to unregister member: %v", err) } - utils.WriteOutput("Member has been succesfully unregistered", "") + return nil } diff --git a/internal/api/handle_admin.go b/internal/api/handle_admin.go index 3221c16d..0472dbfd 100644 --- a/internal/api/handle_admin.go +++ b/internal/api/handle_admin.go @@ -12,7 +12,7 @@ import ( "golang.org/x/exp/slices" ) -func handleReadonlyState(w http.ResponseWriter, r *http.Request) { +func handleReadonlyState(w http.ResponseWriter, _ *http.Request) { res := &Response{ Result: false, } @@ -24,7 +24,7 @@ func handleReadonlyState(w http.ResponseWriter, r *http.Request) { renderJSON(w, res, http.StatusOK) } -func handleHaproxyRestart(w http.ResponseWriter, r *http.Request) { +func handleHaproxyRestart(w http.ResponseWriter, _ *http.Request) { if err := flypg.RestartHaproxy(); err != nil { renderErr(w, err) return @@ -78,18 +78,18 @@ func handleDisableReadonly(w http.ResponseWriter, r *http.Request) { } func handleRole(w http.ResponseWriter, r *http.Request) { - conn, close, err := localConnection(r.Context(), "postgres") + node, err := flypg.NewNode() if err != nil { renderErr(w, err) return } - defer close() - node, err := flypg.NewNode() + conn, err := localConnection(r.Context(), "postgres") if err != nil { renderErr(w, err) return } + defer conn.Close(r.Context()) member, err := node.RepMgr.Member(r.Context(), conn) if err != nil { @@ -118,13 +118,19 @@ type SettingsUpdate struct { RestartRequired bool `json:"restart_required"` } -func (s *Server) handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Request) { - conn, close, err := localConnection(r.Context(), "postgres") +func handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Request) { + node, err := flypg.NewNode() if err != nil { renderErr(w, err) return } - defer close() + + conn, err := localConnection(r.Context(), "postgres") + if err != nil { + renderErr(w, err) + return + } + defer conn.Close(r.Context()) consul, err := state.NewStore() if err != nil { @@ -132,7 +138,7 @@ func (s *Server) handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Req return } - user, err := flypg.ReadFromFile(s.node.PGConfig.UserConfigFile()) + user, err := flypg.ReadFromFile(node.PGConfig.UserConfigFile()) if err != nil { renderErr(w, err) return @@ -158,7 +164,7 @@ func (s *Server) handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Req user[k] = v } - s.node.PGConfig.SetUserConfig(user) + node.PGConfig.SetUserConfig(user) var requiresRestart []string @@ -185,7 +191,7 @@ func (s *Server) handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Req }} } - err = flypg.PushUserConfig(s.node.PGConfig, consul) + err = flypg.PushUserConfig(node.PGConfig, consul) if err != nil { renderErr(w, err) return @@ -194,13 +200,19 @@ func (s *Server) handleUpdatePostgresSettings(w http.ResponseWriter, r *http.Req renderJSON(w, res, http.StatusOK) } -func (s *Server) handleApplyConfig(w http.ResponseWriter, r *http.Request) { - conn, close, err := localConnection(r.Context(), "postgres") +func handleApplyConfig(w http.ResponseWriter, r *http.Request) { + node, err := flypg.NewNode() if err != nil { renderErr(w, err) return } - defer close() + + conn, err := localConnection(r.Context(), "postgres") + if err != nil { + renderErr(w, err) + return + } + defer conn.Close(r.Context()) consul, err := state.NewStore() if err != nil { @@ -208,7 +220,7 @@ func (s *Server) handleApplyConfig(w http.ResponseWriter, r *http.Request) { return } - err = flypg.SyncUserConfig(s.node.PGConfig, consul) + err = flypg.SyncUserConfig(node.PGConfig, consul) if err != nil { renderErr(w, err) return @@ -225,16 +237,21 @@ type PGSettingsResponse struct { Settings []admin.PGSetting `json:"settings"` } -func (s *Server) handleViewPostgresSettings(w http.ResponseWriter, r *http.Request) { - conn, close, err := localConnection(r.Context(), "postgres") +func handleViewPostgresSettings(w http.ResponseWriter, r *http.Request) { + node, err := flypg.NewNode() if err != nil { renderErr(w, err) return } - defer close() + conn, err := localConnection(r.Context(), "postgres") + if err != nil { + renderErr(w, err) + return + } + defer conn.Close(r.Context()) - all, err := s.node.PGConfig.CurrentConfig() + all, err := node.PGConfig.CurrentConfig() if err != nil { renderErr(w, err) return @@ -264,8 +281,14 @@ func (s *Server) handleViewPostgresSettings(w http.ResponseWriter, r *http.Reque renderJSON(w, resp, http.StatusOK) } -func (s *Server) handleViewRepmgrSettings(w http.ResponseWriter, r *http.Request) { - all, err := s.node.RepMgr.CurrentConfig() +func handleViewRepmgrSettings(w http.ResponseWriter, r *http.Request) { + node, err := flypg.NewNode() + if err != nil { + renderErr(w, err) + return + } + + all, err := node.RepMgr.CurrentConfig() if err != nil { renderErr(w, err) return diff --git a/internal/api/handle_databases.go b/internal/api/handle_databases.go index f502d2a0..3229b495 100644 --- a/internal/api/handle_databases.go +++ b/internal/api/handle_databases.go @@ -11,12 +11,12 @@ import ( func handleListDatabases(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) dbs, err := admin.ListDatabases(ctx, conn) if err != nil { @@ -36,12 +36,12 @@ func handleGetDatabase(w http.ResponseWriter, r *http.Request) { name = chi.URLParam(r, "name") ) - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) db, err := admin.FindDatabase(ctx, conn, name) if err != nil { @@ -58,12 +58,12 @@ func handleGetDatabase(w http.ResponseWriter, r *http.Request) { func handleCreateDatabase(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) var input createDatabaseRequest if err := json.NewDecoder(r.Body).Decode(&input); err != nil { @@ -77,12 +77,12 @@ func handleCreateDatabase(w http.ResponseWriter, r *http.Request) { return } - dbConn, close, err := localConnection(ctx, input.Name) + dbConn, err := localConnection(ctx, input.Name) if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) if err := admin.GrantCreateOnPublic(ctx, dbConn); err != nil { renderErr(w, err) @@ -98,12 +98,12 @@ func handleDeleteDatabase(w http.ResponseWriter, r *http.Request) { ctx = r.Context() name = chi.URLParam(r, "name") ) - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) err = admin.DeleteDatabase(ctx, conn, name) if err != nil { diff --git a/internal/api/handle_users.go b/internal/api/handle_users.go index aaa0e380..aea31b82 100644 --- a/internal/api/handle_users.go +++ b/internal/api/handle_users.go @@ -12,12 +12,12 @@ import ( func handleListUsers(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) users, err := admin.ListUsers(ctx, conn) if err != nil { @@ -38,12 +38,12 @@ func handleGetUser(w http.ResponseWriter, r *http.Request) { name = chi.URLParam(r, "name") ) - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) user, err := admin.FindUser(ctx, conn, name) if err != nil { @@ -59,12 +59,12 @@ func handleGetUser(w http.ResponseWriter, r *http.Request) { func handleCreateUser(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - conn, close, err := localConnection(ctx, "postgres") + conn, err := localConnection(ctx, "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) var input createUserRequest err = json.NewDecoder(r.Body).Decode(&input) @@ -107,12 +107,12 @@ func handleDeleteUser(w http.ResponseWriter, r *http.Request) { name = chi.URLParam(r, "name") ) - conn, close, err := localConnection(r.Context(), "postgres") + conn, err := localConnection(r.Context(), "postgres") if err != nil { renderErr(w, err) return } - defer close() + defer conn.Close(r.Context()) databases, err := admin.ListDatabases(ctx, conn) if err != nil { @@ -121,12 +121,12 @@ func handleDeleteUser(w http.ResponseWriter, r *http.Request) { } for _, database := range databases { - dbConn, close, err := localConnection(r.Context(), database.Name) + dbConn, err := localConnection(r.Context(), database.Name) if err != nil { renderErr(w, err) return } - defer close() + defer dbConn.Close(r.Context()) if err := admin.ReassignOwnership(ctx, dbConn, name, "postgres"); err != nil { renderErr(w, fmt.Errorf("failed to reassign ownership: %s", err)) diff --git a/internal/api/handler.go b/internal/api/handler.go index d14d9f77..f46e6622 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "time" "github.com/fly-apps/postgres-flex/internal/flycheck" "github.com/fly-apps/postgres-flex/internal/flypg" @@ -13,21 +14,21 @@ import ( const Port = 5500 -type Server struct { - node *flypg.Node -} - -func StartHttpServer(node *flypg.Node) { - server := &Server{node: node} +func StartHttpServer() error { r := chi.NewMux() - r.Mount("/flycheck", flycheck.Handler()) - r.Mount("/commands", server.Handler()) + r.Mount("/commands", Handler()) - http.ListenAndServe(fmt.Sprintf(":%d", Port), r) + server := &http.Server{ + Handler: r, + Addr: fmt.Sprintf(":%v", Port), + ReadHeaderTimeout: 3 * time.Second, + } + + return server.ListenAndServe() } -func (s *Server) Handler() http.Handler { +func Handler() http.Handler { r := chi.NewRouter() r.Route("/users", func(r chi.Router) { @@ -51,28 +52,25 @@ func (s *Server) Handler() http.Handler { r.Get("/haproxy/restart", handleHaproxyRestart) r.Get("/role", handleRole) - r.Get("/settings/view/postgres", s.handleViewPostgresSettings) - r.Get("/settings/view/repmgr", s.handleViewRepmgrSettings) - r.Post("/settings/update/postgres", s.handleUpdatePostgresSettings) - r.Post("/settings/apply", s.handleApplyConfig) + r.Get("/settings/view/postgres", handleViewPostgresSettings) + r.Get("/settings/view/repmgr", handleViewRepmgrSettings) + r.Post("/settings/update/postgres", handleUpdatePostgresSettings) + r.Post("/settings/apply", handleApplyConfig) }) return r } -func localConnection(ctx context.Context, database string) (*pgx.Conn, func() error, error) { +func localConnection(ctx context.Context, database string) (*pgx.Conn, error) { node, err := flypg.NewNode() if err != nil { - return nil, nil, err + return nil, err } pg, err := node.NewLocalConnection(ctx, database) if err != nil { - return nil, nil, err - } - close := func() error { - return pg.Close(ctx) + return nil, err } - return pg, close, nil + return pg, nil } diff --git a/internal/flycheck/vm.go b/internal/flycheck/vm.go index ba256029..94372a07 100644 --- a/internal/flycheck/vm.go +++ b/internal/flycheck/vm.go @@ -3,8 +3,8 @@ package flycheck import ( "errors" "fmt" - "io/ioutil" "math" + "os" "runtime" "strconv" "syscall" @@ -37,8 +37,7 @@ func CheckVM(checks *check.CheckSuite) *check.CheckSuite { func checkPressure(name string) (string, error) { var avg10, avg60, avg300, counter float64 - //var rest string - raw, err := ioutil.ReadFile("/proc/pressure/" + name) + raw, err := os.ReadFile("/proc/pressure/" + name) if err != nil { return "", err } @@ -82,7 +81,7 @@ func checkPressure(name string) (string, error) { func checkLoad() (string, error) { var loadAverage1, loadAverage5, loadAverage10 float64 var runningProcesses, totalProcesses, lastProcessID int - raw, err := ioutil.ReadFile("/proc/loadavg") + raw, err := os.ReadFile("/proc/loadavg") if err != nil { return "", err diff --git a/internal/flypg/admin/admin.go b/internal/flypg/admin/admin.go index 778651f7..07aa2201 100644 --- a/internal/flypg/admin/admin.go +++ b/internal/flypg/admin/admin.go @@ -140,7 +140,7 @@ func DropReplicationSlot(ctx context.Context, pg *pgx.Conn, name string) error { func EnableExtension(ctx context.Context, pg *pgx.Conn, extension string) error { sql := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", extension) - _, err := pg.Exec(context.Background(), sql) + _, err := pg.Exec(ctx, sql) return err } diff --git a/internal/flypg/config.go b/internal/flypg/config.go index c0aa31e2..a0cb75b4 100644 --- a/internal/flypg/config.go +++ b/internal/flypg/config.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/fly-apps/postgres-flex/internal/flypg/state" + "github.com/fly-apps/postgres-flex/internal/utils" ) type ConfigMap map[string]interface{} @@ -23,25 +24,30 @@ type Config interface { } func WriteUserConfig(c Config, consul *state.Store) error { - if c.UserConfig() != nil { - if err := pushToConsul(c, consul); err != nil { - return fmt.Errorf("failed to write to consul: %s", err) - } + if c.UserConfig() == nil { + return nil + } + + if err := pushToConsul(c, consul); err != nil { + return fmt.Errorf("failed to write to consul: %s", err) + } - if err := WriteConfigFiles(c); err != nil { - return fmt.Errorf("failed to write to pg config file: %s", err) - } + if err := WriteConfigFiles(c); err != nil { + return fmt.Errorf("failed to write to pg config file: %s", err) } return nil } func PushUserConfig(c Config, consul *state.Store) error { - if c.UserConfig() != nil { - if err := pushToConsul(c, consul); err != nil { - return fmt.Errorf("failed to write to consul: %s", err) - } + if c.UserConfig() == nil { + return nil + } + + if err := pushToConsul(c, consul); err != nil { + return fmt.Errorf("failed to write to consul: %s", err) } + return nil } @@ -69,11 +75,11 @@ func pushToConsul(c Config, consul *state.Store) error { configBytes, err := json.Marshal(c.UserConfig()) if err != nil { - return err + return fmt.Errorf("failed to marshal user config: %s", err) } if err := consul.PushUserConfig(c.ConsulKey(), configBytes); err != nil { - return err + return fmt.Errorf("failed to push user config to consul: %s", err) } return nil @@ -97,29 +103,12 @@ func pullFromConsul(c Config, consul *state.Store) (ConfigMap, error) { } func WriteConfigFiles(c Config) error { - internalFile, err := os.OpenFile(c.InternalConfigFile(), os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644) - if err != nil { - return err + if err := writeUserConfigFile(c); err != nil { + return fmt.Errorf("failed to write user config: %s", err) } - defer internalFile.Close() - userFile, err := os.OpenFile(c.UserConfigFile(), os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644) - if err != nil { - return err - } - defer userFile.Close() - - internal := c.InternalConfig() - - for key, value := range c.UserConfig() { - entry := fmt.Sprintf("%s = %v\n", key, value) - delete(internal, key) - userFile.Write([]byte(entry)) - } - - for key, value := range internal { - entry := fmt.Sprintf("%s = %v\n", key, value) - internalFile.Write([]byte(entry)) + if err := writeInternalConfigFile(c); err != nil { + return fmt.Errorf("failed to write internal config: %s", err) } return nil @@ -141,5 +130,48 @@ func ReadFromFile(path string) (ConfigMap, error) { conf[key] = value } - return conf, nil + return conf, file.Sync() +} + +func writeInternalConfigFile(c Config) error { + file, err := os.Create(c.InternalConfigFile()) + if err != nil { + return err + } + defer file.Close() + + internal := c.InternalConfig() + + for key, value := range internal { + entry := fmt.Sprintf("%s = %v\n", key, value) + file.Write([]byte(entry)) + } + + if err := utils.SetFileOwnership(c.InternalConfigFile(), "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership on %s: %s", c.InternalConfigFile(), err) + } + + return file.Sync() +} + +func writeUserConfigFile(c Config) error { + file, err := os.Create(c.UserConfigFile()) + if err != nil { + return err + } + defer file.Close() + + internal := c.InternalConfig() + + for key, value := range c.UserConfig() { + entry := fmt.Sprintf("%s = %v\n", key, value) + delete(internal, key) + file.Write([]byte(entry)) + } + + if err := utils.SetFileOwnership(c.UserConfigFile(), "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership on %s: %s", c.UserConfigFile(), err) + } + + return file.Sync() } diff --git a/internal/flypg/flypg.go b/internal/flypg/flypg.go index dcfa8023..3f06111d 100644 --- a/internal/flypg/flypg.go +++ b/internal/flypg/flypg.go @@ -39,7 +39,7 @@ func (c *FlyPGConfig) UserConfig() ConfigMap { return c.userConfig } -func (c *FlyPGConfig) ConsulKey() string { +func (*FlyPGConfig) ConsulKey() string { return "FlyPGConfig" } @@ -84,13 +84,21 @@ func (c *FlyPGConfig) initialize() error { if err != nil { return err } - defer internal.Close() + defer func() { + if err := internal.Close(); err != nil { + fmt.Printf("failed to close file: %s\n", err) + } + }() user, err := os.Create(c.userConfigFilePath) if err != nil { return err } - defer user.Close() + defer func() { + if err := user.Close(); err != nil { + fmt.Printf("failed to close file: %s\n", err) + } + }() return nil } diff --git a/internal/flypg/haproxy.go b/internal/flypg/haproxy.go index c41f8fe6..49775947 100644 --- a/internal/flypg/haproxy.go +++ b/internal/flypg/haproxy.go @@ -7,7 +7,7 @@ import ( ) func RestartHaproxy() error { - if err := utils.RunCommand("restart-haproxy", "root"); err != nil { + if _, err := utils.RunCommand("restart-haproxy", "root"); err != nil { return fmt.Errorf("failed to restart haproxy: %s", err) } diff --git a/internal/flypg/node.go b/internal/flypg/node.go index 317e054c..23022fd5 100644 --- a/internal/flypg/node.go +++ b/internal/flypg/node.go @@ -114,7 +114,7 @@ func NewNode() (*Node, error) { func (n *Node) Init(ctx context.Context) error { // Ensure directory and files have proper permissions if err := setDirOwnership(); err != nil { - return err + return fmt.Errorf("failed to set directory ownership: %s", err) } // Check to see if we were just restored @@ -122,7 +122,7 @@ func (n *Node) Init(ctx context.Context) error { // Check to see if there's an active restore. active, err := isRestoreActive() if err != nil { - return err + return fmt.Errorf("failed to verify active restore: %s", err) } if active { @@ -135,13 +135,13 @@ func (n *Node) Init(ctx context.Context) error { // Verify whether we are a booting zombie. if ZombieLockExists() { if err := handleZombieLock(ctx, n); err != nil { - return err + return fmt.Errorf("failed to handle zombie lock: %s", err) } } err := writeSSHKey() if err != nil { - return fmt.Errorf("failed initialize ssh. %v", err) + return fmt.Errorf("failed write ssh keys: %s", err) } store, err := state.NewStore() @@ -166,7 +166,6 @@ func (n *Node) Init(ctx context.Context) error { if !clusterInitialized { fmt.Println("Provisioning primary") - // Initialize ourselves as the primary. if err := n.initializePG(); err != nil { return fmt.Errorf("failed to initialize postgres %s", err) @@ -175,16 +174,19 @@ func (n *Node) Init(ctx context.Context) error { if err := n.setDefaultHBA(); err != nil { return fmt.Errorf("failed updating pg_hba.conf: %s", err) } - } else { fmt.Println("Provisioning standby") - // Initialize ourselves as a standby cloneTarget, err := n.RepMgr.ResolveMemberOverDNS(ctx) if err != nil { - return err + return fmt.Errorf("failed to resolve member over dns: %s", err) } if err := n.RepMgr.clonePrimary(cloneTarget.Hostname); err != nil { + // Clean-up the directory so it can be retried. + if rErr := os.Remove(n.DataDir); rErr != nil { + fmt.Printf("failed to cleanup postgresql dir after clone error: %s\n", rErr) + } + return fmt.Errorf("failed to clone primary: %s", err) } } @@ -195,7 +197,7 @@ func (n *Node) Init(ctx context.Context) error { } if err := setDirOwnership(); err != nil { - return err + return fmt.Errorf("failed to set directory ownership: %s", err) } return nil @@ -282,13 +284,13 @@ func (n *Node) PostInit(ctx context.Context) error { primary, err := PerformScreening(ctx, conn, n) if errors.Is(err, ErrZombieDiagnosisUndecided) { fmt.Println("Unable to confirm that we are the true primary!") - if err := Quarantine(ctx, conn, n, primary); err != nil { + if err := Quarantine(ctx, n, primary); err != nil { return fmt.Errorf("failed to quarantine failed primary: %s", err) } } else if errors.Is(err, ErrZombieDiscovered) { fmt.Printf("The majority of registered members agree that '%s' is the real primary.\n", primary) - if err := Quarantine(ctx, conn, n, primary); err != nil { + if err := Quarantine(ctx, n, primary); err != nil { return fmt.Errorf("failed to quarantine failed primary: %s", err) } // Issue panic to force a process restart so we can attempt to rejoin @@ -336,12 +338,25 @@ func (n *Node) initializePG() error { return nil } - if err := os.WriteFile("/data/.default_password", []byte(n.OperatorCredentials.Password), 0644); err != nil { - return err + if err := writePasswordFile(n.OperatorCredentials.Password); err != nil { + return fmt.Errorf("failed to write pg password file: %s", err) } - cmd := exec.Command("gosu", "postgres", "initdb", "--pgdata", n.DataDir, "--pwfile=/data/.default_password") - if _, err := cmd.CombinedOutput(); err != nil { - return err + + cmdStr := fmt.Sprintf("initdb --pgdata=%s --pwfile=/data/.default_password", n.DataDir) + if _, err := utils.RunCommand(cmdStr, "postgres"); err != nil { + return fmt.Errorf("failed to run postgres initdb: %s", err) + } + + return nil +} + +func writePasswordFile(pwd string) error { + if err := os.WriteFile("/data/.default_password", []byte(pwd), 0600); err != nil { + return fmt.Errorf("failed to write default password: %s", err) + } + + if err := utils.SetFileOwnership("/data/.default_password", "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership: %s", err) } return nil @@ -349,10 +364,7 @@ func (n *Node) initializePG() error { func (n *Node) isPGInitialized() bool { _, err := os.Stat(n.DataDir) - if os.IsNotExist(err) { - return false - } - return true + return !os.IsNotExist(err) } func (n *Node) configureInternal(store *state.Store) error { @@ -361,11 +373,11 @@ func (n *Node) configureInternal(store *state.Store) error { } if err := SyncUserConfig(&n.InternalConfig, store); err != nil { - return fmt.Errorf("failed to sync user config from consul for internal config: %s", err) + return fmt.Errorf("failed to sync internal config from consul: %s", err) } if err := WriteConfigFiles(&n.InternalConfig); err != nil { - return fmt.Errorf("failed to write config files for internal config: %s", err) + return fmt.Errorf("failed to write internal config files: %s", err) } return nil @@ -397,7 +409,7 @@ func (n *Node) configurePostgres(store *state.Store) error { } if err := WriteConfigFiles(n.PGConfig); err != nil { - return err + return fmt.Errorf("failed to write pg config files: %s", err) } return nil @@ -506,7 +518,7 @@ func (n *Node) setDefaultHBA() error { } path := fmt.Sprintf("%s/pg_hba.conf", n.DataDir) - file, err := os.OpenFile(path, os.O_RDWR|os.O_TRUNC, 0644) + file, err := os.OpenFile(path, os.O_RDWR|os.O_TRUNC, 0600) if err != nil { return err } @@ -520,7 +532,7 @@ func (n *Node) setDefaultHBA() error { } } - return nil + return file.Sync() } func openConnection(parentCtx context.Context, host string, database string, creds Credentials) (*pgx.Conn, error) { diff --git a/internal/flypg/pg.go b/internal/flypg/pg.go index b2bae49e..8de38ed2 100644 --- a/internal/flypg/pg.go +++ b/internal/flypg/pg.go @@ -13,7 +13,6 @@ import ( "github.com/fly-apps/postgres-flex/internal/flypg/admin" "github.com/fly-apps/postgres-flex/internal/utils" "github.com/jackc/pgx/v5" - "github.com/pkg/errors" ) type PGConfig struct { @@ -39,7 +38,7 @@ func (c *PGConfig) UserConfig() ConfigMap { return c.userConfig } -func (c *PGConfig) ConsulKey() string { +func (*PGConfig) ConsulKey() string { return "PGConfig" } @@ -119,65 +118,6 @@ func (c *PGConfig) Print(w io.Writer) error { return e.Encode(cfg) } -// Setup will ensure the required configuration files are stubbed and the parent -// postgresql.conf file includes them. -func (c *PGConfig) initialize() error { - if _, err := os.Stat(c.internalConfigFilePath); err != nil { - if os.IsNotExist(err) { - if err := utils.RunCommand(fmt.Sprintf("touch %s", c.internalConfigFilePath), "postgres"); err != nil { - return err - } - } else { - return err - } - } - - if _, err := os.Stat(c.userConfigFilePath); err != nil { - if os.IsNotExist(err) { - if err := utils.RunCommand(fmt.Sprintf("touch %s", c.userConfigFilePath), "postgres"); err != nil { - return err - } - } else { - return err - } - } - - b, err := os.ReadFile(c.configFilePath) - if err != nil { - return err - } - - var entries []string - if !strings.Contains(string(b), "postgresql.internal.conf") { - entries = append(entries, "include 'postgresql.internal.conf'\n") - } - - if !strings.Contains(string(b), "postgresql.user.conf") { - entries = append(entries, "include 'postgresql.user.conf'\n") - } - - if len(entries) > 0 { - f, err := os.OpenFile(c.configFilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return nil - } - defer f.Close() - - for _, entry := range entries { - if _, err := f.WriteString(entry); err != nil { - return fmt.Errorf("failed append configuration entry: %s", err) - } - } - } - - err = c.SetDefaults() - if err != nil { - return errors.New("Failed to set PG defaults") - } - - return nil -} - // SetDefaults WriteDefaults will resolve the default configuration settings and write them to the // internal config file. func (c *PGConfig) SetDefaults() error { @@ -258,9 +198,74 @@ func (c *PGConfig) RuntimeApply(ctx context.Context, conn *pgx.Conn) error { return nil } +// initialize will ensure the required configuration files are stubbed and the parent +// postgresql.conf file includes them. +func (c *PGConfig) initialize() error { + if _, err := os.Stat(c.internalConfigFilePath); err != nil { + if os.IsNotExist(err) { + if _, err := utils.RunCommand(fmt.Sprintf("touch %s", c.internalConfigFilePath), "postgres"); err != nil { + return err + } + } else { + return err + } + } + + if _, err := os.Stat(c.userConfigFilePath); err != nil { + if os.IsNotExist(err) { + if _, err := utils.RunCommand(fmt.Sprintf("touch %s", c.userConfigFilePath), "postgres"); err != nil { + return err + } + } else { + return err + } + } + + b, err := os.ReadFile(c.configFilePath) + if err != nil { + return err + } + + var entries []string + if !strings.Contains(string(b), "postgresql.internal.conf") { + entries = append(entries, "include 'postgresql.internal.conf'\n") + } + + if !strings.Contains(string(b), "postgresql.user.conf") { + entries = append(entries, "include 'postgresql.user.conf'\n") + } + + if len(entries) > 0 { + if err := c.writePGConfigEntries(entries); err != nil { + return fmt.Errorf("failed to write pg entries: %s", err) + } + } + + if err := c.SetDefaults(); err != nil { + return fmt.Errorf("failed to set pg defaults: %s", err) + } + + return nil +} + +func (c *PGConfig) writePGConfigEntries(entries []string) error { + f, err := os.OpenFile(c.configFilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + return err + } + defer f.Close() + + for _, entry := range entries { + if _, err := f.WriteString(entry); err != nil { + return fmt.Errorf("failed append configuration entry: %s", err) + } + } + + return f.Sync() +} + func memTotalInBytes() (int64, error) { memoryStr := os.Getenv("FLY_VM_MEMORY_MB") - if memoryStr == "" { return 0, fmt.Errorf("FLY_VM_MEMORY_MB envvar has not been set") } diff --git a/internal/flypg/readonly.go b/internal/flypg/readonly.go index 693186fd..b245322b 100644 --- a/internal/flypg/readonly.go +++ b/internal/flypg/readonly.go @@ -105,11 +105,7 @@ func BroadcastReadonlyChange(ctx context.Context, n *Node, enabled bool) error { func ReadOnlyLockExists() bool { _, err := os.Stat(readOnlyLockFile) - if os.IsNotExist(err) { - return false - } - - return true + return !os.IsNotExist(err) } func writeReadOnlyLock() error { @@ -117,17 +113,12 @@ func writeReadOnlyLock() error { return nil } - if err := os.WriteFile(readOnlyLockFile, []byte(time.Now().String()), 0644); err != nil { - return err - } - - pgUID, pgGID, err := utils.SystemUserIDs("postgres") - if err != nil { + if err := os.WriteFile(readOnlyLockFile, []byte(time.Now().String()), 0600); err != nil { return err } - if err := os.Chown(readOnlyLockFile, pgUID, pgGID); err != nil { - return fmt.Errorf("failed to set readonly.lock owner: %s", err) + if err := utils.SetFileOwnership(readOnlyLockFile, "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership: %s", err) } return nil @@ -138,11 +129,7 @@ func removeReadOnlyLock() error { return nil } - if err := os.Remove(readOnlyLockFile); err != nil { - return err - } - - return nil + return os.Remove(readOnlyLockFile) } func changeReadOnlyState(ctx context.Context, n *Node, enable bool) error { @@ -169,7 +156,6 @@ func changeReadOnlyState(ctx context.Context, n *Node, enable bool) error { return fmt.Errorf("failed to list database: %s", err) } - var dbNames []string for _, db := range databases { // exclude administrative dbs if db.Name == "repmgr" || db.Name == "postgres" { @@ -180,8 +166,6 @@ func changeReadOnlyState(ctx context.Context, n *Node, enable bool) error { if _, err = conn.Exec(ctx, sql); err != nil { return fmt.Errorf("failed to alter readonly state on db %s: %s", db.Name, err) } - - dbNames = append(dbNames, db.Name) } } diff --git a/internal/flypg/repmgr.go b/internal/flypg/repmgr.go index 5ea67d08..92cdb536 100644 --- a/internal/flypg/repmgr.go +++ b/internal/flypg/repmgr.go @@ -80,7 +80,7 @@ func (r *RepMgr) CurrentConfig() (ConfigMap, error) { return all, nil } -func (r *RepMgr) ConsulKey() string { +func (*RepMgr) ConsulKey() string { return "repmgr" } @@ -97,7 +97,7 @@ func (r *RepMgr) NewRemoteConnection(ctx context.Context, hostname string) (*pgx func (r *RepMgr) initialize() error { r.setDefaults() - file, err := os.OpenFile(r.ConfigPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + file, err := os.Create(r.ConfigPath) if err != nil { return nil } @@ -115,11 +115,11 @@ func (r *RepMgr) initialize() error { return fmt.Errorf("failed creating pgpass file: %s", err) } - if err := setDirOwnership(); err != nil { - return fmt.Errorf("failed to set dir ownership: %s", err) + if err := utils.SetFileOwnership(r.ConfigPath, "postgres"); err != nil { + return fmt.Errorf("failed to set repmgr.conf ownership: %s", err) } - return nil + return file.Sync() } func (r *RepMgr) setup(ctx context.Context, conn *pgx.Conn) error { @@ -162,20 +162,16 @@ func (r *RepMgr) setDefaults() { func (r *RepMgr) registerPrimary() error { cmdStr := fmt.Sprintf("repmgr -f %s primary register -F -v", r.ConfigPath) - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { - return err - } + _, err := utils.RunCommand(cmdStr, "postgres") - return nil + return err } func (r *RepMgr) unregisterPrimary(id int) error { cmdStr := fmt.Sprintf("repmgr primary unregister -f %s --node-id=%d", r.ConfigPath, id) - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { - return err - } + _, err := utils.RunCommand(cmdStr, "postgres") - return nil + return err } func (r *RepMgr) rejoinCluster(hostname string) error { @@ -188,18 +184,15 @@ func (r *RepMgr) rejoinCluster(hostname string) error { ) fmt.Println(cmdStr) + _, err := utils.RunCommand(cmdStr, "postgres") - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { - return err - } - - return nil + return err } func (r *RepMgr) registerStandby() error { // Force re-registry to ensure the standby picks up any new configuration changes. cmdStr := fmt.Sprintf("repmgr -f %s standby register -F", r.ConfigPath) - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { + if _, err := utils.RunCommand(cmdStr, "postgres"); err != nil { fmt.Printf("failed to register standby: %s", err) } @@ -208,7 +201,7 @@ func (r *RepMgr) registerStandby() error { func (r *RepMgr) unregisterStandby(id int) error { cmdStr := fmt.Sprintf("repmgr standby unregister -f %s --node-id=%d", r.ConfigPath, id) - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { + if _, err := utils.RunCommand(cmdStr, "postgres"); err != nil { fmt.Printf("failed to unregister standby: %s", err) } @@ -217,8 +210,8 @@ func (r *RepMgr) unregisterStandby(id int) error { func (r *RepMgr) clonePrimary(ipStr string) error { cmdStr := fmt.Sprintf("mkdir -p %s", r.DataDir) - if err := utils.RunCommand(cmdStr, "postgres"); err != nil { - return err + if _, err := utils.RunCommand(cmdStr, "postgres"); err != nil { + return fmt.Errorf("failed to create pg directory: %s", err) } cmdStr = fmt.Sprintf("repmgr -h %s -p %d -d %s -U %s -f %s standby clone -F", @@ -229,17 +222,25 @@ func (r *RepMgr) clonePrimary(ipStr string) error { r.ConfigPath) fmt.Println(cmdStr) - return utils.RunCommand(cmdStr, "postgres") + if _, err := utils.RunCommand(cmdStr, "postgres"); err != nil { + return fmt.Errorf("failed to clone primary: %s", err) + } + + return nil } func (r *RepMgr) writePasswdConf() error { path := "/data/.pgpass" file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0600) if err != nil { - return err + return fmt.Errorf("failed to open repmgr password file: %s", err) } defer file.Close() + if err := utils.SetFileOwnership(path, "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership: %s", err) + } + entries := []string{ fmt.Sprintf("*:*:*:%s:%s", r.Credentials.Username, r.Credentials.Password), } @@ -252,7 +253,7 @@ func (r *RepMgr) writePasswdConf() error { } } - return nil + return file.Sync() } type Member struct { @@ -263,7 +264,7 @@ type Member struct { Role string } -func (r *RepMgr) Members(ctx context.Context, pg *pgx.Conn) ([]Member, error) { +func (*RepMgr) Members(ctx context.Context, pg *pgx.Conn) ([]Member, error) { sql := "select node_id, node_name, location, active, type from repmgr.nodes;" rows, err := pg.Query(ctx, sql) if err != nil { @@ -272,7 +273,6 @@ func (r *RepMgr) Members(ctx context.Context, pg *pgx.Conn) ([]Member, error) { defer rows.Close() var members []Member - for rows.Next() { var member Member if err := rows.Scan(&member.ID, &member.Hostname, &member.Region, &member.Active, &member.Role); err != nil { @@ -282,7 +282,7 @@ func (r *RepMgr) Members(ctx context.Context, pg *pgx.Conn) ([]Member, error) { members = append(members, member) } - return members, err + return members, nil } func (r *RepMgr) Member(ctx context.Context, conn *pgx.Conn) (*Member, error) { @@ -300,7 +300,7 @@ func (r *RepMgr) Member(ctx context.Context, conn *pgx.Conn) (*Member, error) { return nil, pgx.ErrNoRows } -func (r *RepMgr) PrimaryMember(ctx context.Context, pg *pgx.Conn) (*Member, error) { +func (*RepMgr) PrimaryMember(ctx context.Context, pg *pgx.Conn) (*Member, error) { var member Member sql := "select node_id, node_name, location, active, type from repmgr.nodes where type = 'primary' and active = true;" err := pg.QueryRow(ctx, sql).Scan(&member.ID, &member.Hostname, &member.Region, &member.Active, &member.Role) @@ -327,7 +327,7 @@ func (r *RepMgr) StandbyMembers(ctx context.Context, conn *pgx.Conn) ([]Member, return standbys, nil } -func (r *RepMgr) MemberByID(ctx context.Context, pg *pgx.Conn, id int) (*Member, error) { +func (*RepMgr) MemberByID(ctx context.Context, pg *pgx.Conn, id int) (*Member, error) { var member Member sql := fmt.Sprintf("select node_id, node_name, location, active, type from repmgr.nodes where node_id = %d;", id) @@ -339,7 +339,7 @@ func (r *RepMgr) MemberByID(ctx context.Context, pg *pgx.Conn, id int) (*Member, return &member, nil } -func (r *RepMgr) MemberByHostname(ctx context.Context, pg *pgx.Conn, hostname string) (*Member, error) { +func (*RepMgr) MemberByHostname(ctx context.Context, pg *pgx.Conn, hostname string) (*Member, error) { var member Member sql := fmt.Sprintf("select node_id, node_name, location, active, type from repmgr.nodes where node_name = '%s';", hostname) @@ -410,7 +410,7 @@ func (r *RepMgr) HostInRegion(ctx context.Context, hostname string) (bool, error return false, nil } -func (r *RepMgr) UnregisterMember(ctx context.Context, member Member) error { +func (r *RepMgr) UnregisterMember(member Member) error { if member.Role == PrimaryRoleName { if err := r.unregisterPrimary(member.ID); err != nil { return fmt.Errorf("failed to unregister member %d: %s", member.ID, err) diff --git a/internal/flypg/restore.go b/internal/flypg/restore.go index be051879..6e76ef06 100644 --- a/internal/flypg/restore.go +++ b/internal/flypg/restore.go @@ -112,15 +112,11 @@ func backupHBAFile() error { return err } - if err = os.WriteFile(pathToHBABackup, val, 0644); err != nil { - return err - } - - return nil + return os.WriteFile(pathToHBABackup, val, 0600) } func grantLocalAccess() error { - file, err := os.OpenFile(pathToHBAFile, os.O_RDWR|os.O_TRUNC, 0644) + file, err := os.OpenFile(pathToHBAFile, os.O_RDWR|os.O_TRUNC, 0600) if err != nil { return err } @@ -132,7 +128,7 @@ func grantLocalAccess() error { return err } - return nil + return file.Sync() } func restoreHBAFile() error { @@ -143,7 +139,7 @@ func restoreHBAFile() error { } // open the main pg_hba - file, err := os.OpenFile(pathToHBAFile, os.O_RDWR|os.O_TRUNC, 0644) + file, err := os.OpenFile(pathToHBAFile, os.O_RDWR|os.O_TRUNC, 0600) if err != nil { return err } @@ -160,11 +156,11 @@ func restoreHBAFile() error { return err } - return nil + return file.Sync() } func setRestoreLock() error { - file, err := os.OpenFile(restoreLockFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) + file, err := os.OpenFile(restoreLockFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } @@ -175,7 +171,7 @@ func setRestoreLock() error { return err } - return nil + return file.Sync() } func openConn(ctx context.Context, n *Node) (*pgx.Conn, error) { diff --git a/internal/flypg/ssh.go b/internal/flypg/ssh.go index e802f889..a452c4c6 100644 --- a/internal/flypg/ssh.go +++ b/internal/flypg/ssh.go @@ -6,18 +6,38 @@ import ( "os/exec" ) -func writeSSHKey() error { - var ( - key = os.Getenv("SSH_KEY") - cert = os.Getenv("SSH_CERT") - ) +const ( + privateKeyFile = "/data/.ssh/id_rsa" + publicKeyFile = "/data/.ssh/id_rsa-cert.pub" +) +func writeSSHKey() error { err := os.Mkdir("/data/.ssh", 0700) if err != nil && !os.IsExist(err) { return err } - keyFile, err := os.Create("/data/.ssh/id_rsa") + if err := writePrivateKey(); err != nil { + return fmt.Errorf("failed to write private key: %s", err) + } + + if err := writePublicKey(); err != nil { + return fmt.Errorf("failed to write cert: %s", err) + } + + cmdStr := fmt.Sprintf("chmod 600 %s %s", privateKeyFile, publicKeyFile) + cmd := exec.Command("sh", "-c", cmdStr) + if _, err := cmd.Output(); err != nil { + return err + } + + return nil +} + +func writePrivateKey() error { + key := os.Getenv("SSH_KEY") + + keyFile, err := os.Create(privateKeyFile) if err != nil { return err } @@ -27,7 +47,13 @@ func writeSSHKey() error { return err } - certFile, err := os.Create("/data/.ssh/id_rsa-cert.pub") + return keyFile.Sync() +} + +func writePublicKey() error { + cert := os.Getenv("SSH_CERT") + + certFile, err := os.Create(publicKeyFile) if err != nil { return err } @@ -38,11 +64,5 @@ func writeSSHKey() error { return err } - cmdStr := fmt.Sprintf("chmod 600 %s %s", "/data/.ssh/id_rsa", "/data/.ssh/id_rsa-cert.pub") - cmd := exec.Command("sh", "-c", cmdStr) - if _, err := cmd.Output(); err != nil { - return err - } - - return nil + return certFile.Sync() } diff --git a/internal/flypg/zombie.go b/internal/flypg/zombie.go index 6d58c486..4c981f45 100644 --- a/internal/flypg/zombie.go +++ b/internal/flypg/zombie.go @@ -27,35 +27,23 @@ const zombieLockFile = "/data/zombie.lock" func ZombieLockExists() bool { _, err := os.Stat(zombieLockFile) - if os.IsNotExist(err) { - return false - } - return true + return !os.IsNotExist(err) } func writeZombieLock(hostname string) error { - if err := os.WriteFile(zombieLockFile, []byte(hostname), 0644); err != nil { - return err - } - - pgUID, pgGID, err := utils.SystemUserIDs("postgres") - if err != nil { + if err := os.WriteFile(zombieLockFile, []byte(hostname), 0600); err != nil { return err } - if err := os.Chown(zombieLockFile, pgUID, pgGID); err != nil { - return fmt.Errorf("failed to set zombie.lock owner: %s", err) + if err := utils.SetFileOwnership(zombieLockFile, "postgres"); err != nil { + return fmt.Errorf("failed to set file ownership: %s", err) } return nil } func RemoveZombieLock() error { - if err := os.Remove(zombieLockFile); err != nil { - return err - } - - return nil + return os.Remove(zombieLockFile) } func ReadZombieLock() (string, error) { @@ -175,7 +163,7 @@ func ZombieDiagnosis(s *DNASample) (string, error) { return "", ErrZombieDiagnosisUndecided } -func Quarantine(ctx context.Context, conn *pgx.Conn, n *Node, primary string) error { +func Quarantine(ctx context.Context, n *Node, primary string) error { if err := writeZombieLock(primary); err != nil { return fmt.Errorf("failed to set zombie lock: %s", err) } diff --git a/internal/utils/response.go b/internal/utils/response.go index 4f3063f8..ced1910b 100644 --- a/internal/utils/response.go +++ b/internal/utils/response.go @@ -3,7 +3,6 @@ package utils import ( "encoding/json" "fmt" - "os" ) type Response struct { @@ -35,5 +34,4 @@ func sendToStdout(resp *Response) { fmt.Println(err.Error()) } fmt.Println(string(e)) - os.Exit(0) } diff --git a/internal/utils/shell.go b/internal/utils/shell.go index e4b7a47d..efabb47b 100644 --- a/internal/utils/shell.go +++ b/internal/utils/shell.go @@ -1,23 +1,37 @@ package utils import ( + "fmt" + "os" "os/exec" "os/user" "strconv" "syscall" ) -func RunCommand(cmdStr, user string) error { +func RunCommand(cmdStr, user string) ([]byte, error) { pgUID, pgGID, err := SystemUserIDs(user) if err != nil { - return err + return nil, err } cmd := exec.Command("sh", "-c", cmdStr) cmd.SysProcAttr = &syscall.SysProcAttr{} cmd.SysProcAttr.Credential = &syscall.Credential{Uid: uint32(pgUID), Gid: uint32(pgGID)} - _, err = cmd.Output() - return err + return cmd.Output() +} + +func SetFileOwnership(pathToFile, owner string) error { + uid, gid, err := SystemUserIDs(owner) + if err != nil { + return fmt.Errorf("failed to resolve system user ids: %s", err) + } + + if err := os.Chown(pathToFile, uid, gid); err != nil { + return fmt.Errorf("failed to set ownership on file %s: %s", pathToFile, err) + } + + return nil } func SystemUserIDs(usr string) (int, int, error) {