Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support canceling remote commands via context. #1

Merged
merged 1 commit into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 53 additions & 68 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package winrm

import (
"bytes"
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"strings"
"sync"

"github.com/masterzen/winrm/soap"
Expand Down Expand Up @@ -105,82 +107,47 @@ func (c *Client) sendRequest(request *soap.SoapMessage) (string, error) {

// Run will run command on the the remote host, writing the process stdout and stderr to
// the given writers. Note with this method it isn't possible to inject stdin.
//
// Deprecated: use RunWithContext()
func (c *Client) Run(command string, stdout io.Writer, stderr io.Writer) (int, error) {
shell, err := c.CreateShell()
if err != nil {
return 1, err
}
defer shell.Close()
cmd, err := shell.Execute(command)
if err != nil {
return 1, err
}

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
io.Copy(stdout, cmd.Stdout)
}()

go func() {
defer wg.Done()
io.Copy(stderr, cmd.Stderr)
}()

cmd.Wait()
wg.Wait()
cmd.Close()
return c.RunWithContext(context.Background(), command, stdout, stderr)
}

return cmd.ExitCode(), cmd.err
// RunWithContext will run command on the the remote host, writing the process stdout and stderr to
// the given writers. Note with this method it isn't possible to inject stdin.
// If the context is canceled, the remote command is canceled.
func (c *Client) RunWithContext(ctx context.Context, command string, stdout io.Writer, stderr io.Writer) (int, error) {
return c.RunWithContextWithInput(ctx, command, stdout, stderr, nil)
}

// RunWithString will run command on the the remote host, returning the process stdout and stderr
// as strings, and using the input stdin string as the process input
//
// Deprecated: use RunWithContextWithString()
func (c *Client) RunWithString(command string, stdin string) (string, string, int, error) {
shell, err := c.CreateShell()
if err != nil {
return "", "", 1, err
}
defer shell.Close()

cmd, err := shell.Execute(command)
if err != nil {
return "", "", 1, err
}

if len(stdin) > 0 {
defer cmd.Stdin.Close()
_, err := cmd.Stdin.Write([]byte(stdin))
if err != nil {
return "", "", -1, err
}
}
return c.RunWithContextWithString(context.Background(), command, stdin)
}

// RunWithContextWithString will run command on the the remote host, returning the process stdout and stderr
// as strings, and using the input stdin string as the process input
// If the context is canceled, the remote command is canceled.
func (c *Client) RunWithContextWithString(ctx context.Context, command string, stdin string) (string, string, int, error) {
var outWriter, errWriter bytes.Buffer
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
io.Copy(&outWriter, cmd.Stdout)
}()

go func() {
defer wg.Done()
io.Copy(&errWriter, cmd.Stderr)
}()

cmd.Wait()
wg.Wait()
cmd.Close()

return outWriter.String(), errWriter.String(), cmd.ExitCode(), cmd.err
exitCode, err := c.RunWithContextWithInput(ctx, command, &outWriter, &errWriter, strings.NewReader(stdin))
return outWriter.String(), errWriter.String(), exitCode, err
}

//RunPSWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// RunPSWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// runs commands in cmd.exe
//
// Deprecated: use RunPSWithContextWithString()
func (c *Client) RunPSWithString(command string, stdin string) (string, string, int, error) {
return c.RunPSWithContextWithString(context.Background(), command, stdin)
}

// RunPSWithContextWithString will basically wrap your code to execute commands in powershell.exe. Default RunWithString
// runs commands in cmd.exe
func (c *Client) RunPSWithContextWithString(ctx context.Context, command string, stdin string) (string, string, int, error) {
command = Powershell(command)

// Let's check if we actually created a command
Expand All @@ -189,21 +156,35 @@ func (c *Client) RunPSWithString(command string, stdin string) (string, string,
}

// Specify powershell.exe to run encoded command
return c.RunWithString(command, stdin)
return c.RunWithContextWithString(ctx, command, stdin)
}

