Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions client/internal/portforward/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,27 @@ import (
)

const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
)

func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
return parseBoolEnv(envDisableNATMapper)
}

func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}

func parseBoolEnv(key string) bool {
val := os.Getenv(key)
if val == "" {
return false
}

disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
log.Warnf("failed to parse %s: %v", key, err)
return false
}
return disabled
Expand Down
82 changes: 72 additions & 10 deletions client/internal/portforward/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (

"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)

const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
defaultMappingTTL = 2 * time.Hour
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)

// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
Expand Down Expand Up @@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()

gateway, err := nat.DiscoverGateway(discoverCtx)
gateway, err := discoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
Expand Down Expand Up @@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}

mapping := &Mapping{
Expand All @@ -208,25 +210,85 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,

func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
// Permanent mappings don't expire, just wait for cancellation
// but still run health checks for PCP gateways.
m.permanentLeaseLoop(ctx, gateway)
return
}

ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
renewTicker := time.NewTicker(ttl / 2)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-renewTicker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(ttl / 2)
}
}
}
}

func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
healthTicker := time.NewTicker(healthCheckInterval)
defer healthTicker.Stop()

for {
select {
case <-ctx.Done():
return
case <-healthTicker.C:
m.checkHealthAndRecreate(ctx, gateway)
}
}
}

func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}

m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()

if !hasMapping {
return false
}

pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}

ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}

if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}

return false
}

func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
Expand Down
Loading
Loading