diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 73f70035f7b..2c6b7dbc3a8 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }) } -// TestLocalResolver_Stop tests cleanup on Stop +// TestLocalResolver_Stop tests cleanup on Stop func TestLocalResolver_Stop(t *testing.T) { - t.Run("Stop clears all state", func(t *testing.T) { + t.Run("GracefullyStop clears all state", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) { assert.False(t, resolver.isInManagedZone("host.example.com.")) }) - t.Run("Stop is safe to call multiple times", func(t *testing.T) { + t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) { resolver.Stop() }) - t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) { resolver := NewResolver() lookupStarted := make(chan struct{}) diff --git a/client/internal/engine.go b/client/internal/engine.go index 46e4d7e0c71..68b1f327f4d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -266,7 +266,7 @@ func NewEngine( networkSerial: 0, statusRecorder: statusRecorder, stateManager: stateManager, - portForwardManager: portforward.NewManager(stateManager), + portForwardManager: portforward.NewManager(), checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), @@ -513,7 +513,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.setupWGProxyNoTrack() // 
Start after interface is up since port may have been resolved from 0 or changed if occupied - e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + e.shutdownWg.Add(1) + go func() { + defer e.shutdownWg.Done() + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + }() // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { @@ -1692,7 +1696,11 @@ func (e *Engine) close() { _ = e.rpManager.Close() } - e.portForwardManager.Stop() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.portForwardManager.GracefullyStop(ctx); err != nil { + log.Warnf("failed to gracefully stop port forwarding manager: %s", err) + } } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 041be95dafb..d0b21117da1 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -385,7 +385,7 @@ func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { w.muxAgent.Unlock() pfManager := w.conn.portForwardManager - if pfManager == nil || !pfManager.IsAvailable() { + if pfManager == nil { return } diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go new file mode 100644 index 00000000000..444a6b47834 --- /dev/null +++ b/client/internal/portforward/env.go @@ -0,0 +1,26 @@ +package portforward + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" +) + +func isDisabledByEnv() bool { + val := os.Getenv(envDisableNATMapper) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + return false + } + return disabled +} diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index 
6a0c60bbdb7..d0768009cd3 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -6,15 +6,11 @@ import ( "context" "fmt" "net" - "os" - "strconv" "sync" "time" "github.com/libp2p/go-nat" log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -22,8 +18,6 @@ const ( renewalInterval = defaultMappingTTL / 2 discoveryTimeout = 10 * time.Second mappingDescription = "NetBird" - - envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" ) type Mapping struct { @@ -35,106 +29,138 @@ type Mapping struct { } type Manager struct { - ctx context.Context - cancel context.CancelFunc - stateManager *statemanager.Manager + cancel context.CancelFunc - mu sync.RWMutex - gateway nat.NAT - mapping *Mapping + mapping *Mapping + mappingLock sync.Mutex wgPort uint16 - wg sync.WaitGroup + + done chan struct{} + stopCtx chan context.Context + + // protect exported functions + mu sync.Mutex } -func NewManager(stateManager *statemanager.Manager) *Manager { +func NewManager() *Manager { return &Manager{ - stateManager: stateManager, + stopCtx: make(chan context.Context, 1), } } -// Start begins async discovery and mapping creation for the given WireGuard port. -// This does not block - use GetMapping() to check if mapping is ready. 
func (m *Manager) Start(ctx context.Context, wgPort uint16) { m.mu.Lock() - defer m.mu.Unlock() - if m.cancel != nil { + m.mu.Unlock() return } if isDisabledByEnv() { log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + m.mu.Unlock() return } if wgPort == 0 { log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + m.mu.Unlock() return } - - m.ctx, m.cancel = context.WithCancel(ctx) m.wgPort = wgPort - m.stateManager.RegisterState(&State{}) + m.done = make(chan struct{}) + defer close(m.done) - m.wg.Add(1) - go m.run() + ctx, m.cancel = context.WithCancel(ctx) + m.mu.Unlock() + + gateway, mapping, err := m.setup(ctx) + if err != nil { + log.Errorf("failed to setup NAT port mapping: %v", err) + + return + } + + m.mappingLock.Lock() + m.mapping = mapping + m.mappingLock.Unlock() + + m.renewLoop(ctx, gateway) + + select { + case cleanupCtx := <-m.stopCtx: + // a stop context is pending: clean up synchronously so Start returns only after cleanup has run + m.cleanup(cleanupCtx, gateway) + default: + // no stop context is pending: return from Start immediately and clean up in the background + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second) + go func() { + defer cleanupCancel() + m.cleanup(cleanupCtx, gateway) + }() + } } -func (m *Manager) run() { - defer m.wg.Done() +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mappingLock.Lock() + defer m.mappingLock.Unlock() - if err := m.stateManager.LoadState(&State{}); err != nil { - log.Warnf("failed to load port forward state: %v", err) + if m.mapping == nil { + return nil } - var residualState *State - if existing := m.stateManager.GetState(&State{}); existing != nil { - if state, ok := existing.(*State); ok && state.InternalPort != 0 { - residualState = state - } + mapping := *m.mapping + return &mapping +} + +// GracefullyStop cancels the manager and attempts to delete the port mapping. +// After GracefullyStop returns, the manager cannot be restarted. 
+func (m *Manager) GracefullyStop(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel == nil { + return nil } - discoverCtx, discoverCancel := context.WithTimeout(m.ctx, discoveryTimeout) + // Send cleanup context before cancelling, so Start picks it up after renewLoop exits. + m.startTearDown(ctx) + + m.cancel() + m.cancel = nil + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } +} + +func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { + discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) defer discoverCancel() gateway, err := nat.DiscoverGateway(discoverCtx) if err != nil { log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err) - return + return nil, nil, err } - m.mu.Lock() - m.gateway = gateway - m.mu.Unlock() - log.Infof("discovered NAT gateway: %s", gateway.Type()) - if residualState != nil { - if err := m.cleanupResidual(residualState); err != nil { - log.Warnf("failed to cleanup residual mapping: %v", err) - } - } - - if err := m.createMapping(); err != nil { + mapping, err := m.createMapping(ctx, gateway) + if err != nil { log.Warnf("failed to create port mapping: %v", err) - return + return nil, nil, err } - - m.renewLoop() + return gateway, mapping, nil } -func (m *Manager) cleanupResidual(state *State) error { - m.mu.RLock() - gateway := m.gateway - m.mu.RUnlock() - - if gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 10*time.Second) +func (m *Manager) cleanupResidual(ctx context.Context, gateway nat.NAT, state *State) error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() if err := gateway.DeletePortMapping(ctx, state.Protocol, int(state.InternalPort)); err != nil { @@ -142,193 +168,90 @@ func (m *Manager) cleanupResidual(state *State) error { } log.Infof("cleaned up residual port mapping for port %d", state.InternalPort) - - if err := 
m.stateManager.UpdateState(&State{}); err != nil { - return fmt.Errorf("clear state after cleanup: %w", err) - } - return nil } -func (m *Manager) createMapping() error { - m.mu.Lock() - gateway := m.gateway - wgPort := m.wgPort - m.mu.Unlock() - - if gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) +func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - externalPort, err := gateway.AddPortMapping(ctx, "udp", int(wgPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL) if err != nil { - return err + return nil, err } externalIP, err := gateway.GetExternalAddress() if err != nil { log.Debugf("failed to get external address: %v", err) + // TODO: decide whether to return this error instead of continuing without an external IP } - m.mu.Lock() - defer m.mu.Unlock() - - if m.gateway != gateway { - log.Debugf("gateway changed during mapping creation, discarding result") - return nil - } - - m.mapping = &Mapping{ + mapping := &Mapping{ Protocol: "udp", - InternalPort: wgPort, + InternalPort: m.wgPort, ExternalPort: uint16(externalPort), ExternalIP: externalIP, NATType: gateway.Type(), } log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", - wgPort, externalPort, gateway.Type(), externalIP) - - return m.persistStateLocked() -} - -// Stop cancels the manager and attempts to delete the port mapping. -// After Stop returns, the manager cannot be restarted. 
-func (m *Manager) Stop() { - m.mu.Lock() - cancel := m.cancel - gateway := m.gateway - mapping := m.mapping - m.cancel = nil - m.gateway = nil - m.mapping = nil - m.mu.Unlock() - - if cancel != nil { - cancel() - } - - m.wg.Wait() - - if gateway == nil || mapping == nil { - return - } - - ctx, ctxCancel := context.WithTimeout(context.Background(), 3*time.Second) - defer ctxCancel() - - if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { - log.Debugf("delete port mapping on stop: %v", err) - return - } - - log.Infof("deleted port mapping for port %d", mapping.InternalPort) - - if err := m.stateManager.UpdateState(&State{}); err != nil { - log.Debugf("clear state on stop: %v", err) - } + m.wgPort, externalPort, gateway.Type(), externalIP) + return mapping, nil } -// GetMapping returns the current mapping if ready, nil otherwise -func (m *Manager) GetMapping() *Mapping { - m.mu.RLock() - defer m.mu.RUnlock() - - if m.mapping == nil { - return nil - } - - mapping := *m.mapping - return &mapping -} - -// IsAvailable returns true if port forwarding is available and mapping exists -func (m *Manager) IsAvailable() bool { - m.mu.RLock() - defer m.mu.RUnlock() - - return m.mapping != nil -} - -func (m *Manager) renewLoop() { +func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) { ticker := time.NewTicker(renewalInterval) defer ticker.Stop() for { select { - case <-m.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: - if err := m.renewMapping(); err != nil { + if err := m.renewMapping(ctx, gateway); err != nil { log.Warnf("failed to renew port mapping: %v", err) + continue } } } } -func (m *Manager) renewMapping() error { - m.mu.Lock() - gateway := m.gateway - mapping := m.mapping - m.mu.Unlock() - - if mapping == nil || gateway == nil { - return nil - } - - ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) +func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { + ctx, cancel := 
context.WithTimeout(ctx, 30*time.Second) defer cancel() - externalPort, err := gateway.AddPortMapping(ctx, mapping.Protocol, int(mapping.InternalPort), mappingDescription, defaultMappingTTL) + externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL) if err != nil { return fmt.Errorf("add port mapping: %w", err) } - m.mu.Lock() - defer m.mu.Unlock() - - if m.gateway != gateway || m.mapping == nil { - log.Debugf("state changed during mapping renewal, discarding result") - return nil - } - if uint16(externalPort) != m.mapping.ExternalPort { - log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", - m.mapping.ExternalPort, externalPort) + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) + m.mappingLock.Lock() m.mapping.ExternalPort = uint16(externalPort) + m.mappingLock.Unlock() } log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) return nil } -func (m *Manager) persistStateLocked() error { - var state *State - if m.mapping != nil { - state = &State{ - InternalPort: m.mapping.InternalPort, - Protocol: m.mapping.Protocol, - } - } else { - state = &State{} +func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { + if m.mapping == nil { + return } - return m.stateManager.UpdateState(state) -} - -func isDisabledByEnv() bool { - val := os.Getenv(envDisableNATMapper) - if val == "" { - return false + if err := gateway.DeletePortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort)); err != nil { + log.Warnf("delete port mapping on stop: %v", err) + return } - disabled, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) - return false + log.Infof("deleted port mapping for port %d", m.mapping.InternalPort) +} + +func (m *Manager) startTearDown(ctx context.Context) { + select { + case 
m.stopCtx <- ctx: + default: } - return disabled } diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go index 2ebfa2dcda7..e7fd4a64ed2 100644 --- a/client/internal/portforward/manager_js.go +++ b/client/internal/portforward/manager_js.go @@ -3,8 +3,6 @@ package portforward import ( "context" "net" - - "github.com/netbirdio/netbird/client/internal/statemanager" ) // Mapping represents port mapping information. @@ -20,22 +18,17 @@ type Mapping struct { type Manager struct{} // NewManager returns a stub manager for js/wasm builds. -func NewManager(_ *statemanager.Manager) *Manager { +func NewManager() *Manager { return &Manager{} } // Start is a no-op on js/wasm. func (m *Manager) Start(context.Context, uint16) {} -// Stop is a no-op on js/wasm. -func (m *Manager) Stop() {} +// GracefullyStop is a no-op on js/wasm. +func (m *Manager) GracefullyStop(context.Context) error { return nil } // GetMapping always returns nil on js/wasm. func (m *Manager) GetMapping() *Mapping { return nil } - -// IsAvailable always returns false on js/wasm. -func (m *Manager) IsAvailable() bool { - return false -} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go index 732054975d5..5548fa4d590 100644 --- a/client/internal/portforward/manager_test.go +++ b/client/internal/portforward/manager_test.go @@ -96,7 +96,7 @@ func TestManager_CreateMapping(t *testing.T) { m, cancel := setupTestManager(t) defer cancel() - err := m.createMapping() + err := m.createMapping(nil) require.NoError(t, err) mapping := m.GetMapping()