From 1b0b1db575894ebc5a613bdfe23a7ab010a4898b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 26 Jan 2026 21:43:48 +0800 Subject: [PATCH 01/10] Add NAT-PMP/UPnP support --- client/internal/engine.go | 65 +++-- client/internal/peer/conn.go | 50 ++-- client/internal/peer/worker_ice.go | 83 +++++- client/internal/portforward/manager.go | 289 ++++++++++++++++++++ client/internal/portforward/manager_test.go | 156 +++++++++++ client/internal/portforward/state.go | 44 +++ client/server/state_generic.go | 5 + client/server/state_linux.go | 5 + go.mod | 4 + go.sum | 8 + 10 files changed, 655 insertions(+), 54 deletions(-) create mode 100644 client/internal/portforward/manager.go create mode 100644 client/internal/portforward/manager_test.go create mode 100644 client/internal/portforward/state.go diff --git a/client/internal/engine.go b/client/internal/engine.go index f0693e82cc2..b9c7953e0a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -43,6 +43,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" @@ -199,9 +200,10 @@ type Engine struct { // checks are the client-applied posture checks that need to be evaluated on the client checks []*mgmProto.Checks - relayManager *relayClient.Manager - stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + relayManager *relayClient.Manager + stateManager *statemanager.Manager + portForwardManager *portforward.Manager + srWatcher *guard.SRWatcher // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex @@ -249,25 +251,26 @@ func NewEngine( stateManager *statemanager.Manager, ) *Engine { engine := &Engine{ - clientCtx: clientCtx, - clientCancel: clientCancel, - signal: signalClient, - signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), - mgmClient: mgmClient, - relayManager: relayManager, - peerStore: peerstore.NewConnStore(), - syncMsgMux: &sync.Mutex{}, - config: config, - mobileDep: mobileDep, - STUNs: []*stun.URI{}, - TURNs: []*stun.URI{}, - networkSerial: 0, - statusRecorder: statusRecorder, - stateManager: stateManager, - checks: checks, - connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), - jobExecutor: jobexec.NewExecutor(), + clientCtx: clientCtx, + clientCancel: clientCancel, + signal: signalClient, + signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), + mgmClient: mgmClient, + relayManager: relayManager, + peerStore: peerstore.NewConnStore(), + syncMsgMux: &sync.Mutex{}, + config: config, + mobileDep: mobileDep, + STUNs: []*stun.URI{}, + TURNs: []*stun.URI{}, + networkSerial: 0, + statusRecorder: statusRecorder, + stateManager: stateManager, + portForwardManager: portforward.NewManager(stateManager), + checks: checks, + connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), + jobExecutor: jobexec.NewExecutor(), } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -509,6 +512,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // conntrack entries from being created before the rules are in place e.setupWGProxyNoTrack() + // Start after interface is up since port may have been resolved from 0 or changed if occupied + e.portForwardManager.Start(e.ctx, e.config.WgPort) + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -1526,12 +1532,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV } serviceDependencies := peer.ServiceDependencies{ - StatusRecorder: e.statusRecorder, - Signaler: e.signaler, - IFaceDiscover: e.mobileDep.IFaceDiscover, - RelayManager: e.relayManager, - SrWatcher: e.srWatcher, - Semaphore: e.connSemaphore, + StatusRecorder: e.statusRecorder, + Signaler: e.signaler, + IFaceDiscover: e.mobileDep.IFaceDiscover, + RelayManager: e.relayManager, + SrWatcher: e.srWatcher, + Semaphore: e.connSemaphore, + PortForwardManager: e.portForwardManager, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1684,6 +1691,8 @@ func (e *Engine) close() { if e.rpManager != nil { _ = e.rpManager.Close() } + + e.portForwardManager.Stop() } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 39133a6d3cd..4a9e2855ad4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -22,6 +22,7 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/worker" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" @@ -36,6 +37,7 @@ type ServiceDependencies struct { SrWatcher *guard.SRWatcher Semaphore *semaphoregroup.SemaphoreGroup PeerConnDispatcher *dispatcher.ConnectionDispatcher + PortForwardManager *portforward.Manager } type WgConfig struct { @@ -77,16 +79,17 @@ type ConnConfig struct { } type Conn struct { - Log *log.Entry - mu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc - config ConnConfig - statusRecorder *Status - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - relayManager *relayClient.Manager - srWatcher *guard.SRWatcher + Log *log.Entry + mu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + config ConnConfig + statusRecorder *Status + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + relayManager *relayClient.Manager + srWatcher *guard.SRWatcher + portForwardManager *portforward.Manager onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string) @@ -132,19 +135,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: dumpState, - endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), - wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + portForwardManager: services.PortForwardManager, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: dumpState, + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), } return conn, nil diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index b6b9d2cf447..15d5df0dd58 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) @@ -60,6 +61,9 @@ type WorkerICE struct { // we record the last known state of the ICE agent to avoid duplicate on disconnected events lastKnownState ice.ConnectionState + + // portForwardAttempted tracks if we've already tried port forwarding this session + portForwardAttempted bool } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { @@ -210,6 +214,8 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { + w.portForwardAttempted = false + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -362,6 +368,77 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() + + if candidate.Type() == ice.CandidateTypeServerReflexive { + w.injectPortForwardedCandidate(candidate) + } +} + +// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping. +func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { + w.muxAgent.Lock() + if w.portForwardAttempted { + w.muxAgent.Unlock() + return + } + w.portForwardAttempted = true + w.muxAgent.Unlock() + + pfManager := w.conn.portForwardManager + if pfManager == nil || !pfManager.IsAvailable() { + return + } + + mapping := pfManager.GetMapping() + if mapping == nil { + return + } + + forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping) + if err != nil { + w.log.Warnf("create forwarded candidate: %v", err) + return + } + + w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)", + forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority()) + + go func() { + if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil { + w.log.Errorf("signal port-forwarded candidate: %v", err) + } + }() +} + +// createForwardedCandidate creates a new server reflexive candidate with the forwarded port. +// It uses the NAT gateway's external IP with the forwarded port. +func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) { + var externalIP string + if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() { + externalIP = mapping.ExternalIP.String() + } else { + // Fallback to STUN-discovered address if NAT didn't provide external IP + externalIP = srflxCandidate.Address() + } + + // Per RFC 8445, the related address for srflx is the base (host candidate address). + // If the original srflx has unspecified related address, use its own address as base. + relAddr := srflxCandidate.RelatedAddress().Address + if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" { + relAddr = srflxCandidate.Address() + } + + priority := srflxCandidate.Priority() + 1000 + + return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + Network: srflxCandidate.NetworkType().String(), + Address: externalIP, + Port: int(mapping.ExternalPort), + Component: srflxCandidate.Component(), + Priority: priority, + RelAddr: relAddr, + RelPort: int(mapping.InternalPort), + }) } func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { @@ -403,10 +480,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { if !lok || !rok { continue } - w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms", sessionID, - local.NetworkType(), local.Type(), local.Address(), - remote.NetworkType(), remote.Type(), remote.Address(), + local.NetworkType(), local.Type(), local.Address(), local.Port(), + remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(), stat.CurrentRoundTripTime*1000) } } diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go new file mode 100644 index 00000000000..2a0b6873222 --- /dev/null +++ b/client/internal/portforward/manager.go @@ -0,0 +1,289 @@ +package portforward + +import ( + "context" + "fmt" + "net" + "os" + "strconv" + "sync" + "time" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +const ( + defaultMappingTTL = 2 * time.Hour + renewalInterval = defaultMappingTTL / 2 + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" + + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" +) + +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string +} + +type Manager struct { + ctx context.Context + cancel context.CancelFunc + stateManager *statemanager.Manager + + mu sync.RWMutex + gateway nat.NAT + mapping *Mapping + + wgPort uint16 + wg sync.WaitGroup +} + +func NewManager(stateManager *statemanager.Manager) *Manager { + return &Manager{ + stateManager: stateManager, + } +} + +// Start begins async discovery and mapping creation for the given WireGuard port. +// This does not block - use GetMapping() to check if mapping is ready. +func (m *Manager) Start(ctx context.Context, wgPort int) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + return + } + + if isDisabledByEnv() { + log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + return + } + + m.ctx, m.cancel = context.WithCancel(ctx) + m.wgPort = uint16(wgPort) + + m.stateManager.RegisterState(&State{}) + + m.wg.Add(1) + go m.run() +} + +func isDisabledByEnv() bool { + val := os.Getenv(envDisableNATMapper) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + return false + } + return disabled +} + +func (m *Manager) run() { + defer m.wg.Done() + + if err := m.stateManager.LoadState(&State{}); err != nil { + log.Warnf("failed to load port forward state: %v", err) + } + + var residualState *State + if existing := m.stateManager.GetState(&State{}); existing != nil { + if state, ok := existing.(*State); ok && state.InternalPort != 0 { + residualState = state + } + } + + discoverCtx, discoverCancel := context.WithTimeout(m.ctx, discoveryTimeout) + defer discoverCancel() + + gateway, err := nat.DiscoverGateway(discoverCtx) + if err != nil { + log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err) + return + } + + m.mu.Lock() + m.gateway = gateway + m.mu.Unlock() + + log.Infof("discovered NAT gateway: %s", gateway.Type()) + + if residualState != nil { + if err := m.cleanupResidual(residualState); err != nil { + log.Warnf("failed to cleanup residual mapping: %v", err) + } + } + + if err := m.createMapping(); err != nil { + log.Warnf("failed to create port mapping: %v", err) + return + } + + m.renewLoop() +} + +func (m *Manager) cleanupResidual(state *State) error { + m.mu.RLock() + gateway := m.gateway + m.mu.RUnlock() + + if gateway == nil { + return nil + } + + ctx, cancel := context.WithTimeout(m.ctx, 10*time.Second) + defer cancel() + + if err := gateway.DeletePortMapping(ctx, state.Protocol, int(state.InternalPort)); err != nil { + return fmt.Errorf("delete residual mapping: %w", err) + } + + log.Infof("cleaned up residual port mapping for port %d", state.InternalPort) + + if err := m.stateManager.UpdateState(&State{}); err != nil { + return fmt.Errorf("clear state after cleanup: %w", err) + } + + return nil +} + +func (m *Manager) createMapping() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.gateway == nil { + return nil + } + + ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + defer cancel() + + externalPort, err := m.gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL) + if err != nil { + return err + } + + externalIP, err := m.gateway.GetExternalAddress() + if err != nil { + log.Debugf("failed to get external address: %v", err) + } + + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: m.wgPort, + ExternalPort: uint16(externalPort), + ExternalIP: externalIP, + NATType: m.gateway.Type(), + } + + log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", + m.wgPort, externalPort, m.gateway.Type(), externalIP) + + return m.persistStateLocked() +} + +func (m *Manager) Stop() { + m.mu.Lock() + cancel := m.cancel + m.cancel = nil + m.mu.Unlock() + + if cancel != nil { + cancel() + } + + m.wg.Wait() + + m.mu.Lock() + if err := m.persistStateLocked(); err != nil { + log.Debugf("persist state on stop: %v", err) + } + m.mu.Unlock() +} + +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.mapping == nil { + return nil + } + + mapping := *m.mapping + return &mapping +} + +// IsAvailable returns true if port forwarding is available and mapping exists +func (m *Manager) IsAvailable() bool { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.mapping != nil +} + +func (m *Manager) renewLoop() { + ticker := time.NewTicker(renewalInterval) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + if err := m.renewMapping(); err != nil { + log.Warnf("failed to renew port mapping: %v", err) + } + } + } +} + +func (m *Manager) renewMapping() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.mapping == nil || m.gateway == nil { + return nil + } + + ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + defer cancel() + + externalPort, err := m.gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL) + if err != nil { + return fmt.Errorf("add port mapping: %w", err) + } + + if uint16(externalPort) != m.mapping.ExternalPort { + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", + m.mapping.ExternalPort, externalPort) + m.mapping.ExternalPort = uint16(externalPort) + } + + log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) + return nil +} + +func (m *Manager) persistStateLocked() error { + var state *State + if m.mapping != nil { + state = &State{ + InternalPort: m.mapping.InternalPort, + Protocol: m.mapping.Protocol, + } + } else { + state = &State{} + } + + return m.stateManager.UpdateState(state) +} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go new file mode 100644 index 00000000000..3c1422fdffb --- /dev/null +++ b/client/internal/portforward/manager_test.go @@ -0,0 +1,156 @@ +package portforward + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +type mockNAT struct { + natType string + deviceAddr net.IP + externalAddr net.IP + internalAddr net.IP + mappings map[int]int + addMappingErr error + deleteMappingErr error +} + +func newMockNAT() *mockNAT { + return &mockNAT{ + natType: "Mock-NAT", + deviceAddr: net.ParseIP("192.168.1.1"), + externalAddr: net.ParseIP("203.0.113.50"), + internalAddr: net.ParseIP("192.168.1.100"), + mappings: make(map[int]int), + } +} + +func (m *mockNAT) Type() string { + return m.natType +} + +func (m *mockNAT) GetDeviceAddress() (net.IP, error) { + return m.deviceAddr, nil +} + +func (m *mockNAT) GetExternalAddress() (net.IP, error) { + return m.externalAddr, nil +} + +func (m *mockNAT) GetInternalAddress() (net.IP, error) { + return m.internalAddr, nil +} + +func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) { + if m.addMappingErr != nil { + return 0, m.addMappingErr + } + externalPort := internalPort + m.mappings[internalPort] = externalPort + return externalPort, nil +} + +func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if m.deleteMappingErr != nil { + return m.deleteMappingErr + } + delete(m.mappings, internalPort) + return nil +} + +func setupTestManager(t *testing.T) (*Manager, context.CancelFunc) { + tmpDir := t.TempDir() + statePath := tmpDir + "/state.json" + sm := statemanager.New(statePath) + sm.Start() + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + sm.Stop(ctx) + }) + + m := NewManager(sm) + m.gateway = newMockNAT() + + ctx, cancel := context.WithCancel(context.Background()) + m.ctx = ctx + m.cancel = cancel + m.wgPort = 51820 + + sm.RegisterState(&State{}) + + return m, cancel +} + +func TestManager_CreateMapping(t *testing.T) { + m, cancel := setupTestManager(t) + defer cancel() + + err := m.createMapping() + require.NoError(t, err) + + mapping := m.GetMapping() + require.NotNil(t, mapping) + + assert.Equal(t, "udp", mapping.Protocol) + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, uint16(51820), mapping.ExternalPort) + assert.Equal(t, "Mock-NAT", mapping.NATType) + assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4()) +} + +func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) { + tmpDir := t.TempDir() + statePath := tmpDir + "/state.json" + sm := statemanager.New(statePath) + + m := NewManager(sm) + + assert.Nil(t, m.GetMapping()) +} + +func TestManager_IsAvailable(t *testing.T) { + tmpDir := t.TempDir() + statePath := tmpDir + "/state.json" + sm := statemanager.New(statePath) + + m := NewManager(sm) + + // Initially not available (no mapping) + assert.False(t, m.IsAvailable()) + + // Set gateway but no mapping - still not available + m.gateway = newMockNAT() + assert.False(t, m.IsAvailable()) + + // Add mapping - now available + m.mapping = &Mapping{InternalPort: 51820} + assert.True(t, m.IsAvailable()) + + // Clear mapping - not available again + m.mapping = nil + assert.False(t, m.IsAvailable()) +} + +func TestState_Cleanup(t *testing.T) { + state := &State{ + Protocol: "udp", + InternalPort: 51820, + } + + // Cleanup should not error even if NAT discovery fails + err := state.Cleanup() + assert.NoError(t, err) +} + +func TestState_Name(t *testing.T) { + state := &State{} + assert.Equal(t, "port_forward_state", state.Name()) +} diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go new file mode 100644 index 00000000000..b88ee5dc6e1 --- /dev/null +++ b/client/internal/portforward/state.go @@ -0,0 +1,44 @@ +package portforward + +import ( + "context" + "fmt" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" +) + +// State is persisted only for crash recovery cleanup +type State struct { + InternalPort uint16 `json:"internal_port,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +func (s *State) Name() string { + return "port_forward_state" +} + +// Cleanup implements statemanager.CleanableState for crash recovery +func (s *State) Cleanup() error { + if s.InternalPort == 0 { + return nil + } + + log.Infof("cleaning up stale port mapping for port %d", s.InternalPort) + + ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + gateway, err := nat.DiscoverGateway(ctx) + if err != nil { + // Discovery failure is not an error - gateway may not exist + log.Debugf("cleanup: no gateway found: %v", err) + return nil + } + + if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + + return nil +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go index 980ba0cda42..3f794b611d1 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -9,6 +9,11 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. +// Note: portforward.State is intentionally NOT registered here to avoid blocking startup +// for up to 10 seconds during NAT gateway discovery when no gateway is present. +// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery. +// Port forward cleanup is handled by the Manager during normal operation instead. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 019477d8eae..655edfc5337 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -11,6 +11,11 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. +// Note: portforward.State is intentionally NOT registered here to avoid blocking startup +// for up to 10 seconds during NAT gateway discovery when no gateway is present. +// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery. +// Port forward cleanup is handled by the Manager during normal operation instead. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/go.mod b/go.mod index 2a6c311ce02..a284a6b4ddd 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( github.com/hashicorp/go-version v1.6.0 github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 + github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 @@ -203,10 +204,12 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/huandu/xstrings v1.5.0 // indirect + github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -216,6 +219,7 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.2 // indirect diff --git a/go.sum b/go.sum index 17e5c8ffaa4..f1e45108cdf 100644 --- a/go.sum +++ b/go.sum @@ -287,6 +287,8 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY= +github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -297,6 +299,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= @@ -334,6 +338,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0= +github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -352,6 +358,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= +github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= From 28f2c1b3b50a94c28ce5ab97d2392664642fb981 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 28 Jan 2026 23:41:31 +0800 Subject: [PATCH 02/10] Exclude JS --- client/internal/portforward/manager.go | 2 + client/internal/portforward/manager_js.go | 41 +++++++++++++++++++++ client/internal/portforward/manager_test.go | 4 +- client/internal/portforward/state.go | 2 + 4 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 client/internal/portforward/manager_js.go diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index 2a0b6873222..b0706e2b7bb 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -1,3 +1,5 @@ +//go:build !js + package portforward import ( diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go new file mode 100644 index 00000000000..27dd40c06af --- /dev/null +++ b/client/internal/portforward/manager_js.go @@ -0,0 +1,41 @@ +package portforward + +import ( + "context" + "net" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// Mapping represents port mapping information. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string +} + +// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported. +type Manager struct{} + +// NewManager returns a stub manager for js/wasm builds. +func NewManager(_ *statemanager.Manager) *Manager { + return &Manager{} +} + +// Start is a no-op on js/wasm. +func (m *Manager) Start(context.Context, int) {} + +// Stop is a no-op on js/wasm. +func (m *Manager) Stop() {} + +// GetMapping always returns nil on js/wasm. +func (m *Manager) GetMapping() *Mapping { + return nil +} + +// IsAvailable always returns false on js/wasm. +func (m *Manager) IsAvailable() bool { + return false +} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go index 3c1422fdffb..6cb4de6b0d5 100644 --- a/client/internal/portforward/manager_test.go +++ b/client/internal/portforward/manager_test.go @@ -1,3 +1,5 @@ +//go:build !js + package portforward import ( @@ -73,7 +75,7 @@ func setupTestManager(t *testing.T) (*Manager, context.CancelFunc) { t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - sm.Stop(ctx) + require.NoError(t, sm.Stop(ctx)) }) m := NewManager(sm) diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go index b88ee5dc6e1..d443e26711c 100644 --- a/client/internal/portforward/state.go +++ b/client/internal/portforward/state.go @@ -1,3 +1,5 @@ +//go:build !js + package portforward import ( From 05ba63a5bd9a8e4cec30765eaea9b9fea614358f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 28 Jan 2026 23:50:25 +0800 Subject: [PATCH 03/10] Copy extensions --- client/internal/peer/worker_ice.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 15d5df0dd58..d4cd4f7b7b2 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -430,7 +430,7 @@ func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mappi priority := srflxCandidate.Priority() + 1000 - return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ Network: srflxCandidate.NetworkType().String(), Address: externalIP, Port: int(mapping.ExternalPort), @@ -439,6 +439,20 @@ func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mappi RelAddr: relAddr, RelPort: int(mapping.InternalPort), }) + if err != nil { + return nil, fmt.Errorf("create candidate: %w", err) + } + + for _, e := range srflxCandidate.Extensions() { + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = srflxCandidate.ID() + } + if err := candidate.AddExtension(e); err != nil { + return nil, fmt.Errorf("add extension: %w", err) + } + } + + return candidate, nil } func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { From 124aa1a87531c6d82146346538c98d69cfba80e9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 28 Jan 2026 23:58:45 +0800 Subject: [PATCH 04/10] Address review --- client/internal/engine.go | 2 +- client/internal/portforward/manager.go | 24 +++++++++++++++++++---- client/internal/portforward/manager_js.go | 2 +- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index b9c7953e0a8..46e4d7e0c71 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -513,7 +513,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.setupWGProxyNoTrack() // Start after interface is up since port may have been resolved from 0 or changed if occupied - e.portForwardManager.Start(e.ctx, e.config.WgPort) + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index b0706e2b7bb..d1774a1ef07 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -55,7 +55,7 @@ func NewManager(stateManager *statemanager.Manager) *Manager { // Start begins async discovery and mapping creation for the given WireGuard port. // This does not block - use GetMapping() to check if mapping is ready. -func (m *Manager) Start(ctx context.Context, wgPort int) { +func (m *Manager) Start(ctx context.Context, wgPort uint16) { m.mu.Lock() defer m.mu.Unlock() @@ -69,7 +69,7 @@ func (m *Manager) Start(ctx context.Context, wgPort int) { } m.ctx, m.cancel = context.WithCancel(ctx) - m.wgPort = uint16(wgPort) + m.wgPort = wgPort m.stateManager.RegisterState(&State{}) @@ -207,10 +207,26 @@ func (m *Manager) Stop() { m.wg.Wait() m.mu.Lock() + defer m.mu.Unlock() + + if m.gateway == nil || m.mapping == nil { + return + } + + ctx, ctxCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer ctxCancel() + + if err := m.gateway.DeletePortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort)); err != nil { + log.Debugf("delete port mapping on stop: %v", err) + return + } + + log.Infof("deleted port mapping for port %d", m.mapping.InternalPort) + m.mapping = nil + if err := m.persistStateLocked(); err != nil { - log.Debugf("persist state on stop: %v", err) + log.Debugf("clear state on stop: %v", err) } - m.mu.Unlock() } // GetMapping returns the current mapping if ready, nil otherwise diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go index 27dd40c06af..2ebfa2dcda7 100644 --- a/client/internal/portforward/manager_js.go +++ b/client/internal/portforward/manager_js.go @@ -25,7 +25,7 @@ func NewManager(_ *statemanager.Manager) *Manager { } // Start is a no-op on js/wasm. -func (m *Manager) Start(context.Context, int) {} +func (m *Manager) Start(context.Context, uint16) {} // Stop is a no-op on js/wasm. func (m *Manager) Stop() {} From c5c2e02580eeb294d870cbe86526fa7a8229eb08 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 29 Jan 2026 00:04:18 +0800 Subject: [PATCH 05/10] Mock out test dep --- client/internal/portforward/manager_test.go | 14 +++++++++++++- client/internal/portforward/state.go | 6 +++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go index 6cb4de6b0d5..732054975d5 100644 --- a/client/internal/portforward/manager_test.go +++ b/client/internal/portforward/manager_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/libp2p/go-nat" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -142,14 +143,25 @@ func TestManager_IsAvailable(t *testing.T) { } func TestState_Cleanup(t *testing.T) { + origDiscover := discoverGateway + defer func() { discoverGateway = origDiscover }() + + mockGateway := newMockNAT() + discoverGateway = func(ctx context.Context) (nat.NAT, error) { + return mockGateway, nil + } + state := &State{ Protocol: "udp", InternalPort: 51820, } - // Cleanup should not error even if NAT discovery fails err := state.Cleanup() assert.NoError(t, err) + + // Verify the mapping was deleted + _, exists := mockGateway.mappings[51820] + assert.False(t, exists, "mapping should be deleted after cleanup") } func TestState_Name(t *testing.T) { diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go index d443e26711c..3f939751a73 100644 --- a/client/internal/portforward/state.go +++ b/client/internal/portforward/state.go @@ -10,6 +10,10 @@ import ( log "github.com/sirupsen/logrus" ) +// discoverGateway is the function used for NAT gateway discovery. +// It can be replaced in tests to avoid real network operations. +var discoverGateway = nat.DiscoverGateway + // State is persisted only for crash recovery cleanup type State struct { InternalPort uint16 `json:"internal_port,omitempty"` @@ -31,7 +35,7 @@ func (s *State) Cleanup() error { ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) defer cancel() - gateway, err := nat.DiscoverGateway(ctx) + gateway, err := discoverGateway(ctx) if err != nil { // Discovery failure is not an error - gateway may not exist log.Debugf("cleanup: no gateway found: %v", err) From 2efc7134a3aef7a015581bf5f994f6a07a79651a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 29 Jan 2026 00:06:44 +0800 Subject: [PATCH 06/10] Add prio docs --- client/internal/peer/worker_ice.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index d4cd4f7b7b2..041be95dafb 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -428,6 +428,8 @@ func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mappi relAddr = srflxCandidate.Address() } + // Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates + // over regular srflx during ICE connectivity checks. priority := srflxCandidate.Priority() + 1000 candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ From 7fa7f1c7d74f91e668dfcc263f026c326f4aa728 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 30 Jan 2026 22:45:57 +0800 Subject: [PATCH 07/10] Address review --- client/internal/portforward/manager.go | 51 +++++++++++++++----------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index d1774a1ef07..d496891348f 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -68,6 +68,11 @@ func (m *Manager) Start(ctx context.Context, wgPort uint16) { return } + if wgPort == 0 { + log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + return + } + m.ctx, m.cancel = context.WithCancel(ctx) m.wgPort = wgPort @@ -77,20 +82,6 @@ func (m *Manager) Start(ctx context.Context, wgPort uint16) { go m.run() } -func isDisabledByEnv() bool { - val := os.Getenv(envDisableNATMapper) - if val == "" { - return false - } - - disabled, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) - return false - } - return disabled -} - func (m *Manager) run() { defer m.wg.Done() @@ -194,10 +185,16 @@ func (m *Manager) createMapping() error { return m.persistStateLocked() } +// Stop cancels the manager and attempts to delete the port mapping. +// After Stop returns, the manager cannot be restarted. func (m *Manager) Stop() { m.mu.Lock() cancel := m.cancel + gateway := m.gateway + mapping := m.mapping m.cancel = nil + m.gateway = nil + m.mapping = nil m.mu.Unlock() if cancel != nil { @@ -206,25 +203,21 @@ func (m *Manager) Stop() { m.wg.Wait() - m.mu.Lock() - defer m.mu.Unlock() - - if m.gateway == nil || m.mapping == nil { + if gateway == nil || mapping == nil { return } ctx, ctxCancel := context.WithTimeout(context.Background(), 3*time.Second) defer ctxCancel() - if err := m.gateway.DeletePortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort)); err != nil { + if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { log.Debugf("delete port mapping on stop: %v", err) return } - log.Infof("deleted port mapping for port %d", m.mapping.InternalPort) - m.mapping = nil + log.Infof("deleted port mapping for port %d", mapping.InternalPort) - if err := m.persistStateLocked(); err != nil { + if err := m.stateManager.UpdateState(&State{}); err != nil { log.Debugf("clear state on stop: %v", err) } } @@ -305,3 +298,17 @@ func (m *Manager) persistStateLocked() error { return m.stateManager.UpdateState(state) } + +func isDisabledByEnv() bool { + val := os.Getenv(envDisableNATMapper) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + return false + } + return disabled +} From 3f93b5dba22df7ec4f1d1945b2487ef00319049f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 30 Jan 2026 23:12:58 +0800 Subject: [PATCH 08/10] Unlock Stop --- client/internal/portforward/manager.go | 40 +++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index d496891348f..6a0c60bbdb7 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -152,35 +152,45 @@ func (m *Manager) cleanupResidual(state *State) error { func (m *Manager) createMapping() error { m.mu.Lock() - defer m.mu.Unlock() + gateway := m.gateway + wgPort := m.wgPort + m.mu.Unlock() - if m.gateway == nil { + if gateway == nil { return nil } ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) defer cancel() - externalPort, err := m.gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(wgPort), mappingDescription, defaultMappingTTL) if err != nil { return err } - externalIP, err := m.gateway.GetExternalAddress() + externalIP, err := gateway.GetExternalAddress() if err != nil { log.Debugf("failed to get external address: %v", err) } + m.mu.Lock() + defer m.mu.Unlock() + + if m.gateway != gateway { + log.Debugf("gateway changed during mapping creation, discarding result") + return nil + } + m.mapping = &Mapping{ Protocol: "udp", - InternalPort: m.wgPort, + InternalPort: wgPort, ExternalPort: uint16(externalPort), ExternalIP: externalIP, - NATType: m.gateway.Type(), + NATType: gateway.Type(), } log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", - m.wgPort, externalPort, m.gateway.Type(), externalIP) + wgPort, externalPort, gateway.Type(), externalIP) return m.persistStateLocked() } @@ -261,20 +271,30 @@ func (m *Manager) renewLoop() { func (m *Manager) renewMapping() error { m.mu.Lock() - defer m.mu.Unlock() + gateway := m.gateway + mapping := m.mapping + m.mu.Unlock() - if m.mapping == nil || m.gateway == nil { + if mapping == nil || gateway == nil { return nil } ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) defer cancel() - externalPort, err := m.gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, mapping.Protocol, int(mapping.InternalPort), mappingDescription, defaultMappingTTL) if err != nil { return fmt.Errorf("add port mapping: %w", err) } + m.mu.Lock() + defer m.mu.Unlock() + + if m.gateway != gateway || m.mapping == nil { + log.Debugf("state changed during mapping renewal, discarding result") + return nil + } + if uint16(externalPort) != m.mapping.ExternalPort { log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) From 4d6b810f8b3ea04074bd1b84a190b0e35d7ac207 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 4 Mar 2026 15:04:28 +0100 Subject: [PATCH 09/10] Refactor port forward manager cleanup lifecycle (#5342) Replace async goroutine-based cleanup with a synchronous flow where Start runs cleanup inline after renewLoop exits. Use a stopCtx channel so GracefullyStop can pass its deadline-bounded context to Start's cleanup path. When no graceful stop occurs, Start fires cleanup in a background goroutine with a 10s timeout. Also fix GetMapping double Lock, renewMapping referencing undefined m.mu, cleanup referencing undefined variables, remove statemanager dependency, and align manager_js.go stub signatures. --- client/internal/dns/local/local_test.go | 8 +- client/internal/engine.go | 14 +- client/internal/peer/worker_ice.go | 2 +- client/internal/portforward/env.go | 26 ++ client/internal/portforward/manager.go | 311 ++++++++------------ client/internal/portforward/manager_js.go | 13 +- client/internal/portforward/manager_test.go | 2 +- 7 files changed, 163 insertions(+), 213 deletions(-) create mode 100644 client/internal/portforward/env.go diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 73f70035f7b..2c6b7dbc3a8 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }) } -// TestLocalResolver_Stop tests cleanup on Stop +// TestLocalResolver_Stop tests cleanup on GracefullyStop func TestLocalResolver_Stop(t *testing.T) { - t.Run("Stop clears all state", func(t *testing.T) { + t.Run("GracefullyStop clears all state", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) { assert.False(t, resolver.isInManagedZone("host.example.com.")) }) - t.Run("Stop is safe to call multiple times", func(t *testing.T) { + t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) { resolver.Stop() }) - t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) { resolver := NewResolver() lookupStarted := make(chan struct{}) diff --git a/client/internal/engine.go b/client/internal/engine.go index 46e4d7e0c71..68b1f327f4d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -266,7 +266,7 @@ func NewEngine( networkSerial: 0, statusRecorder: statusRecorder, stateManager: stateManager, - portForwardManager: portforward.NewManager(stateManager), + portForwardManager: portforward.NewManager(), checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), @@ -513,7 +513,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.setupWGProxyNoTrack() // Start after interface is up since port may have been resolved from 0 or changed if occupied - e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + e.shutdownWg.Add(1) + go func() { + defer e.shutdownWg.Done() + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + }() // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { @@ -1692,7 +1696,11 @@ func (e *Engine) close() { _ = e.rpManager.Close() } - e.portForwardManager.Stop() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.portForwardManager.GracefullyStop(ctx); err != nil { + log.Warnf("failed to gracefully stop port forwarding manager: %s", err) + } } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 041be95dafb..d0b21117da1 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -385,7 +385,7 @@ func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { w.muxAgent.Unlock() pfManager := w.conn.portForwardManager - if pfManager == nil || !pfManager.IsAvailable() { + if pfManager == nil { return } diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go new file mode 100644 index 00000000000..444a6b47834 --- /dev/null +++ b/client/internal/portforward/env.go @@ -0,0 +1,26 @@ +package portforward + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" +) + +func isDisabledByEnv() bool { + val := os.Getenv(envDisableNATMapper) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + return false + } + return disabled +} diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index 6a0c60bbdb7..d0768009cd3 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -6,15 +6,11 @@ import ( "context" "fmt" "net" - "os" - "strconv" "sync" "time" "github.com/libp2p/go-nat" log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -22,8 +18,6 @@ const ( renewalInterval = defaultMappingTTL / 2 discoveryTimeout = 10 * time.Second mappingDescription = "NetBird" - - envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" ) type Mapping struct { @@ -35,106 +29,138 @@ type Mapping struct { } type Manager struct { - ctx context.Context - cancel context.CancelFunc - stateManager *statemanager.Manager + cancel context.CancelFunc - mu sync.RWMutex - gateway nat.NAT - mapping *Mapping + mapping *Mapping + mappingLock sync.Mutex wgPort uint16 - wg sync.WaitGroup + + done chan struct{} + stopCtx chan context.Context + + // protect exported functions + mu sync.Mutex } -func NewManager(stateManager *statemanager.Manager) *Manager { +func NewManager() *Manager { return &Manager{ - stateManager: stateManager, + stopCtx: make(chan context.Context, 1), } } -// Start begins async discovery and mapping creation for the given WireGuard port. -// This does not block - use GetMapping() to check if mapping is ready. func (m *Manager) Start(ctx context.Context, wgPort uint16) { m.mu.Lock() - defer m.mu.Unlock() - if m.cancel != nil { + m.mu.Unlock() return } if isDisabledByEnv() { log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + m.mu.Unlock() return } if wgPort == 0 { log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + m.mu.Unlock() return } - - m.ctx, m.cancel = context.WithCancel(ctx) m.wgPort = wgPort - m.stateManager.RegisterState(&State{}) + m.done = make(chan struct{}) + defer close(m.done) - m.wg.Add(1) - go m.run() + ctx, m.cancel = context.WithCancel(ctx) + m.mu.Unlock() + + gateway, mapping, err := m.setup(ctx) + if err != nil { + log.Errorf("failed to setup NAT port mapping: %v", err) + + return + } + + m.mappingLock.Lock() + m.mapping = mapping + m.mappingLock.Unlock() + + m.renewLoop(ctx, gateway) + + select { + case cleanupCtx := <-m.stopCtx: + // block the Start while cleaned up gracefully + m.cleanup(cleanupCtx, gateway) + default: + // return Start immediately and cleanup in background + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second) + go func() { + defer cleanupCancel() + m.cleanup(cleanupCtx, gateway) + }() + } } -func (m *Manager) run() { - defer m.wg.Done() +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mappingLock.Lock() + defer m.mappingLock.Unlock() - if err := m.stateManager.LoadState(&State{}); err != nil { - log.Warnf("failed to load port forward state: %v", err) + if m.mapping == nil { + return nil } - var residualState *State - if existing := m.stateManager.GetState(&State{}); existing != nil { - if state, ok := existing.(*State); ok && state.InternalPort != 0 { - residualState = state - } + mapping := *m.mapping + return &mapping +} + +// GracefullyStop cancels the manager and attempts to delete the port mapping. +// After GracefullyStop returns, the manager cannot be restarted. +func (m *Manager) GracefullyStop(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel == nil { + return nil } - discoverCtx, discoverCancel := context.WithTimeout(m.ctx, discoveryTimeout) + // Send cleanup context before cancelling, so Start picks it up after renewLoop exits. + m.startTearDown(ctx) + + m.cancel() + m.cancel = nil + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } +} + +func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { + discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) defer discoverCancel() gateway, err := nat.DiscoverGateway(discoverCtx) if err != nil { log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err) - return + return nil, nil, err } - m.mu.Lock() - m.gateway = gateway - m.mu.Unlock() - log.Infof("discovered NAT gateway: %s", gateway.Type()) - if residualState != nil { - if err := m.cleanupResidual(residualState); err != nil { - log.Warnf("failed to cleanup residual mapping: %v", err) - } - } - - if err := m.createMapping(); err != nil { + mapping, err := m.createMapping(ctx, gateway) + if err != nil { log.Warnf("failed to create port mapping: %v", err) - return + return nil, nil, err } - - m.renewLoop() + return gateway, mapping, nil } -func (m *Manager) cleanupResidual(state *State) error { - m.mu.RLock() - gateway := m.gateway - m.mu.RUnlock() - - if gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 10*time.Second) +func (m *Manager) cleanupResidual(ctx context.Context, gateway nat.NAT, state *State) error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() if err := gateway.DeletePortMapping(ctx, state.Protocol, int(state.InternalPort)); err != nil { @@ -142,193 +168,90 @@ func (m *Manager) cleanupResidual(state *State) error { } log.Infof("cleaned up residual port mapping for port %d", state.InternalPort) - - if err := m.stateManager.UpdateState(&State{}); err != nil { - return fmt.Errorf("clear state after cleanup: %w", err) - } - return nil } -func (m *Manager) createMapping() error { - m.mu.Lock() - gateway := m.gateway - wgPort := m.wgPort - m.mu.Unlock() - - if gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) +func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - externalPort, err := gateway.AddPortMapping(ctx, "udp", int(wgPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL) if err != nil { - return err + return nil, err } externalIP, err := gateway.GetExternalAddress() if err != nil { log.Debugf("failed to get external address: %v", err) + // todo return with err? } - m.mu.Lock() - defer m.mu.Unlock() - - if m.gateway != gateway { - log.Debugf("gateway changed during mapping creation, discarding result") - return nil - } - - m.mapping = &Mapping{ + mapping := &Mapping{ Protocol: "udp", - InternalPort: wgPort, + InternalPort: m.wgPort, ExternalPort: uint16(externalPort), ExternalIP: externalIP, NATType: gateway.Type(), } log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", - wgPort, externalPort, gateway.Type(), externalIP) - - return m.persistStateLocked() -} - -// Stop cancels the manager and attempts to delete the port mapping. -// After Stop returns, the manager cannot be restarted. -func (m *Manager) Stop() { - m.mu.Lock() - cancel := m.cancel - gateway := m.gateway - mapping := m.mapping - m.cancel = nil - m.gateway = nil - m.mapping = nil - m.mu.Unlock() - - if cancel != nil { - cancel() - } - - m.wg.Wait() - - if gateway == nil || mapping == nil { - return - } - - ctx, ctxCancel := context.WithTimeout(context.Background(), 3*time.Second) - defer ctxCancel() - - if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { - log.Debugf("delete port mapping on stop: %v", err) - return - } - - log.Infof("deleted port mapping for port %d", mapping.InternalPort) - - if err := m.stateManager.UpdateState(&State{}); err != nil { - log.Debugf("clear state on stop: %v", err) - } + m.wgPort, externalPort, gateway.Type(), externalIP) + return mapping, nil } -// GetMapping returns the current mapping if ready, nil otherwise -func (m *Manager) GetMapping() *Mapping { - m.mu.RLock() - defer m.mu.RUnlock() - - if m.mapping == nil { - return nil - } - - mapping := *m.mapping - return &mapping -} - -// IsAvailable returns true if port forwarding is available and mapping exists -func (m *Manager) IsAvailable() bool { - m.mu.RLock() - defer m.mu.RUnlock() - - return m.mapping != nil -} - -func (m *Manager) renewLoop() { +func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) { ticker := time.NewTicker(renewalInterval) defer ticker.Stop() for { select { - case <-m.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: - if err := m.renewMapping(); err != nil { + if err := m.renewMapping(ctx, gateway); err != nil { log.Warnf("failed to renew port mapping: %v", err) + continue } } } } -func (m *Manager) renewMapping() error { - m.mu.Lock() - gateway := m.gateway - mapping := m.mapping - m.mu.Unlock() - - if mapping == nil || gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) +func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - externalPort, err := gateway.AddPortMapping(ctx, mapping.Protocol, int(mapping.InternalPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL) if err != nil { return fmt.Errorf("add port mapping: %w", err) } - m.mu.Lock() - defer m.mu.Unlock() - - if m.gateway != gateway || m.mapping == nil { - log.Debugf("state changed during mapping renewal, discarding result") - return nil - } - if uint16(externalPort) != m.mapping.ExternalPort { - log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", - m.mapping.ExternalPort, externalPort) + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) + m.mappingLock.Lock() m.mapping.ExternalPort = uint16(externalPort) + m.mappingLock.Unlock() } log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) return nil } -func (m *Manager) persistStateLocked() error { - var state *State - if m.mapping != nil { - state = &State{ - InternalPort: m.mapping.InternalPort, - Protocol: m.mapping.Protocol, - } - } else { - state = &State{} +func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { + if m.mapping == nil { + return } - return m.stateManager.UpdateState(state) -} - -func isDisabledByEnv() bool { - val := os.Getenv(envDisableNATMapper) - if val == "" { - return false + if err := gateway.DeletePortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort)); err != nil { + log.Warnf("delete port mapping on stop: %v", err) + return } - disabled, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) - return false + log.Infof("deleted port mapping for port %d", m.mapping.InternalPort) +} + +func (m *Manager) startTearDown(ctx context.Context) { + select { + case m.stopCtx <- ctx: + default: } - return disabled } diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go index 2ebfa2dcda7..e7fd4a64ed2 100644 --- a/client/internal/portforward/manager_js.go +++ b/client/internal/portforward/manager_js.go @@ -3,8 +3,6 @@ package portforward import ( "context" "net" - - "github.com/netbirdio/netbird/client/internal/statemanager" ) // Mapping represents port mapping information. @@ -20,22 +18,17 @@ type Mapping struct { type Manager struct{} // NewManager returns a stub manager for js/wasm builds. -func NewManager(_ *statemanager.Manager) *Manager { +func NewManager() *Manager { return &Manager{} } // Start is a no-op on js/wasm. func (m *Manager) Start(context.Context, uint16) {} -// Stop is a no-op on js/wasm. -func (m *Manager) Stop() {} +// GracefullyStop is a no-op on js/wasm. +func (m *Manager) GracefullyStop(context.Context) error { return nil } // GetMapping always returns nil on js/wasm. func (m *Manager) GetMapping() *Mapping { return nil } - -// IsAvailable always returns false on js/wasm. -func (m *Manager) IsAvailable() bool { - return false -} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go index 732054975d5..5548fa4d590 100644 --- a/client/internal/portforward/manager_test.go +++ b/client/internal/portforward/manager_test.go @@ -96,7 +96,7 @@ func TestManager_CreateMapping(t *testing.T) { m, cancel := setupTestManager(t) defer cancel() - err := m.createMapping() + err := m.createMapping(nil) require.NoError(t, err) mapping := m.GetMapping() From ec50347d10d60f81ec3c2053ac860f62dff3426b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 17 Mar 2026 06:46:12 +0100 Subject: [PATCH 10/10] Address PR review feedback for NAT-PMP/UPnP support --- client/internal/peer/worker_ice.go | 16 ++-- client/internal/portforward/manager.go | 23 ++--- client/internal/portforward/manager_js.go | 6 +- client/internal/portforward/manager_test.go | 95 +++++++++------------ 4 files changed, 62 insertions(+), 78 deletions(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4d07ade9512..29bf5aaaa74 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -384,14 +384,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { // injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping. func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { - w.muxAgent.Lock() - if w.portForwardAttempted { - w.muxAgent.Unlock() - return - } - w.portForwardAttempted = true - w.muxAgent.Unlock() - pfManager := w.conn.portForwardManager if pfManager == nil { return @@ -402,6 +394,14 @@ func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { return } + w.muxAgent.Lock() + if w.portForwardAttempted { + w.muxAgent.Unlock() + return + } + w.portForwardAttempted = true + w.muxAgent.Unlock() + forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping) if err != nil { w.log.Warnf("create forwarded candidate: %v", err) diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index d0768009cd3..019c2ad8670 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -159,18 +159,6 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { return gateway, mapping, nil } -func (m *Manager) cleanupResidual(ctx context.Context, gateway nat.NAT, state *State) error { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - if err := gateway.DeletePortMapping(ctx, state.Protocol, int(state.InternalPort)); err != nil { - return fmt.Errorf("delete residual mapping: %w", err) - } - - log.Infof("cleaned up residual port mapping for port %d", state.InternalPort) - return nil -} - func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -237,16 +225,21 @@ func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { } func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { - if m.mapping == nil { + m.mappingLock.Lock() + mapping := m.mapping + m.mapping = nil + m.mappingLock.Unlock() + + if mapping == nil { return } - if err := gateway.DeletePortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort)); err != nil { + if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { log.Warnf("delete port mapping on stop: %v", err) return } - log.Infof("deleted port mapping for port %d", m.mapping.InternalPort) + log.Infof("deleted port mapping for port %d", mapping.InternalPort) } func (m *Manager) startTearDown(ctx context.Context) { diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go index e7fd4a64ed2..d5db147f240 100644 --- a/client/internal/portforward/manager_js.go +++ b/client/internal/portforward/manager_js.go @@ -22,8 +22,10 @@ func NewManager() *Manager { return &Manager{} } -// Start is a no-op on js/wasm. -func (m *Manager) Start(context.Context, uint16) {} +// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments. +func (m *Manager) Start(context.Context, uint16) { + // no NAT traversal in wasm +} // GracefullyStop is a no-op on js/wasm. func (m *Manager) GracefullyStop(context.Context) error { return nil } diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go index 5548fa4d590..1029e87f5cc 100644 --- a/client/internal/portforward/manager_test.go +++ b/client/internal/portforward/manager_test.go @@ -11,8 +11,6 @@ import ( "github.com/libp2p/go-nat" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal/statemanager" ) type mockNAT struct { @@ -68,38 +66,13 @@ func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, intern return nil } -func setupTestManager(t *testing.T) (*Manager, context.CancelFunc) { - tmpDir := t.TempDir() - statePath := tmpDir + "/state.json" - sm := statemanager.New(statePath) - sm.Start() - t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - require.NoError(t, sm.Stop(ctx)) - }) - - m := NewManager(sm) - m.gateway = newMockNAT() - - ctx, cancel := context.WithCancel(context.Background()) - m.ctx = ctx - m.cancel = cancel - m.wgPort = 51820 - - sm.RegisterState(&State{}) - - return m, cancel -} - func TestManager_CreateMapping(t *testing.T) { - m, cancel := setupTestManager(t) - defer cancel() + m := NewManager() + m.wgPort = 51820 - err := m.createMapping(nil) + gateway := newMockNAT() + mapping, err := m.createMapping(context.Background(), gateway) require.NoError(t, err) - - mapping := m.GetMapping() require.NotNil(t, mapping) assert.Equal(t, "udp", mapping.Protocol) @@ -110,36 +83,52 @@ func TestManager_CreateMapping(t *testing.T) { } func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) { - tmpDir := t.TempDir() - statePath := tmpDir + "/state.json" - sm := statemanager.New(statePath) + m := NewManager() + assert.Nil(t, m.GetMapping()) +} - m := NewManager(sm) +func TestManager_GetMapping_ReturnsCopy(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } - assert.Nil(t, m.GetMapping()) + mapping := m.GetMapping() + require.NotNil(t, mapping) + assert.Equal(t, uint16(51820), mapping.InternalPort) + + // Mutating the returned copy should not affect the manager's mapping. + mapping.ExternalPort = 9999 + assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort) } -func TestManager_IsAvailable(t *testing.T) { - tmpDir := t.TempDir() - statePath := tmpDir + "/state.json" - sm := statemanager.New(statePath) +func TestManager_Cleanup_DeletesMapping(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } - m := NewManager(sm) + gateway := newMockNAT() + // Seed the mock so we can verify deletion. + gateway.mappings[51820] = 51820 - // Initially not available (no mapping) - assert.False(t, m.IsAvailable()) + m.cleanup(context.Background(), gateway) - // Set gateway but no mapping - still not available - m.gateway = newMockNAT() - assert.False(t, m.IsAvailable()) + _, exists := gateway.mappings[51820] + assert.False(t, exists, "mapping should be deleted from gateway") + assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared") +} - // Add mapping - now available - m.mapping = &Mapping{InternalPort: 51820} - assert.True(t, m.IsAvailable()) +func TestManager_Cleanup_NilMapping(t *testing.T) { + m := NewManager() + gateway := newMockNAT() - // Clear mapping - not available again - m.mapping = nil - assert.False(t, m.IsAvailable()) + // Should not panic or call gateway. + m.cleanup(context.Background(), gateway) } func TestState_Cleanup(t *testing.T) { @@ -147,6 +136,7 @@ func TestState_Cleanup(t *testing.T) { defer func() { discoverGateway = origDiscover }() mockGateway := newMockNAT() + mockGateway.mappings[51820] = 51820 discoverGateway = func(ctx context.Context) (nat.NAT, error) { return mockGateway, nil } @@ -159,7 +149,6 @@ func TestState_Cleanup(t *testing.T) { err := state.Cleanup() assert.NoError(t, err) - // Verify the mapping was deleted _, exists := mockGateway.mappings[51820] assert.False(t, exists, "mapping should be deleted after cleanup") }