Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import (
const ()

type migrateCoordinator interface {
Migrate(_ context.Context, _ *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff) error
actionCoordinator

Migrate(_ context.Context, _ *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff, _ func(context.Context, *fleetapi.ActionMigrate) error) error
ReExec(callback reexec.ShutdownCallbackFn, argOverrides ...string)
Protection() protection.Config
}
Expand Down Expand Up @@ -90,7 +92,7 @@ func (h *Migrate) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker

action.Data.EnrollmentToken = enrollmentToken

if err := h.coord.Migrate(ctx, action, fleetgateway.RequestBackoff); err != nil {
if err := h.coord.Migrate(ctx, action, fleetgateway.RequestBackoff, h.notifyComponents); err != nil {
// this should not happen, unmanaged agent should not receive the action
// defensive coding to avoid misbehavior
if errors.Is(err, coordinator.ErrNotManaged) {
Expand All @@ -112,6 +114,22 @@ func (h *Migrate) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker
return nil
}

// notifyComponents proxies the migrate action to the units that
// findMatchingUnitsByActionType selects from the coordinator's current state
// for the migrate action type.
//
// When no unit matches, a debug message is logged and nil is returned.
// Otherwise the result of notifyUnitsOfProxiedAction is returned as-is.
func (h *Migrate) notifyComponents(ctx context.Context, migrateAction *fleetapi.ActionMigrate) error {
	units := findMatchingUnitsByActionType(h.coord.State(), fleetapi.ActionTypeMigrate)
	if len(units) == 0 {
		// Nothing to notify: log and continue.
		h.log.Debugf("No components running for %v action type", fleetapi.ActionTypeMigrate)
		return nil
	}

	return notifyUnitsOfProxiedAction(ctx, h.log, migrateAction, units, h.coord.PerformAction)
}

func (h *Migrate) ackFailure(ctx context.Context, err error, action *fleetapi.ActionMigrate, acker acker.Acker) {
action.Err = err

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/elastic/elastic-agent/internal/pkg/agent/protection"
"github.com/elastic/elastic-agent/internal/pkg/core/backoff"
"github.com/elastic/elastic-agent/internal/pkg/fleetapi"
"github.com/elastic/elastic-agent/pkg/component"
"github.com/elastic/elastic-agent/pkg/core/logger/loggertest"
mockinfo "github.com/elastic/elastic-agent/testing/mocks/internal_/pkg/agent/application/info"
)
Expand All @@ -38,7 +39,7 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil})

Expand Down Expand Up @@ -77,7 +78,8 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("State").Return(coordinator.State{})
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil, Enabled: tc.protectionEnabled})

Expand Down Expand Up @@ -114,7 +116,8 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("State").Return(coordinator.State{})
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil})

Expand Down Expand Up @@ -163,7 +166,8 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("State").Return(coordinator.State{})
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey})

Expand Down Expand Up @@ -199,7 +203,7 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey})

Expand Down Expand Up @@ -248,7 +252,8 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("State").Return(coordinator.State{})
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil})

Expand Down Expand Up @@ -300,7 +305,7 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(nil)
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey})

Expand All @@ -322,7 +327,8 @@ func TestActionMigratelHandler(t *testing.T) {
ack.On("Commit", t.Context()).Return(nil)

coord := &fakeMigrateCoordinator{}
coord.On("Migrate", mock.Anything, mock.Anything).Return(coordinator.ErrFleetServer)
coord.On("State").Return(coordinator.State{})
coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(coordinator.ErrFleetServer)
coord.On("ReExec", mock.Anything, mock.Anything)
coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil})

Expand All @@ -343,11 +349,21 @@ type fakeMigrateCoordinator struct {
mock.Mock
}

