From 9049447e992ff1b676fecfaef62be0aac692751e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 12 Jan 2026 16:48:08 +0800 Subject: [PATCH 01/10] Support non-PTY no-command interactive SSH sessions --- client/cmd/ssh_exec_unix.go | 2 +- client/cmd/ssh_sftp_unix.go | 2 +- client/ssh/proxy/proxy.go | 26 ++++-- client/ssh/server/command_execution.go | 25 +++--- client/ssh/server/command_execution_js.go | 7 +- client/ssh/server/command_execution_unix.go | 25 +++--- .../ssh/server/command_execution_windows.go | 79 ++++++----------- client/ssh/server/executor_unix.go | 34 +++++--- client/ssh/server/executor_unix_test.go | 10 +-- client/ssh/server/executor_windows.go | 84 ++++++++++--------- client/ssh/server/server.go | 2 +- client/ssh/server/server_config_test.go | 42 +++------- client/ssh/server/session_handlers.go | 42 ++-------- client/ssh/server/session_handlers_js.go | 4 +- client/ssh/server/sftp_windows.go | 2 +- client/ssh/server/userswitching_unix.go | 6 +- client/ssh/server/userswitching_windows.go | 23 +++-- client/ssh/server/winpty/conpty.go | 4 +- 18 files changed, 191 insertions(+), 228 deletions(-) diff --git a/client/cmd/ssh_exec_unix.go b/client/cmd/ssh_exec_unix.go index 2412f072c5e..0b085ebd916 100644 --- a/client/cmd/ssh_exec_unix.go +++ b/client/cmd/ssh_exec_unix.go @@ -52,7 +52,7 @@ func init() { // runSSHExec handles the SSH exec subcommand execution. func runSSHExec(cmd *cobra.Command, _ []string) error { - privilegeDropper := sshserver.NewPrivilegeDropper() + privilegeDropper := sshserver.NewPrivilegeDropper(nil) var groups []uint32 for _, groupInt := range sshExecGroups { diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go index c06aab01713..ed6b8993dfb 100644 --- a/client/cmd/ssh_sftp_unix.go +++ b/client/cmd/ssh_sftp_unix.go @@ -36,7 +36,7 @@ func init() { } func sftpMain(cmd *cobra.Command, _ []string) error { - privilegeDropper := sshserver.NewPrivilegeDropper() + privilegeDropper := sshserver.NewPrivilegeDropper(nil) var groups []uint32 for _, groupInt := range sftpGroupsInt { diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index cb1c36e1313..8897b9c7e9f 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { } func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { - // Create a backend session to mirror the client's session request. - // This keeps the connection alive on the server side while port forwarding channels operate. serverSession, err := sshClient.NewSession() if err != nil { _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) @@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c } defer func() { _ = serverSession.Close() }() - <-session.Context().Done() + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() + + if err := serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } - if err := session.Exit(0); err != nil { - log.Debugf("session exit: %v", err) + done := make(chan error, 1) + go func() { + done <- serverSession.Wait() + }() + + select { + case <-session.Context().Done(): + return + case err := <-done: + if err != nil { + log.Debugf("shell session: %v", err) + p.handleProxyExitCode(session, err) + } } } diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index 7a01ce4f665..b0a85fe4b52 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -12,8 +12,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handleCommand executes an SSH command with privilege validation -func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { +// handleExecution executes an SSH command or shell with privilege validation +func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) { hasPty := winCh != nil commandType := "command" @@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) - execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty) if err != nil { logger.Errorf("%s creation failed: %v", commandType, err) @@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege defer cleanup() - ptyReq, _, _ := session.Pty() if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { logger.Debugf("%s execution completed", commandType) } } -func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { +func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { localUser := privilegeResult.User if localUser == nil { return nil, nil, errors.New("no user in privilege result") @@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh // If PTY requested but su doesn't support --pty, skip su and use executor // This ensures PTY functionality is provided (executor runs within our allocated PTY) if hasPty && !s.suSupportsPty { - log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } // Try su first for system integration (PAM/audit) when privileged - cmd, err := s.createSuCommand(session, localUser, hasPty) + cmd, err := s.createSuCommand(logger, session, localUser, hasPty) if err != nil || privilegeResult.UsedFallback { - log.Debugf("su command failed, falling back to executor: %v", err) - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, func() {}, nil } diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 01759a3371a..9a723f5d01f 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -15,17 +15,17 @@ import ( var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") // createSuCommand is not supported on JS/WASM -func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { +func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { return nil, errNotSupported } // createExecutorCommand is not supported on JS/WASM -func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { +func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { return nil, nil, errNotSupported } // prepareCommandEnv is not supported on JS/WASM -func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string { return nil } @@ -55,3 +55,4 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e } return false } + diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index db1a9bcfe9d..f8dfa6c9ace 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -99,24 +99,29 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { return isUtilLinux } -// createSuCommand creates a command using su -l -c for privilege switching -func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { +// createSuCommand creates a command using su -l for privilege switching +func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + if err := validateUsername(localUser.Username); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + suPath, err := exec.LookPath("su") if err != nil { return nil, fmt.Errorf("su command not available: %w", err) } - command := session.RawCommand() - if command == "" { - return nil, fmt.Errorf("no command specified for su execution") - } - args := []string{"-l"} if hasPty && s.suSupportsPty { args = append(args, "--pty") } - args = append(args, localUser.Username, "-c", command) + args = append(args, localUser.Username) + + command := session.RawCommand() + if command != "" { + args = append(args, "-c", command) + } + logger.Debugf("creating su command: %s %v", suPath, args) cmd := exec.CommandContext(session.Context(), suPath, args...) cmd.Dir = localUser.HomeDir @@ -132,7 +137,7 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string { } // prepareCommandEnv prepares environment variables for command execution on Unix -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string { env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env = append(env, prepareSSHEnv(session)...) for _, v := range session.Environ() { @@ -154,7 +159,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) if err != nil { logger.Errorf("Pty command creation failed: %v", err) diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 9987968714a..2e3e1ec019d 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -20,32 +20,32 @@ import ( // getUserEnvironment retrieves the Windows environment for the target user. // Follows OpenSSH's resilient approach with graceful degradation on failures. -func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { - userToken, err := s.getUserToken(username, domain) +func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) { + userToken, err := s.getUserToken(logger, username, domain) if err != nil { return nil, fmt.Errorf("get user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() - return s.getUserEnvironmentWithToken(userToken, username, domain) + return s.getUserEnvironmentWithToken(logger, userToken, username, domain) } // getUserEnvironmentWithToken retrieves the Windows environment using an existing token. -func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { +func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) { userProfile, err := s.loadUserProfile(userToken, username, domain) if err != nil { - log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) userProfile = fmt.Sprintf("C:\\Users\\%s", username) } envMap := make(map[string]string) if err := s.loadSystemEnvironment(envMap); err != nil { - log.Debugf("failed to load system environment from registry: %v", err) + logger.Debugf("failed to load system environment from registry: %v", err) } s.setUserEnvironmentVariables(envMap, userProfile, username, domain) @@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, } // getUserToken creates a user token for the specified user. -func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { - privilegeDropper := NewPrivilegeDropper() +func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper(logger) token, err := privilegeDropper.createToken(username, domain) if err != nil { return 0, fmt.Errorf("generate S4U user token: %w", err) @@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi } // prepareCommandEnv prepares environment variables for command execution on Windows -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string { username, domain := s.parseUsername(localUser.Username) - userEnv, err := s.getUserEnvironment(username, domain) + userEnv, err := s.getUserEnvironment(logger, username, domain) if err != nil { log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) @@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) [] return env } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { if privilegeResult.User == nil { logger.Errorf("no user in privilege result") return false } - cmd := session.Command() shell := getUserShell(privilegeResult.User.Uid) + logger.Infof("starting interactive shell: %s", shell) - if len(cmd) == 0 { - logger.Infof("starting interactive shell: %s", shell) - } else { - logger.Infof("executing command: %s", safeLogCommand(cmd)) - } - - s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil) return true } @@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string { return []string{shell, "-Command", cmdString} } -func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { - logger.Info("starting interactive shell") - s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) -} - type PtyExecutionRequest struct { Shell string Command string @@ -308,25 +297,25 @@ type PtyExecutionRequest struct { Domain string } -func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { - log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", +func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error { + logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(logger) userToken, err := privilegeDropper.createToken(req.Username, req.Domain) if err != nil { return fmt.Errorf("create user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() server := &Server{} - userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain) if err != nil { - log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) + logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) userEnv = os.Environ() } @@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re Environment: userEnv, } - log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) - return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) + logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig) } func getUserHomeFromEnv(env []string) string { @@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) { return } - logger := log.WithField("pid", cmd.Process.Pid) - if err := cmd.Process.Kill(); err != nil { - logger.Debugf("kill process failed: %v", err) + log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err) } } @@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool { } // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty -func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { - command := session.RawCommand() - if command == "" { - logger.Error("no command specified for PTY execution") - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - return false - } - - return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) -} - -// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) -func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { localUser := privilegeResult.User if localUser == nil { logger.Errorf("no user in privilege result") @@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr req := PtyExecutionRequest{ Shell: shell, - Command: command, + Command: session.RawCommand(), Width: ptyReq.Window.Width, Height: ptyReq.Window.Height, Username: username, Domain: domain, } - if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + if err := executePtyCommandWithUserToken(logger, session, req); err != nil { logger.Errorf("ConPty execution failed: %v", err) if err := session.Exit(1); err != nil { logSessionExitError(logger, err) diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go index 8adc824effe..ac848dd1e22 100644 --- a/client/ssh/server/executor_unix.go +++ b/client/ssh/server/executor_unix.go @@ -35,11 +35,21 @@ type ExecutorConfig struct { } // PrivilegeDropper handles secure privilege dropping in child processes -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} // NewPrivilegeDropper creates a new privilege dropper -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +func NewPrivilegeDropper(logger *log.Entry) *PrivilegeDropper { + return &PrivilegeDropper{logger: logger} +} + +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } // CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping @@ -83,7 +93,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex break } } - log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs) return exec.CommandContext(ctx, netbirdPath, args...), nil } @@ -206,17 +216,21 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config var execCmd *exec.Cmd if config.Command == "" { - os.Exit(ExitCodeSuccess) + execCmd = exec.CommandContext(ctx, config.Shell, "-l") + } else { + execCmd = exec.CommandContext(ctx, config.Shell, "-l", "-c", config.Command) } - - execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) execCmd.Stdin = os.Stdin execCmd.Stdout = os.Stdout execCmd.Stderr = os.Stderr - cmdParts := strings.Fields(config.Command) - safeCmd := safeLogCommand(cmdParts) - log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if config.Command == "" { + log.Tracef("executing login shell: %s -l", execCmd.Path) + } else { + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -l -c %s", execCmd.Path, safeCmd) + } if err := execCmd.Run(); err != nil { var exitError *exec.ExitError if errors.As(err, &exitError) { diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go index 0c5108f57fa..f5dd46134dc 100644 --- a/client/ssh/server/executor_unix_test.go +++ b/client/ssh/server/executor_unix_test.go @@ -16,7 +16,7 @@ import ( ) func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) currentUID := uint32(os.Geteuid()) currentGID := uint32(os.Getegid()) @@ -74,7 +74,7 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { } func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) config := ExecutorConfig{ UID: 1000, @@ -108,7 +108,7 @@ func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { } func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) config := ExecutorConfig{ UID: 1000, @@ -157,7 +157,7 @@ func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { // Test in a child process to avoid affecting the test runner if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) // This should succeed err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) @@ -227,7 +227,7 @@ func findNonRootUser() (*user.User, error) { } func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) currentUID := uint32(os.Geteuid()) if currentUID == 0 { diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index d3504e05682..f34593f50e7 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -28,22 +28,31 @@ const ( ) type WindowsExecutorConfig struct { - Username string - Domain string - WorkingDir string - Shell string - Command string - Args []string - Interactive bool - Pty bool - PtyWidth int - PtyHeight int + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Pty bool + PtyWidth int + PtyHeight int } -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} + +func NewPrivilegeDropper(logger *log.Entry) *PrivilegeDropper { + return &PrivilegeDropper{logger: logger} +} -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } var ( @@ -56,7 +65,6 @@ const ( // Common error messages commandFlag = "-Command" - closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials convertUsernameError = "convert username to UTF16: %w" convertDomainError = "convert domain to UTF16: %w" ) @@ -80,7 +88,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co shellArgs = []string{shell} } - log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) cmd, token, err := pd.CreateWindowsProcessAsUser( ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) @@ -180,10 +188,10 @@ func newLsaString(s string) lsaString { // generateS4UUserToken creates a Windows token using S4U authentication // This is the exact approach OpenSSH for Windows uses for public key authentication -func generateS4UUserToken(username, domain string) (windows.Handle, error) { +func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { userCpn := buildUserCpn(username, domain) - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(logger) isDomainUser := !pd.isLocalUser(domain) lsaHandle, err := initializeLsaConnection() @@ -197,12 +205,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) { return 0, err } - logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser) if err != nil { return 0, err } - return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) + return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) } // buildUserCpn constructs the user principal name @@ -310,21 +318,21 @@ func lookupPrincipalName(username, domain string) (string, error) { } // prepareS4ULogonStructure creates the appropriate S4U logon structure -func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { +func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { if isDomainUser { - return prepareDomainS4ULogon(username, domain) + return prepareDomainS4ULogon(logger, username, domain) } - return prepareLocalS4ULogon(username) + return prepareLocalS4ULogon(logger, username) } // prepareDomainS4ULogon creates S4U logon structure for domain users -func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { +func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) { upn, err := lookupPrincipalName(username, domain) if err != nil { return nil, 0, fmt.Errorf("lookup principal name: %w", err) } - log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) upnUtf16, err := windows.UTF16FromString(upn) if err != nil { @@ -357,8 +365,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er } // prepareLocalS4ULogon creates S4U logon structure for local users -func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { - log.Debugf("using Msv1_0S4ULogon for local user: %s", username) +func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) { + logger.Debugf("using Msv1_0S4ULogon for local user: %s", username) usernameUtf16, err := windows.UTF16FromString(username) if err != nil { @@ -406,11 +414,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { } // performS4ULogon executes the S4U logon operation -func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { +func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { var tokenSource tokenSource copy(tokenSource.SourceName[:], "netbird") if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { - log.Debugf("AllocateLocallyUniqueId failed") + logger.Debugf("AllocateLocallyUniqueId failed") } originName := newLsaString("netbird") @@ -441,7 +449,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u if profile != 0 { if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { - log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) } } @@ -449,7 +457,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) } - log.Debugf("created S4U %s token for user %s", + logger.Debugf("created S4U %s token for user %s", map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) return token, nil } @@ -497,8 +505,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool { // authenticateLocalUser handles authentication for local users func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for local user %s", fullUsername) - token, err := generateS4UUserToken(username, ".") + pd.log().Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, ".") if err != nil { return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) } @@ -507,12 +515,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) // authenticateDomainUser handles authentication for domain users func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for domain user %s", fullUsername) - token, err := generateS4UUserToken(username, domain) + pd.log().Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, domain) if err != nil { return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) } - log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername) return token, nil } @@ -526,7 +534,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec defer func() { if err := windows.CloseHandle(token); err != nil { - log.Debugf("close impersonation token: %v", err) + pd.log().Debugf("close impersonation token: %v", err) } }() @@ -564,7 +572,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo return cmd, primaryToken, nil } -// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) -func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { +// createSuCommand creates a command using su -l for privilege switching (Windows stub) +func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) { return nil, fmt.Errorf("su command not available on Windows") } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 3a8568979ab..e2162d498f9 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { sessions = append(sessions, info) } - // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) + // Add authenticated connections without sessions (e.g., -N or port-forwarding only) for key, connState := range s.connections { remoteAddr := string(key) if reportedAddrs[remoteAddr] { diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index d85d85a51b8..f70e29963dd 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) { } } -func TestServer_PortForwardingOnlySession(t *testing.T) { - // Test that sessions without PTY and command are allowed when port forwarding is enabled +func TestServer_NonPtyShellSession(t *testing.T) { + // Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings. currentUser, err := user.Current() require.NoError(t, err, "Should be able to get current user") - // Generate host key for server hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) require.NoError(t, err) @@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { name string allowLocalForwarding bool allowRemoteForwarding bool - expectAllowed bool - description string }{ { - name: "session_allowed_with_local_forwarding", + name: "shell_with_local_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: false, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when local forwarding is enabled", }, { - name: "session_allowed_with_remote_forwarding", + name: "shell_with_remote_forwarding_enabled", allowLocalForwarding: false, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when remote forwarding is enabled", }, { - name: "session_allowed_with_both", + name: "shell_with_both_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when both forwarding types enabled", }, { - name: "session_denied_without_forwarding", + name: "shell_with_forwarding_disabled", allowLocalForwarding: false, allowRemoteForwarding: false, - expectAllowed: false, - description: "Port-forwarding-only session should be denied when all forwarding is disabled", }, } @@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { _ = server.Stop() }() - // Connect to the server without requesting PTY or command ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { _ = client.Close() }() - // Execute a command without PTY - this simulates ssh -T with no command - // The server should either allow it (port forwarding enabled) or reject it - output, err := client.ExecuteCommand(ctx, "") - if tt.expectAllowed { - // When allowed, the session stays open until cancelled - // ExecuteCommand with empty command should return without error - assert.NoError(t, err, "Session should be allowed when port forwarding is enabled") - assert.NotContains(t, output, "port forwarding is disabled", - "Output should not contain port forwarding disabled message") - } else if err != nil { - // When denied, we expect an error message about port forwarding being disabled - assert.Contains(t, err.Error(), "port forwarding is disabled", - "Should get port forwarding disabled message") - } + // Execute without PTY and no command - simulates ssh -T (shell without PTY) + // Should always succeed regardless of port forwarding settings + _, err = client.ExecuteCommand(ctx, "") + assert.NoError(t, err, "Non-PTY shell session should be allowed") }) } } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 3fd57806462..7582b632816 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -62,42 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() hasCommand := len(session.Command()) > 0 - switch { - case isPty && hasCommand: - // ssh -t - Pty command execution - s.handleCommand(logger, session, privilegeResult, winCh) - case isPty: - // ssh - Pty interactive session (login) - s.handlePty(logger, session, privilegeResult, ptyReq, winCh) - case hasCommand: - // ssh - non-Pty command execution - s.handleCommand(logger, session, privilegeResult, nil) - default: - // ssh -T (or ssh -N) - no PTY, no command - s.handleNonInteractiveSession(logger, session) - } -} - -// handleNonInteractiveSession handles sessions that have no PTY and no command. -// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N). -func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) { - s.updateSessionType(session, cmdNonInteractive) - - if !s.isPortForwardingEnabled() { - if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil { - logger.Debugf(errWriteSession, err) - } - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - logger.Infof("rejected non-interactive session: port forwarding disabled") - return - } - - <-session.Context().Done() - - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) + if isPty && !hasCommand { + // ssh - PTY interactive session (login) + s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh) + } else { + // ssh , ssh -t , ssh -T - command or shell execution + s.handleExecution(logger, session, privilegeResult, ptyReq, winCh) } } diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go index c35e4da0b6e..4a6cf3d9221 100644 --- a/client/ssh/server/session_handlers_js.go +++ b/client/ssh/server/session_handlers_js.go @@ -9,8 +9,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handlePty is not supported on JS/WASM -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { +// handlePtyLogin is not supported on JS/WASM +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { errorMsg := "PTY sessions are not supported on WASM/JS platform\n" if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { logger.Debugf(errWriteSession, err) diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go index dc532b9e766..f3e09543a14 100644 --- a/client/ssh/server/sftp_windows.go +++ b/client/ssh/server/sftp_windows.go @@ -31,7 +31,7 @@ func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*ex "--windows-domain", domain, } - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(nil) token, err := pd.createToken(username, domain) if err != nil { return nil, 0, fmt.Errorf("create token: %w", err) diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index bc15574195e..c3ba9845e21 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { // createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. // Returns the command and a cleanup function (no-op on Unix). -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) if err := validateUsername(localUser.Username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) @@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User if err != nil { return nil, nil, fmt.Errorf("parse user credentials: %w", err) } - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(logger) config := ExecutorConfig{ UID: uid, GID: gid, diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go index 5a5f75fa4c5..fb9c52233ad 100644 --- a/client/ssh/server/userswitching_windows.go +++ b/client/ssh/server/userswitching_windows.go @@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error { // createExecutorCommand creates a command using Windows executor for privilege dropping. // Returns the command and a cleanup function that must be called after starting the process. -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) username, _ := s.parseUsername(localUser.Username) if err := validateUsername(username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) } - return s.createUserSwitchCommand(localUser, session, hasPty) + return s.createUserSwitchCommand(logger, session, localUser) } // createUserSwitchCommand creates a command with Windows user switching. // Returns the command and a cleanup function that must be called after starting the process. -func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { +func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) { username, domain := s.parseUsername(localUser.Username) shell := getUserShell(localUser.Uid) @@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi } config := WindowsExecutorConfig{ - Username: username, - Domain: domain, - WorkingDir: localUser.HomeDir, - Shell: shell, - Command: command, - Interactive: interactive || (rawCmd == ""), + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, } - dropper := NewPrivilegeDropper() + dropper := NewPrivilegeDropper(logger) cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) if err != nil { return nil, nil, err @@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi cleanup := func() { if token != 0 { if err := windows.CloseHandle(windows.Handle(token)); err != nil { - log.Debugf("close primary token: %v", err) + logger.Debugf("close primary token: %v", err) } } } diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go index 0f3659ffe86..c08ccfd059d 100644 --- a/client/ssh/server/winpty/conpty.go +++ b/client/ssh/server/winpty/conpty.go @@ -56,7 +56,7 @@ var ( ) // ExecutePtyWithUserToken executes a command with ConPty using user token. -func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { +func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) commandLine := buildCommandLine(args) @@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig Pty: ptyConfig, User: userConfig, Session: session, - Context: ctx, + Context: session.Context(), } return executeConPtyWithConfig(commandLine, config) From c9c15aed966685fcdff88247f65264da38ef99d5 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 12 Jan 2026 17:44:07 +0800 Subject: [PATCH 02/10] Remove obsolete code --- client/ssh/server/port_forwarding.go | 7 ------- client/ssh/server/session_handlers.go | 12 ------------ 2 files changed, 19 deletions(-) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index c60cf4f58a3..e16ff5d46f7 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool { return s.allowRemotePortForwarding } -// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled -func (s *Server) isPortForwardingEnabled() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.allowLocalPortForwarding || s.allowRemotePortForwarding -} - // parseTcpipForwardRequest parses the SSH request payload func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { var payload tcpipForwardMsg diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 7582b632816..f12a75961e0 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -71,18 +71,6 @@ func (s *Server) sessionHandler(session ssh.Session) { } } -func (s *Server) updateSessionType(session ssh.Session, sessionType string) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, state := range s.sessions { - if state.session == session { - state.sessionType = sessionType - return - } - } -} - func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey { sessionID := session.Context().Value(ssh.ContextKeySessionID) if sessionID == nil { From 113ff0187fc91da6d669ab411c9eaeffa2eaabbe Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 12 Jan 2026 18:17:36 +0800 Subject: [PATCH 03/10] WIP test --- client/ssh/server/command_execution_unix.go | 22 +- client/ssh/server/compatibility_test.go | 210 ++++++++++++++++++++ 2 files changed, 220 insertions(+), 12 deletions(-) diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index f8dfa6c9ace..55de3fd103d 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -249,11 +249,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty }() go func() { - defer func() { - if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { - logger.Debugf("session close error: %v", err) - } - }() if _, err := io.Copy(session, ptmx); err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { logger.Warnf("Pty output copy error: %v", err) @@ -273,7 +268,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex case <-ctx.Done(): s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) case err := <-done: - s.handlePtyCommandCompletion(logger, session, err) + s.handlePtyCommandCompletion(logger, session, ptyMgr, err) } } @@ -301,17 +296,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses } } -func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) { if err != nil { logger.Debugf("Pty command execution failed: %v", err) s.handleSessionExit(session, err, logger) - return + } else { + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } } - // Normal completion - logger.Debugf("Pty command completed successfully") - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) + // Close PTY to unblock io.Copy goroutines + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close after completion: %v", err) } } diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 34ffccfd22a..759eea19d02 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -405,6 +405,216 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { return createTempKeyFileFromBytes(t, privateKey) } +// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags) +// This ensures our implementation matches OpenSSH behavior for: +// - ssh host command (no PTY - default when no TTY) +// - ssh -T host command (explicit no PTY) +// - ssh -t host command (force PTY) +// - ssh -T host (no PTY shell - our implementation) +func TestSSHPtyModes(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH PTY mode tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues") + } + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + username := testutil.GetTestUsername(t) + + baseArgs := []string{ + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "BatchMode=yes", + } + + t.Run("command_default_no_pty", func(t *testing.T) { + // ssh host command - no PTY allocation (tests don't have TTY) + args := append(baseArgs, fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command (default no PTY) failed: %v, output: %s", err, output) + return + } + assert.Contains(t, string(output), "no_pty_default") + }) + + t.Run("command_explicit_no_pty", func(t *testing.T) { + // ssh -T host command - explicit no PTY + args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command (-T explicit no PTY) failed: %v, output: %s", err, output) + return + } + assert.Contains(t, string(output), "explicit_no_pty") + }) + + t.Run("command_force_pty", func(t *testing.T) { + // ssh -t host command - force PTY allocation + // Use -tt to really force PTY even without TTY on our end + args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command (-tt force PTY) failed: %v, output: %s", err, output) + // PTY allocation might fail in some test environments, that's OK + return + } + // PTY output might have \r\n line endings + assert.Contains(t, string(output), "force_pty") + }) + + t.Run("shell_explicit_no_pty", func(t *testing.T) { + // ssh -T host - shell without PTY (our new behavior matching OpenSSH) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host)) + cmd := exec.CommandContext(ctx, "ssh", args...) + + stdin, err := cmd.StdinPipe() + require.NoError(t, err) + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + if err != nil { + t.Logf("Shell (-T no PTY) start failed: %v", err) + return + } + + // Send commands through the non-PTY shell + go func() { + defer stdin.Close() + time.Sleep(100 * time.Millisecond) + stdin.Write([]byte("echo shell_no_pty_test\n")) + time.Sleep(100 * time.Millisecond) + stdin.Write([]byte("exit 0\n")) + }() + + output, _ := io.ReadAll(stdout) + err = cmd.Wait() + + if err != nil { + t.Logf("Shell (-T no PTY) failed: %v, output: %s", err, output) + return + } + assert.Contains(t, string(output), "shell_no_pty_test") + }) + + t.Run("exit_code_preserved_no_pty", func(t *testing.T) { + // Verify exit codes work with -T + args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") + cmd := exec.Command("ssh", args...) + + err := cmd.Run() + require.Error(t, err) + + if exitErr, ok := err.(*exec.ExitError); ok { + assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") + } + }) + + t.Run("exit_code_preserved_with_pty", func(t *testing.T) { + // Verify exit codes work with -tt + // Use bash -c to ensure proper exit code handling + args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), "bash -c 'exit 43'") + cmd := exec.Command("ssh", args...) + + err := cmd.Run() + if err == nil { + t.Log("PTY command succeeded unexpectedly (exit 0)") + return + } + + exitErr, ok := err.(*exec.ExitError) + if !ok { + t.Logf("PTY exit code test: non-exit error: %v", err) + return + } + + assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt") + }) + + t.Run("stderr_works_no_pty", func(t *testing.T) { + // Verify stderr is separate from stdout without PTY + // Pass the entire command as a single string for proper shell interpretation + args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) + + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + t.Logf("stderr test failed: %v", err) + return + } + + assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg") + assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg") + assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg") + }) + + t.Run("stderr_merged_with_pty", func(t *testing.T) { + // With PTY, stderr is merged into stdout + args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("PTY stderr test failed: %v, output: %s", err, output) + return + } + + // With PTY, both messages should appear in combined output + assert.Contains(t, string(output), "stdout_msg") + assert.Contains(t, string(output), "stderr_msg") + }) +} + // TestSSHServerFeatureCompatibility tests specific SSH features for compatibility func TestSSHServerFeatureCompatibility(t *testing.T) { if testing.Short() { From 9a9b7138bdb636c87c15ecd36dca1edb8b35568d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 16:46:44 +0800 Subject: [PATCH 04/10] Make options optional --- client/cmd/ssh_exec_unix.go | 2 +- client/cmd/ssh_sftp_unix.go | 2 +- client/ssh/server/executor_unix.go | 18 ++++++++++++++++-- client/ssh/server/executor_unix_test.go | 10 +++++----- client/ssh/server/executor_windows.go | 18 ++++++++++++++++-- client/ssh/server/userswitching_unix.go | 2 +- 6 files changed, 40 insertions(+), 12 deletions(-) diff --git a/client/cmd/ssh_exec_unix.go b/client/cmd/ssh_exec_unix.go index 0b085ebd916..2412f072c5e 100644 --- a/client/cmd/ssh_exec_unix.go +++ b/client/cmd/ssh_exec_unix.go @@ -52,7 +52,7 @@ func init() { // runSSHExec handles the SSH exec subcommand execution. func runSSHExec(cmd *cobra.Command, _ []string) error { - privilegeDropper := sshserver.NewPrivilegeDropper(nil) + privilegeDropper := sshserver.NewPrivilegeDropper() var groups []uint32 for _, groupInt := range sshExecGroups { diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go index ed6b8993dfb..c06aab01713 100644 --- a/client/cmd/ssh_sftp_unix.go +++ b/client/cmd/ssh_sftp_unix.go @@ -36,7 +36,7 @@ func init() { } func sftpMain(cmd *cobra.Command, _ []string) error { - privilegeDropper := sshserver.NewPrivilegeDropper(nil) + privilegeDropper := sshserver.NewPrivilegeDropper() var groups []uint32 for _, groupInt := range sftpGroupsInt { diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go index ac848dd1e22..0c856d98016 100644 --- a/client/ssh/server/executor_unix.go +++ b/client/ssh/server/executor_unix.go @@ -39,9 +39,23 @@ type PrivilegeDropper struct { logger *log.Entry } +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type PrivilegeDropperOption func(*PrivilegeDropper) + // NewPrivilegeDropper creates a new privilege dropper -func NewPrivilegeDropper(logger *log.Entry) *PrivilegeDropper { - return &PrivilegeDropper{logger: logger} +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } } // log returns the logger, falling back to standard logger if none set diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go index f5dd46134dc..0c5108f57fa 100644 --- a/client/ssh/server/executor_unix_test.go +++ b/client/ssh/server/executor_unix_test.go @@ -16,7 +16,7 @@ import ( ) func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() currentUID := uint32(os.Geteuid()) currentGID := uint32(os.Getegid()) @@ -74,7 +74,7 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { } func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() config := ExecutorConfig{ UID: 1000, @@ -108,7 +108,7 @@ func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { } func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() config := ExecutorConfig{ UID: 1000, @@ -157,7 +157,7 @@ func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { // Test in a child process to avoid affecting the test runner if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() // This should succeed err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) @@ -227,7 +227,7 @@ func findNonRootUser() (*user.User, error) { } func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() currentUID := uint32(os.Geteuid()) if currentUID == 0 { diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index f34593f50e7..5d12ef5a76f 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -43,8 +43,22 @@ type PrivilegeDropper struct { logger *log.Entry } -func NewPrivilegeDropper(logger *log.Entry) *PrivilegeDropper { - return &PrivilegeDropper{logger: logger} +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type PrivilegeDropperOption func(*PrivilegeDropper) + +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } } // log returns the logger, falling back to standard logger if none set diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index c3ba9845e21..d9940065f76 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l if err != nil { return nil, nil, fmt.Errorf("parse user credentials: %w", err) } - privilegeDropper := NewPrivilegeDropper(logger) + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) config := ExecutorConfig{ UID: uid, GID: gid, From 94ade9dc951a9c56d5e80c6d2179275bb722e481 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 16:57:39 +0800 Subject: [PATCH 05/10] Fix build --- client/ssh/server/command_execution_js.go | 1 - .../ssh/server/command_execution_windows.go | 4 ++-- client/ssh/server/compatibility_test.go | 20 ++++++++++++------- client/ssh/server/executor_windows.go | 2 +- client/ssh/server/userswitching_windows.go | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 9a723f5d01f..3aeaa135cda 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -55,4 +55,3 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e } return false } - diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 2e3e1ec019d..e1ba777f606 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -60,7 +60,7 @@ func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken window // getUserToken creates a user token for the specified user. func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { - privilegeDropper := NewPrivilegeDropper(logger) + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) token, err := privilegeDropper.createToken(username, domain) if err != nil { return 0, fmt.Errorf("generate S4U user token: %w", err) @@ -301,7 +301,7 @@ func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) - privilegeDropper := NewPrivilegeDropper(logger) + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) userToken, err := privilegeDropper.createToken(req.Username, req.Domain) if err != nil { return fmt.Errorf("create user token: %w", err) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 759eea19d02..56084cbc985 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "runtime" + "slices" "strings" "testing" "time" @@ -462,7 +463,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("command_default_no_pty", func(t *testing.T) { // ssh host command - no PTY allocation (tests don't have TTY) - args := append(baseArgs, fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") + args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() @@ -475,7 +476,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("command_explicit_no_pty", func(t *testing.T) { // ssh -T host command - explicit no PTY - args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() @@ -489,7 +490,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("command_force_pty", func(t *testing.T) { // ssh -t host command - force PTY allocation // Use -tt to really force PTY even without TTY on our end - args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() @@ -507,7 +508,7 @@ func TestSSHPtyModes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host)) + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host)) cmd := exec.CommandContext(ctx, "ssh", args...) stdin, err := cmd.StdinPipe() @@ -526,9 +527,14 @@ func TestSSHPtyModes(t *testing.T) { go func() { defer stdin.Close() time.Sleep(100 * time.Millisecond) - stdin.Write([]byte("echo shell_no_pty_test\n")) + if _, err := stdin.Write([]byte("echo shell_no_pty_test\n")); err != nil { + t.Errorf("write echo command: %v", err) + return + } time.Sleep(100 * time.Millisecond) - stdin.Write([]byte("exit 0\n")) + if _, err := stdin.Write([]byte("exit 0\n")); err != nil { + t.Errorf("write exit command: %v", err) + } }() output, _ := io.ReadAll(stdout) @@ -543,7 +549,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("exit_code_preserved_no_pty", func(t *testing.T) { // Verify exit codes work with -T - args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") cmd := exec.Command("ssh", args...) err := cmd.Run() diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index 5d12ef5a76f..92a521ef6a1 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -205,7 +205,7 @@ func newLsaString(s string) lsaString { func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { userCpn := buildUserCpn(username, domain) - pd := NewPrivilegeDropper(logger) + pd := NewPrivilegeDropper(WithLogger(logger)) isDomainUser := !pd.isLocalUser(domain) lsaHandle, err := initializeLsaConnection() diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go index fb9c52233ad..260e1301ed6 100644 --- a/client/ssh/server/userswitching_windows.go +++ b/client/ssh/server/userswitching_windows.go @@ -120,7 +120,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, Command: command, } - dropper := NewPrivilegeDropper(logger) + dropper := NewPrivilegeDropper(WithLogger(logger)) cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) if err != nil { return nil, nil, err From 9f6e259700e1e5975d22a90746955183e8abf1c5 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 17:08:18 +0800 Subject: [PATCH 06/10] Fix lint --- client/ssh/server/compatibility_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 56084cbc985..16e12f6a04f 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -563,7 +563,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("exit_code_preserved_with_pty", func(t *testing.T) { // Verify exit codes work with -tt // Use bash -c to ensure proper exit code handling - args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), "bash -c 'exit 43'") + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "bash -c 'exit 43'") cmd := exec.Command("ssh", args...) err := cmd.Run() @@ -584,7 +584,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("stderr_works_no_pty", func(t *testing.T) { // Verify stderr is separate from stdout without PTY // Pass the entire command as a single string for proper shell interpretation - args := append(baseArgs, "-T", fmt.Sprintf("%s@%s", username, host), + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "sh -c 'echo stdout_msg; echo stderr_msg >&2'") cmd := exec.Command("ssh", args...) @@ -605,7 +605,7 @@ func TestSSHPtyModes(t *testing.T) { t.Run("stderr_merged_with_pty", func(t *testing.T) { // With PTY, stderr is merged into stdout - args := append(baseArgs, "-tt", fmt.Sprintf("%s@%s", username, host), + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'echo stdout_msg; echo stderr_msg >&2'") cmd := exec.Command("ssh", args...) From 7ada564eeedae1a3748668a7d9fe62616781e871 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 20:18:36 +0800 Subject: [PATCH 07/10] Fix tests where su or su with --pty is unavailable --- client/ssh/server/compatibility_test.go | 70 +++++++++++++++++++------ 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 16e12f6a04f..9e4fb8c9ff0 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -24,25 +24,66 @@ import ( "github.com/netbirdio/netbird/client/ssh/testutil" ) -// TestMain handles package-level setup and cleanup func TestMain(m *testing.M) { - // Guard against infinite recursion when test binary is called as "netbird ssh exec" - // This happens when running tests as non-privileged user with fallback + // On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server + // spawns an executor subprocess via os.Executable(). During tests, this invokes the test + // binary with "ssh exec" args. We handle that here to properly execute commands and + // propagate exit codes. if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { - // Just exit with error to break the recursion - fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") - os.Exit(1) + runTestExecutor() + return } - // Run tests code := m.Run() - - // Cleanup any created test users testutil.CleanupTestUsers() - os.Exit(code) } +// runTestExecutor emulates the netbird executor for tests. +// Parses --shell and --cmd args, runs the command, and exits with the correct code. +func runTestExecutor() { + if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" { + fmt.Fprintf(os.Stderr, "executor recursion detected\n") + os.Exit(1) + } + os.Setenv("_NETBIRD_TEST_EXECUTOR", "1") + + shell := "/bin/sh" + var command string + for i := 3; i < len(os.Args); i++ { + switch os.Args[i] { + case "--shell": + if i+1 < len(os.Args) { + shell = os.Args[i+1] + i++ + } + case "--cmd": + if i+1 < len(os.Args) { + command = os.Args[i+1] + i++ + } + } + } + + var cmd *exec.Cmd + if command == "" { + cmd = exec.Command(shell, "-l") + } else { + cmd = exec.Command(shell, "-l", "-c", command) + } + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } + os.Exit(0) +} + // TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client func TestSSHServerCompatibility(t *testing.T) { if testing.Short() { @@ -528,12 +569,12 @@ func TestSSHPtyModes(t *testing.T) { defer stdin.Close() time.Sleep(100 * time.Millisecond) if _, err := stdin.Write([]byte("echo shell_no_pty_test\n")); err != nil { - t.Errorf("write echo command: %v", err) + t.Logf("write echo command: %v", err) return } time.Sleep(100 * time.Millisecond) if _, err := stdin.Write([]byte("exit 0\n")); err != nil { - t.Errorf("write exit command: %v", err) + t.Logf("write exit command: %v", err) } }() @@ -561,9 +602,8 @@ func TestSSHPtyModes(t *testing.T) { }) t.Run("exit_code_preserved_with_pty", func(t *testing.T) { - // Verify exit codes work with -tt - // Use bash -c to ensure proper exit code handling - args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "bash -c 'exit 43'") + // Verify exit codes work with -tt (use sh for portability - bash may not be installed on FreeBSD) + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'") cmd := exec.Command("ssh", args...) err := cmd.Run() From 8865acab1e874a069054297c832438e25a9d8f88 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 20:45:24 +0800 Subject: [PATCH 08/10] Fix windows build --- client/ssh/server/sftp_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go index f3e09543a14..dc532b9e766 100644 --- a/client/ssh/server/sftp_windows.go +++ b/client/ssh/server/sftp_windows.go @@ -31,7 +31,7 @@ func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*ex "--windows-domain", domain, } - pd := NewPrivilegeDropper(nil) + pd := NewPrivilegeDropper() token, err := pd.createToken(username, domain) if err != nil { return nil, 0, fmt.Errorf("create token: %w", err) From 1e506cc430521448d3d5f2756c7004d3f53d34f9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 21:24:32 +0800 Subject: [PATCH 09/10] Address review --- client/ssh/server/compatibility_test.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 9e4fb8c9ff0..ac9123c8299 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -594,11 +594,18 @@ func TestSSHPtyModes(t *testing.T) { cmd := exec.Command("ssh", args...) err := cmd.Run() - require.Error(t, err) + if err == nil { + t.Log("Command succeeded unexpectedly (exit 0)") + return + } - if exitErr, ok := err.(*exec.ExitError); ok { - assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") + exitErr, ok := err.(*exec.ExitError) + if !ok { + t.Logf("Non-exit error: %v", err) + return } + + assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") }) t.Run("exit_code_preserved_with_pty", func(t *testing.T) { From 82dc0c665b077cf80337a5c6ea1dba88b3586b61 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 22 Jan 2026 23:48:05 +0800 Subject: [PATCH 10/10] Fix freebsd test --- client/ssh/server/command_execution_unix.go | 18 ++-- client/ssh/server/compatibility_test.go | 97 +++++---------------- client/ssh/server/executor_unix.go | 10 ++- client/ssh/server/executor_windows.go | 2 +- client/ssh/server/server_test.go | 10 ++- client/ssh/server/userswitching_unix.go | 2 +- 6 files changed, 49 insertions(+), 90 deletions(-) diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index 55de3fd103d..279b8934103 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "runtime" "strings" "sync" @@ -99,7 +100,7 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { return isUtilLinux } -// createSuCommand creates a command using su -l for privilege switching +// createSuCommand creates a command using su - for privilege switching. func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { if err := validateUsername(localUser.Username); err != nil { return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) @@ -110,7 +111,7 @@ func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUs return nil, fmt.Errorf("su command not available: %w", err) } - args := []string{"-l"} + args := []string{"-"} if hasPty && s.suSupportsPty { args = append(args, "--pty") } @@ -128,12 +129,19 @@ func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUs return cmd, nil } -// getShellCommandArgs returns the shell command and arguments for executing a command string +// getShellCommandArgs returns the shell command and arguments for executing a command string. func (s *Server) getShellCommandArgs(shell, cmdString string) []string { if cmdString == "" { - return []string{shell, "-l"} + return []string{shell} } - return []string{shell, "-l", "-c", cmdString} + return []string{shell, "-c", cmdString} +} + +// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname". +func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd { + cmd := exec.CommandContext(ctx, shell, args[1:]...) + cmd.Args[0] = "-" + filepath.Base(shell) + return cmd } // prepareCommandEnv prepares environment variables for command execution on Unix diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index ac9123c8299..7fe2d6c5e63 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -4,11 +4,13 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "errors" "fmt" "io" "net" "os" "os/exec" + "path/filepath" "runtime" "slices" "strings" @@ -67,10 +69,11 @@ func runTestExecutor() { var cmd *exec.Cmd if command == "" { - cmd = exec.Command(shell, "-l") + cmd = exec.Command(shell) } else { - cmd = exec.Command(shell, "-l", "-c", command) + cmd = exec.Command(shell, "-c", command) } + cmd.Args[0] = "-" + filepath.Base(shell) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -503,49 +506,33 @@ func TestSSHPtyModes(t *testing.T) { } t.Run("command_default_no_pty", func(t *testing.T) { - // ssh host command - no PTY allocation (tests don't have TTY) args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() - if err != nil { - t.Logf("Command (default no PTY) failed: %v, output: %s", err, output) - return - } + require.NoError(t, err, "Command (default no PTY) failed: %s", output) assert.Contains(t, string(output), "no_pty_default") }) t.Run("command_explicit_no_pty", func(t *testing.T) { - // ssh -T host command - explicit no PTY args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() - if err != nil { - t.Logf("Command (-T explicit no PTY) failed: %v, output: %s", err, output) - return - } + require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output) assert.Contains(t, string(output), "explicit_no_pty") }) t.Run("command_force_pty", func(t *testing.T) { - // ssh -t host command - force PTY allocation - // Use -tt to really force PTY even without TTY on our end args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() - if err != nil { - t.Logf("Command (-tt force PTY) failed: %v, output: %s", err, output) - // PTY allocation might fail in some test environments, that's OK - return - } - // PTY output might have \r\n line endings + require.NoError(t, err, "Command (-tt force PTY) failed: %s", output) assert.Contains(t, string(output), "force_pty") }) t.Run("shell_explicit_no_pty", func(t *testing.T) { - // ssh -T host - shell without PTY (our new behavior matching OpenSSH) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -558,79 +545,50 @@ func TestSSHPtyModes(t *testing.T) { stdout, err := cmd.StdoutPipe() require.NoError(t, err) - err = cmd.Start() - if err != nil { - t.Logf("Shell (-T no PTY) start failed: %v", err) - return - } + require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed") - // Send commands through the non-PTY shell go func() { defer stdin.Close() time.Sleep(100 * time.Millisecond) - if _, err := stdin.Write([]byte("echo shell_no_pty_test\n")); err != nil { - t.Logf("write echo command: %v", err) - return - } + _, err := stdin.Write([]byte("echo shell_no_pty_test\n")) + assert.NoError(t, err, "write echo command") time.Sleep(100 * time.Millisecond) - if _, err := stdin.Write([]byte("exit 0\n")); err != nil { - t.Logf("write exit command: %v", err) - } + _, err = stdin.Write([]byte("exit 0\n")) + assert.NoError(t, err, "write exit command") }() output, _ := io.ReadAll(stdout) err = cmd.Wait() - if err != nil { - t.Logf("Shell (-T no PTY) failed: %v, output: %s", err, output) - return - } + require.NoError(t, err, "Shell (-T no PTY) failed: %s", output) assert.Contains(t, string(output), "shell_no_pty_test") }) t.Run("exit_code_preserved_no_pty", func(t *testing.T) { - // Verify exit codes work with -T args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") cmd := exec.Command("ssh", args...) err := cmd.Run() - if err == nil { - t.Log("Command succeeded unexpectedly (exit 0)") - return - } - - exitErr, ok := err.(*exec.ExitError) - if !ok { - t.Logf("Non-exit error: %v", err) - return - } + require.Error(t, err, "Command should exit with non-zero") + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") }) t.Run("exit_code_preserved_with_pty", func(t *testing.T) { - // Verify exit codes work with -tt (use sh for portability - bash may not be installed on FreeBSD) args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'") cmd := exec.Command("ssh", args...) err := cmd.Run() - if err == nil { - t.Log("PTY command succeeded unexpectedly (exit 0)") - return - } - - exitErr, ok := err.(*exec.ExitError) - if !ok { - t.Logf("PTY exit code test: non-exit error: %v", err) - return - } + require.Error(t, err, "PTY command should exit with non-zero") + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt") }) t.Run("stderr_works_no_pty", func(t *testing.T) { - // Verify stderr is separate from stdout without PTY - // Pass the entire command as a single string for proper shell interpretation args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "sh -c 'echo stdout_msg; echo stderr_msg >&2'") cmd := exec.Command("ssh", args...) @@ -639,30 +597,19 @@ func TestSSHPtyModes(t *testing.T) { cmd.Stdout = &stdout cmd.Stderr = &stderr - err := cmd.Run() - if err != nil { - t.Logf("stderr test failed: %v", err) - return - } - + require.NoError(t, cmd.Run(), "stderr test failed") assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg") assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg") assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg") }) t.Run("stderr_merged_with_pty", func(t *testing.T) { - // With PTY, stderr is merged into stdout args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'echo stdout_msg; echo stderr_msg >&2'") cmd := exec.Command("ssh", args...) output, err := cmd.CombinedOutput() - if err != nil { - t.Logf("PTY stderr test failed: %v, output: %s", err, output) - return - } - - // With PTY, both messages should appear in combined output + require.NoError(t, err, "PTY stderr test failed: %s", output) assert.Contains(t, string(output), "stdout_msg") assert.Contains(t, string(output), "stderr_msg") }) diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go index 0c856d98016..ee0b0ff781d 100644 --- a/client/ssh/server/executor_unix.go +++ b/client/ssh/server/executor_unix.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "runtime" "strings" "syscall" @@ -230,20 +231,21 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config var execCmd *exec.Cmd if config.Command == "" { - execCmd = exec.CommandContext(ctx, config.Shell, "-l") + execCmd = exec.CommandContext(ctx, config.Shell) } else { - execCmd = exec.CommandContext(ctx, config.Shell, "-l", "-c", config.Command) + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) } + execCmd.Args[0] = "-" + filepath.Base(config.Shell) execCmd.Stdin = os.Stdin execCmd.Stdout = os.Stdout execCmd.Stderr = os.Stderr if config.Command == "" { - log.Tracef("executing login shell: %s -l", execCmd.Path) + log.Tracef("executing login shell: %s", execCmd.Path) } else { cmdParts := strings.Fields(config.Command) safeCmd := safeLogCommand(cmdParts) - log.Tracef("executing %s -l -c %s", execCmd.Path, safeCmd) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) } if err := execCmd.Run(); err != nil { var exitError *exec.ExitError diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index 92a521ef6a1..51c995ec3cb 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -586,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo return cmd, primaryToken, nil } -// createSuCommand creates a command using su -l for privilege switching (Windows stub) +// createSuCommand creates a command using su - for privilege switching (Windows stub). func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) { return nil, fmt.Errorf("su command not available on Windows") } diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go index 6610685393d..89fab717fe7 100644 --- a/client/ssh/server/server_test.go +++ b/client/ssh/server/server_test.go @@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) { assert.Equal(t, "-Command", args[1]) assert.Equal(t, "echo test", args[2]) } else { - // Test Unix shell behavior args := server.getShellCommandArgs("/bin/sh", "echo test") assert.Equal(t, "/bin/sh", args[0]) - assert.Equal(t, "-l", args[1]) - assert.Equal(t, "-c", args[2]) - assert.Equal(t, "echo test", args[3]) + assert.Equal(t, "-c", args[1]) + assert.Equal(t, "echo test", args[2]) + + args = server.getShellCommandArgs("/bin/sh", "") + assert.Equal(t, "/bin/sh", args[0]) + assert.Len(t, args, 1) } } diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index d9940065f76..d80b77042a3 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use shell := getUserShell(localUser.Uid) args := s.getShellCommandArgs(shell, session.RawCommand()) - cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd := s.createShellCommand(session.Context(), shell, args) cmd.Dir = localUser.HomeDir cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)