diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go index f1347d17c20e5..c2f51f3240cb4 100644 --- a/lib/autoupdate/agent/config.go +++ b/lib/autoupdate/agent/config.go @@ -222,16 +222,8 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { if override.Path != "" { spec.Path = override.Path } - if override.Group != "" { - spec.Group = override.Group - } - switch override.BaseURL { - case "": - case "default": - spec.BaseURL = "" - default: - spec.BaseURL = override.BaseURL - } + spec.Group = overrideOptional(spec.Group, override.Group) + spec.BaseURL = overrideOptional(spec.BaseURL, override.BaseURL) if spec.BaseURL != "" && !strings.HasPrefix(strings.ToLower(spec.BaseURL), "https://") { return trace.Errorf("Teleport download base URL %s must use TLS (https://)", spec.BaseURL) @@ -245,6 +237,17 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { return nil } +func overrideOptional(orig, override string) string { + switch override { + case "": + return orig + case "default": + return "" + default: + return override + } +} + // Status of the agent auto-updates system. type Status struct { UpdateSpec `yaml:",inline"` diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go index 39d318cd6ee4c..d301a16a0b1fa 100644 --- a/lib/autoupdate/agent/config_test.go +++ b/lib/autoupdate/agent/config_test.go @@ -125,3 +125,111 @@ func TestNewRevisionFromDir(t *testing.T) { }) } } + +func TestValidateConfigSpec(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + config UpdateSpec + override UpdateSpec + result UpdateSpec + errMatch string + }{ + { + name: "overrides", + config: UpdateSpec{ + Proxy: "proxy", + Path: "/path", + Group: "group", + BaseURL: "https://example.com", + }, + override: UpdateSpec{ + Enabled: true, + Pinned: true, + Proxy: "overrideProxy", + Path: "/overridePath", + Group: "group2", + BaseURL: "https://example.com", + }, + result: UpdateSpec{ + Enabled: true, + Pinned: true, + Proxy: "overrideProxy", + Path: "/overridePath", + Group: "group2", + BaseURL: "https://example.com", + }, + }, + { + name: "default overrides", + config: UpdateSpec{ + Proxy: "proxy", + Path: "/path", + Group: "group", + BaseURL: "https://example.com", + }, + override: UpdateSpec{ + Proxy: "default", + Path: "default", + Group: "default", + BaseURL: "default", + }, + result: UpdateSpec{ + Proxy: "default", + Path: "default", + }, + }, + { + name: "only overrides", + override: UpdateSpec{ + Enabled: true, + Pinned: true, + Proxy: "overrideProxy", + Path: "/overridePath", + Group: "group2", + BaseURL: "https://example.com", + }, + result: UpdateSpec{ + Enabled: true, + Pinned: true, + Proxy: "overrideProxy", + Path: "/overridePath", + Group: "group2", + BaseURL: "https://example.com", + }, + }, + { + name: "no overrides", + config: UpdateSpec{ + Proxy: "proxy", + Path: "/path", + Group: "group", + BaseURL: "https://example.com", + }, + result: UpdateSpec{ + Proxy: "proxy", + Path: "/path", + Group: "group", + BaseURL: "https://example.com", + }, + }, + { + name: "BaseURL validation fails", + override: UpdateSpec{ + BaseURL: "http://example.com", + }, + errMatch: "must use TLS", + }, + } { + t.Run(tt.name, func(t *testing.T) { + err := validateConfigSpec(&tt.config, OverrideConfig{UpdateSpec: tt.override}) + if tt.errMatch != "" { + require.ErrorContains(t, err, tt.errMatch) + return + } + require.NoError(t, err) + require.Equal(t, tt.result, tt.config) + }) + } +} diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index cf83d4a527b74..aa852256e0f98 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -884,12 +884,16 @@ func (u *Updater) find(ctx context.Context, cfg *UpdateConfig, id string) (FindR if err != nil { return FindResp{}, trace.Wrap(err, "failed to parse proxy server address") } + group := cfg.Spec.Group + if group == "" { + group = "default" + } resp, err := webclient.Find(&webclient.Config{ Context: ctx, ProxyAddr: addr.Addr, Insecure: u.InsecureSkipVerify, Timeout: 30 * time.Second, - UpdateGroup: cfg.Spec.Group, + UpdateGroup: group, UpdateID: id, Pool: u.Pool, }) diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index e57a28ed5ed49..05e78f5464ed7 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -460,6 +460,7 @@ func TestUpdater_Update(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, + requestGroup: "default", errMatch: "install error", }, { @@ -476,7 +477,8 @@ func TestUpdater_Update(t *testing.T) { Active: NewRevision("16.3.0", 0), }, }, - inWindow: true, + inWindow: true, + requestGroup: "default", }, { name: "version already installed outside of window", @@ -492,6 +494,7 @@ func TestUpdater_Update(t *testing.T) { Active: NewRevision("16.3.0", 0), }, }, + requestGroup: "default", }, { name: "version detects as linked", @@ -510,6 +513,7 @@ func TestUpdater_Update(t *testing.T) { linkedRevisions: []Revision{NewRevision("16.3.0", 0)}, inWindow: true, + requestGroup: "default", installedRevision: NewRevision("16.3.0", 0), installedBaseURL: "https://example.com", linkedRevision: NewRevision("16.3.0", 0), @@ -536,6 +540,7 @@ func TestUpdater_Update(t *testing.T) { }, inWindow: true, + requestGroup: "default", installedRevision: NewRevision("16.3.0", 0), installedBaseURL: "https://example.com", linkedRevision: NewRevision("16.3.0", 0), @@ -564,6 +569,7 @@ func TestUpdater_Update(t *testing.T) { inWindow: true, linkedRevisions: []Revision{NewRevision("backup-version", 0)}, + requestGroup: "default", installedRevision: NewRevision("16.3.0", 0), installedBaseURL: "https://example.com", linkedRevision: NewRevision("16.3.0", 0), @@ -588,7 +594,8 @@ func TestUpdater_Update(t *testing.T) { Backup: toPtr(NewRevision("backup-version", 0)), }, }, - inWindow: true, + inWindow: true, + requestGroup: "default", }, { name: "config does not exist", @@ -611,6 +618,7 @@ func TestUpdater_Update(t *testing.T) { inWindow: true, flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS, + requestGroup: "default", installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), installedBaseURL: "https://example.com", linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), @@ -644,6 +652,7 @@ func TestUpdater_Update(t *testing.T) { inWindow: true, setupErr: errors.New("setup error"), + requestGroup: "default", installedRevision: NewRevision("16.3.0", 0), installedBaseURL: "https://example.com", linkedRevision: NewRevision("16.3.0", 0), @@ -673,10 +682,11 @@ func TestUpdater_Update(t *testing.T) { inWindow: true, agpl: true, - reloadCalls: 0, - revertCalls: 0, - setupCalls: 0, - errMatch: "AGPL", + requestGroup: "default", + reloadCalls: 0, + revertCalls: 0, + setupCalls: 0, + errMatch: "AGPL", }, { name: "skip version", @@ -694,7 +704,8 @@ func TestUpdater_Update(t *testing.T) { Skip: toPtr(NewRevision("16.3.0", 0)), }, }, - inWindow: true, + inWindow: true, + requestGroup: "default", }, { name: "pinned version", @@ -712,7 +723,8 @@ func TestUpdater_Update(t *testing.T) { Backup: toPtr(NewRevision("backup-version", 0)), }, }, - inWindow: true, + inWindow: true, + requestGroup: "default", }, } @@ -1486,6 +1498,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1503,6 +1516,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1528,6 +1542,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, + requestGroup: "default", errMatch: "install error", }, { @@ -1536,8 +1551,9 @@ func TestUpdater_Install(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, }, - agpl: true, - errMatch: "AGPL", + agpl: true, + requestGroup: "default", + errMatch: "AGPL", }, { name: "version already installed", @@ -1552,6 +1568,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: false, }, @@ -1570,6 +1587,7 @@ func TestUpdater_Install(t *testing.T) { installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), removedRevision: NewRevision("backup-version", 0), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1587,6 +1605,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, }, { @@ -1595,6 +1614,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1604,6 +1624,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1619,6 +1640,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", revertCalls: 1, setupCalls: 1, reloadCalls: 1, @@ -1639,6 +1661,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", revertCalls: 1, setupCalls: 1, errMatch: "setup error", @@ -1650,6 +1673,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: true, }, @@ -1661,6 +1685,7 @@ func TestUpdater_Install(t *testing.T) { installedRevision: NewRevision("16.3.0", 0), installedBaseURL: autoupdate.DefaultBaseURL, linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "default", setupCalls: 1, restarted: true, },