Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client/internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand All @@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)

Expand Down Expand Up @@ -272,6 +272,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan

c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
c.engine.handleAutoUpdateVersion(loginResp.PeerConfig.AutoUpdate)
}
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock()

Expand Down
17 changes: 10 additions & 7 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,16 +721,19 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}

func (e *Engine) handleAutoUpdateVersion(autoUpdateVersion string) {
if e.updateManager == nil && autoUpdateVersion != disableAutoUpdate {
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder)
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
if autoUpdateSettings == nil {
return
}
if e.updateManager == nil && autoUpdateSettings.Version != disableAutoUpdate && autoUpdateSettings.AlwaysUpdate {
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
e.updateManager.Start(e.ctx)
} else if e.updateManager != nil && autoUpdateVersion == disableAutoUpdate {
} else if e.updateManager != nil && autoUpdateSettings.Version == disableAutoUpdate {
e.updateManager.Stop()
e.updateManager = nil
}
if e.updateManager != nil {
e.updateManager.SetVersion(autoUpdateVersion)
if e.updateManager != nil && autoUpdateSettings.AlwaysUpdate {
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
}

Expand All @@ -739,7 +742,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
defer e.syncMsgMux.Unlock()

if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdateVersion)
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
Expand Down
123 changes: 115 additions & 8 deletions client/internal/updatemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
Expand All @@ -32,6 +33,15 @@ type UpdateInterface interface {
StartFetcher()
}

type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}

func (u UpdateState) Name() string {
return "autoUpdate"
}

type UpdateManager struct {
lastTrigger time.Time
statusRecorder *peer.Status
Expand All @@ -40,6 +50,7 @@ type UpdateManager struct {
wg sync.WaitGroup
currentVersion string
updateFunc func(ctx context.Context, targetVersion string) error
stateManager *statemanager.Manager

cancel context.CancelFunc
update UpdateInterface
Expand All @@ -49,38 +60,83 @@ type UpdateManager struct {
expectedVersionMutex sync.Mutex
}

func NewUpdateManager(statusRecorder *peer.Status) *UpdateManager {
func NewUpdateManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *UpdateManager {
manager := &UpdateManager{
statusRecorder: statusRecorder,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
updateFunc: triggerUpdate,
update: version.NewUpdate("nb/client"),
stateManager: stateManager,
}

return manager
}

func (u *UpdateManager) StartWithTimeout(ctx context.Context, timeout time.Duration) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen to the fetcher after a timeout? It seems like it will never be freed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since timeout will cancel the context, onContextCancel will handle fetcher cleanup

if u.cancel != nil {
log.Errorf("UpdateManager already started")
return
}

u.startInit(ctx)
ctx, cancel := context.WithTimeout(ctx, timeout)
u.cancel = cancel

u.wg.Add(1)
go u.updateLoop(ctx)
}

func (u *UpdateManager) Start(ctx context.Context) {
if u.cancel != nil {
log.Errorf("UpdateManager already started")
return
}

go u.update.StartFetcher()
u.startInit(ctx)
ctx, cancel := context.WithCancel(ctx)
u.cancel = cancel

u.wg.Add(1)
go u.updateLoop(ctx)
}

func (u *UpdateManager) startInit(ctx context.Context) {
u.update.SetDaemonVersion(u.currentVersion)
u.update.SetOnUpdateListener(func() {
select {
case u.updateChannel <- struct{}{}:
default:
}
})
go u.update.StartFetcher()

ctx, cancel := context.WithCancel(ctx)
u.cancel = cancel

u.wg.Add(1)
go u.updateLoop(ctx)
u.stateManager.RegisterState(&UpdateState{})
if err := u.stateManager.LoadState(&UpdateState{}); err != nil {
log.Warnf("failed to load state: %v", err)
return
}
if u.stateManager.GetState(&UpdateState{}) == nil {
return
}
updateState := u.stateManager.GetState(&UpdateState{}).(*UpdateState)
log.Warnf("autoUpdate state loaded, %v", *updateState)
if updateState.TargetVersion == u.currentVersion {
log.Warnf("published notification event")
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", u.currentVersion),
nil,
)
}
if err := u.stateManager.DeleteState(updateState); err != nil {
log.Warnf("failed to delete state: %v", err)
} else if err = u.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}

func (u *UpdateManager) SetVersion(expectedVersion string) {
Expand Down Expand Up @@ -129,12 +185,26 @@ func (u *UpdateManager) Stop() {
u.wg.Wait()
}

func (u *UpdateManager) onContextCancel() {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You clean up updateManager with Stop and also with context cancellation. This means we have a race condition in the onContextCancel and Stop() functions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the locking mechanism from the original PR modifications

if u.cancel == nil {
return
}

u.expectedVersionMutex.Lock()
defer u.expectedVersionMutex.Unlock()
if u.update != nil {
u.update.StopWatch()
u.update = nil
}
}

func (u *UpdateManager) updateLoop(ctx context.Context) {
defer u.wg.Done()

for {
select {
case <-ctx.Done():
u.onContextCancel()
return
case <-u.mgmUpdateChan:
case <-u.updateChannel:
Expand Down Expand Up @@ -189,9 +259,46 @@ func (u *UpdateManager) handleUpdate(ctx context.Context) {
nil,
)

err := u.updateFunc(ctx, updateVersion.String())
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show"},
)

updateState := UpdateState{
PreUpdateVersion: u.currentVersion,
TargetVersion: updateVersion.String(),
}
err := u.stateManager.UpdateState(updateState)
if err != nil {
log.Warnf("failed to update state: %v", err)
} else {
err = u.stateManager.PersistState(ctx)
if err != nil {
log.Warnf("failed to persist state: %v", err)
}
}

err = u.updateFunc(ctx, updateVersion.String())

if err != nil {
log.Errorf("Error triggering auto-update: %v", err)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "hide"},
)
}
}

Expand Down
13 changes: 9 additions & 4 deletions client/internal/updatemanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package updatemanager

import (
"context"
"fmt"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"path"
"testing"
"time"
)
Expand Down Expand Up @@ -61,9 +64,10 @@ func Test_LatestVersion(t *testing.T) {
},
}

for _, c := range testMatrix {
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
m := NewUpdateManager(peer.NewRecorder("")).WithCustomVersionUpdate(mockUpdate)
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(mockUpdate)

targetVersionChan := make(chan string, 1)

Expand Down Expand Up @@ -174,8 +178,9 @@ func Test_HandleUpdate(t *testing.T) {
shouldUpdate: false,
},
}
for _, c := range testMatrix {
m := NewUpdateManager(peer.NewRecorder("")).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion})
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion})
targetVersionChan := make(chan string, 1)

m.updateFunc = func(ctx context.Context, targetVersion string) error {
Expand Down
3 changes: 1 addition & 2 deletions client/internal/updatemanager/update_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ package updatemanager
import (
"context"
"fmt"
"os/exec"

"golang.org/x/sys/windows/registry"
"os/exec"

log "github.com/sirupsen/logrus"
)
Expand Down
Loading
Loading