Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ type Conn struct {

workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup

wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc

// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
Expand Down Expand Up @@ -126,6 +129,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {

connLog := log.WithField("peer", config.Key)

dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
Expand All @@ -137,8 +141,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}

return conn, nil
Expand All @@ -162,7 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {

conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)

conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)

relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
Expand Down Expand Up @@ -231,7 +236,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection")
conn.ctxCancel()

conn.workerRelay.DisableWgWatcher()
if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()

Expand Down Expand Up @@ -366,9 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}

conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here

if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
Expand All @@ -390,6 +394,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}

conn.enableWgWatcherIfNeeded()

conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
Expand Down Expand Up @@ -423,11 +429,6 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}

conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
Expand All @@ -444,15 +445,15 @@ func (conn *Conn) onICEStateDisconnected() {
}
conn.statusICE.SetDisconnected()

conn.disableWgWatcherIfNeeded()

peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}

err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
Expand Down Expand Up @@ -500,11 +501,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}

conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.enableWgWatcherIfNeeded()

wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
Expand All @@ -519,7 +516,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}

// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil {
return
}
Expand All @@ -545,6 +546,8 @@ func (conn *Conn) onRelayDisconnected() {
}
conn.statusRelay.SetDisconnected()

conn.disableWgWatcherIfNeeded()

peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Expand All @@ -563,6 +566,28 @@ func (conn *Conn) onGuardEvent() {
}
}

func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()

if conn.ctx.Err() != nil {
return
}

conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")

// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}

func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
Expand Down Expand Up @@ -689,6 +714,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}

func (conn *Conn) enableWgWatcherIfNeeded() {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
}()
}
}

func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
}
}

func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
Expand Down
72 changes: 36 additions & 36 deletions client/internal/peer/wg_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump

ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
enabled bool
muEnabled sync.RWMutex
}

func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
Expand All @@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}

// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()

if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return
}

ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()

initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}

w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
}

// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()

if w.ctxCancel == nil {
return
}
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)

w.log.Debugf("disable WireGuard watcher")
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}

w.ctxCancel()
w.ctxCancel = nil
// IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) IsEnabled() bool {
w.muEnabled.RLock()
defer w.muEnabled.RUnlock()
return w.enabled
}

// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")

timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()

lastHandshake := initialHandshake

Expand All @@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}

Expand Down Expand Up @@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {

// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}

// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}

// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
return nil, false
}

Expand All @@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}

// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}
20 changes: 13 additions & 7 deletions client/internal/peer/wg_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package peer

import (
"context"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

func TestWGWatcher_ReEnable(t *testing.T) {
Expand All @@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))

ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()

wg.Wait()

// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel()

onDisconnected := make(chan struct{}, 1)

go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()

go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
Expand All @@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
Loading
Loading