diff --git a/go/vt/hook/hook.go b/go/vt/hook/hook.go index 0215a143829..6cee35e4241 100644 --- a/go/vt/hook/hook.go +++ b/go/vt/hook/hook.go @@ -18,6 +18,8 @@ package hook import ( "bytes" + "context" + "errors" "fmt" "io" "os" @@ -25,6 +27,7 @@ import ( "path" "strings" "syscall" + "time" vtenv "vitess.io/vitess/go/vt/env" "vitess.io/vitess/go/vt/log" @@ -69,6 +72,10 @@ const ( // HOOK_GENERIC_ERROR is returned for unknown errors. HOOK_GENERIC_ERROR = -6 + + // HOOK_TIMEOUT_ERROR is returned when a CommandContext has its context + // become done before the command terminates. + HOOK_TIMEOUT_ERROR = -7 ) // WaitFunc is a return type for the Pipe methods. @@ -90,8 +97,8 @@ func NewHookWithEnv(name string, params []string, env map[string]string) *Hook { return &Hook{Name: name, Parameters: params, ExtraEnv: env} } -// findHook trie to locate the hook, and returns the exec.Cmd for it. -func (hook *Hook) findHook() (*exec.Cmd, int, error) { +// findHook tries to locate the hook, and returns the exec.Cmd for it. +func (hook *Hook) findHook(ctx context.Context) (*exec.Cmd, int, error) { // Check the hook path. if strings.Contains(hook.Name, "/") { return nil, HOOK_INVALID_NAME, fmt.Errorf("hook cannot contain '/'") @@ -116,7 +123,7 @@ func (hook *Hook) findHook() (*exec.Cmd, int, error) { // Configure the command. log.Infof("hook: executing hook: %v %v", vthook, strings.Join(hook.Parameters, " ")) - cmd := exec.Command(vthook, hook.Parameters...) + cmd := exec.CommandContext(ctx, vthook, hook.Parameters...) if len(hook.ExtraEnv) > 0 { cmd.Env = os.Environ() for key, value := range hook.ExtraEnv { @@ -127,12 +134,12 @@ func (hook *Hook) findHook() (*exec.Cmd, int, error) { return cmd, HOOK_SUCCESS, nil } -// Execute tries to execute the Hook and returns a HookResult. -func (hook *Hook) Execute() (result *HookResult) { +// ExecuteContext tries to execute the Hook with the given context and returns a HookResult. +func (hook *Hook) ExecuteContext(ctx context.Context) (result *HookResult) { result = &HookResult{} // Find the hook. - cmd, status, err := hook.findHook() + cmd, status, err := hook.findHook(ctx) if err != nil { result.ExitStatus = status result.Stderr = err.Error() + "\n" @@ -143,25 +150,54 @@ func (hook *Hook) Execute() (result *HookResult) { var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr + + start := time.Now() err = cmd.Run() + duration := time.Since(start) + result.Stdout = stdout.String() result.Stderr = stderr.String() + + defer func() { + log.Infof("hook: result is %v", result.String()) + }() + if err == nil { result.ExitStatus = HOOK_SUCCESS - } else { - if cmd.ProcessState != nil && cmd.ProcessState.Sys() != nil { - result.ExitStatus = cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() - } else { - result.ExitStatus = HOOK_CANNOT_GET_EXIT_STATUS - } - result.Stderr += "ERROR: " + err.Error() + "\n" + return result } - log.Infof("hook: result is %v", result.String()) + if ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) { + // When (exec.Cmd).Run hits a context cancelled, the process is killed via SIGTERM. + // This means: + // 1. cmd.ProcessState.Exited() is false. + // 2. cmd.ProcessState.ExitCode() is -1. + // [ref]: https://golang.org/pkg/os/#ProcessState.ExitCode + // + // Therefore, we need to catch this error specifically, and set result.ExitStatus to + // HOOK_TIMEOUT_ERROR, because just using ExitStatus will result in HOOK_DOES_NOT_EXIST, + // which would be wrong. Since we're already doing some custom handling, we'll also include + // the amount of time the command was running in the error string, in case that is helpful. + result.ExitStatus = HOOK_TIMEOUT_ERROR + result.Stderr += fmt.Sprintf("ERROR: (after %s) %s\n", duration, err) + return result + } + + if cmd.ProcessState != nil && cmd.ProcessState.Sys() != nil { + result.ExitStatus = cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + } else { + result.ExitStatus = HOOK_CANNOT_GET_EXIT_STATUS + } + result.Stderr += "ERROR: " + err.Error() + "\n" return result } +// Execute tries to execute the Hook and returns a HookResult. +func (hook *Hook) Execute() (result *HookResult) { + return hook.ExecuteContext(context.Background()) +} + // ExecuteOptional executes an optional hook, logs if it doesn't // exist, and returns a printable error. func (hook *Hook) ExecuteOptional() error { @@ -187,7 +223,7 @@ func (hook *Hook) ExecuteOptional() error { // - an error code and an error if anything fails. func (hook *Hook) ExecuteAsWritePipe(out io.Writer) (io.WriteCloser, WaitFunc, int, error) { // Find the hook. - cmd, status, err := hook.findHook() + cmd, status, err := hook.findHook(context.Background()) if err != nil { return nil, nil, status, err } @@ -226,7 +262,7 @@ func (hook *Hook) ExecuteAsWritePipe(out io.Writer) (io.WriteCloser, WaitFunc, i // - an error code and an error if anything fails. func (hook *Hook) ExecuteAsReadPipe(in io.Reader) (io.Reader, WaitFunc, int, error) { // Find the hook. - cmd, status, err := hook.findHook() + cmd, status, err := hook.findHook(context.Background()) if err != nil { return nil, nil, status, err } diff --git a/go/vt/hook/hook_test.go b/go/vt/hook/hook_test.go new file mode 100644 index 00000000000..fd9235f14b3 --- /dev/null +++ b/go/vt/hook/hook_test.go @@ -0,0 +1,39 @@ +package hook + +import ( + "context" + "os" + "os/exec" + "path" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + vtenv "vitess.io/vitess/go/vt/env" +) + +func TestExecuteContext(t *testing.T) { + vtroot, err := vtenv.VtRoot() + require.NoError(t, err) + + sleep, err := exec.LookPath("sleep") + require.NoError(t, err) + + sleepHookPath := path.Join(vtroot, "vthook", "sleep") + require.NoError(t, os.Symlink(sleep, sleepHookPath)) + defer func() { + require.NoError(t, os.Remove(sleepHookPath)) + }() + + h := NewHook("sleep", []string{"5"}) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + + hr := h.ExecuteContext(ctx) + assert.Equal(t, HOOK_TIMEOUT_ERROR, hr.ExitStatus) + + h.Parameters = []string{"0.1"} + hr = h.Execute() + assert.Equal(t, HOOK_SUCCESS, hr.ExitStatus) +} diff --git a/go/vt/mysqlctl/builtinbackupengine.go b/go/vt/mysqlctl/builtinbackupengine.go index cc077bb6610..a1db9af5107 100644 --- a/go/vt/mysqlctl/builtinbackupengine.go +++ b/go/vt/mysqlctl/builtinbackupengine.go @@ -20,6 +20,7 @@ import ( "bufio" "context" "encoding/json" + "flag" "fmt" "io" "os" @@ -47,6 +48,13 @@ const ( dataDictionaryFile = "mysql.ibd" ) +var ( + // BuiltinBackupMysqldTimeout is how long ExecuteBackup should wait for response from mysqld.Shutdown. + // It can later be extended for other calls to mysqld during backup functions. + // Exported for testing. + BuiltinBackupMysqldTimeout = flag.Duration("builtinbackup_mysqld_timeout", 10*time.Minute, "how long to wait for mysqld to shutdown at the start of the backup") +) + // BuiltinBackupEngine encapsulates the logic of the builtin engine // it implements the BackupEngine interface and contains all the logic // required to implement a backup/restore by copying files from and to @@ -182,7 +190,9 @@ func (be *BuiltinBackupEngine) ExecuteBackup(ctx context.Context, params BackupP params.Logger.Infof("using replication position: %v", replicationPosition) // shutdown mysqld - err = params.Mysqld.Shutdown(ctx, params.Cnf, true) + shutdownCtx, cancel := context.WithTimeout(ctx, *BuiltinBackupMysqldTimeout) + err = params.Mysqld.Shutdown(shutdownCtx, params.Cnf, true) + defer cancel() if err != nil { return false, vterrors.Wrap(err, "can't shutdown mysqld") } diff --git a/go/vt/mysqlctl/builtinbackupengine_test.go b/go/vt/mysqlctl/builtinbackupengine_test.go new file mode 100644 index 00000000000..3e9b739839b --- /dev/null +++ b/go/vt/mysqlctl/builtinbackupengine_test.go @@ -0,0 +1,135 @@ +// Package mysqlctl_test is the blackbox tests for package mysqlctl. +// Tests that need to use fakemysqldaemon must be written as blackbox tests; +// since fakemysqldaemon imports mysqlctl, importing fakemysqldaemon in +// a `package mysqlctl` test would cause a circular import. +package mysqlctl_test + +import ( + "context" + "os" + "path" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/fakesqldb" + "vitess.io/vitess/go/vt/logutil" + "vitess.io/vitess/go/vt/mysqlctl" + "vitess.io/vitess/go/vt/mysqlctl/fakemysqldaemon" + "vitess.io/vitess/go/vt/mysqlctl/filebackupstorage" + "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/proto/vttime" + "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vttablet/faketmclient" + "vitess.io/vitess/go/vt/vttablet/tmclient" +) + +func setBuiltinBackupMysqldDeadline(t time.Duration) time.Duration { + old := *mysqlctl.BuiltinBackupMysqldTimeout + mysqlctl.BuiltinBackupMysqldTimeout = &t + + return old +} + +func createBackupDir(root string, dirs ...string) error { + for _, dir := range dirs { + if err := os.MkdirAll(path.Join(root, dir), 0755); err != nil { + return err + } + } + + return nil +} + +func TestExecuteBackup(t *testing.T) { + // Set up local backup directory + backupRoot := "testdata/builtinbackup_test" + *filebackupstorage.FileBackupStorageRoot = backupRoot + require.NoError(t, createBackupDir(backupRoot, "innodb", "log", "datadir")) + defer os.RemoveAll(backupRoot) + + ctx := context.Background() + + // Set up topo + keyspace, shard := "mykeyspace", "-80" + ts := memorytopo.NewServer("cell1") + defer ts.Close() + + require.NoError(t, ts.CreateKeyspace(ctx, keyspace, &topodata.Keyspace{})) + require.NoError(t, ts.CreateShard(ctx, keyspace, shard)) + + tablet := topo.NewTablet(100, "cell1", "mykeyspace-00-80-0100") + tablet.Keyspace = keyspace + tablet.Shard = shard + + require.NoError(t, ts.CreateTablet(ctx, tablet)) + + _, err := ts.UpdateShardFields(ctx, keyspace, shard, func(si *topo.ShardInfo) error { + si.MasterAlias = &topodata.TabletAlias{Uid: 100, Cell: "cell1"} + + now := time.Now() + si.MasterTermStartTime = &vttime.Time{Seconds: int64(now.Second()), Nanoseconds: int32(now.Nanosecond())} + + return nil + }) + require.NoError(t, err) + + // Set up tm client + // Note that using faketmclient.NewFakeTabletManagerClient will cause infinite recursion :shrug: + tmclient.RegisterTabletManagerClientFactory("grpc", + func() tmclient.TabletManagerClient { return &faketmclient.FakeTabletManagerClient{} }, + ) + + be := &mysqlctl.BuiltinBackupEngine{} + + // Configure a tight deadline to force a timeout + oldDeadline := setBuiltinBackupMysqldDeadline(time.Second) + defer setBuiltinBackupMysqldDeadline(oldDeadline) + + bh := filebackupstorage.FileBackupHandle{} + + // Spin up a fake daemon to be used in backups. It needs to be allowed to receive: + // "STOP SLAVE", "START SLAVE", in that order. + mysqld := fakemysqldaemon.NewFakeMysqlDaemon(fakesqldb.New(t)) + mysqld.ExpectedExecuteSuperQueryList = []string{"STOP SLAVE", "START SLAVE"} + // mysqld.ShutdownTime = time.Minute + + ok, err := be.ExecuteBackup(ctx, mysqlctl.BackupParams{ + Logger: logutil.NewConsoleLogger(), + Mysqld: mysqld, + Cnf: &mysqlctl.Mycnf{ + InnodbDataHomeDir: path.Join(backupRoot, "innodb"), + InnodbLogGroupHomeDir: path.Join(backupRoot, "log"), + DataDir: path.Join(backupRoot, "datadir"), + }, + HookExtraEnv: map[string]string{}, + TopoServer: ts, + Keyspace: keyspace, + Shard: shard, + }, &bh) + + require.NoError(t, err) + assert.True(t, ok) + + mysqld.ExpectedExecuteSuperQueryCurrent = 0 // resest the index of what queries we've run + mysqld.ShutdownTime = time.Minute // reminder that shutdownDeadline is 1s + + ok, err = be.ExecuteBackup(ctx, mysqlctl.BackupParams{ + Logger: logutil.NewConsoleLogger(), + Mysqld: mysqld, + Cnf: &mysqlctl.Mycnf{ + InnodbDataHomeDir: path.Join(backupRoot, "innodb"), + InnodbLogGroupHomeDir: path.Join(backupRoot, "log"), + DataDir: path.Join(backupRoot, "datadir"), + }, + HookExtraEnv: map[string]string{}, + TopoServer: ts, + Keyspace: keyspace, + Shard: shard, + }, &bh) + + assert.Error(t, err) + assert.False(t, ok) +} diff --git a/go/vt/mysqlctl/fakemysqldaemon/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon/fakemysqldaemon.go index f4142ffee20..84e933dfa6f 100644 --- a/go/vt/mysqlctl/fakemysqldaemon/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon/fakemysqldaemon.go @@ -48,6 +48,14 @@ type FakeMysqlDaemon struct { // Running is used by Start / Shutdown Running bool + // StartupTime is used to simulate mysqlds that take some time to respond + // to a "start" command. It is used by Start. + StartupTime time.Duration + + // ShutdownTime is used to simulate mysqlds that take some time to respond + // to a "stop" request (i.e. a wedged systemd unit). It is used by Shutdown. + ShutdownTime time.Duration + // MysqlPort will be returned by GetMysqlPort(). Set to -1 to // return an error. MysqlPort sync2.AtomicInt32 @@ -181,6 +189,15 @@ func (fmd *FakeMysqlDaemon) Start(ctx context.Context, cnf *mysqlctl.Mycnf, mysq if fmd.Running { return fmt.Errorf("fake mysql daemon already running") } + + if fmd.StartupTime > 0 { + select { + case <-time.After(fmd.StartupTime): + case <-ctx.Done(): + return ctx.Err() + } + } + fmd.Running = true return nil } @@ -190,6 +207,15 @@ func (fmd *FakeMysqlDaemon) Shutdown(ctx context.Context, cnf *mysqlctl.Mycnf, w if !fmd.Running { return fmt.Errorf("fake mysql daemon not running") } + + if fmd.ShutdownTime > 0 { + select { + case <-time.After(fmd.ShutdownTime): + case <-ctx.Done(): + return ctx.Err() + } + } + fmd.Running = false return nil } diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index 7abd1da7a78..fc6fce373be 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -512,7 +512,7 @@ func (mysqld *Mysqld) Shutdown(ctx context.Context, cnf *Mycnf, waitForMysqld bo // try the mysqld shutdown hook, if any h := hook.NewSimpleHook("mysqld_shutdown") - hr := h.Execute() + hr := h.ExecuteContext(ctx) switch hr.ExitStatus { case hook.HOOK_SUCCESS: // hook exists and worked, we can keep going