Skip to content

Commit

Permalink
Merge pull request #2127 from sschaap/feature/windows-openssh
Browse files Browse the repository at this point in the history
Support for Windows OpenSSH agent forwarding
  • Loading branch information
tonistiigi authored Jun 5, 2021
2 parents 12cfc87 + c9a5f88 commit 0c13337
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 12 deletions.
51 changes: 39 additions & 12 deletions session/sshforward/sshprovider/agentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
}

if conf.Paths[0] == "" {
return nil, errors.Errorf("invalid empty ssh agent socket, make sure SSH_AUTH_SOCK is set")
p, err := getFallbackAgentPath()
if err != nil {
return nil, errors.Wrap(err, "invalid empty ssh agent socket")
}
conf.Paths[0] = p
}

src, err := toAgentSource(conf.Paths)
Expand All @@ -56,7 +60,20 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {

type source struct {
agent agent.Agent
socket string
socket *socketDialer
}

type socketDialer struct {
path string
dialer func(string) (net.Conn, error)
}

func (s socketDialer) Dial() (net.Conn, error) {
return s.dialer(s.path)
}

func (s socketDialer) String() string {
return s.path
}

type socketProvider struct {
Expand Down Expand Up @@ -94,8 +111,8 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)

var a agent.Agent

if src.socket != "" {
conn, err := net.DialTimeout("unix", src.socket, time.Second)
if src.socket != nil {
conn, err := src.socket.Dial()
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", src.socket)
}
Expand Down Expand Up @@ -124,21 +141,24 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)

func toAgentSource(paths []string) (source, error) {
var keys bool
var socket string
var socket *socketDialer
a := agent.NewKeyring()
for _, p := range paths {
if socket != "" {
if socket != nil {
return source{}, errors.New("only single socket allowed")
}

if parsed := getWindowsPipeDialer(p); parsed != nil {
socket = parsed
continue
}

fi, err := os.Stat(p)
if err != nil {
return source{}, errors.WithStack(err)
}
if fi.Mode()&os.ModeSocket > 0 {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}

Expand All @@ -160,7 +180,7 @@ func toAgentSource(paths []string) (source, error) {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}

Expand All @@ -173,13 +193,20 @@ func toAgentSource(paths []string) (source, error) {
keys = true
}

if socket != "" {
if socket != nil {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
return source{socket: socket}, nil
}

return source{agent: a}, nil
}

func unixSocketDialer(path string) (net.Conn, error) {
return net.DialTimeout("unix", path, 2*time.Second)
}

func sockPair() (io.ReadWriteCloser, io.ReadWriteCloser) {
pr1, pw1 := io.Pipe()
pr2, pw2 := io.Pipe()
Expand Down
15 changes: 15 additions & 0 deletions session/sshforward/sshprovider/agentprovider_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// +build !windows

package sshprovider

import (
"github.com/pkg/errors"
)

func getFallbackAgentPath() (string, error) {
return "", errors.Errorf("make sure SSH_AUTH_SOCK is set")
}

func getWindowsPipeDialer(path string) *socketDialer {
return nil
}
60 changes: 60 additions & 0 deletions session/sshforward/sshprovider/agentprovider_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// +build windows

package sshprovider

import (
"net"
"regexp"
"strings"

"github.com/Microsoft/go-winio"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
)

// Returns the Windows OpenSSH agent named pipe path, but
// only if the agent is running. Returns an error otherwise.
func getFallbackAgentPath() (string, error) {
// Windows OpenSSH agent uses a named pipe rather
// than a UNIX socket. These pipes do not play nice
// with os.Stat (which tries to open its target), so
// use a FindFirstFile syscall to check for existence.
var fd windows.Win32finddata

path := `\\.\pipe\openssh-ssh-agent`
pathPtr, _ := windows.UTF16PtrFromString(path)
handle, err := windows.FindFirstFile(pathPtr, &fd)

if err != nil {
msg := "Windows OpenSSH agent not available at %s." +
" Enable the SSH agent service or set SSH_AUTH_SOCK."
return "", errors.Errorf(msg, path)
}

_ = windows.CloseHandle(handle)

return path, nil
}

// Returns true if the path references a named pipe.
func isWindowsPipePath(path string) bool {
// If path matches \\*\pipe\* then it references a named pipe
// and requires winio.DialPipe() rather than DialTimeout("unix").
// Slashes and backslashes may be used interchangeably in the path.
// Path separators may consist of multiple consecutive (back)slashes.
pipePattern := strings.ReplaceAll("^[/]{2}[^/]+[/]+pipe[/]+", "/", `\\/`)
ok, _ := regexp.MatchString(pipePattern, path)
return ok
}

func getWindowsPipeDialer(path string) *socketDialer {
if isWindowsPipePath(path) {
return &socketDialer{path: path, dialer: windowsPipeDialer}
}

return nil
}

func windowsPipeDialer(path string) (net.Conn, error) {
return winio.DialPipe(path, nil)
}

0 comments on commit 0c13337

Please sign in to comment.