Skip to content

Commit

Permalink
Support canceling remote commands via context. (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
kgadams authored Jun 3, 2022
1 parent 69f69af commit 3206f10
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 72 deletions.
121 changes: 53 additions & 68 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package winrm

import (
"bytes"
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"strings"
"sync"

"github.com/masterzen/winrm/soap"
Expand Down Expand Up @@ -105,82 +107,47 @@ func (c *Client) sendRequest(request *soap.SoapMessage) (string, error) {

// Run will run command on the the remote host, writing the process stdout and stderr to
// the given writers. Note with this method it isn't possible to inject stdin.
//
// Deprecated: use RunWithContext()
func (c *Client) Run(command string, stdout io.Writer, stderr io.Writer) (int, error) {
shell, err := c.CreateShell()
if err != nil {
return 1, err
}
defer shell.Close()
cmd, err := shell.Execute(command)
if err != nil {
return 1, err
}

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
io.Copy(stdout, cmd.Stdout)
}()

go func() {
defer wg.Done()
io.Copy(stderr, cmd.Stderr)
}()

cmd.Wait()
wg.Wait()
cmd.Close()
return c.RunWithContext(context.Background(), command, stdout, stderr)
}

return cmd.ExitCode(), cmd.err
// RunWithContext will run command on the the remote host, writing the process stdout and stderr to
// the given writers. Note with this method it isn't possible to inject stdin.
// If the context is canceled, the remote command is canceled.
func (c *Client) RunWithContext(ctx context.Context, command string, stdout io.Writer, stderr io.Writer) (int, error) {
return c.RunWithContextWithInput(ctx, command, stdout, stderr, nil)
}

// RunWithString will run command on the the remote host, returning the process stdout and stderr
// as strings, and using the input stdin string as the process input
//
// Deprecated: use RunWithContextWithString()
func (c *Client) RunWithString(command string, stdin string) (string, string, int, error) {
shell, err := c.CreateShell()
if err != nil {
return "", "", 1, err
}
defer shell.Close()

cmd, err := shell.Execute(command)
if err != nil {
return "", "", 1, err
}

if len(stdin) > 0 {
defer cmd.Stdin.Close()
_, err := cmd.Stdin.Write([]byte(stdin))
if err != nil {
return "", "", -1, err
}
}
return c.RunWithContextWithString(context.Background(), command, stdin)
}

// RunWithContextWithString will run command on the the remote host, returning the process stdout and stderr
// as strings, and using the input stdin string as the process input
// If the context is canceled, the remote command is canceled.
func (c *Client) RunWithContextWithString(ctx context.Context, command string, stdin string) (string, string, int, error) {
var outWriter, errWriter bytes.Buffer
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
io.Copy(&outWriter, cmd.Stdout)
}()

go func() {
defer wg.Done()
io.Copy(&errWriter, cmd.Stderr)
}()

cmd.Wait()
wg.Wait()
cmd.Close()

return outWriter.String(), errWriter.String(), cmd.ExitCode(), cmd.err
exitCode, err := c.RunWithContextWithInput(ctx, command, &outWriter, &errWriter, strings.NewReader(stdin))
return outWriter.String(), errWriter.String(), exitCode, err
}

//RunPSWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// RunPSWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// runs commands in cmd.exe
//
// Deprecated: use RunPSWithContextWithString()
func (c *Client) RunPSWithString(command string, stdin string) (string, string, int, error) {
return c.RunPSWithContextWithString(context.Background(), command, stdin)
}

// RunPSWithContextWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// runs commands in cmd.exe
func (c *Client) RunPSWithContextWithString(ctx context.Context, command string, stdin string) (string, string, int, error) {
command = Powershell(command)

// Let's check if we actually created a command
Expand All @@ -189,21 +156,35 @@ func (c *Client) RunPSWithString(command string, stdin string) (string, string,
}

// Specify powershell.exe to run encoded command
return c.RunWithString(command, stdin)
return c.RunWithContextWithString(ctx, command, stdin)
}

// RunWithInput will run command on the the remote host, writing the process stdout and stderr to
// the given writers, and injecting the process stdin with the stdin reader.
// Warning stdin (not stdout/stderr) are bufferized, which means reading only one byte in stdin will
// send a winrm http packet to the remote host. If stdin is a pipe, it might be better for
// performance reasons to buffer it.
func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
// If stdin is nil, this is equivalent to c.Run()
//
// Deprecated: use RunWithContextWithInput()
func (c *Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
return c.RunWithContextWithInput(context.Background(), command, stdout, stderr, stdin)
}

// RunWithContextWithInput will run command on the the remote host, writing the process stdout and stderr to
// the given writers, and injecting the process stdin with the stdin reader.
// If the context is canceled, the command on the remote machine is canceled.
// Warning stdin (not stdout/stderr) are bufferized, which means reading only one byte in stdin will
// send a winrm http packet to the remote host. If stdin is a pipe, it might be better for
// performance reasons to buffer it.
// If stdin is nil, this is equivalent to c.RunWithContext()
func (c *Client) RunWithContextWithInput(ctx context.Context, command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
shell, err := c.CreateShell()
if err != nil {
return 1, err
}
defer shell.Close()
cmd, err := shell.Execute(command)
cmd, err := shell.ExecuteWithContext(ctx, command)
if err != nil {
return 1, err
}
Expand All @@ -213,9 +194,14 @@ func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.

go func() {
defer func() {
cmd.Stdin.Close()
wg.Done()
}()
if stdin == nil {
return
}
defer func() {
cmd.Stdin.Close()
}()
io.Copy(cmd.Stdin, stdin)
}()
go func() {
Expand All @@ -232,5 +218,4 @@ func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.
cmd.Close()

return cmd.ExitCode(), cmd.err

}
16 changes: 13 additions & 3 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package winrm

import (
"bytes"
"context"
"errors"
"io"
"strings"
Expand Down Expand Up @@ -39,7 +40,7 @@ type Command struct {
cancel chan struct{}
}

func newCommand(shell *Shell, ids string) *Command {
func newCommand(ctx context.Context, shell *Shell, ids string) *Command {
command := &Command{
shell: shell,
client: shell.client,
Expand All @@ -57,7 +58,7 @@ func newCommand(shell *Shell, ids string) *Command {
}
command.Stderr = newCommandReader("stderr", command)

go fetchOutput(command)
go fetchOutput(ctx, command)

return command
}
Expand All @@ -72,12 +73,21 @@ func newCommandReader(stream string, command *Command) *commandReader {
}
}

func fetchOutput(command *Command) {
func fetchOutput(ctx context.Context, command *Command) {
for {
ctxDone := ctx.Done()
select {
case <-command.cancel:
_, _ = command.slurpAllOutput()
err := errors.New("canceled")
command.Stderr.write.CloseWithError(err)
command.Stdout.write.CloseWithError(err)
close(command.done)
return
case <-ctxDone:
command.err = ctx.Err()
ctxDone = nil
command.Close()
default:
finished, err := command.slurpAllOutput()
if finished {
Expand Down
11 changes: 10 additions & 1 deletion shell.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package winrm

import "context"

// Shell is the local view of a WinRM Shell of a given Client
type Shell struct {
client *Client
id string
}

// Execute command on the given Shell, returning either an error or a Command
//
// Deprecated: user ExecuteWithContext
func (s *Shell) Execute(command string, arguments ...string) (*Command, error) {
return s.ExecuteWithContext(context.Background(), command, arguments...)
}

// ExecuteWithContext command on the given Shell, returning either an error or a Command
func (s *Shell) ExecuteWithContext(ctx context.Context, command string, arguments ...string) (*Command, error) {
request := NewExecuteCommandRequest(s.client.url, s.id, command, arguments, &s.client.Parameters)
defer request.Free()

Expand All @@ -21,7 +30,7 @@ func (s *Shell) Execute(command string, arguments ...string) (*Command, error) {
return nil, err
}

cmd := newCommand(s, commandID)
cmd := newCommand(ctx, s, commandID)

return cmd, nil
}
Expand Down

0 comments on commit 3206f10

Please sign in to comment.