// RunWithInput will run command on the the remote host, writing the process stdout and stderr to
// the given writers, and injecting the process stdin with the stdin reader.
// Warning stdin (not stdout/stderr) are bufferized, which means reading only one byte in stdin will
// send a winrm http packet to the remote host. If stdin is a pipe, it might be better for
// performance reasons to buffer it.
func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
// If stdin is nil, this is equivalent to c.Run()
//
// Deprecated: use RunWithContextWithInput()
func (c *Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
return c.RunWithContextWithInput(context.Background(), command, stdout, stderr, stdin)
}

// RunWithContextWithInput will run command on the the remote host, writing the process stdout and stderr to
// the given writers, and injecting the process stdin with the stdin reader.
// If the context is canceled, the command on the remote machine is canceled.
// Warning stdin (not stdout/stderr) are bufferized, which means reading only one byte in stdin will
// send a winrm http packet to the remote host. If stdin is a pipe, it might be better for
// performance reasons to buffer it.
// If stdin is nil, this is equivalent to c.RunWithContext()
func (c *Client) RunWithContextWithInput(ctx context.Context, command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
shell, err := c.CreateShell()
if err != nil {
return 1, err
}
defer shell.Close()
cmd, err := shell.Execute(command)
cmd, err := shell.ExecuteWithContext(ctx, command)
if err != nil {
return 1, err
}
Expand All @@ -213,9 +194,14 @@ func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.

go func() {
defer func() {
cmd.Stdin.Close()
wg.Done()
}()
if stdin == nil {
return
}
defer func() {
cmd.Stdin.Close()
}()
io.Copy(cmd.Stdin, stdin)
}()
go func() {
Expand All @@ -232,5 +218,4 @@ func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.
cmd.Close()

return cmd.ExitCode(), cmd.err

}
16 changes: 13 additions & 3 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package winrm

import (
"bytes"
"context"
"errors"
"io"
"strings"
Expand Down Expand Up @@ -39,7 +40,7 @@ type Command struct {
cancel chan struct{}
}

func newCommand(shell *Shell, ids string) *Command {
func newCommand(ctx context.Context, shell *Shell, ids string) *Command {
command := &Command{
shell: shell,
client: shell.client,
Expand All @@ -57,7 +58,7 @@ func newCommand(shell *Shell, ids string) *Command {
}
command.Stderr = newCommandReader("stderr", command)

go fetchOutput(command)
go fetchOutput(ctx, command)

return command
}
Expand All @@ -72,12 +73,21 @@ func newCommandReader(stream string, command *Command) *commandReader {
}
}

func fetchOutput(command *Command) {
func fetchOutput(ctx context.Context, command *Command) {
for {
ctxDone := ctx.Done()
select {
case <-command.cancel:
_, _ = command.slurpAllOutput()
err := errors.New("canceled")
command.Stderr.write.CloseWithError(err)
command.Stdout.write.CloseWithError(err)
close(command.done)
return
case <-ctxDone:
command.err = ctx.Err()
ctxDone = nil
command.Close()
default:
finished, err := command.slurpAllOutput()
if finished {
Expand Down
11 changes: 10 additions & 1 deletion shell.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package winrm

import "context"

// Shell is the local view of a WinRM Shell of a given Client
type Shell struct {
client *Client
id string
}

// Execute command on the given Shell, returning either an error or a Command
//
// Deprecated: user ExecuteWithContext
func (s *Shell) Execute(command string, arguments ...string) (*Command, error) {
return s.ExecuteWithContext(context.Background(), command, arguments...)
}

// ExecuteWithContext command on the given Shell, returning either an error or a Command
func (s *Shell) ExecuteWithContext(ctx context.Context, command string, arguments ...string) (*Command, error) {
request := NewExecuteCommandRequest(s.client.url, s.id, command, arguments, &s.client.Parameters)
defer request.Free()

Expand All @@ -21,7 +30,7 @@ func (s *Shell) Execute(command string, arguments ...string) (*Command, error) {
return nil, err
}

cmd := newCommand(s, commandID)
cmd := newCommand(ctx, s, commandID)

return cmd, nil
}
Expand Down