Skip to content
Merged
8 changes: 4 additions & 4 deletions client/internal/dns/local/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}

// TestLocalResolver_Stop tests cleanup on Stop
// TestLocalResolver_Stop tests cleanup on GracefullyStop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Expand All @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})

t.Run("Stop is safe to call multiple times", func(t *testing.T) {
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Expand All @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})

t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()

lookupStarted := make(chan struct{})
Expand Down
75 changes: 46 additions & 29 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
Expand Down Expand Up @@ -210,9 +211,10 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks

relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher

// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
Expand Down Expand Up @@ -259,26 +261,27 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
}

log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
Expand Down Expand Up @@ -537,6 +540,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()

// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()

// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
Expand Down Expand Up @@ -1540,12 +1550,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}

serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
Expand Down Expand Up @@ -1702,6 +1713,12 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}

func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
Expand Down
50 changes: 27 additions & 23 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
Expand All @@ -45,6 +46,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}

Expand Down Expand Up @@ -87,16 +89,17 @@ type ConnConfig struct {
}

type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager

onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
Expand Down Expand Up @@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {

dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}

return conn, nil
Expand Down
99 changes: 96 additions & 3 deletions client/internal/peer/worker_ice.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
Expand Down Expand Up @@ -61,6 +62,9 @@ type WorkerICE struct {

// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState

// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}

func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
Expand Down Expand Up @@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
}

func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false

agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
Expand Down Expand Up @@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()

if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}

// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}

mapping := pfManager.GetMapping()
if mapping == nil {
return
}

w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()

forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}

w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())

go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}

// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}

// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}

// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
Comment thread
coderabbitai[bot] marked this conversation as resolved.

candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}

for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}

return candidate, nil
}

func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
Expand Down Expand Up @@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
stat.CurrentRoundTripTime*1000)
}
}
Expand Down
26 changes: 26 additions & 0 deletions client/internal/portforward/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package portforward

import (
"os"
"strconv"

log "github.com/sirupsen/logrus"
)

const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)

func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}

disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled
}
Loading
Loading