Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/anonymize/anonymize.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
Expand Down
10 changes: 10 additions & 0 deletions client/anonymize/anonymize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
},
{
name: "STUN URI with IPv6",
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
},
{
name: "HTTPS URI with IPv6",
input: "Visit https://[2001:db8::ff00:42]:443/path",
expect: "Visit https://[2001:db8:ffff::]:443/path",
},
}

for _, tc := range tests {
Expand Down
2 changes: 1 addition & 1 deletion client/cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
}

func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
Expand Down
48 changes: 48 additions & 0 deletions client/firewall/uspfilter/conntrack/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,54 @@ import (
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()

func TestConnKey_String(t *testing.T) {
tests := []struct {
name string
key ConnKey
expect string
}{
{
name: "IPv4",
key: ConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
SrcPort: 12345,
DstPort: 80,
},
expect: "192.168.1.1:12345 → 10.0.0.1:80",
},
{
name: "IPv6",
key: ConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
SrcPort: 54321,
DstPort: 443,
},
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
},
{
name: "IPv4-mapped IPv6 unmaps",
key: ConnKey{
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
SrcPort: 1000,
DstPort: 2000,
},
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}

// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
Expand Down
5 changes: 3 additions & 2 deletions client/firewall/uspfilter/conntrack/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -137,12 +138,12 @@ func (info ICMPInfo) parseOriginalPacket() string {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))

case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))

case nftypes.ICMP:
icmpType := transportData[0]
Expand Down
36 changes: 36 additions & 0 deletions client/firewall/uspfilter/conntrack/icmp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,42 @@ import (
"testing"
)

func TestICMPConnKey_String(t *testing.T) {
tests := []struct {
name string
key ICMPConnKey
expect string
}{
{
name: "IPv4",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
ID: 1234,
},
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
},
{
name: "IPv6",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
ID: 5678,
},
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}

func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
Expand Down
4 changes: 3 additions & 1 deletion client/firewall/uspfilter/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package uspfilter

import (
"fmt"
"net"
"net/netip"
"strconv"
"time"

"github.com/google/gopacket"
Expand Down Expand Up @@ -443,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
trace.AddResult(StageRouteACL, msg, allowed)

if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
}

trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
Expand Down
4 changes: 2 additions & 2 deletions client/internal/profilemanager/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/url"
"os"
"os/user"
Expand Down Expand Up @@ -759,8 +760,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, nil
}

newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s", config.ManagementURL.Scheme, net.JoinHostPort(defaultManagementURL.Hostname(), "443")))
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion client/internal/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -257,7 +258,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
}
}()

turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
turnServerAddr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))

var conn net.PacketConn
switch uri.Proto {
Expand Down
7 changes: 5 additions & 2 deletions client/internal/rosenpass/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ func findRandomAvailableUDPPort() (int, error) {
}
defer conn.Close()

splitAddress := strings.Split(conn.LocalAddr().String(), ":")
return strconv.Atoi(splitAddress[len(splitAddress)-1])
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
return 0, fmt.Errorf("parse local address %s: %w", conn.LocalAddr(), err)
}
return strconv.Atoi(portStr)
}
14 changes: 14 additions & 0 deletions client/internal/rosenpass/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package rosenpass

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
2 changes: 1 addition & 1 deletion client/ssh/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
return
}

dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
dest := net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort)))
log.Debugf("local port forwarding: %s", dest)

backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
Expand Down
29 changes: 15 additions & 14 deletions client/ssh/server/port_forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
return false
}

if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
return false
}

Expand All @@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote {
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
return false
}

if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
return false
}

Expand Down Expand Up @@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}

key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
key := forwardKey(hostPort)
if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
forwardAddr := "-R " + hostPort
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
logger.Infof("remote port forwarding cancelled: %s", hostPort)
return true, nil
}

logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
return false, nil
}

Expand All @@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h

defer func() {
if err := ln.Close(); err != nil {
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
}
}()

Expand Down Expand Up @@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
return
}
}
Expand Down Expand Up @@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}

key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
s.storeRemoteForwardListener(key, ln)
Comment on lines +315 to 316
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use the allocated port for forwardKey to keep cancel lookups consistent.

When remote forwarding requests port 0, the server allocates actualPort. Storing the listener under payload.Port (0) can break cancel-tcpip-forward lookup when cancellation uses the allocated port.

🐛 Proposed fix
-	key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
+	key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
 	s.storeRemoteForwardListener(key, ln)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
s.storeRemoteForwardListener(key, ln)
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
s.storeRemoteForwardListener(key, ln)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@client/ssh/server/port_forwarding.go` around lines 315 - 316, The listener is
being stored under the requested port (payload.Port) which can be zero; change
storeRemoteForwardListener to use the actual allocated port instead so
cancel-tcpip-forward lookups match—construct the forward key with
forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))) (using
the actualPort returned/assigned by the listener) and call
s.storeRemoteForwardListener with that key and ln, ensuring any later cancel
handling that derives keys from the allocated port will find the stored
listener.


forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)

response := make([]byte, 4)
binary.BigEndian.PutUint32(response, actualPort)

logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
return true, response
}

Expand Down Expand Up @@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h

channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
_ = conn.Close()
return
}
Expand Down
10 changes: 6 additions & 4 deletions client/ssh/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strings"
Expand Down Expand Up @@ -918,20 +919,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()

if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}

if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}

forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
forwardAddr := "-L " + hostPort
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
logger.Infof("local port forwarding: %s", hostPort)

ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}
2 changes: 1 addition & 1 deletion combined/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost,
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
URI: "stun:" + net.JoinHostPort(strings.Trim(exposedHost, "[]"), fmt.Sprintf("%d", port)),
})
}
}
Expand Down
Loading
Loading