Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 24 additions & 45 deletions cli/connhelper/commandconn/commandconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,43 +80,48 @@ type commandConn struct {
remoteAddr net.Addr
}

// kill returns nil if the command terminated, regardless to the exit status.
func (c *commandConn) kill() error {
// kill terminates the process. On Windows it kills the process directly,
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
// the process after 3 seconds.
func (c *commandConn) kill() {
if c.cmdExited.Load() {
return nil
return
}
c.cmdMutex.Lock()
var werr error
if runtime.GOOS != "windows" {
werrCh := make(chan error)
go func() { werrCh <- c.cmd.Wait() }()
c.cmd.Process.Signal(syscall.SIGTERM)
_ = c.cmd.Process.Signal(syscall.SIGTERM)
select {
case werr = <-werrCh:
case <-time.After(3 * time.Second):
c.cmd.Process.Kill()
_ = c.cmd.Process.Kill()
werr = <-werrCh
}
} else {
c.cmd.Process.Kill()
_ = c.cmd.Process.Kill()
werr = c.cmd.Wait()
}
c.cmdWaitErr = werr
c.cmdMutex.Unlock()
c.cmdExited.Store(true)
return nil
}

// onEOF gets called if we receive an io.EOF while reading
// or writing from the undelying command pipes.
// handleEOF handles io.EOF errors while reading or writing from the underlying
// command pipes.
//
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned
func (c *commandConn) onEOF(eof error) error {
func (c *commandConn) handleEOF(err error) error {
if err != io.EOF {
return err
}

c.cmdMutex.Lock()
defer c.cmdMutex.Unlock()

Expand All @@ -134,12 +139,12 @@ func (c *commandConn) onEOF(eof error) error {
c.stderrMu.Lock()
stderr := c.stderr.String()
c.stderrMu.Unlock()
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
}
}

if werr == nil {
return eof
return err
}
c.stderrMu.Lock()
stderr := c.stderr.String()
Expand All @@ -148,16 +153,7 @@ func (c *commandConn) onEOF(eof error) error {
}

func ignorableCloseError(err error) bool {
errS := err.Error()
ss := []string{
os.ErrClosed.Error(),
}
for _, s := range ss {
if strings.Contains(errS, s) {
return true
}
}
return false
return strings.Contains(err.Error(), os.ErrClosed.Error())
}

func (c *commandConn) Read(p []byte) (int, error) {
Expand All @@ -172,10 +168,7 @@ func (c *commandConn) Read(p []byte) (int, error) {
return 0, io.EOF
}

if err == io.EOF {
err = c.onEOF(err)
}
return n, err
return n, c.handleEOF(err)
}

func (c *commandConn) Write(p []byte) (int, error) {
Expand All @@ -190,26 +183,19 @@ func (c *commandConn) Write(p []byte) (int, error) {
return 0, io.EOF
}

if err == io.EOF {
err = c.onEOF(err)
}
return n, err
return n, c.handleEOF(err)
}

// CloseRead allows commandConn to implement halfCloser
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseRead: %v", err)
return err
}
c.stdoutClosed.Store(true)

if c.stdinClosed.Load() {
if err := c.kill(); err != nil {
logrus.Warnf("commandConn.CloseRead: %v", err)
return err
}
c.kill()
}

return nil
Expand All @@ -219,18 +205,13 @@ func (c *commandConn) CloseRead() error {
func (c *commandConn) CloseWrite() error {
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseWrite: %v", err)
return err
}
c.stdinClosed.Store(true)

if c.stdoutClosed.Load() {
if err := c.kill(); err != nil {
logrus.Warnf("commandConn.CloseWrite: %v", err)
return err
}
c.kill()
}

return nil
}

Expand All @@ -244,13 +225,11 @@ func (c *commandConn) Close() error {
c.closing.Store(true)
defer c.closing.Store(false)

err := c.CloseRead()
if err != nil {
if err := c.CloseRead(); err != nil {
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
return err
}
err = c.CloseWrite()
if err != nil {
if err := c.CloseWrite(); err != nil {
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
return err
}
Expand Down