func (f *fakeMigrateCoordinator) Migrate(ctx context.Context, a *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff) error {
args := f.Called(ctx, a)
// Migrate records the call, including the backoff factory and the notify
// callback, on the embedded mock and returns the error configured as the
// first return value.
func (f *fakeMigrateCoordinator) Migrate(ctx context.Context, a *fleetapi.ActionMigrate, b func(done <-chan struct{}) backoff.Backoff, n func(context.Context, *fleetapi.ActionMigrate) error) error {
	return f.Called(ctx, a, b, n).Error(0)
}

// State records the call on the embedded mock and returns the
// coordinator.State configured as the first return value. The type assertion
// is unchecked, so the mock must be set up with a coordinator.State value.
func (f *fakeMigrateCoordinator) State() coordinator.State {
	return f.Called().Get(0).(coordinator.State)
}

// PerformAction records the call on the embedded mock and returns the
// configured result map and error.
//
// A nil first return value (e.g. Return(nil, err)) is passed through as a nil
// map instead of panicking on the type assertion, so tests can configure
// error-only results.
func (f *fakeMigrateCoordinator) PerformAction(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) {
	args := f.Called(ctx, comp, unit, name, params)

	var result map[string]interface{}
	if v := args.Get(0); v != nil {
		result = v.(map[string]interface{})
	}
	return result, args.Error(1)
}

// ReExec records the call on the embedded mock. The variadic argOverrides are
// forwarded as a single slice argument, not expanded into separate arguments.
func (f *fakeMigrateCoordinator) ReExec(callback reexec.ShutdownCallbackFn, argOverrides ...string) {
f.Called(callback, argOverrides)
}
Expand Down
48 changes: 43 additions & 5 deletions internal/pkg/agent/application/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -361,6 +362,11 @@ type Coordinator struct {

// Abstraction for diagnostics AddSecretMarkers function for testability
secretMarkerFunc func(*logger.Logger, *config.Config) error

// migrationProgressWg is used to block processing of incoming policies after enroll is done.
// Incoming policies are blocked until we reboot, so that components receiving the proxied
// MIGRATE action are not confused.
migrationProgressWg sync.WaitGroup
}

// The channels Coordinator reads to receive updates from the various managers.
Expand Down Expand Up @@ -596,7 +602,12 @@ func (c *Coordinator) ReExec(callback reexec.ShutdownCallbackFn, argOverrides ..

// Migrate migrates agent to a new cluster and ACKs success to the old one.
// In case of failure no ack is performed and error is returned.
func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrate, backoffFactory func(done <-chan struct{}) backoff.Backoff) error {
func (c *Coordinator) Migrate(
ctx context.Context,
action *fleetapi.ActionMigrate,
backoffFactory func(done <-chan struct{}) backoff.Backoff,
notifyFn func(context.Context, *fleetapi.ActionMigrate) error,
) error {
if !c.isManaged {
return ErrNotManaged
}
Expand Down Expand Up @@ -666,6 +677,24 @@ func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrat
return errors.Join(fmt.Errorf("failed to enroll: %w", err), restoreErr)
}

// lock processing of new config before notifying components
// hold lock until notification failure or reexec
c.migrationProgressWg.Add(1)
if notifyFn != nil {
// notify before completing migration
// components such as endpoint are crucial and must keep working even when on a stale cluster
// error on component side is returned as part of Action response
if err := notifyFn(ctx, action); err != nil {
restoreErr := RestoreConfig()

// in case of failure no need to lock processing
// safe to forward policy from source cluster
c.migrationProgressWg.Done()

return errors.Join(fmt.Errorf("failed to notify components: %w", err), restoreErr)
}
}

// ACK success to source fleet server
if err := c.ackMigration(ctx, action, c.fleetAcker); err != nil {
c.logger.Warnf("failed to ACK success: %v", err)
Expand All @@ -677,23 +706,30 @@ func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrat
return fmt.Errorf("failed to clean backup config: %w", err)
}

c.bestEffortUnenroll(ctx, originalOptions)

return nil
}

func (c *Coordinator) bestEffortUnenroll(ctx context.Context, originalOptions enroll.EnrollOptions) {
originalRemoteConfig, err := originalOptions.RemoteConfig(false)
if err != nil {
return fmt.Errorf("failed to construct original remote config: %w", err)
c.logger.Warnf("failed to construct original remote config: %v", err)
return
}

originalClient, err := fleetapiClient.NewAuthWithConfig(
c.logger, originalOptions.EnrollAPIKey, originalRemoteConfig)
if err != nil {
return fmt.Errorf("failed to create original fleet client: %w", err)
c.logger.Warnf("failed to create original fleet client: %v", err)
return
}

// Best effort: call unenroll on source cluster once done
if err := c.unenroll(ctx, originalClient); err != nil {
c.logger.Warnf("failed to unenroll from original cluster: %v", err)
return
}

return nil
}

type upgradeOpts struct {
Expand Down Expand Up @@ -1532,6 +1568,8 @@ func (c *Coordinator) runLoopIteration(ctx context.Context) {

// Always called on the main Coordinator goroutine.
func (c *Coordinator) processConfig(ctx context.Context, cfg *config.Config) (err error) {
c.migrationProgressWg.Wait()

if c.otelMgr != nil {
c.otelCfg = cfg.OTel
}
Expand Down
Loading
Loading