diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_migrate.go b/internal/pkg/agent/application/actions/handlers/handler_action_migrate.go index fe9d69e1d95..4e1c8a18884 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_migrate.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_migrate.go @@ -25,7 +25,9 @@ import ( const () type migrateCoordinator interface { - Migrate(_ context.Context, _ *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff) error + actionCoordinator + + Migrate(_ context.Context, _ *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff, _ func(context.Context, *fleetapi.ActionMigrate) error) error ReExec(callback reexec.ShutdownCallbackFn, argOverrides ...string) Protection() protection.Config } @@ -90,7 +92,7 @@ func (h *Migrate) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker action.Data.EnrollmentToken = enrollmentToken - if err := h.coord.Migrate(ctx, action, fleetgateway.RequestBackoff); err != nil { + if err := h.coord.Migrate(ctx, action, fleetgateway.RequestBackoff, h.notifyComponents); err != nil { // this should not happen, unmanaged agent should not receive the action // defensive coding to avoid misbehavior if errors.Is(err, coordinator.ErrNotManaged) { @@ -112,6 +114,22 @@ func (h *Migrate) Handle(ctx context.Context, a fleetapi.Action, ack acker.Acker return nil } +func (h *Migrate) notifyComponents(ctx context.Context, migrateAction *fleetapi.ActionMigrate) error { + state := h.coord.State() + ucs := findMatchingUnitsByActionType(state, fleetapi.ActionTypeMigrate) + if len(ucs) > 0 { + err := notifyUnitsOfProxiedAction(ctx, h.log, migrateAction, ucs, h.coord.PerformAction) + if err != nil { + return err + } + } else { + // Log and continue + h.log.Debugf("No components running for %v action type", fleetapi.ActionTypeMigrate) + } + + return nil +} + func (h *Migrate) ackFailure(ctx context.Context, err error, action *fleetapi.ActionMigrate, acker acker.Acker) { action.Err = err diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_migrate_test.go b/internal/pkg/agent/application/actions/handlers/handler_action_migrate_test.go index e3d8a76136d..5560a34cc96 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_migrate_test.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_migrate_test.go @@ -22,6 +22,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/protection" "github.com/elastic/elastic-agent/internal/pkg/core/backoff" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" + "github.com/elastic/elastic-agent/pkg/component" "github.com/elastic/elastic-agent/pkg/core/logger/loggertest" mockinfo "github.com/elastic/elastic-agent/testing/mocks/internal_/pkg/agent/application/info" ) @@ -38,7 +39,7 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil}) @@ -77,7 +78,8 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("State").Return(coordinator.State{}) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil, Enabled: tc.protectionEnabled}) @@ -114,7 +116,8 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("State").Return(coordinator.State{}) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil}) @@ -163,7 +166,8 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("State").Return(coordinator.State{}) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey}) @@ -199,7 +203,7 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey}) @@ -248,7 +252,8 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("State").Return(coordinator.State{}) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil}) @@ -300,7 +305,7 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(nil) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: signatureValidationKey}) @@ -322,7 +327,8 @@ func TestActionMigratelHandler(t *testing.T) { ack.On("Commit", t.Context()).Return(nil) coord := &fakeMigrateCoordinator{} - coord.On("Migrate", mock.Anything, mock.Anything).Return(coordinator.ErrFleetServer) + coord.On("State").Return(coordinator.State{}) + coord.On("Migrate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(coordinator.ErrFleetServer) coord.On("ReExec", mock.Anything, mock.Anything) coord.On("Protection").Return(protection.Config{SignatureValidationKey: nil}) @@ -343,11 +349,21 @@ type fakeMigrateCoordinator struct { mock.Mock } -func (f *fakeMigrateCoordinator) Migrate(ctx context.Context, a *fleetapi.ActionMigrate, _ func(done <-chan struct{}) backoff.Backoff) error { - args := f.Called(ctx, a) +func (f *fakeMigrateCoordinator) Migrate(ctx context.Context, a *fleetapi.ActionMigrate, b func(done <-chan struct{}) backoff.Backoff, n func(context.Context, *fleetapi.ActionMigrate) error) error { + args := f.Called(ctx, a, b, n) return args.Error(0) } +func (f *fakeMigrateCoordinator) State() coordinator.State { + args := f.Called() + return args.Get(0).(coordinator.State) +} + +func (f *fakeMigrateCoordinator) PerformAction(ctx context.Context, comp component.Component, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) { + args := f.Called(ctx, comp, unit, name, params) + return args.Get(0).(map[string]interface{}), args.Error(1) +} + func (f *fakeMigrateCoordinator) ReExec(callback reexec.ShutdownCallbackFn, argOverrides ...string) { f.Called(callback, argOverrides) } diff --git a/internal/pkg/agent/application/coordinator/coordinator.go b/internal/pkg/agent/application/coordinator/coordinator.go index f20b5a3e849..260ee630667 100644 --- a/internal/pkg/agent/application/coordinator/coordinator.go +++ b/internal/pkg/agent/application/coordinator/coordinator.go @@ -10,6 +10,7 @@ import ( "fmt" "reflect" "strings" + "sync" "sync/atomic" "time" @@ -361,6 +362,11 @@ type Coordinator struct { // Abstraction for diagnostics AddSecretMarkers function for testability secretMarkerFunc func(*logger.Logger, *config.Config) error + + // migrationProgressWg is used to block processing of incoming policies after enroll is done + // incomming policies are blocked until we reboot so components receiving proxied MIGRATE action + // are not confused + migrationProgressWg sync.WaitGroup } // The channels Coordinator reads to receive updates from the various managers. @@ -596,7 +602,12 @@ func (c *Coordinator) ReExec(callback reexec.ShutdownCallbackFn, argOverrides .. // Migrate migrates agent to a new cluster and ACKs success to the old one. // In case of failure no ack is performed and error is returned. -func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrate, backoffFactory func(done <-chan struct{}) backoff.Backoff) error { +func (c *Coordinator) Migrate( + ctx context.Context, + action *fleetapi.ActionMigrate, + backoffFactory func(done <-chan struct{}) backoff.Backoff, + notifyFn func(context.Context, *fleetapi.ActionMigrate) error, +) error { if !c.isManaged { return ErrNotManaged } @@ -666,6 +677,24 @@ func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrat return errors.Join(fmt.Errorf("failed to enroll: %w", err), restoreErr) } + // lock processing of new config before notifying components + // hold lock until notification failure or reexec + c.migrationProgressWg.Add(1) + if notifyFn != nil { + // notify before completing migration + // components such endpoint are crucial to work even though it's on stale cluster + // error on component side is returned as part of Action response + if err := notifyFn(ctx, action); err != nil { + restoreErr := RestoreConfig() + + // in case of failure no need to lock processing + // safe to forward policy from source cluster + c.migrationProgressWg.Done() + + return errors.Join(fmt.Errorf("failed to notify components: %w", err), restoreErr) + } + } + // ACK success to source fleet server if err := c.ackMigration(ctx, action, c.fleetAcker); err != nil { c.logger.Warnf("failed to ACK success: %v", err) @@ -677,23 +706,30 @@ func (c *Coordinator) Migrate(ctx context.Context, action *fleetapi.ActionMigrat return fmt.Errorf("failed to clean backup config: %w", err) } + c.bestEffortUnenroll(ctx, originalOptions) + + return nil +} + +func (c *Coordinator) bestEffortUnenroll(ctx context.Context, originalOptions enroll.EnrollOptions) { originalRemoteConfig, err := originalOptions.RemoteConfig(false) if err != nil { - return fmt.Errorf("failed to construct original remote config: %w", err) + c.logger.Warnf("failed to construct original remote config: %v", err) + return } originalClient, err := fleetapiClient.NewAuthWithConfig( c.logger, originalOptions.EnrollAPIKey, originalRemoteConfig) if err != nil { - return fmt.Errorf("failed to create original fleet client: %w", err) + c.logger.Warnf("failed to create original fleet client: %v", err) + return } // Best effort: call unenroll on source cluster once done if err := c.unenroll(ctx, originalClient); err != nil { c.logger.Warnf("failed to unenroll from original cluster: %v", err) + return } - - return nil } type upgradeOpts struct { @@ -1532,6 +1568,8 @@ func (c *Coordinator) runLoopIteration(ctx context.Context) { // Always called on the main Coordinator goroutine. func (c *Coordinator) processConfig(ctx context.Context, cfg *config.Config) (err error) { + c.migrationProgressWg.Wait() + if c.otelMgr != nil { c.otelCfg = cfg.OTel } diff --git a/internal/pkg/agent/application/coordinator/coordinator_unit_test.go b/internal/pkg/agent/application/coordinator/coordinator_unit_test.go index 5c4146d098a..008a9505c81 100644 --- a/internal/pkg/agent/application/coordinator/coordinator_unit_test.go +++ b/internal/pkg/agent/application/coordinator/coordinator_unit_test.go @@ -1582,7 +1582,7 @@ func TestCoordinator_UnmanagedAgent_SkipsMigrate(t *testing.T) { return backoff.NewExpBackoff(done, 30*time.Millisecond, 2*time.Second) } - err := coord.Migrate(ctx, action, backoffFactory) + err := coord.Migrate(ctx, action, backoffFactory, nil) require.ErrorIs(t, err, ErrNotManaged) } @@ -1628,7 +1628,7 @@ func TestCoordinator_FleetServer_SkipsMigration(t *testing.T) { return backoff.NewExpBackoff(done, 30*time.Millisecond, 2*time.Second) } - err := coord.Migrate(ctx, action, backoffFactory) + err := coord.Migrate(ctx, action, backoffFactory, nil) require.ErrorIs(t, err, ErrFleetServer) } @@ -1785,7 +1785,7 @@ func TestCoordinator_InitiatesMigration(t *testing.T) { return backoff.NewExpBackoff(done, 30*time.Millisecond, 2*time.Second) } - err = coord.Migrate(ctx, action, backoffFactory) + err = coord.Migrate(ctx, action, backoffFactory, nil) require.NoError(t, err) acker.AssertCalled(t, "Ack", mock.Anything, action) @@ -1793,6 +1793,171 @@ func TestCoordinator_InitiatesMigration(t *testing.T) { require.True(t, unenrollCalled) } +func TestCoordinator_InvalidComponentRevertsMigration(t *testing.T) { + fipsutils.SkipIfFIPSOnly(t, "vault does not use NewGCMWithRandomNonce.") + cfgPath := paths.Config() + defer paths.SetConfig(cfgPath) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + tmpConfig := t.TempDir() + paths.SetConfig(tmpConfig) + agentConfigFile := paths.ConfigFile() + + var unenrollCalled bool + oldFleetServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if strings.Contains(r.URL.Path, "unenroll") { + unenrollCalled = true + } + + _, err := w.Write(nil) + require.NoError(t, err) + + })) + defer oldFleetServer.Close() + + fleetConfig := configuration.DefaultFleetAgentConfig() + fleetConfig.Enabled = true + fleetConfig.AccessAPIKey = "access-api-key" + fleetConfig.Info.ID = "agent.id" + fleetConfig.Client.Host = oldFleetServer.URL + fleetConfig.Client.Hosts = []string{oldFleetServer.URL} + + agentConfig := &configuration.Configuration{ + Fleet: fleetConfig, + Settings: &configuration.SettingsConfig{ + ID: "agent.id", + }, + } + + rawAgentConfig := &configuration.Configuration{ + Fleet: &configuration.FleetAgentConfig{ + Enabled: true, + }, + Settings: &configuration.SettingsConfig{ + ID: "agent.id", + }, + } + + rawAgentConfigData, err := yaml.Marshal(rawAgentConfig) + require.NoError(t, err) + require.NoError(t, os.WriteFile(agentConfigFile, rawAgentConfigData, 0644)) + + // setup secret normally previously created by enroll + err = secret.CreateAgentSecret(ctx, + vault.WithUnprivileged(true), + vault.WithVaultPath(paths.AgentVaultPath()), + ) + require.NoError(t, err) + + store, err := storage.NewEncryptedDiskStore(ctx, paths.AgentConfigFile(), + storage.WithUnprivileged(true), + storage.WithVaultPath(paths.AgentVaultPath()), + ) + require.NoError(t, err) + + fleetAgentConfigData, err := yaml.Marshal(agentConfig) + require.NoError(t, err) + require.NoError(t, store.Save(bytes.NewReader(fleetAgentConfigData))) + + // overrideStateChan has buffer 2 so we can run on a single goroutine, + // since a successful upgrade sets the override state twice. + overrideStateChan := make(chan *coordinatorOverrideState, 2) + + // similarly, upgradeDetailsChan is a buffered channel as well. + upgradeDetailsChan := make(chan *details.Details, 2) + + // Create a manager that will allow upgrade attempts but return a failure + // from Upgrade itself (success requires testing ReExec and we aren't + // quite ready to do that yet). + upgradeMgr := &fakeUpgradeManager{ + upgradeable: true, + upgradeErr: errors.New("failed upgrade"), + } + + acker := &fakeActionAcker{} + + acker.On("Ack", mock.Anything, mock.Anything).Return(nil) + acker.On("Commit", mock.Anything).Return(nil) + + agentInfo, err := info.NewAgentInfo(ctx, false) + require.NoError(t, err) + coord := &Coordinator{ + stateBroadcaster: broadcaster.New(State{}, 0, 0), + overrideStateChan: overrideStateChan, + upgradeDetailsChan: upgradeDetailsChan, + upgradeMgr: upgradeMgr, + logger: logp.NewLogger("testing"), + // is managed so we proceed with migration + isManaged: true, + fleetAcker: acker, + agentInfo: agentInfo, + } + + coord.state.Components = append(coord.state.Components, runtime.ComponentComponentState{ + Component: component.Component{ + InputType: "not-a-fleet-server", + }, + }) + + newFleetServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if strings.Contains(r.URL.Path, "status") { + _, err := w.Write(nil) + require.NoError(t, err) + return + } + + body := []byte(`{ + "action": "created", + "item": { + "id": "a4937110-e53e-11e9-934f-47a8e38a522c", + "active": true, + "policy_id": "default", + "type": "PERMANENT", + "enrolled_at": "2019-10-02T18:01:22.337Z", + "user_provided_metadata": {}, + "local_metadata": {}, + "actions": [], + "access_api_key": "API_KEY" + } + }`) + _, err := w.Write(body) + require.NoError(t, err) + + })) + defer newFleetServer.Close() + + action := &fleetapi.ActionMigrate{ + Data: fleetapi.ActionMigrateData{ + TargetURI: newFleetServer.URL, + EnrollmentToken: "token", + Settings: json.RawMessage(`{"insecure":true}`), + }, + ActionID: "migrate-id", + ActionType: "MIGRATE", + } + + backoffFactory := func(done <-chan struct{}) backoff.Backoff { + return backoff.NewExpBackoff(done, 30*time.Millisecond, 2*time.Second) + } + + failingComponentNotify := func(_ context.Context, _ *fleetapi.ActionMigrate) error { + return fmt.Errorf("failed to notify") + } + + err = coord.Migrate(ctx, action, backoffFactory, failingComponentNotify) + require.Error(t, err) + + acker.AssertNumberOfCalls(t, "Ack", 0) + acker.AssertNotCalled(t, "Commit", 0) + require.False(t, unenrollCalled) +} + // Returns an empty but non-nil set of transpiler variables for testing // (Coordinator will only regenerate its component model when it has non-nil // vars). diff --git a/internal/pkg/agent/application/enroll/options.go b/internal/pkg/agent/application/enroll/options.go index cd87c69e35c..80e90f5d2bd 100644 --- a/internal/pkg/agent/application/enroll/options.go +++ b/internal/pkg/agent/application/enroll/options.go @@ -139,9 +139,11 @@ func MergeOptionsWithMigrateAction(action *fleetapi.ActionMigrate, options Enrol return EnrollOptions{}, fmt.Errorf("failed to decode enroll options: %w", err) } - // do not preserve ID by default - delete(configMap, "id") - options.ID = "" + if len(options.ReplaceToken) == 0 { + // do not preserve ID by default + delete(configMap, "id") + options.ID = "" + } // overwriting what's needed if len(action.Data.Settings) > 0 { diff --git a/internal/pkg/fleetapi/action.go b/internal/pkg/fleetapi/action.go index 921a01dc4cc..469bdef8e59 100644 --- a/internal/pkg/fleetapi/action.go +++ b/internal/pkg/fleetapi/action.go @@ -467,6 +467,13 @@ func (a *ActionMigrate) String() string { return s.String() } +// MarshalMap marshals ActionMigrate into a corresponding map +func (a *ActionMigrate) MarshalMap() (map[string]interface{}, error) { + var res map[string]interface{} + err := mapstructure.Decode(a, &res) + return res, err +} + func (a *ActionMigrate) AckEvent() AckEvent { event := newAckEvent(a.ActionID, a.ActionType) if a.Err != nil { diff --git a/specs/endpoint-security.spec.yml b/specs/endpoint-security.spec.yml index ef2d2634123..eb0ec54cf45 100644 --- a/specs/endpoint-security.spec.yml +++ b/specs/endpoint-security.spec.yml @@ -14,6 +14,7 @@ inputs: proxied_actions: &proxied_actions - UNENROLL - UPGRADE + - MIGRATE runtime: preventions: - condition: ${runtime.native_arch} != '' and ${runtime.arch} != ${runtime.native_arch}