Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/ssh/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {

func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""

sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
Expand Down Expand Up @@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
}

if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
if err := serverSession.Run(session.RawCommand()); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}
Expand Down
186 changes: 186 additions & 0 deletions client/ssh/proxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
Expand Down Expand Up @@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel()
}

// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
Comment thread
lixmal marked this conversation as resolved.

sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()

// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()

var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf

outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()

select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}

// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()

const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)

jwksServer, privateKey, jwksURL := setupJWKSServer(t)

hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)

serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)

testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)

authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)

sshServerAddr := server.StartTestServer(t, sshServer)

mockDaemon := startMockDaemon(t)

host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)

mockDaemon.setHostKey(host, hostPubKey)

validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)

proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)

origStdin := os.Stdin
origStdout := os.Stdout

stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)

os.Stdin = stdinReader
os.Stdout = stdoutWriter

clientConn, proxyConn := net.Pipe()

go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

go func() {
_ = proxyInstance.Connect(ctx)
}()

sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}

sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)

client := cryptossh.NewClient(sshClientConn, chans, reqs)

cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}

return client, cleanupFn
}

type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
Expand Down
2 changes: 1 addition & 1 deletion client/ssh/server/session_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
}

ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""

if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)
Expand Down
Loading