diff --git a/cmd/sshd/channel_handlers.go b/cmd/sshd/channel_handlers.go index ded13a7..bb6259a 100644 --- a/cmd/sshd/channel_handlers.go +++ b/cmd/sshd/channel_handlers.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !windows2012R2 package main diff --git a/cmd/sshd/channel_handlers_windows.go b/cmd/sshd/channel_handlers_windows2012R2.go similarity index 69% rename from cmd/sshd/channel_handlers_windows.go rename to cmd/sshd/channel_handlers_windows2012R2.go index 6f2e197..724eac7 100644 --- a/cmd/sshd/channel_handlers_windows.go +++ b/cmd/sshd/channel_handlers_windows2012R2.go @@ -1,8 +1,10 @@ -// +build windows +// +build windows2012R2 package main -import "code.cloudfoundry.org/diego-ssh/handlers" +import ( + "code.cloudfoundry.org/diego-ssh/handlers" +) func newChannelHandlers() map[string]handlers.NewChannelHandler { return map[string]handlers.NewChannelHandler{ diff --git a/cmd/sshd/helpers_internal_port_test.go b/cmd/sshd/helpers_test.go similarity index 96% rename from cmd/sshd/helpers_internal_port_test.go rename to cmd/sshd/helpers_test.go index c75e2d8..220bb19 100644 --- a/cmd/sshd/helpers_internal_port_test.go +++ b/cmd/sshd/helpers_test.go @@ -1,4 +1,4 @@ -// +build !external +// +build !windows2012R2 package main_test diff --git a/cmd/sshd/helpers_external_port_test.go b/cmd/sshd/helpers_windows2012R2_test.go similarity index 91% rename from cmd/sshd/helpers_external_port_test.go rename to cmd/sshd/helpers_windows2012R2_test.go index 4d44301..cf9d9bb 100644 --- a/cmd/sshd/helpers_external_port_test.go +++ b/cmd/sshd/helpers_windows2012R2_test.go @@ -1,4 +1,4 @@ -// +build external +// +build windows2012R2 package main_test @@ -13,7 +13,7 @@ import ( ) func buildSshd() string { - sshd, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/sshd", "-race", "-tags", "external") + sshd, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/sshd", "-race", "-tags", "windows2012R2") Expect(err).NotTo(HaveOccurred()) return sshd } diff --git a/cmd/sshd/main_internal_port.go b/cmd/sshd/main_port.go similarity index 91% rename from cmd/sshd/main_internal_port.go rename to cmd/sshd/main_port.go index e7c982e..2ffbbfa 100644 --- a/cmd/sshd/main_internal_port.go +++ b/cmd/sshd/main_port.go @@ -1,4 +1,4 @@ -// +build !external +// +build !windows2012R2 package main diff --git a/cmd/sshd/main_external_port.go b/cmd/sshd/main_port_windows2012R2.go similarity index 97% rename from cmd/sshd/main_external_port.go rename to cmd/sshd/main_port_windows2012R2.go index bafc559..624c73f 100644 --- a/cmd/sshd/main_external_port.go +++ b/cmd/sshd/main_port_windows2012R2.go @@ -1,4 +1,4 @@ -// +build external +// +build windows2012R2 package main diff --git a/cmd/sshd/main_suite_test.go b/cmd/sshd/main_suite_test.go index 6227c6f..69cad57 100644 --- a/cmd/sshd/main_suite_test.go +++ b/cmd/sshd/main_suite_test.go @@ -2,6 +2,8 @@ package main_test import ( "encoding/json" + "os" + "runtime" "code.cloudfoundry.org/diego-ssh/keys" . "github.com/onsi/ginkgo" @@ -26,6 +28,11 @@ func TestSSHDaemon(t *testing.T) { } var _ = SynchronizedBeforeSuite(func() []byte { + if runtime.GOOS == "windows" { + if os.Getenv("WINPTY_DLL_PATH") == "" { + Fail("Missing WINPTY_DLL_PATH environment variable") + } + } sshd := buildSshd() hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) diff --git a/cmd/sshd/main_test.go b/cmd/sshd/main_test.go index 61e03af..c823112 100644 --- a/cmd/sshd/main_test.go +++ b/cmd/sshd/main_test.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !windows2012R2 package main_test @@ -11,6 +11,7 @@ import ( "os" "os/exec" "regexp" + "runtime" "strconv" "strings" "time" @@ -241,6 +242,10 @@ var _ = Describe("SSH daemon", func() { var ItDoesNotExposeSensitiveInformation = func() { It("does not expose the key on the command line", func() { + if runtime.GOOS == "windows" { + Skip("no fork/exec on windows") + } + pid := runner.(*ginkgomon.Runner).Command.Process.Pid command := exec.Command("ps", "-fp", strconv.Itoa(pid)) session, err := gexec.Start(command, GinkgoWriter, GinkgoWriter) @@ -489,10 +494,17 @@ var _ = Describe("SSH daemon", func() { session, err := client.NewSession() Expect(err).NotTo(HaveOccurred()) - result, err := session.Output("/bin/echo -n 'Hello there!'") + var cmd string + if runtime.GOOS == "windows" { + cmd = "echo Hello There!" + } else { + cmd = "/bin/echo -n 'Hello There!'" + } + + result, err := session.Output(cmd) Expect(err).NotTo(HaveOccurred()) - Expect(string(result)).To(Equal("Hello there!")) + Expect(strings.TrimSpace(string(result))).To(Equal(strings.TrimSpace("Hello There!"))) }) }) @@ -510,7 +522,7 @@ var _ = Describe("SSH daemon", func() { stdout := &bytes.Buffer{} - session.Stdin = strings.NewReader("/bin/echo -n $ENV_VAR") + session.Stdin = strings.NewReader(envVarCmd("ENV_VAR")) session.Stdout = stdout session.Setenv("ENV_VAR", "env_var_value") @@ -529,7 +541,7 @@ var _ = Describe("SSH daemon", func() { stdout := &bytes.Buffer{} - session.Stdin = strings.NewReader("/bin/echo -n $TEST") + session.Stdin = strings.NewReader(envVarCmd("TEST")) session.Stdout = stdout err = session.Shell() @@ -547,7 +559,7 @@ var _ = Describe("SSH daemon", func() { stdout := &bytes.Buffer{} - session.Stdin = strings.NewReader("/bin/echo -n $PATH") + session.Stdin = strings.NewReader(envVarCmd("PATH")) session.Stdout = stdout err = session.Shell() @@ -571,7 +583,7 @@ var _ = Describe("SSH daemon", func() { stdout := &bytes.Buffer{} - session.Stdin = strings.NewReader("/bin/echo -n $ENV_VAR") + session.Stdin = strings.NewReader(envVarCmd("ENV_VAR")) session.Stdout = stdout session.Setenv("ENV_VAR", "env_var_value") @@ -590,7 +602,7 @@ var _ = Describe("SSH daemon", func() { stdout := &bytes.Buffer{} - session.Stdin = strings.NewReader("/bin/echo -n $TEST") + session.Stdin = strings.NewReader(envVarCmd("TEST")) session.Stdout = stdout err = session.Shell() @@ -639,3 +651,11 @@ var _ = Describe("SSH daemon", func() { }) }) }) + +func envVarCmd(envVar string) string { + if runtime.GOOS == "windows" { + return "echo %" + envVar + "%\r\n" + } + + return fmt.Sprintf("/bin/echo -n $%s", envVar) +} diff --git a/cmd/sshd/main_windows_test.go b/cmd/sshd/main_windows2012R2_test.go similarity index 98% rename from cmd/sshd/main_windows_test.go rename to cmd/sshd/main_windows2012R2_test.go index 6c60efd..0171e9e 100644 --- a/cmd/sshd/main_windows_test.go +++ b/cmd/sshd/main_windows2012R2_test.go @@ -1,4 +1,4 @@ -// +build windows +// +build windows2012R2 package main_test diff --git a/handlers/handlers_suite_test.go b/handlers/handlers_suite_test.go index ab4dc3e..d7a76f4 100644 --- a/handlers/handlers_suite_test.go +++ b/handlers/handlers_suite_test.go @@ -1,6 +1,9 @@ package handlers_test import ( + "os" + "runtime" + "code.cloudfoundry.org/diego-ssh/keys" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -21,4 +24,10 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) TestHostKey = hostKey.PrivateKey() + + if runtime.GOOS == "windows" { + if os.Getenv("WINPTY_DLL_PATH") == "" { + Fail("Missing WINPTY_DLL_PATH environment variable") + } + } }) diff --git a/handlers/session_channel_handler.go b/handlers/session_channel_handler.go index 74b74a5..980a23d 100644 --- a/handlers/session_channel_handler.go +++ b/handlers/session_channel_handler.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !windows,!windows2012R2 package handlers diff --git a/handlers/session_channel_handler_test.go b/handlers/session_channel_handler_test.go index 754a58a..0448545 100644 --- a/handlers/session_channel_handler_test.go +++ b/handlers/session_channel_handler_test.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !windows,!windows2012R2 package handlers_test diff --git a/handlers/session_channel_handler_windows.go b/handlers/session_channel_handler_windows2012R2.go similarity index 90% rename from handlers/session_channel_handler_windows.go rename to handlers/session_channel_handler_windows2012R2.go index cbdd85a..1c70a6e 100644 --- a/handlers/session_channel_handler_windows.go +++ b/handlers/session_channel_handler_windows2012R2.go @@ -1,4 +1,4 @@ -// +build windows +// +build windows2012R2 package handlers @@ -15,7 +15,7 @@ func NewSessionChannelHandler() *SessionChannelHandler { } func (handler *SessionChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { - err := newChannel.Reject(ssh.Prohibited, "SSH is not supported on windows cells") + err := newChannel.Reject(ssh.Prohibited, "SSH is not supported on windows2012R2 cells") if err != nil { logger.Error("handle-new-session-channel-failed", err) } diff --git a/handlers/session_channel_handler_windows_test.go b/handlers/session_channel_handler_windows2012R2_test.go similarity index 98% rename from handlers/session_channel_handler_windows_test.go rename to handlers/session_channel_handler_windows2012R2_test.go index 9b5e094..822e062 100644 --- a/handlers/session_channel_handler_windows_test.go +++ b/handlers/session_channel_handler_windows2012R2_test.go @@ -1,4 +1,4 @@ -// +build windows +// +build windows2012R2 package handlers_test @@ -68,7 +68,6 @@ var _ = Describe("SessionChannelHandler", func() { }) Context("when a session is opened", func() { - It("doesn't accept sessions", func() { _, sessionErr := client.NewSession() diff --git a/handlers/session_channel_handler_windows2016.go b/handlers/session_channel_handler_windows2016.go new file mode 100644 index 0000000..8db596e --- /dev/null +++ b/handlers/session_channel_handler_windows2016.go @@ -0,0 +1,639 @@ +// +build windows,!windows2012R2 + +package handlers + +import ( + "errors" + "fmt" + "os" + "os/exec" + "regexp" + "sync" + "syscall" + "time" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/signals" + "code.cloudfoundry.org/diego-ssh/winpty" + "code.cloudfoundry.org/lager" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var scpRegex = regexp.MustCompile(`^\s*scp($|\s+)`) + +type SessionChannelHandler struct { + runner Runner + shellLocator ShellLocator + defaultEnv map[string]string + keepalive time.Duration + winPTYDllPath string +} + +func NewSessionChannelHandler( + runner Runner, + shellLocator ShellLocator, + defaultEnv map[string]string, + keepalive time.Duration, +) *SessionChannelHandler { + winPTYDllPath := os.Getenv("WINPTY_DLL_PATH") + return &SessionChannelHandler{ + runner: runner, + shellLocator: shellLocator, + defaultEnv: defaultEnv, + keepalive: keepalive, + winPTYDllPath: winPTYDllPath, + } +} + +func (handler *SessionChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error("handle-new-session-channel-failed", err) + return + } + + handler.newSession(logger, channel, handler.keepalive).serviceRequests(requests) +} + +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +type session struct { + logger lager.Logger + complete bool + keepaliveDuration time.Duration + keepaliveStopCh chan struct{} + + shellPath string + runner Runner + channel ssh.Channel + + sync.Mutex + env map[string]string + command *exec.Cmd + + wg sync.WaitGroup + allocPty bool + ptyRequest ptyRequestMsg + + winpty *winpty.WinPTY + winPTYDllPath string +} + +func (handler *SessionChannelHandler) newSession(logger lager.Logger, channel ssh.Channel, keepalive time.Duration) *session { + return &session{ + logger: logger.Session("session-channel"), + keepaliveDuration: keepalive, + runner: handler.runner, + shellPath: handler.shellLocator.ShellPath(), + channel: channel, + env: handler.defaultEnv, + winPTYDllPath: handler.winPTYDllPath, + } +} + +func (sess *session) serviceRequests(requests <-chan *ssh.Request) { + logger := sess.logger + logger.Info("starting") + defer logger.Info("finished") + + defer sess.destroy() + + for req := range requests { + sess.logger.Info("received-request", lager.Data{"type": req.Type}) + switch req.Type { + case "env": + sess.handleEnvironmentRequest(req) + case "signal": + sess.handleSignalRequest(req) + case "pty-req": + sess.handlePtyRequest(req) + case "window-change": + sess.handleWindowChangeRequest(req) + case "exec": + sess.handleExecRequest(req) + case "shell": + sess.handleShellRequest(req) + case "subsystem": + sess.handleSubsystemRequest(req) + default: + if req.WantReply { + req.Reply(false, nil) + } + } + } +} + +func (sess *session) handleEnvironmentRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-environment-request") + + type envMsg struct { + Name string + Value string + } + var envMessage envMsg + + err := ssh.Unmarshal(request.Payload, &envMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + request.Reply(false, nil) + return + } + + sess.Lock() + sess.env[envMessage.Name] = envMessage.Value + sess.Unlock() + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleSignalRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-signal-request") + + type signalMsg struct { + Signal string + } + var signalMessage signalMsg + + err := ssh.Unmarshal(request.Payload, &signalMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + cmd := sess.command + + if cmd != nil { + var err error + signal := signals.SyscallSignals[ssh.Signal(signalMessage.Signal)] + if sess.winpty != nil { + err = sess.winpty.Signal(signal) + } else { + err = sess.runner.Signal(cmd, signal) + } + if err != nil { + logger.Error("process-signal-failed", err) + } + } + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handlePtyRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-pty-request") + + var ptyRequestMessage ptyRequestMsg + + err := ssh.Unmarshal(request.Payload, &ptyRequestMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + sess.allocPty = true + sess.winpty, err = winpty.New(sess.winPTYDllPath) + if err != nil { + logger.Error("couldn't intialize winpty.dll", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.ptyRequest = ptyRequestMessage + sess.env["TERM"] = ptyRequestMessage.Term + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleWindowChangeRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-window-change") + + type windowChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + var windowChangeMessage windowChangeMsg + + err := ssh.Unmarshal(request.Payload, &windowChangeMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + if sess.allocPty { + sess.ptyRequest.Columns = windowChangeMessage.Columns + sess.ptyRequest.Rows = windowChangeMessage.Rows + } + + if sess.winpty != nil { + err = setWindowSize(logger, sess.winpty, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + if err != nil { + logger.Error("failed-to-set-window-size", err) + } + } + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleExecRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-exec-request") + + type execMsg struct { + Command string + } + var execMessage execMsg + + err := ssh.Unmarshal(request.Payload, &execMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if scpRegex.MatchString(execMessage.Command) { + logger.Info("handling-scp-command", lager.Data{"Command": execMessage.Command}) + sess.executeSCP(execMessage.Command, request) + } else { + sess.executeShell(request, "/c", execMessage.Command) + } +} + +func (sess *session) handleShellRequest(request *ssh.Request) { + sess.executeShell(request) +} + +func (sess *session) handleSubsystemRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-subsystem-request") + logger.Info("starting") + defer logger.Info("finished") + + type subsysMsg struct { + Subsystem string + } + var subsystemMessage subsysMsg + + err := ssh.Unmarshal(request.Payload, &subsystemMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if subsystemMessage.Subsystem != "sftp" { + logger.Info("unsupported-subsystem", lager.Data{"subsystem": subsystemMessage.Subsystem}) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + lagerWriter := helpers.NewLagerWriter(logger.Session("sftp-server")) + sftpServer, err := sftp.NewServer(sess.channel, sess.channel, sftp.WithDebug(lagerWriter)) + if err != nil { + logger.Error("sftp-new-server-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if request.WantReply { + request.Reply(true, nil) + } + + logger.Info("starting-server") + go func() { + defer sess.destroy() + err = sftpServer.Serve() + if err != nil { + logger.Error("sftp-serve-error", err) + } + }() +} + +func (sess *session) executeShell(request *ssh.Request, args ...string) { + logger := sess.logger.Session("execute-shell") + + sess.Lock() + cmd, err := sess.createCommand(args...) + if err != nil { + sess.Unlock() + logger.Error("failed-to-create-command", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if request.WantReply { + request.Reply(true, nil) + } + + if sess.allocPty { + err = sess.runWithPty(cmd) + } else { + err = sess.run(cmd) + } + + sess.Unlock() + + if err != nil { + sess.sendExitMessage(err) + sess.destroy() + return + } + + go func() { + err := sess.wait(cmd) + sess.sendExitMessage(err) + sess.destroy() + }() +} + +func (sess *session) createCommand(args ...string) (*exec.Cmd, error) { + if sess.command != nil { + return nil, errors.New("command already started") + } + + cmd := exec.Command(sess.shellPath, args...) + cmd.Env = sess.environment() + sess.command = cmd + + return cmd, nil +} + +func (sess *session) environment() []string { + env := []string{} + + env = append(env, `PATH=C:\Windows\system32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`) + env = append(env, "LANG=en_US.UTF8") + + for k, v := range sess.env { + if k != "HOME" && k != "USER" { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + } + + env = append(env, fmt.Sprintf("HOME=%s", os.Getenv("HOME"))) + env = append(env, fmt.Sprintf("USER=%s", os.Getenv("USER"))) + + return env +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Error string + Lang string +} + +func (sess *session) sendExitMessage(err error) { + logger := sess.logger.Session("send-exit-message") + logger.Info("started") + defer logger.Info("finished") + + if err != nil { + logger.Error("building-exit-message-from-error", err) + } + + if err == nil { + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitStatusMsg{})) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + var exitCode uint32 + winptyError, ok := err.(*winpty.ExitError) + if ok { + exitCode = winptyError.WaitStatus.ExitCode + } else { + exitError, ok := err.(*exec.ExitError) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + if waitStatus.Signaled() { + exitMessage := exitSignalMsg{ + Signal: string(signals.SSHSignals[waitStatus.Signal()]), + CoreDumped: waitStatus.CoreDump(), + } + _, sendErr := sess.channel.SendRequest("exit-signal", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + exitCode = uint32(waitStatus.ExitStatus()) + } + + exitMessage := exitStatusMsg{Status: exitCode} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} + +func setWindowSize(logger lager.Logger, pty *winpty.WinPTY, columns, rows uint32) error { + logger.Info("new-size", lager.Data{"columns": columns, "rows": rows}) + return pty.SetWinsize(columns, rows) +} + +func (sess *session) run(command *exec.Cmd) error { + logger := sess.logger.Session("run") + + command.Stdout = sess.channel + command.Stderr = sess.channel.Stderr() + + stdin, err := command.StdinPipe() + if err != nil { + return err + } + + go helpers.CopyAndClose(logger.Session("to-stdin"), nil, stdin, sess.channel, func() { stdin.Close() }) + + return sess.runner.Start(command) +} + +func (sess *session) runWithPty(command *exec.Cmd) error { + var err error + logger := sess.logger.Session("run") + + if err := sess.winpty.Open(); err != nil { + logger.Error("failed-to-open-pty", err) + return err + } + + setWindowSize(logger, sess.winpty, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + + sess.wg.Add(1) + go helpers.Copy(logger.Session("to-pty"), nil, sess.winpty.StdIn, sess.channel) + go func() { + helpers.Copy(logger.Session("from-pty-out"), &sess.wg, sess.channel, sess.winpty.StdOut) + sess.channel.CloseWrite() + }() + + err = sess.winpty.Run(command) + if err == nil { + sess.keepaliveStopCh = make(chan struct{}) + go sess.keepalive(sess.keepaliveStopCh) + } + return err +} + +func (sess *session) keepalive(stopCh chan struct{}) { + logger := sess.logger.Session("keepalive") + + ticker := time.NewTicker(sess.keepaliveDuration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, err := sess.channel.SendRequest("keepalive@cloudfoundry.org", true, nil) + logger.Info("keepalive", lager.Data{"success": err == nil}) + + if err != nil { + err = sess.winpty.Signal(syscall.SIGINT) + logger.Info("process-signaled", lager.Data{"error": err}) + return + } + case <-stopCh: + return + } + } +} + +func (sess *session) wait(command *exec.Cmd) error { + logger := sess.logger.Session("wait") + logger.Info("started") + defer logger.Info("done") + if sess.allocPty { + return sess.winpty.Wait() + } else { + return sess.runner.Wait(command) + } +} + +func (sess *session) destroy() { + logger := sess.logger.Session("destroy") + logger.Info("started") + defer logger.Info("done") + + sess.Lock() + defer sess.Unlock() + + if sess.complete { + return + } + + sess.complete = true + sess.wg.Wait() + + if sess.channel != nil { + sess.channel.Close() + } + + if sess.winpty != nil { + sess.winpty.Close() + sess.winpty = nil + } + + if sess.keepaliveStopCh != nil { + close(sess.keepaliveStopCh) + } +} + +func (sess *session) executeSCP(command string, request *ssh.Request) { + logger := sess.logger.Session("execute-scp") + + if request.WantReply { + request.Reply(true, nil) + } + + copier, err := scp.NewFromCommand(command, sess.channel, sess.channel, sess.channel.Stderr(), logger) + if err == nil { + err = copier.Copy() + } + + sess.sendSCPExitMessage(err) + sess.destroy() +} + +func (sess *session) sendSCPExitMessage(err error) { + logger := sess.logger.Session("send-scp-exit-message") + logger.Info("started") + defer logger.Info("finished") + + var exitMessage exitStatusMsg + if err != nil { + logger.Error("building-scp-exit-message-from-error", err) + exitMessage = exitStatusMsg{Status: 1} + } + + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} diff --git a/handlers/session_channel_handler_windows2016_test.go b/handlers/session_channel_handler_windows2016_test.go new file mode 100644 index 0000000..289aa7f --- /dev/null +++ b/handlers/session_channel_handler_windows2016_test.go @@ -0,0 +1,975 @@ +// +build windows,!windows2012R2 + +package handlers_test + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/lagertest" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("SessionChannelHandler", func() { + var ( + sshd *daemon.Daemon + client *ssh.Client + + logger *lagertest.TestLogger + serverSSHConfig *ssh.ServerConfig + + runner *fakes.FakeRunner + shellLocator *fakes.FakeShellLocator + sessionChannelHandler *handlers.SessionChannelHandler + + newChannelHandlers map[string]handlers.NewChannelHandler + defaultEnv map[string]string + connectionFinished chan struct{} + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + runner = &fakes.FakeRunner{} + realRunner := handlers.NewCommandRunner() + runner.StartStub = realRunner.Start + runner.WaitStub = realRunner.Wait + runner.SignalStub = realRunner.Signal + + shellLocator = &fakes.FakeShellLocator{} + shellLocator.ShellPathReturns("cmd.exe") + + defaultEnv = map[string]string{} + for _, env := range os.Environ() { + k := strings.Split(env, "=")[0] + v := strings.Split(env, "=")[1] + defaultEnv[k] = v + } + defaultEnv["TEST"] = "FOO" + + delete(defaultEnv, "Path") + delete(defaultEnv, "PATH") + + sessionChannelHandler = handlers.NewSessionChannelHandler(runner, shellLocator, defaultEnv, time.Second) + + newChannelHandlers = map[string]handlers.NewChannelHandler{ + "session": sessionChannelHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + connectionFinished = make(chan struct{}) + go func() { + sshd.HandleConnection(serverNetConn) + close(connectionFinished) + }() + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + if client != nil { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + } + Eventually(connectionFinished).Should(BeClosed()) + }) + + Context("when a session is opened", func() { + var session *ssh.Session + + BeforeEach(func() { + var sessionErr error + session, sessionErr = client.NewSession() + + Expect(sessionErr).NotTo(HaveOccurred()) + }) + + It("can use the session to execute a command with stdout and stderr", func() { + stdout, err := session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + stderr, err := session.StderrPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("echo Hello && echo Goodbye 1>&2") + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := ioutil.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + Expect(string(stdoutBytes)).To(ContainSubstring("Hello")) + Expect(string(stdoutBytes)).NotTo(ContainSubstring("Goodbye")) + + stderrBytes, err := ioutil.ReadAll(stderr) + Expect(err).NotTo(HaveOccurred()) + Expect(string(stderrBytes)).To(ContainSubstring("Goodbye")) + Expect(string(stderrBytes)).NotTo(ContainSubstring("Hello")) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("dir") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + Describe("scp", func() { + var ( + sourceDir, generatedTextFile, targetDir string + err error + stdin io.WriteCloser + stdout io.Reader + fileContents []byte + ) + + BeforeEach(func() { + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + sourceDir, err = ioutil.TempDir("", "scp-source") + Expect(err).NotTo(HaveOccurred()) + + fileContents = []byte("---\nthis is a simple file\n\n") + generatedTextFile = filepath.Join(sourceDir, "textfile.txt") + + err = ioutil.WriteFile(generatedTextFile, fileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + targetDir, err = ioutil.TempDir("", "scp-target") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(sourceDir)).To(Succeed()) + Expect(os.RemoveAll(targetDir)).To(Succeed()) + }) + + It("properly copies using the secure copier", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := session.Run(fmt.Sprintf("scp -v -t %s", strings.Replace(targetDir, `\`, `\\`, -1))) + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + expectedFileInfo, err := os.Stat(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte(fmt.Sprintf("C0664 %d textfile.txt\n", expectedFileInfo.Size()))) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + actualFilePath := filepath.Join(targetDir, filepath.Base(generatedTextFile)) + actualFileInfo, err := os.Stat(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualFileInfo.Mode()).To(Equal(expectedFileInfo.Mode())) + Expect(actualFileInfo.Size()).To(Equal(expectedFileInfo.Size())) + + actualContents, err := ioutil.ReadFile(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + expectedContents, err := ioutil.ReadFile(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualContents).To(Equal(expectedContents)) + + Eventually(done).Should(BeClosed()) + }) + + It("properly fails when secure copying fails", func() { + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- session.Run(fmt.Sprintf("scp -v -t %s", strings.Replace(targetDir, `\`, `\\`, -1))) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write([]byte("BOGUS PROTOCOL MESSAGE\n")) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{1})) + + err = <-errCh + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + + It("properly fails when incorrect arguments are supplied", func() { + err := session.Run(fmt.Sprintf("scp -v -t /tmp/foo /tmp/bar")) + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + }) + + Describe("the shell locator", func() { + BeforeEach(func() { + err := session.Run("exit 0") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the shell locator to find the default shell path", func() { + Expect(shellLocator.ShellPathCallCount()).To(Equal(1)) + + cmd := runner.StartArgsForCall(0) + Expect(cmd.Path).To(Equal("C:\\Windows\\system32\\cmd.exe")) + }) + }) + + Context("when stdin is provided by the client", func() { + BeforeEach(func() { + session.Stdin = strings.NewReader("Hello") + }) + + It("can use the session to execute a command that reads it", func() { + result, err := session.Output(`findstr x*`) + + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(result))).To(Equal("Hello")) + }) + }) + + Context("when the command exits with a non-zero value", func() { + It("it preserve the exit code", func() { + err := session.Run("exit 3") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(3)) + }) + }) + + Context("when SIGKILL is sent across the session", func() { + Context("before a command has been run", func() { + BeforeEach(func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not prevent the command from running", func() { + result, err := session.Output("echo still kicking") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(result))).To(Equal(strings.TrimSpace("still kicking"))) + }) + }) + + Context("while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("Microsoft Windows")) + + Eventually(runner.StartCallCount).Should(Equal(1)) + }) + + It("is sent to the process", func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + + Eventually(runner.SignalCallCount).Should(Equal(1)) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).To(HaveOccurred()) + Expect(err.(*ssh.ExitError).ExitStatus()).To(Equal(1)) + }) + }) + }) + + Context("when running a command without an explicit environemnt", func() { + It("does not inherit daemon's environment", func() { + os.Setenv("DAEMON_ENV", "daemon_env_value") + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).NotTo(ContainSubstring("DAEMON_ENV=daemon_env_value")) + os.Unsetenv("DAEMON_ENV") + }) + + It("includes a default environment excluding PATH", func() { + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(fmt.Sprintf(`PATH=C:\Windows\system32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`))) + Expect(result).To(ContainSubstring(fmt.Sprintf("LANG=en_US.UTF8"))) + Expect(result).To(ContainSubstring(fmt.Sprintf("TEST=FOO"))) + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + }) + + Context("when environment variables are requested", func() { + Context("before starting the command", func() { + It("runs the command with the specified environment", func() { + err := session.Setenv("ENV1", "value1") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV2", "value2") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=value1")) + Expect(result).To(ContainSubstring("ENV2=value2")) + }) + + It("uses the value last specified", func() { + err := session.Setenv("ENV1", "original") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV1", "updated") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=updated")) + }) + + It("can override PATH and LANG", func() { + err := session.Setenv("PATH", "/bin:/usr/local/bin:/sbin") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("LANG", "en_UK.UTF8") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("PATH=/bin:/usr/local/bin:/sbin")) + Expect(result).To(ContainSubstring("LANG=en_UK.UTF8")) + }) + + It("cannot override HOME and USER", func() { + err := session.Setenv("HOME", "/some/other/home") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("USER", "not-a-user") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + + It("can override default env variables", func() { + err := session.Setenv("TEST", "BAR") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("TEST=BAR")) + }) + }) + + Context("after starting the command", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start(`findstr x* & set`) + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + err := session.Setenv("ENV3", "value3") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := ioutil.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(stdoutBytes)).NotTo(ContainSubstring("ENV3")) + }) + }) + }) + + Context("when a pty request is received", func() { + var terminalModes ssh.TerminalModes + + BeforeEach(func() { + terminalModes = ssh.TerminalModes{} + }) + + JustBeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, terminalModes) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should allocate a console for the session", func() { + result, err := session.Output("timeout 1 2>nul >nul & if errorlevel 1 (echo redirect) else (echo console)") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring("console")) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("dir") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + It("terminates the shell when the stdin closes", func() { + err := session.Shell() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(1 * time.Second) + + err = client.Conn.Close() + client = nil + Expect(err).NotTo(HaveOccurred()) + err = session.Wait() + Expect(err.Error()).To(Equal("wait: remote command exited without exit status or exit signal")) + }) + + It("should set the terminal type", func() { + result, err := session.Output(`echo %TERM%`) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring("vt100")) + }) + + It("sets the correct window size for the terminal", func() { + result, err := session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(ContainSubstring("43 80")) + }) + + Context("when an interactive command is executed", func() { + var stdin io.WriteCloser + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + }) + + It("terminates the session when the shell exits", func() { + err := session.Start("cmd.exe") + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte("exit\r\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when a signal is sent across the session", func() { + Context("before a command has been run", func() { + BeforeEach(func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not prevent the command from running", func() { + result, err := session.Output("echo still kicking") + Expect(err).NotTo(HaveOccurred()) + Expect(string(result)).To(ContainSubstring("still kicking")) + }) + }) + + Context("SIGKILL is sent while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("Microsoft Windows")) + }) + + It("kills the process", func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).To(HaveOccurred()) + Expect(err.(*ssh.ExitError).ExitStatus()).To(Equal(1)) + }) + }) + + Context("SIGINT is sent while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start("echo hello & findstr *x & echo goodbye") + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("hello")) + }) + + It("the process is interrupted", func() { + err := session.Signal(ssh.SIGINT) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("goodbye")) + }) + }) + }) + }) + + Context("when a window change request is received", func() { + type winChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + + var result []byte + + Context("before a pty is allocated", func() { + BeforeEach(func() { + _, err := session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + err = session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + Expect(result).To(ContainSubstring("43 80")) + }) + }) + + Context("after a pty is allocated", func() { + BeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + _, err = session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + }) + + It("changes the the size of the terminal", func() { + Expect(result).To(ContainSubstring("50 132")) + }) + }) + }) + + Context("after executing a command", func() { + BeforeEach(func() { + err := session.Run("exit") + Expect(err).NotTo(HaveOccurred()) + }) + + It("the session is no longer usable", func() { + _, err := session.SendRequest("exec", true, ssh.Marshal(struct{ Command string }{Command: "exit"})) + Expect(err).To(HaveOccurred()) + + _, err = session.SendRequest("bogus", true, nil) + Expect(err).To(HaveOccurred()) + + err = session.Setenv("foo", "bar") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when an interactive shell is requested", func() { + var stdin io.WriteCloser + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + }) + + It("starts the shell with the runner", func() { + Eventually(runner.StartCallCount).Should(Equal(1)) + + command := runner.StartArgsForCall(0) + Expect(command.Path).To(Equal("C:\\Windows\\system32\\cmd.exe")) + Expect(command.Args).To(ConsistOf("cmd.exe")) + }) + + It("terminates the session when the shell exits", func() { + _, err := stdin.Write([]byte("exit\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("and a command is provided", func() { + BeforeEach(func() { + err := session.Run("exit") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the provided runner to start the command", func() { + Expect(runner.StartCallCount()).To(Equal(1)) + Expect(runner.WaitCallCount()).To(Equal(1)) + }) + + It("passes the correct command to the runner", func() { + command := runner.StartArgsForCall(0) + Expect(command.Path).To(Equal("C:\\Windows\\system32\\cmd.exe")) + Expect(command.Args).To(ConsistOf("cmd.exe", "/c", "exit")) + }) + + It("passes the same command to Start and Wait", func() { + command := runner.StartArgsForCall(0) + Expect(runner.WaitArgsForCall(0)).To(Equal(command)) + }) + }) + + Context("when executing an invalid command", func() { + It("returns an exit error with a non-zero exit status", func() { + err := session.Run("not-a-command") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + + Context("when starting the command fails", func() { + BeforeEach(func() { + runner.StartReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + + Context("when waiting on the command fails", func() { + BeforeEach(func() { + runner.WaitReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + }) + + Context("when an unknown request type is sent", func() { + var accepted bool + + BeforeEach(func() { + var err error + accepted, err = session.SendRequest("unknown-request-type", true, []byte("payload")) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("echo Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(response))).To(Equal("Hello")) + }) + }) + + Context("when an unknown subsystem is requested", func() { + var accepted bool + + BeforeEach(func() { + type subsysMsg struct{ Subsystem string } + + var err error + accepted, err = session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "unknown"})) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("echo Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(response))).To(Equal("Hello")) + }) + }) + }) + + Context("when the sftp subystem is requested", func() { + It("accepts the request", func() { + type subsysMsg struct{ Subsystem string } + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + defer session.Close() + + accepted, err := session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "sftp"})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + }) + + It("starts an sftp server in write mode", func() { + tempDir, err := ioutil.TempDir("", "sftp") + Expect(err).NotTo(HaveOccurred()) + defer os.RemoveAll(tempDir) + + sftp, err := sftp.NewClient(client) + Expect(err).NotTo(HaveOccurred()) + defer sftp.Close() + + By("creating the file") + target := filepath.Join(tempDir, "textfile.txt") + file, err := sftp.Create(target) + Expect(err).NotTo(HaveOccurred()) + + fileContents := []byte("---\nthis is a simple file\n\n") + _, err = file.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(ioutil.ReadFile(target)).To(Equal(fileContents)) + + By("reading the file") + file, err = sftp.Open(target) + Expect(err).NotTo(HaveOccurred()) + + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(file) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(buffer.Bytes()).To(Equal(fileContents)) + + By("removing the file") + err = sftp.Remove(target) + Expect(err).NotTo(HaveOccurred()) + + _, err = os.Stat(target) + Expect(err).To(HaveOccurred()) + Expect(os.IsNotExist(err)).To(BeTrue()) + }) + }) + + Describe("invalid session channel requests", func() { + var channel ssh.Channel + var requests <-chan *ssh.Request + + BeforeEach(func() { + var err error + channel, requests, err = client.OpenChannel("session", nil) + Expect(err).NotTo(HaveOccurred()) + + go ssh.DiscardRequests(requests) + }) + + AfterEach(func() { + if channel != nil { + channel.Close() + } + }) + + Context("when an exec request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("exec", true, ssh.Marshal(struct{ Bogus uint32 }{Bogus: 1138})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when an env request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("env", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a signal request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("signal", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a pty request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("pty-req", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a window change request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("window-change", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a subsystem request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("subsystem", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + }) +}) diff --git a/handlers/shell_locator.go b/handlers/shell_locator.go index 055b447..8e6531a 100644 --- a/handlers/shell_locator.go +++ b/handlers/shell_locator.go @@ -1,3 +1,5 @@ +// +build !windows + package handlers import "os/exec" diff --git a/handlers/shell_locator_windows.go b/handlers/shell_locator_windows.go new file mode 100644 index 0000000..0eff466 --- /dev/null +++ b/handlers/shell_locator_windows.go @@ -0,0 +1,21 @@ +// +build windows + +package handlers + +import "os/exec" + +type shellLocator struct{} + +func NewShellLocator() ShellLocator { + return &shellLocator{} +} + +func (shellLocator) ShellPath() string { + for _, shell := range []string{"cmd.exe"} { + if path, err := exec.LookPath(shell); err == nil { + return path + } + } + + return "cmd.exe" +} diff --git a/winpty/winpty.go b/winpty/winpty.go new file mode 100644 index 0000000..b0ad890 --- /dev/null +++ b/winpty/winpty.go @@ -0,0 +1,310 @@ +// +build windows + +package winpty + +import ( + "errors" + "fmt" + "math" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + winpty *windows.LazyDLL + winpty_config_new *windows.LazyProc + winpty_config_free *windows.LazyProc + winpty_error_free *windows.LazyProc + winpty_error_msg *windows.LazyProc + winpty_open *windows.LazyProc + winpty_free *windows.LazyProc + winpty_conin_name *windows.LazyProc + winpty_conout_name *windows.LazyProc + winpty_spawn_config_new *windows.LazyProc + winpty_spawn_config_free *windows.LazyProc + winpty_spawn *windows.LazyProc + winpty_set_size *windows.LazyProc +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + terminateProcess = kernel32.NewProc("TerminateProcess") +) + +type WinPTY struct { + StdIn *os.File + StdOut *os.File + + winPTYHandle uintptr + childHandle uintptr +} + +const ( + WINPTY_SPAWN_FLAG_AUTO_SHUTDOWN = uint64(1) +) + +func New(winPTYDllPath string) (*WinPTY, error) { + winpty = windows.NewLazyDLL(filepath.Join(winPTYDllPath, "winpty.dll")) + winpty_config_new = winpty.NewProc("winpty_config_new") + winpty_config_free = winpty.NewProc("winpty_config_free") + winpty_error_free = winpty.NewProc("winpty_error_free") + winpty_error_msg = winpty.NewProc("winpty_error_msg") + winpty_open = winpty.NewProc("winpty_open") + winpty_free = winpty.NewProc("winpty_free") + winpty_conin_name = winpty.NewProc("winpty_conin_name") + winpty_conout_name = winpty.NewProc("winpty_conout_name") + winpty_spawn_config_new = winpty.NewProc("winpty_spawn_config_new") + winpty_spawn_config_free = winpty.NewProc("winpty_spawn_config_free") + winpty_spawn = winpty.NewProc("winpty_spawn") + winpty_set_size = winpty.NewProc("winpty_set_size") + + var errorPtr uintptr + defer winpty_error_free.Call(errorPtr) + agentCfg, _, _ := winpty_config_new.Call(uintptr(0), uintptr(unsafe.Pointer(&errorPtr))) + if agentCfg == 0 { + return nil, fmt.Errorf("unable to create agent config: %s", winPTYErrorMessage(errorPtr)) + } + + winPTYHandle, _, _ := winpty_open.Call(agentCfg, uintptr(unsafe.Pointer(&errorPtr))) + if winPTYHandle == 0 { + return nil, fmt.Errorf("unable to launch WinPTY agent: %s", winPTYErrorMessage(errorPtr)) + } + winpty_config_free.Call(agentCfg) + + return &WinPTY{ + winPTYHandle: winPTYHandle, + }, nil +} + +func (w *WinPTY) Open() error { + if w.winPTYHandle == 0 { + return errors.New("winpty dll not initialized") + } + + stdinName, _, err := winpty_conin_name.Call(w.winPTYHandle) + if stdinName == 0 { + return fmt.Errorf("unable to get stdin pipe name: %s", err.Error()) + } + + stdoutName, _, err := winpty_conout_name.Call(w.winPTYHandle) + if stdoutName == 0 { + return fmt.Errorf("unable to get stdout pipe name: %s", err.Error()) + } + + stdinHandle, err := syscall.CreateFile((*uint16)(unsafe.Pointer(stdinName)), syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, 0, 0) + if err != nil { + return fmt.Errorf("unable to open stdin pipe: %s", err.Error()) + } + + stdoutHandle, err := syscall.CreateFile((*uint16)(unsafe.Pointer(stdoutName)), syscall.GENERIC_READ, 0, nil, syscall.OPEN_EXISTING, 0, 0) + if err != nil { + return fmt.Errorf("unable to open stdout pipe: %s", err.Error()) + } + + w.StdIn = os.NewFile(uintptr(stdinHandle), "stdin") + w.StdOut = os.NewFile(uintptr(stdoutHandle), "stdout") + return nil +} + +func (w *WinPTY) Run(cmd *exec.Cmd) error { + escaped := makeCmdLine(append([]string{cmd.Path}, cmd.Args...)) + cmdLineStr, err := syscall.UTF16PtrFromString(escaped) + if err != nil { + w.StdOut.Close() + return fmt.Errorf("failed to convert cmd (%s) to pointer: %s", escaped, err.Error()) + } + + env := "" + for _, val := range cmd.Env { + env += (val + "\x00") + } + + var envPtr *uint16 + if env != "" { + envPtr = &utf16.Encode([]rune(env))[0] + } + + var errorPtr uintptr + defer winpty_error_free.Call(errorPtr) + spawnCfg, _, _ := winpty_spawn_config_new.Call( + uintptr(uint64(WINPTY_SPAWN_FLAG_AUTO_SHUTDOWN)), + uintptr(0), + uintptr(unsafe.Pointer(cmdLineStr)), + uintptr(0), + uintptr(unsafe.Pointer(envPtr)), + uintptr(unsafe.Pointer(&errorPtr))) + if spawnCfg == 0 { + w.StdOut.Close() + return fmt.Errorf("unable to create process config: %s", winPTYErrorMessage(errorPtr)) + } + + var createProcessErr uint32 + spawnRet, _, err := winpty_spawn.Call(w.winPTYHandle, + spawnCfg, + uintptr(unsafe.Pointer(&w.childHandle)), + uintptr(0), + uintptr(unsafe.Pointer(&createProcessErr)), + uintptr(unsafe.Pointer(&errorPtr))) + winpty_spawn_config_free.Call(spawnCfg) + if spawnRet == 0 { + w.StdOut.Close() + return fmt.Errorf("unable to spawn process: %s: %s", winPTYErrorMessage(errorPtr), windowsErrorMessage(createProcessErr)) + } + + return nil +} + +func (w *WinPTY) Wait() error { + _, err := syscall.WaitForSingleObject(syscall.Handle(w.childHandle), math.MaxUint32) + if err != nil { + return fmt.Errorf("unable to wait for child process: %s", err.Error()) + } + + var exitCode uint32 + err = syscall.GetExitCodeProcess(syscall.Handle(w.childHandle), &exitCode) + if err != nil { + return fmt.Errorf("couldn't get child exit code: %s", err.Error()) + } + + if exitCode != 0 { + return &ExitError{WaitStatus: syscall.WaitStatus{ExitCode: exitCode}} + } + + return nil +} + +type ExitError struct { + WaitStatus syscall.WaitStatus +} + +func (ee *ExitError) Error() string { + return fmt.Sprintf("exit code %d", ee.WaitStatus.ExitCode) +} + +func (w *WinPTY) Close() { + if w.winPTYHandle == 0 { + return + } + + winpty_free.Call(w.winPTYHandle) + + if w.StdIn != nil { + w.StdIn.Close() + } + + if w.StdOut != nil { + w.StdOut.Close() + } + + if w.childHandle != 0 { + syscall.CloseHandle(syscall.Handle(w.childHandle)) + } +} + +func (w *WinPTY) SetWinsize(columns, rows uint32) error { + if columns == 0 || rows == 0 { + return nil + } + ret, _, err := winpty_set_size.Call(w.winPTYHandle, uintptr(columns), uintptr(rows), uintptr(0)) + if ret == 0 { + return fmt.Errorf("failed to set window size: %s", err.Error()) + } + return nil +} + +func (w *WinPTY) Signal(sig syscall.Signal) error { + if sig == syscall.SIGINT { + return w.sendCtrlC() + } else if sig == syscall.SIGKILL { + return w.terminateChild() + } + + return syscall.Errno(syscall.EWINDOWS) +} + +func (w *WinPTY) sendCtrlC() error { + if w.childHandle == 0 { + return nil + } + + // 0x03 is Ctrl+C + // this tells the agent to generate Ctrl+C in the child process + // https://github.com/rprichard/winpty/blob/4978cf94b6ea48e38eea3146bd0d23210f87aa89/src/agent/ConsoleInput.cc#L387 + _, err := w.StdIn.Write([]byte{0x03}) + if err != nil { + return fmt.Errorf("couldn't send ctrl+c to child: %s", err.Error()) + } + return nil +} + +func (w *WinPTY) terminateChild() error { + if w.childHandle == 0 { + return nil + } + ret, _, err := terminateProcess.Call(w.childHandle, 1) + if ret == 0 { + return fmt.Errorf("failed to terminate child process: %s", err.Error()) + } + return nil +} + +func winPTYErrorMessage(ptr uintptr) string { + msgPtr, _, err := winpty_error_msg.Call(ptr) + if msgPtr == 0 { + return fmt.Sprintf("unknown error, couldn't convert: %s", err.Error()) + } + + out := make([]uint16, 0) + p := unsafe.Pointer(msgPtr) + + for { + val := *(*uint16)(p) + if val == 0 { + break + } + + out = append(out, val) + p = unsafe.Pointer(uintptr(p) + unsafe.Sizeof(uint16(0))) + } + return string(utf16.Decode(out)) +} + +func windowsErrorMessage(code uint32) string { + flags := uint32(windows.FORMAT_MESSAGE_FROM_SYSTEM | windows.FORMAT_MESSAGE_IGNORE_INSERTS) + langId := uint32(windows.SUBLANG_ENGLISH_US)<<10 | uint32(windows.LANG_ENGLISH) + buf := make([]uint16, 512) + + _, err := windows.FormatMessage(flags, uintptr(0), code, langId, buf, nil) + if err != nil { + return fmt.Sprintf("0x%x", code) + } + return strings.TrimSpace(syscall.UTF16ToString(buf)) +} + +func makeCmdLine(args []string) string { + if len(args) > 0 { + args[0] = filepath.Clean(args[0]) + base := filepath.Base(args[0]) + match, _ := regexp.MatchString(`\.[a-zA-Z]{3}$`, base) + if !match { + args[0] += ".exe" + } + } + var s string + for _, v := range args { + if s != "" { + s += " " + } + s += syscall.EscapeArg(v) + } + + return s +}