Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions client/cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err
}

cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
if err := validateDestinationPort(remoteAddr); err != nil {
return fmt.Errorf("invalid remote address: %w", err)
}

log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)

go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
Expand All @@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err
}

cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
if err := validateDestinationPort(localAddr); err != nil {
return fmt.Errorf("invalid local address: %w", err)
}

log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)

go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
Expand All @@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil
}

// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}

_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}

port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}

if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}

if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}

return nil
}

// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
Expand Down
13 changes: 11 additions & 2 deletions client/proto/daemon.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions client/proto/daemon.proto
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ message SSHSessionInfo {
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
repeated string portForwards = 5;
}

// SSHServerState contains the latest state of the SSH server
Expand Down
1 change: 1 addition & 0 deletions client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
PortForwards: session.PortForwards,
})
}

Expand Down
28 changes: 6 additions & 22 deletions client/ssh/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
Expand Down Expand Up @@ -557,8 +556,9 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {

channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
var openErr *ssh.OpenChannelError
if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
Expand All @@ -570,15 +570,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
}()

go func() {
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()

if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
nbssh.BidirectionalCopy(localConn, channel)
}

// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
Expand Down Expand Up @@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
return fmt.Errorf("remote port forwarding denied by server")
}
return nil
}
Expand Down Expand Up @@ -692,15 +684,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
}()

go func() {
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()

if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
nbssh.BidirectionalCopy(localConn, channel)
}

// tcpipForwardMsg represents the structure for tcpip-forward requests
Expand Down
49 changes: 49 additions & 0 deletions client/ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,52 @@ func buildAddressList(hostname string, remote net.Addr) []string {
}
return addresses
}

// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)

go func() {
_, _ = io.Copy(rw2, rw1)
done <- struct{}{}
}()

go func() {
_, _ = io.Copy(rw1, rw2)
done <- struct{}{}
}()

<-done
<-done
}

// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)

go func() {
_, _ = io.Copy(conn2, conn1)
done <- struct{}{}
}()

go func() {
_, _ = io.Copy(conn1, conn2)
done <- struct{}{}
}()

select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}

_ = conn1.Close()
_ = conn2.Close()
}
Loading
Loading