diff --git a/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml b/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml
index b9d547638df9a..e7359243d1601 100644
--- a/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml
+++ b/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml
@@ -65,7 +65,7 @@ spec:
- "--agent-name={{ .Release.Name }}"
- "--agent-namespace={{ .Release.Namespace }}"
- "--base-image={{ include "teleport-kube-agent.baseImage" . }}"
- - "--version-server={{ $updater.versionServer }}"
+ - "--version-server={{ tpl $updater.versionServer . }}"
- "--version-channel={{ $updater.releaseChannel }}"
{{- if .Values.updater.extraArgs }}
{{- toYaml .Values.updater.extraArgs | nindent 10 }}
diff --git a/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml b/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml
index 032c8348f88d5..b6b28f40bee22 100644
--- a/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml
+++ b/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml
@@ -57,6 +57,16 @@ tests:
- contains:
path: spec.template.spec.containers[0].args
content: "--agent-namespace=my-namespace"
+ - it: defaults the updater version server to the proxy address
+ set:
+ proxyAddr: proxy.teleport.example.com:443
+ roles: "custom"
+ updater:
+ enabled: true
+ asserts:
+ - contains:
+ path: spec.template.spec.containers[0].args
+ content: "--version-server=https://proxy.teleport.example.com:443/v1/webapi/automaticupgrades/channel"
- it: sets the updater version server
values:
- ../.lint/updater.yaml
diff --git a/examples/chart/teleport-kube-agent/values.yaml b/examples/chart/teleport-kube-agent/values.yaml
index 567973ecbbd34..0c8ba7411616e 100644
--- a/examples/chart/teleport-kube-agent/values.yaml
+++ b/examples/chart/teleport-kube-agent/values.yaml
@@ -146,7 +146,8 @@ updater:
# `updater.versionServer` is the URL of the version server the agent fetches
# the target version from. The complete version endpoint is built by
# concatenating `versionServer` and `releaseChannel`.
- versionServer: "https://updates.releases.teleport.dev/v1/"
+ # This field supports gotemplate.
+ versionServer: "https://{{ .Values.proxyAddr }}/v1/webapi/automaticupgrades/channel"
# Release channel the agent subscribes to.
releaseChannel: "stable/cloud"
image: public.ecr.aws/gravitational/teleport-kube-agent-updater
diff --git a/integration/proxy/automaticupgrades_test.go b/integration/proxy/automaticupgrades_test.go
new file mode 100644
index 0000000000000..28fa4d3cd49e6
--- /dev/null
+++ b/integration/proxy/automaticupgrades_test.go
@@ -0,0 +1,189 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package proxy
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "path/filepath"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/integration/helpers"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
+ "github.com/gravitational/teleport/lib/automaticupgrades/basichttp"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+ "github.com/gravitational/teleport/lib/service/servicecfg"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func createProxyWithChannels(t *testing.T, channels automaticupgrades.Channels) string {
+ features := proto.Features{}
+ require.NoError(t, channels.CheckAndSetDefaults(features))
+ testDir := t.TempDir()
+
+ cfg := helpers.InstanceConfig{
+ ClusterName: "root.example.com",
+ HostID: uuid.New().String(),
+ NodeName: helpers.Loopback,
+ Log: utils.NewLoggerForTests(),
+ }
+ cfg.Listeners = helpers.SingleProxyPortSetup(t, &cfg.Fds)
+ rc := helpers.NewInstance(t, cfg)
+
+ var err error
+ rcConf := servicecfg.MakeDefaultConfig()
+ rcConf.DataDir = filepath.Join(testDir, "data")
+ rcConf.Auth.Enabled = true
+ rcConf.Proxy.Enabled = true
+ rcConf.SSH.Enabled = false
+ rcConf.Proxy.DisableWebInterface = true
+ rcConf.Version = "v3"
+ rcConf.Proxy.AutomaticUpgradesChannels = channels
+
+ err = rc.CreateEx(t, nil, rcConf)
+ require.NoError(t, err)
+ err = rc.Start()
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ assert.NoError(t, rc.StopAll())
+ })
+
+ return cfg.Listeners.Web
+}
+
+func TestVersionServer(t *testing.T) {
+ // Test setup
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ testVersion := "v12.2.6"
+ testVersionMajorTooHigh := "v99.1.3"
+
+ staticChannel := "static/ok"
+ staticHighChannel := "static/high"
+ staticNoVersionChannel := "static/none"
+ forwardChannel := "forward/ok"
+ forwardHighChannel := "forward/high"
+ forwardNoVersionChannel := "forward/none"
+ forwardPath := "/version-server/"
+
+ upstreamServer := basichttp.NewServerMock(forwardPath + constants.VersionPath)
+ upstreamServer.SetResponse(t, http.StatusOK, testVersion)
+ upstreamHighServer := basichttp.NewServerMock(forwardPath + constants.VersionPath)
+ upstreamHighServer.SetResponse(t, http.StatusOK, testVersionMajorTooHigh)
+ upstreamNoVersionServer := basichttp.NewServerMock(forwardPath + constants.VersionPath)
+ upstreamNoVersionServer.SetResponse(t, http.StatusOK, constants.NoVersion)
+
+ channels := automaticupgrades.Channels{
+ staticChannel: {
+ StaticVersion: testVersion,
+ },
+ staticHighChannel: {
+ StaticVersion: testVersionMajorTooHigh,
+ },
+ staticNoVersionChannel: {
+ StaticVersion: constants.NoVersion,
+ },
+ forwardChannel: {
+ ForwardURL: upstreamServer.Srv.URL + forwardPath,
+ },
+ forwardHighChannel: {
+ ForwardURL: upstreamHighServer.Srv.URL + forwardPath,
+ },
+ forwardNoVersionChannel: {
+ ForwardURL: upstreamNoVersionServer.Srv.URL + forwardPath,
+ },
+ }
+
+ proxyAddr := createProxyWithChannels(t, channels)
+
+ tr := &http.Transport{
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+ }
+ httpClient := http.Client{Transport: tr}
+
+ // Test execution
+ tests := []struct {
+ name string
+ channel string
+ expectedResponse string
+ }{
+ {
+ name: "static version OK",
+ channel: staticChannel,
+ expectedResponse: testVersion,
+ },
+ {
+ name: "static version too high",
+ channel: staticHighChannel,
+ expectedResponse: teleport.Version,
+ },
+ {
+ name: "static version none",
+ channel: staticNoVersionChannel,
+ expectedResponse: constants.NoVersion,
+ },
+ {
+ name: "forward version OK",
+ channel: forwardChannel,
+ expectedResponse: testVersion,
+ },
+ {
+ name: "forward version too high",
+ channel: forwardHighChannel,
+ expectedResponse: teleport.Version,
+ },
+ {
+ name: "forward version none",
+ channel: forwardNoVersionChannel,
+ expectedResponse: constants.NoVersion,
+ },
+ }
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ channelUrl, err := url.Parse(
+ fmt.Sprintf("https://%s/v1/webapi/automaticupgrades/channel/%s/version", proxyAddr, tt.channel),
+ )
+ require.NoError(t, err)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, channelUrl.String(), nil)
+ require.NoError(t, err)
+ res, err := httpClient.Do(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+
+ body, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+
+ require.Equal(t, http.StatusOK, res.StatusCode)
+ require.Equal(t, tt.expectedResponse, string(body))
+ })
+ }
+}
diff --git a/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go b/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go
index 30e2f485cffc0..1a98dfa316d0c 100644
--- a/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go
+++ b/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go
@@ -188,6 +188,8 @@ func main() {
os.Exit(1)
}
+ ctrl.Log.Info("starting the updater", "url", versionServerURL.String())
+
if err := mgr.Start(ctx); err != nil {
ctrl.Log.Error(err, "failed to start manager, exiting")
os.Exit(1)
diff --git a/integrations/kube-agent-updater/pkg/constants/constants.go b/integrations/kube-agent-updater/pkg/constants/constants.go
index 65c9e9c9d4e5b..2ca709962c89c 100644
--- a/integrations/kube-agent-updater/pkg/constants/constants.go
+++ b/integrations/kube-agent-updater/pkg/constants/constants.go
@@ -29,4 +29,8 @@ const (
VersionPath = "version"
HTTPTimeout = 10 * time.Second
CacheDuration = time.Minute
+
+ // NoVersion is returned by the version endpoint when there is no valid target version.
+ // This can be caused by the target version being incompatible with the cluster version.
+ NoVersion = "none"
)
diff --git a/integrations/kube-agent-updater/pkg/controller/deployment.go b/integrations/kube-agent-updater/pkg/controller/deployment.go
index 1d208f2f43b41..c377bea0dbef9 100644
--- a/integrations/kube-agent-updater/pkg/controller/deployment.go
+++ b/integrations/kube-agent-updater/pkg/controller/deployment.go
@@ -27,6 +27,8 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
+
+ "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/version"
)
// DeploymentVersionUpdater Reconciles a podSpec by changing its image
@@ -74,7 +76,7 @@ func (r *DeploymentVersionUpdater) Reconcile(ctx context.Context, req ctrl.Reque
image, err := r.GetVersion(ctx, &obj, currentVersion)
var (
- noNewVersionErr *NoNewVersionError
+ noNewVersionErr *version.NoNewVersionError
maintenanceErr *MaintenanceNotTriggeredError
)
switch {
diff --git a/integrations/kube-agent-updater/pkg/controller/errors.go b/integrations/kube-agent-updater/pkg/controller/errors.go
index 157f7de32a816..fccbbad75756c 100644
--- a/integrations/kube-agent-updater/pkg/controller/errors.go
+++ b/integrations/kube-agent-updater/pkg/controller/errors.go
@@ -16,10 +16,6 @@ limitations under the License.
package controller
-import (
- "fmt"
-)
-
// MaintenanceNotTriggeredError indicates that no trigger returned true and the controller did not reconcile.
type MaintenanceNotTriggeredError struct {
Message string `json:"message"`
@@ -32,18 +28,3 @@ func (e *MaintenanceNotTriggeredError) Error() string {
}
return "maintenance not triggered"
}
-
-// NoNewVersionError indicates that no new version was found and the controller did not reconcile.
-type NoNewVersionError struct {
- Message string `json:"message"`
- CurrentVersion string `json:"currentVersion"`
- NextVersion string `json:"nextVersion"`
-}
-
-// Error returns log friendly description of an error
-func (e *NoNewVersionError) Error() string {
- if e.Message != "" {
- return e.Message
- }
- return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion)
-}
diff --git a/integrations/kube-agent-updater/pkg/controller/statefulset.go b/integrations/kube-agent-updater/pkg/controller/statefulset.go
index 5b78227197a2f..4c014c80f3fbc 100644
--- a/integrations/kube-agent-updater/pkg/controller/statefulset.go
+++ b/integrations/kube-agent-updater/pkg/controller/statefulset.go
@@ -32,6 +32,7 @@ import (
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
"github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/podutils"
+ "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/version"
)
type StatefulSetVersionUpdater struct {
@@ -99,7 +100,7 @@ func (r *StatefulSetVersionUpdater) Reconcile(ctx context.Context, req ctrl.Requ
image, err := r.GetVersion(ctx, &obj, currentVersion)
var (
- noNewVersionErr *NoNewVersionError
+ noNewVersionErr *version.NoNewVersionError
maintenanceErr *MaintenanceNotTriggeredError
)
switch {
diff --git a/integrations/kube-agent-updater/pkg/controller/updater.go b/integrations/kube-agent-updater/pkg/controller/updater.go
index 1ecc87647f2c9..764e5f7f948e3 100644
--- a/integrations/kube-agent-updater/pkg/controller/updater.go
+++ b/integrations/kube-agent-updater/pkg/controller/updater.go
@@ -61,7 +61,7 @@ func (r *VersionUpdater) GetVersion(ctx context.Context, obj client.Object, curr
log.Info("New version candidate", "nextVersion", nextVersion)
if !version.ValidVersionChange(ctx, currentVersion, nextVersion) {
- return nil, &NoNewVersionError{CurrentVersion: currentVersion, NextVersion: nextVersion}
+ return nil, &version.NoNewVersionError{CurrentVersion: currentVersion, NextVersion: nextVersion}
}
log.Info("Version change is valid, building img candidate")
diff --git a/integrations/kube-agent-updater/pkg/controller/updater_test.go b/integrations/kube-agent-updater/pkg/controller/updater_test.go
index d33678577b1c5..51ae579aaf837 100644
--- a/integrations/kube-agent-updater/pkg/controller/updater_test.go
+++ b/integrations/kube-agent-updater/pkg/controller/updater_test.go
@@ -42,8 +42,8 @@ const (
)
var (
- alwaysTrigger = maintenance.NewMaintenanceTriggerMock("always trigger", true)
- neverTrigger = maintenance.NewMaintenanceTriggerMock("never trigger", false)
+ alwaysTrigger = maintenance.NewMaintenanceStaticTrigger("always trigger", true)
+ neverTrigger = maintenance.NewMaintenanceStaticTrigger("never trigger", false)
alwaysValid = img.NewImageValidatorMock(
"always",
true,
@@ -80,7 +80,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: versionMid,
- versionGetter: version.NewGetterMock(versionHigh, nil),
+ versionGetter: version.NewStaticGetter(versionHigh, nil),
maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
imageCheckers: []img.Validator{alwaysValid},
assertErr: require.NoError,
@@ -91,7 +91,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: "",
- versionGetter: version.NewGetterMock(versionHigh, nil),
+ versionGetter: version.NewStaticGetter(versionHigh, nil),
maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
imageCheckers: []img.Validator{alwaysValid},
assertErr: require.NoError,
@@ -102,10 +102,21 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: versionMid,
- versionGetter: version.NewGetterMock(versionMid, nil),
+ versionGetter: version.NewStaticGetter(versionMid, nil),
maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
imageCheckers: []img.Validator{alwaysValid},
- assertErr: errorIsType(&NoNewVersionError{}),
+ assertErr: errorIsType(&version.NoNewVersionError{}),
+ expectedImage: "",
+ },
+ {
+ name: "no version",
+ releaseRegistry: defaultTestRegistry,
+ releasePath: defaultTestPath,
+ currentVersion: versionMid,
+ versionGetter: version.NewStaticGetter("", &version.NoNewVersionError{Message: "version server did not advertise a version"}),
+ maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
+ imageCheckers: []img.Validator{alwaysValid},
+ assertErr: errorIsType(&version.NoNewVersionError{}),
expectedImage: "",
},
{
@@ -113,7 +124,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: versionMid,
- versionGetter: version.NewGetterMock(versionHigh, nil),
+ versionGetter: version.NewStaticGetter(versionHigh, nil),
maintenanceTriggers: []maintenance.Trigger{neverTrigger},
imageCheckers: []img.Validator{alwaysValid},
assertErr: errorIsType(&MaintenanceNotTriggeredError{}),
@@ -124,7 +135,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: versionMid,
- versionGetter: version.NewGetterMock(versionHigh, nil),
+ versionGetter: version.NewStaticGetter(versionHigh, nil),
maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
imageCheckers: []img.Validator{neverValid},
assertErr: errorIsType(&trace.TrustError{}),
@@ -135,7 +146,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
releaseRegistry: defaultTestRegistry,
releasePath: defaultTestPath,
currentVersion: versionMid,
- versionGetter: version.NewGetterMock("", &trace.ConnectionProblemError{}),
+ versionGetter: version.NewStaticGetter("", &trace.ConnectionProblemError{}),
maintenanceTriggers: []maintenance.Trigger{alwaysTrigger},
imageCheckers: []img.Validator{neverValid},
assertErr: errorIsType(&trace.ConnectionProblemError{}),
@@ -159,7 +170,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) {
baseImage: baseImage,
}
- // We need a dummy Kubernetes object, it is not used by the TriggerMock
+ // We need a dummy Kubernetes object, it is not used by the StaticTrigger
obj := &core.Pod{}
// Doing the test
diff --git a/integrations/kube-agent-updater/pkg/maintenance/mock.go b/integrations/kube-agent-updater/pkg/maintenance/mock.go
index 0c0c6b5738d21..a777bfca0764c 100644
--- a/integrations/kube-agent-updater/pkg/maintenance/mock.go
+++ b/integrations/kube-agent-updater/pkg/maintenance/mock.go
@@ -22,33 +22,33 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)
-// TriggerMock is a fake Trigger that return a static answer. This is used
+// StaticTrigger is a fake Trigger that return a static answer. This is used
// for testing purposes and is inherently disruptive.
-type TriggerMock struct {
+type StaticTrigger struct {
name string
canStart bool
}
-// Name returns the TriggerMock name.
-func (m TriggerMock) Name() string {
+// Name returns the StaticTrigger name.
+func (m StaticTrigger) Name() string {
return m.name
}
// CanStart returns the statically defined maintenance approval result.
-func (m TriggerMock) CanStart(_ context.Context, _ client.Object) (bool, error) {
+func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) {
return m.canStart, nil
}
// Default returns the default behavior if the trigger fails. This cannot
-// happen for a TriggerMock and is here solely to implement the Trigger
+// happen for a StaticTrigger and is here solely to implement the Trigger
// interface.
-func (m TriggerMock) Default() bool {
+func (m StaticTrigger) Default() bool {
return m.canStart
}
-// NewMaintenanceTriggerMock creates a TriggerMock
-func NewMaintenanceTriggerMock(name string, canStart bool) Trigger {
- return TriggerMock{
+// NewMaintenanceStaticTrigger creates a StaticTrigger
+func NewMaintenanceStaticTrigger(name string, canStart bool) Trigger {
+ return StaticTrigger{
name: name,
canStart: canStart,
}
diff --git a/integrations/kube-agent-updater/pkg/version/basichttp.go b/integrations/kube-agent-updater/pkg/version/basichttp.go
index 40b69cae91da8..cb0418e703dd5 100644
--- a/integrations/kube-agent-updater/pkg/version/basichttp.go
+++ b/integrations/kube-agent-updater/pkg/version/basichttp.go
@@ -47,8 +47,12 @@ func (b *basicHTTPVersionClient) Get(ctx context.Context) (string, error) {
if err != nil {
return "", trace.Wrap(err)
}
+ response := string(body)
+ if response == constants.NoVersion {
+ return "", &NoNewVersionError{Message: "version server did not advertise a version"}
+ }
// We trim spaces because the value might end with one or many newlines
- version, err := EnsureSemver(strings.TrimSpace(string(body)))
+ version, err := EnsureSemver(strings.TrimSpace(response))
return version, trace.Wrap(err)
}
diff --git a/integrations/kube-agent-updater/pkg/version/errors.go b/integrations/kube-agent-updater/pkg/version/errors.go
new file mode 100644
index 0000000000000..96344c60858de
--- /dev/null
+++ b/integrations/kube-agent-updater/pkg/version/errors.go
@@ -0,0 +1,36 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package version
+
+import "fmt"
+
+// NoNewVersionError indicates that no new version was found and the controller did not reconcile.
+type NoNewVersionError struct {
+ Message string `json:"message"`
+ CurrentVersion string `json:"currentVersion"`
+ NextVersion string `json:"nextVersion"`
+}
+
+// Error returns log friendly description of an error
+func (e *NoNewVersionError) Error() string {
+ if e.Message != "" {
+ return e.Message
+ }
+ return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion)
+}
diff --git a/integrations/kube-agent-updater/pkg/version/mock.go b/integrations/kube-agent-updater/pkg/version/mock.go
index 6746ebcd62a67..725d8758794f5 100644
--- a/integrations/kube-agent-updater/pkg/version/mock.go
+++ b/integrations/kube-agent-updater/pkg/version/mock.go
@@ -18,28 +18,38 @@ package version
import (
"context"
+ "fmt"
"strings"
+
+ "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/constants"
)
-// GetterMock is a fake version.Getter that return a static answer. This is used
+// StaticGetter is a fake version.Getter that return a static answer. This is used
// for testing purposes.
-type GetterMock struct {
+type StaticGetter struct {
version string
err error
}
// GetVersion returns the statically defined version.
-func (v GetterMock) GetVersion(_ context.Context) (string, error) {
+func (v StaticGetter) GetVersion(_ context.Context) (string, error) {
return v.version, v.err
}
-// NewGetterMock creates a GetterMock
-func NewGetterMock(version string, err error) Getter {
+// NewStaticGetter creates a StaticGetter
+func NewStaticGetter(version string, err error) Getter {
+ if version == constants.NoVersion {
+ return StaticGetter{
+ version: "",
+ err: &NoNewVersionError{Message: fmt.Sprintf("target version set to '%s'", constants.NoVersion)},
+ }
+ }
+
semVersion := version
if semVersion != "" && !strings.HasPrefix(semVersion, "v") {
semVersion = "v" + version
}
- return GetterMock{
+ return StaticGetter{
version: semVersion,
err: err,
}
diff --git a/integrations/kube-agent-updater/pkg/version/versionget.go b/integrations/kube-agent-updater/pkg/version/versionget.go
index ace71623b1ae0..00ffdc31fda3f 100644
--- a/integrations/kube-agent-updater/pkg/version/versionget.go
+++ b/integrations/kube-agent-updater/pkg/version/versionget.go
@@ -28,6 +28,8 @@ import (
// Getter gets the target image version for an external source. It should cache
// the result to reduce io and avoid potential rate-limits and is safe to call
// multiple times over a short period.
+// If the version source intentionally returns no version, a NoNewVersionError is
+// returned.
type Getter interface {
GetVersion(context.Context) (string, error)
}
@@ -35,7 +37,6 @@ type Getter interface {
// ValidVersionChange receives the current version and the candidate next version
// and evaluates if the version transition is valid.
func ValidVersionChange(ctx context.Context, current, next string) bool {
- // TODO: clarify rollback constraints regarding previous version and add a "previous" parameter
log := ctrllog.FromContext(ctx).V(1)
// Cannot upgrade to a non-valid version
if !semver.IsValid(next) {
diff --git a/lib/automaticupgrades/basichttp/client.go b/lib/automaticupgrades/basichttp/client.go
new file mode 100644
index 0000000000000..b1684a2e42087
--- /dev/null
+++ b/lib/automaticupgrades/basichttp/client.go
@@ -0,0 +1,59 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package basichttp
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/url"
+
+ "github.com/gravitational/trace"
+)
+
+// Client extends the regular http.Client by adding a GetContent method that does
+// a GET query to a given URL and returns an error if the status is non-200.
+// This is typically used to retrieve small files stored in a S3 bucket like the
+// maintenance.BasicHTTPMaintenanceTrigger or the version.BasicHTTPVersionGetter
+// are doing.
+type Client struct {
+ *http.Client
+}
+
+// GetContent sends a GET HTTP request and fails if the response is not 200.
+func (c *Client) GetContent(ctx context.Context, targetURL url.URL) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL.String(), nil)
+ if err != nil {
+ return []byte{}, trace.Wrap(err)
+ }
+ res, err := c.Do(req)
+ if err != nil {
+ return []byte{}, trace.Wrap(err)
+ }
+ defer res.Body.Close()
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ return []byte{}, trace.Wrap(err)
+ }
+
+ if res.StatusCode != http.StatusOK {
+ return []byte{}, trace.Errorf("non-200 status code received: '%d'", res.StatusCode)
+ }
+
+ return body, nil
+}
diff --git a/lib/automaticupgrades/basichttp/servermock.go b/lib/automaticupgrades/basichttp/servermock.go
new file mode 100644
index 0000000000000..fa245248d4f0d
--- /dev/null
+++ b/lib/automaticupgrades/basichttp/servermock.go
@@ -0,0 +1,59 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package basichttp
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ServerMock is a HTTP server whose response can be controlled from the tests.
+// This is used to mock external dependencies like s3 buckets or a remote HTTP server.
+type ServerMock struct {
+ Srv *httptest.Server
+
+ t *testing.T
+ code int
+ response string
+ path string
+}
+
+// SetResponse sets the ServerMock's response.
+func (m *ServerMock) SetResponse(t *testing.T, code int, response string) {
+ m.t = t
+ m.code = code
+ m.response = response
+}
+
+// ServeHTTP implements the http.Handler interface.
+func (m *ServerMock) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ require.Equal(m.t, m.path, r.URL.Path)
+ w.WriteHeader(m.code)
+ _, err := io.WriteString(w, m.response)
+ require.NoError(m.t, err)
+}
+
+// NewServerMock builds and returns a
+func NewServerMock(path string) *ServerMock {
+ mock := ServerMock{path: path}
+ mock.Srv = httptest.NewServer(http.HandlerFunc(mock.ServeHTTP))
+ return &mock
+}
diff --git a/lib/automaticupgrades/cache/cache.go b/lib/automaticupgrades/cache/cache.go
new file mode 100644
index 0000000000000..a03555b3cdd51
--- /dev/null
+++ b/lib/automaticupgrades/cache/cache.go
@@ -0,0 +1,66 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package cache
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+)
+
+// TimedMemoize wraps a function returning (value, error) and caches the
+// value AND the error for a specific time. This cache is mainly used to ensure
+// external calls rates are reasonable. The cache is thread-safe.
+type TimedMemoize[T any] struct {
+ clock clockwork.Clock
+ mutex sync.Mutex
+ cachedValue T
+ cachedError error
+ validUntil time.Time
+ cacheDuration time.Duration
+ getter func(ctx context.Context) (T, error)
+}
+
+// Get does a cache lookup and updates the cache in case of cache miss.
+func (c *TimedMemoize[T]) Get(ctx context.Context) (T, error) {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+
+ if c.validUntil.After(c.clock.Now()) {
+ // TimedMemoize hit, we return cached result
+ return c.cachedValue, trace.Wrap(c.cachedError)
+ }
+
+ // Cache miss, we do a query and update the cache
+ value, err := c.getter(ctx)
+ c.validUntil = c.clock.Now().Add(c.cacheDuration)
+ c.cachedValue = value
+ c.cachedError = newCachedError(err, c.validUntil)
+ return value, trace.Wrap(err)
+}
+
+// NewTimedMemoize builds and returns a TimedMemoize
+func NewTimedMemoize[T any](getter func(ctx context.Context) (T, error), duration time.Duration) *TimedMemoize[T] {
+ return &TimedMemoize[T]{
+ clock: clockwork.NewRealClock(),
+ getter: getter,
+ cacheDuration: duration,
+ }
+}
diff --git a/lib/automaticupgrades/cache/cache_test.go b/lib/automaticupgrades/cache/cache_test.go
new file mode 100644
index 0000000000000..7e200a79be882
--- /dev/null
+++ b/lib/automaticupgrades/cache/cache_test.go
@@ -0,0 +1,125 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package cache
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+)
+
+func failIfCalled(t *testing.T, value string, err error) func(context.Context) (string, error) {
+ return func(_ context.Context) (string, error) {
+ require.Fail(t, "should not be called")
+ return value, err
+ }
+}
+
+func failIfCalledTwice(t *testing.T, value string, err error) func(context.Context) (string, error) {
+ var fuse bool
+ return func(_ context.Context) (string, error) {
+ if !fuse {
+ fuse = true
+ return value, err
+ }
+ require.Fail(t, "should not be called twice")
+ return value, err
+ }
+}
+
+func TestTimedMemoize_Get(t *testing.T) {
+ ctx := context.Background()
+
+ now := time.Now()
+ longBefore := now.Add(-2 * time.Hour)
+ longAfter := now.Add(2 * time.Hour)
+
+ upstreamValue := "upstream"
+ cachedValue := "cached"
+ upstreamError := trace.LimitExceeded("rate-limited")
+ oldUpstreamError := trace.CompareFailed("comparison failed")
+
+ assertUncachedError := func(t2 require.TestingT, err error, _ ...interface{}) {
+ require.Equal(t2, err, upstreamError)
+ }
+ assertCachedError := func(t2 require.TestingT, err error, _ ...interface{}) {
+ _, ok := trace.Unwrap(err).(cachedError)
+ require.True(t2, ok)
+ }
+
+ tests := []struct {
+ name string
+ cachedValue string
+ cachedError error
+ validUntil time.Time
+ getter func(*testing.T, string, error) func(context.Context) (string, error)
+ expectedValue string
+ assertErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "fresh cache",
+ cachedValue: "",
+ cachedError: nil,
+ validUntil: time.Time{},
+ getter: failIfCalledTwice,
+ expectedValue: upstreamValue,
+ assertErr: assertUncachedError,
+ },
+ {
+ name: "valid cache",
+ cachedValue: cachedValue,
+ cachedError: newCachedError(oldUpstreamError, longAfter),
+ validUntil: longAfter,
+ getter: failIfCalled,
+ expectedValue: cachedValue,
+ assertErr: assertCachedError,
+ },
+ {
+ name: "expired cache",
+ cachedValue: cachedValue,
+ cachedError: newCachedError(oldUpstreamError, longBefore),
+ validUntil: longBefore,
+ getter: failIfCalledTwice,
+ expectedValue: upstreamValue,
+ assertErr: assertUncachedError,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cache := TimedMemoize[string]{
+ cachedValue: tt.cachedValue,
+ cachedError: tt.cachedError,
+ validUntil: tt.validUntil,
+ clock: clockwork.NewFakeClockAt(now),
+ cacheDuration: 30 * time.Minute,
+ getter: tt.getter(t, upstreamValue, upstreamError),
+ }
+ result, err := cache.Get(ctx)
+ require.Equal(t, tt.expectedValue, result)
+ // The first error might or might not be cached
+ tt.assertErr(t, err)
+ result, err = cache.Get(ctx)
+ require.Equal(t, tt.expectedValue, result)
+ // The second error must be cached
+ assertCachedError(t, err)
+ })
+ }
+}
diff --git a/lib/automaticupgrades/cache/error.go b/lib/automaticupgrades/cache/error.go
new file mode 100644
index 0000000000000..952ba0c362dad
--- /dev/null
+++ b/lib/automaticupgrades/cache/error.go
@@ -0,0 +1,54 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package cache
+
+import (
+ "fmt"
+ "time"
+)
+
+// cachedError wraps an error before storing it into cache. This adds more
+// context into the original error by clearly indicating the error have been
+// cached and for how long.
+type cachedError struct {
+ err error
+ until time.Time
+}
+
+func (e cachedError) Error() string {
+ return fmt.Sprintf("error cached until '%s': %s", e.until, e.err)
+}
+
+// OrigError returns the original error. This implements trace.ErrorWrapper
+// and allows to be unwrapped by trace.Unwrap().
+func (e cachedError) OrigError() error {
+ return e.err
+}
+
+// Unwrap returns the original error.
+func (e cachedError) Unwrap() error {
+ return e.err
+}
+
+// newCachedError takes an error and wraps it into a cachedError. If there is no
+// error, it returns nothing.
+func newCachedError(err error, until time.Time) error {
+ if err == nil {
+ return nil
+ }
+ return cachedError{err, until}
+}
diff --git a/lib/automaticupgrades/channel.go b/lib/automaticupgrades/channel.go
new file mode 100644
index 0000000000000..de2ba83acb796
--- /dev/null
+++ b/lib/automaticupgrades/channel.go
@@ -0,0 +1,212 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package automaticupgrades
+
+import (
+ "context"
+ "net/url"
+ "strconv"
+ "strings"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/mod/semver"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/lib/automaticupgrades/maintenance"
+ "github.com/gravitational/teleport/lib/automaticupgrades/version"
+)
+
+const (
+ DefaultChannelName = "default"
+ DefaultCloudChannelName = "stable/cloud"
+ stableCloudVersionBaseURL = "https://updates.releases.teleport.dev/v1/stable/cloud"
+)
+
+// Channels is a map of Channel objects.
+type Channels map[string]*Channel
+
+// CheckAndSetDefaults checks that every Channel is valid and initializes them.
+// It also creates default channels if they are not already present.
+// Cloud must have the `default` and `stable/cloud` channels.
+// Self-hosted with automatic upgrades must have the `default` channel.
+func (c Channels) CheckAndSetDefaults(features proto.Features) error {
+ defaultChannel, err := NewDefaultChannel()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // If we're on cloud, we need at least "cloud/stable" and "default"
+ if features.GetCloud() {
+ if _, ok := c[DefaultCloudChannelName]; !ok {
+ c[DefaultCloudChannelName] = defaultChannel
+ }
+ if _, ok := c[DefaultChannelName]; !ok {
+ c[DefaultChannelName] = c[DefaultCloudChannelName]
+ }
+ }
+
+ // If we're on self-hosted with automatic upgrades, we need a "default" channel
+ // We don't want to break existing setups so we'll automatically point to the
+ // `cloud/stable` channel.
+ // TODO: in v15 make this a hard requirement and error if `default` is not set
+ // and automatic upgrades are enabled
+ if features.GetAutomaticUpgrades() {
+ if _, ok := c[DefaultChannelName]; !ok {
+ c[DefaultChannelName] = defaultChannel
+ }
+ }
+
+ var errs []error
+ for name, channel := range c {
+ // Wrapping is not mandatory here, but it adds the channel name in the
+ // error, which makes troubleshooting easier.
+ err = trace.Wrap(channel.CheckAndSetDefaults(), "failed to create channel %s", name)
+ if err != nil {
+ errs = append(errs, err)
+ }
+ }
+ return trace.NewAggregate(errs...)
+}
+
+// DefaultVersion returns the version served by the default upgrade channel.
+func (c Channels) DefaultVersion(ctx context.Context) (string, error) {
+ channel, ok := c[DefaultChannelName]
+ if !ok {
+ return "", trace.NotFound("default version channel not found")
+ }
+ targetVersion, err := channel.GetVersion(ctx)
+ return targetVersion, trace.Wrap(err)
+}
+
+// Channel describes an automatic update channel configuration.
+// It can be configured to serve a static version, or forward version requests
+// to an upstream version server. Forwarded results are cached for 1 minute.
+type Channel struct {
+ // ForwardURL is the URL of the upstream version server providing the channel version/criticality.
+ ForwardURL string `yaml:"forward_url,omitempty"`
+ // StaticVersion is a static version the channel must serve. With or without the leading "v".
+ StaticVersion string `yaml:"static_version,omitempty"`
+ // Critical is whether the static version channel should be marked as a critical update.
+ Critical bool `yaml:"critical"`
+
+ // versionGetter gets the version of the channel. It is populated by CheckAndSetDefaults.
+ versionGetter version.Getter
+ // criticalTrigger gets the criticality of the channel. It is populated by CheckAndSetDefaults.
+ criticalTrigger maintenance.Trigger
+ // teleportMajor stores the current teleport major for comparison.
+ // This field is initialized during CheckAndSetDefaults.
+ teleportMajor int
+}
+
+// CheckAndSetDefaults checks that the Channel configuration is valid and inits
+// the version getter and maintenance trigger of the Channel based on its
+// configuration. This function must be called before using the channel.
+func (c *Channel) CheckAndSetDefaults() error {
+ switch {
+ case c.ForwardURL != "" && (c.StaticVersion != "" || c.Critical):
+ return trace.BadParameter("cannot set both ForwardURL and (StaticVersion or Critical)")
+ case c.ForwardURL != "":
+ baseURL, err := url.Parse(c.ForwardURL)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ c.versionGetter = version.NewBasicHTTPVersionGetter(baseURL)
+ c.criticalTrigger = maintenance.NewBasicHTTPMaintenanceTrigger("remote", baseURL)
+ case c.StaticVersion != "":
+ c.versionGetter = version.NewStaticGetter(c.StaticVersion, nil)
+ c.criticalTrigger = maintenance.NewMaintenanceStaticTrigger("remote", c.Critical)
+ default:
+ return trace.BadParameter("either ForwardURL or StaticVersion must be set")
+ }
+
+ var err error
+ c.teleportMajor, err = parseMajorFromVersionString(teleport.Version)
+ if err != nil {
+ return trace.Wrap(err, "failed to process teleport version")
+ }
+
+ return nil
+}
+
+// GetVersion returns the current version of the channel. If io is involved,
+// this function implements cache and is safe to call frequently.
+// If the target version major is higher than the Teleport version (the one
+// in the Teleport binary, this is usually the proxy version), this function
+// returns the Teleport version instead.
+// If the version source intentionally did not specify a version, a
+// NoNewVersionError is returned.
+func (c *Channel) GetVersion(ctx context.Context) (string, error) {
+ targetVersion, err := c.versionGetter.GetVersion(ctx)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ targetMajor, err := parseMajorFromVersionString(targetVersion)
+ if err != nil {
+ return "", trace.Wrap(err, "failed to process target version")
+ }
+
+ // The target version is officially incompatible with our version,
+ // we prefer returning our version rather than having a broken client
+ if targetMajor > c.teleportMajor {
+ return teleport.Version, nil
+ }
+
+ return targetVersion, nil
+}
+
+// GetCritical returns the current criticality of the channel. If io is involved,
+// this function implements cache and is safe to call frequently.
+func (c *Channel) GetCritical(ctx context.Context) (bool, error) {
+ return c.criticalTrigger.CanStart(ctx, nil)
+}
+
+// NewDefaultChannel creates a default automatic upgrade channel
+// It looks up the environment variable, and if not found uses the default
+// base URL. This default channel can be used in the proxy (to back its own version server)
+// or in other Teleport process such as integration services deploying and
+// updating teleport agents.
+func NewDefaultChannel() (*Channel, error) {
+ forwardURL := GetChannel()
+ if forwardURL == "" {
+ forwardURL = stableCloudVersionBaseURL
+ }
+ defaultChannel := &Channel{
+ ForwardURL: forwardURL,
+ }
+ if err := defaultChannel.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return defaultChannel, nil
+}
+
+func parseMajorFromVersionString(v string) (int, error) {
+ v, err := version.EnsureSemver(v)
+ if err != nil {
+ return 0, trace.Wrap(err, "invalid semver: %s", v)
+ }
+ majorStr := semver.Major(v)
+ if majorStr == "" {
+ return 0, trace.BadParameter("cannot detect version major")
+ }
+
+ major, err := strconv.Atoi(strings.TrimPrefix(majorStr, "v"))
+ return major, trace.Wrap(err, "cannot convert version major to int")
+}
diff --git a/lib/automaticupgrades/channel_test.go b/lib/automaticupgrades/channel_test.go
new file mode 100644
index 0000000000000..e69e1a6ec3b68
--- /dev/null
+++ b/lib/automaticupgrades/channel_test.go
@@ -0,0 +1,211 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package automaticupgrades
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+ "github.com/gravitational/teleport/lib/automaticupgrades/maintenance"
+ "github.com/gravitational/teleport/lib/automaticupgrades/version"
+)
+
+const testVersion = "v1.2.3"
+
+func Test_Channels_CheckAndSetDefaults(t *testing.T) {
+ noFeatures := proto.Features{}
+ cloudFeatures := proto.Features{Cloud: true, AutomaticUpgrades: true}
+ customChannelURL := "https://foo.example.com/bar"
+ t.Run("no channels", func(t *testing.T) {
+ c := Channels{}
+ require.NoError(t, c.CheckAndSetDefaults(noFeatures))
+ })
+ t.Run("single channel", func(t *testing.T) {
+ channel := &Channel{StaticVersion: testVersion}
+ c := Channels{"foo": channel}
+ require.NoError(t, c.CheckAndSetDefaults(noFeatures))
+ require.NotNil(t, channel.versionGetter)
+ require.NotNil(t, channel.criticalTrigger)
+ })
+ t.Run("many channels", func(t *testing.T) {
+ channel1 := &Channel{StaticVersion: testVersion}
+ channel2 := &Channel{StaticVersion: testVersion}
+ channel3 := &Channel{StaticVersion: testVersion}
+ c := Channels{"foo": channel1, "bar": channel2, "baz": channel3}
+ require.NoError(t, c.CheckAndSetDefaults(noFeatures))
+ require.NotNil(t, channel1.versionGetter)
+ require.NotNil(t, channel1.criticalTrigger)
+ require.NotNil(t, channel2.versionGetter)
+ require.NotNil(t, channel2.criticalTrigger)
+ require.NotNil(t, channel3.versionGetter)
+ require.NotNil(t, channel3.criticalTrigger)
+ })
+ t.Run("default channels for cloud", func(t *testing.T) {
+ // Cloud must have `default` and `stable/cloud` channels by default
+ c := Channels{}
+ require.NoError(t, c.CheckAndSetDefaults(cloudFeatures))
+ require.Len(t, c, 2)
+ defaultChannel, ok := c[DefaultChannelName]
+ require.True(t, ok)
+ require.Equal(t, stableCloudVersionBaseURL, defaultChannel.ForwardURL)
+ stableCloudChannel, ok := c[DefaultCloudChannelName]
+ require.True(t, ok)
+ require.Equal(t, stableCloudVersionBaseURL, stableCloudChannel.ForwardURL)
+ })
+ t.Run("cloud override stable/cloud", func(t *testing.T) {
+ // When "stable/cloud" channel is configured, CheckAndSetDefaults
+ // must honor it AND also use it as the "default" channel.
+ c := Channels{DefaultCloudChannelName: &Channel{ForwardURL: customChannelURL}}
+ require.NoError(t, c.CheckAndSetDefaults(cloudFeatures))
+ require.Len(t, c, 2)
+ stableCloudChannel, ok := c[DefaultCloudChannelName]
+ require.True(t, ok)
+ require.Equal(t, customChannelURL, stableCloudChannel.ForwardURL)
+ defaultChannel, ok := c[DefaultChannelName]
+ require.True(t, ok)
+ require.Equal(t, customChannelURL, defaultChannel.ForwardURL)
+ })
+ t.Run("cloud override default", func(t *testing.T) {
+ // When the "default" channel is manually configured, CheckAndSetDefaults
+ // must honor it.
+ // In this test, only the "default" channel must be custom, the
+ // "stable/cloud" channel must point to the standard cloud URL.
+ c := Channels{DefaultChannelName: &Channel{ForwardURL: customChannelURL}}
+ require.NoError(t, c.CheckAndSetDefaults(cloudFeatures))
+ require.Len(t, c, 2)
+ defaultChannel, ok := c[DefaultChannelName]
+ require.True(t, ok)
+ require.Equal(t, customChannelURL, defaultChannel.ForwardURL)
+ stableCloudChannel, ok := c[DefaultCloudChannelName]
+ require.True(t, ok)
+ require.Equal(t, stableCloudVersionBaseURL, stableCloudChannel.ForwardURL)
+ })
+ t.Run("self-hosted no channel", func(t *testing.T) {
+ // In self-hosted automatic-upgrades setups, we need a default channel.
+ // For backward compatibility we should add it instead of requiring it.
+ c := Channels{}
+ require.NoError(t, c.CheckAndSetDefaults(proto.Features{AutomaticUpgrades: true}))
+ require.Len(t, c, 1)
+ defaultChannel, ok := c[DefaultChannelName]
+ require.True(t, ok)
+ require.Equal(t, stableCloudVersionBaseURL, defaultChannel.ForwardURL)
+ _, ok = c[DefaultCloudChannelName]
+ require.False(t, ok)
+ })
+
+}
+
+func Test_Channel_CheckAndSetDefaults(t *testing.T) {
+
+ tests := []struct {
+ name string
+ channel Channel
+ assertError require.ErrorAssertionFunc
+ expectedVersionGetterType interface{}
+ expectedCriticalTriggerType interface{}
+ }{
+ {
+ name: "empty (invalid)",
+ channel: Channel{},
+ assertError: require.Error,
+ },
+ {
+ name: "forward URL (valid)",
+ channel: Channel{
+ ForwardURL: stableCloudVersionBaseURL,
+ },
+ assertError: require.NoError,
+ expectedVersionGetterType: version.BasicHTTPVersionGetter{},
+ expectedCriticalTriggerType: maintenance.BasicHTTPMaintenanceTrigger{},
+ },
+ {
+ name: "static version (valid)",
+ channel: Channel{
+ StaticVersion: testVersion,
+ },
+ assertError: require.NoError,
+ expectedVersionGetterType: version.StaticGetter{},
+ expectedCriticalTriggerType: maintenance.StaticTrigger{},
+ },
+ {
+ name: "all set (invalid)",
+ channel: Channel{
+ ForwardURL: stableCloudVersionBaseURL,
+ StaticVersion: testVersion,
+ },
+ assertError: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.assertError(t, tt.channel.CheckAndSetDefaults())
+ require.IsType(t, tt.expectedVersionGetterType, tt.channel.versionGetter)
+ require.IsType(t, tt.expectedCriticalTriggerType, tt.channel.criticalTrigger)
+ })
+ }
+}
+
+func Test_Channel_GetVersion(t *testing.T) {
+ ctx := context.Background()
+ tests := []struct {
+ name string
+ targetVersion string
+ expectedVersion string
+ assertErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "normal version",
+ targetVersion: "v1.2.3",
+ expectedVersion: "v1.2.3",
+ assertErr: require.NoError,
+ },
+ {
+ name: "no version",
+ targetVersion: constants.NoVersion,
+ expectedVersion: "",
+ assertErr: require.Error,
+ },
+ {
+ name: "version too high",
+ targetVersion: "v99.1.1",
+ expectedVersion: teleport.Version,
+ assertErr: require.NoError,
+ },
+ {
+ name: "version invalid",
+ targetVersion: "foobar",
+ assertErr: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ch := Channel{StaticVersion: tt.targetVersion}
+ require.NoError(t, ch.CheckAndSetDefaults())
+
+ result, err := ch.GetVersion(ctx)
+ tt.assertErr(t, err)
+ require.Equal(t, tt.expectedVersion, result)
+ })
+ }
+}
diff --git a/lib/automaticupgrades/constants/constants.go b/lib/automaticupgrades/constants/constants.go
new file mode 100644
index 0000000000000..ec40a052f8240
--- /dev/null
+++ b/lib/automaticupgrades/constants/constants.go
@@ -0,0 +1,34 @@
+/*
+Copyright 2023 Gravitational, Inc.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package constants
+
+import "time"
+
+const (
+ // MaintenancePath is the version discovery endpoint representing if the
+ // target version represents a critical update as defined in
+ // https://github.com/gravitational/teleport/blob/master/rfd/0109-cloud-agent-upgrades.md#version-discovery-endpoint
+ MaintenancePath = "critical"
+ // VersionPath is the version discovery endpoint returning the current
+ // target version as defined in
+ // https://github.com/gravitational/teleport/blob/master/rfd/0109-cloud-agent-upgrades.md#version-discovery-endpoint
+ VersionPath = "version"
+ HTTPTimeout = 10 * time.Second
+ CacheDuration = time.Minute
+
+ // NoVersion is returned by the version endpoint when there is no valid target version.
+ // This can be caused by the target version being incompatible with the cluster version.
+ NoVersion = "none"
+)
diff --git a/lib/automaticupgrades/maintenance/basichttp.go b/lib/automaticupgrades/maintenance/basichttp.go
new file mode 100644
index 0000000000000..5df16fbb4a2c0
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/basichttp.go
@@ -0,0 +1,110 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package maintenance
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "github.com/gravitational/trace"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+
+ "github.com/gravitational/teleport/lib/automaticupgrades/basichttp"
+ "github.com/gravitational/teleport/lib/automaticupgrades/cache"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+// basicHTTPMaintenanceClient retrieves whether the target version represents a
+// critical update from an HTTP endpoint. It should not be invoked immediately
+// and must be wrapped in a cache layer in order to avoid spamming the version
+// server in case of reconciliation errors.
+// use BasicHTTPMaintenanceTrigger if you need to check if an update is critical.
+type basicHTTPMaintenanceClient struct {
+ baseURL *url.URL
+ client *basichttp.Client
+}
+
+// Get sends an HTTP GET request and returns whether the current target version
+// represents a critical update.
+func (b *basicHTTPMaintenanceClient) Get(ctx context.Context) (bool, error) {
+ versionURL := b.baseURL.JoinPath(constants.MaintenancePath)
+ body, err := b.client.GetContent(ctx, *versionURL)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ // Validating early that the payload can be converted to a boolean allows to
+ // gracefully catch connectivity error caused by mitm infrastructure such as
+ // corporate proxies.
+ result, err := stringToBool(strings.TrimSpace(string(body)))
+ return result, trace.Wrap(err)
+}
+
+// BasicHTTPMaintenanceTrigger gets the critical status from an HTTP response
+// containing only a truthy or falsy string.
+// This is used typically to trigger emergency maintenances from a file hosted
+// in a s3 bucket or raw file served over HTTP.
+// BasicHTTPMaintenanceTrigger uses basicHTTPMaintenanceClient and wraps it in a cache
+// in order to mitigate the impact of too frequent reconciliations.
+// The structure implements the maintenance.Trigger interface.
+type BasicHTTPMaintenanceTrigger struct {
+ name string
+ cachedGetter func(context.Context) (bool, error)
+}
+
+// Name implements maintenance.Triggernd returns the trigger name for logging
+// and debugging pursposes.
+func (g BasicHTTPMaintenanceTrigger) Name() string {
+ return g.name
+}
+
+// Default returns what to do if the trigger can't be evaluated.
+// BasicHTTPMaintenanceTrigger should fail open, so the function returns true.
+func (g BasicHTTPMaintenanceTrigger) Default() bool {
+ return false
+}
+
+// CanStart implements maintenance.Trigger
+func (g BasicHTTPMaintenanceTrigger) CanStart(ctx context.Context, _ client.Object) (bool, error) {
+ result, err := g.cachedGetter(ctx)
+ return result, trace.Wrap(err)
+}
+
+// NewBasicHTTPMaintenanceTrigger builds and return a Trigger checking a public HTTP endpoint.
+func NewBasicHTTPMaintenanceTrigger(name string, baseURL *url.URL) Trigger {
+ client := &http.Client{
+ Timeout: constants.HTTPTimeout,
+ }
+ httpMaintenanceClient := &basicHTTPMaintenanceClient{
+ baseURL: baseURL,
+ client: &basichttp.Client{Client: client},
+ }
+
+ return BasicHTTPMaintenanceTrigger{name, cache.NewTimedMemoize[bool](httpMaintenanceClient.Get, constants.CacheDuration).Get}
+}
+
+func stringToBool(input string) (bool, error) {
+ switch {
+ case strings.EqualFold("true", input), strings.EqualFold("yes", input):
+ return true, nil
+ case strings.EqualFold("false", input), strings.EqualFold("no", input):
+ return false, nil
+ default:
+ return false, trace.BadParameter("cannot convert input to boolean: %s", input)
+ }
+}
diff --git a/lib/automaticupgrades/maintenance/basichttp_test.go b/lib/automaticupgrades/maintenance/basichttp_test.go
new file mode 100644
index 0000000000000..4304a78ee16e2
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/basichttp_test.go
@@ -0,0 +1,106 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package maintenance
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ basichttp2 "github.com/gravitational/teleport/lib/automaticupgrades/basichttp"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+const basicHTTPTestPath = "/v1/cloud-stable"
+
+func Test_basicHTTPMaintenanceClient_Get(t *testing.T) {
+ mock := basichttp2.NewServerMock(basicHTTPTestPath + "/" + constants.MaintenancePath)
+ t.Cleanup(mock.Srv.Close)
+ serverURL, err := url.Parse(mock.Srv.URL)
+ serverURL.Path = basicHTTPTestPath
+ require.NoError(t, err)
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ statusCode int
+ response string
+ expected bool
+ assertErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "all good - no maintenance",
+ statusCode: http.StatusOK,
+ response: "no",
+ expected: false,
+ assertErr: require.NoError,
+ },
+ {
+ name: "all good - maintenance",
+ statusCode: http.StatusOK,
+ response: "yes",
+ expected: true,
+ assertErr: require.NoError,
+ },
+ {
+ name: "all good with newline",
+ statusCode: http.StatusOK,
+ response: "yes\n",
+ expected: true,
+ assertErr: require.NoError,
+ },
+ {
+ name: "invalid response",
+ statusCode: http.StatusOK,
+ response: "hello",
+ expected: false,
+ assertErr: require.Error,
+ },
+ {
+ name: "empty",
+ statusCode: http.StatusOK,
+ response: "",
+ expected: false,
+ assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) {
+ require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2))
+ },
+ },
+ {
+ name: "non-200 response",
+ statusCode: http.StatusInternalServerError,
+ response: "ERROR - SOMETHING WENT WRONG",
+ expected: false,
+ assertErr: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ b := &basicHTTPMaintenanceClient{
+ baseURL: serverURL,
+ client: &basichttp2.Client{Client: mock.Srv.Client()},
+ }
+ mock.SetResponse(t, tt.statusCode, tt.response)
+ result, err := b.Get(ctx)
+ tt.assertErr(t, err)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
diff --git a/lib/automaticupgrades/maintenance/mock.go b/lib/automaticupgrades/maintenance/mock.go
new file mode 100644
index 0000000000000..a777bfca0764c
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/mock.go
@@ -0,0 +1,55 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package maintenance
+
+import (
+ "context"
+
+ "sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+// StaticTrigger is a fake Trigger that return a static answer. This is used
+// for testing purposes and is inherently disruptive.
+type StaticTrigger struct {
+ name string
+ canStart bool
+}
+
+// Name returns the StaticTrigger name.
+func (m StaticTrigger) Name() string {
+ return m.name
+}
+
+// CanStart returns the statically defined maintenance approval result.
+func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) {
+ return m.canStart, nil
+}
+
+// Default returns the default behavior if the trigger fails. This cannot
+// happen for a StaticTrigger and is here solely to implement the Trigger
+// interface.
+func (m StaticTrigger) Default() bool {
+ return m.canStart
+}
+
+// NewMaintenanceStaticTrigger creates a StaticTrigger
+func NewMaintenanceStaticTrigger(name string, canStart bool) Trigger {
+ return StaticTrigger{
+ name: name,
+ canStart: canStart,
+ }
+}
diff --git a/lib/automaticupgrades/maintenance/trigger.go b/lib/automaticupgrades/maintenance/trigger.go
new file mode 100644
index 0000000000000..457078f34fa49
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/trigger.go
@@ -0,0 +1,62 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package maintenance
+
+import (
+ "context"
+
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// Trigger is evaluated to decide whether a maintenance can happen or not.
+// Maintenances can happen because of multiple reasons like:
+// - attempt to recover from a broken state
+// - we are in a maintenance window
+// - emergency security patch
+// Each Trigger has a Name() for logging purposes and a Default() method
+// returning whether the trigger should allow the maintenance or not in case\
+// of error.
+type Trigger interface {
+ Name() string
+ CanStart(ctx context.Context, object client.Object) (bool, error)
+ Default() bool
+}
+
+// Triggers is a list of Trigger. Triggers are OR-ed: any trigger firing in the
+// list will cause the maintenance to be triggered.
+type Triggers []Trigger
+
+// CanStart checks if the maintenance can be started. It will return true if at
+// least a Trigger approves the maintenance.
+func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
+ log := ctrllog.FromContext(ctx).V(1)
+ for _, trigger := range t {
+ start, err := trigger.CanStart(ctx, object)
+ if err != nil {
+ start = trigger.Default()
+ log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start)
+ } else {
+ log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start)
+ }
+ if start {
+ log.Info("maintenance triggered", "trigger", trigger.Name())
+ return true
+ }
+ }
+ return false
+}
diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go
deleted file mode 100644
index d1f6c393f1605..0000000000000
--- a/lib/automaticupgrades/version.go
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
-Copyright 2023 Gravitational, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package automaticupgrades
-
-import (
- "context"
- "net/http"
- "net/url"
- "strings"
-
- "github.com/gravitational/trace"
-
- "github.com/gravitational/teleport"
- "github.com/gravitational/teleport/lib/utils"
-)
-
-const (
- // stableCloudVersionBaseURL is the base URL for the server that returns the current stable/cloud version.
- stableCloudVersionBaseURL = "https://updates.releases.teleport.dev"
-
- // stableCloudVersionPath is the URL path that returns the current stable/cloud version.
- stableCloudVersionPath = "/v1/stable/cloud/version"
-
- // stableCloudCriticalPath is the URL path that returns the stable/cloud critical flag.
- stableCloudCriticalPath = "/v1/stable/cloud/critical"
-)
-
-// Version returns the version that should be used for installing Teleport Services
-// This is used when installing agents using scripts.
-// Even when Teleport Auth/Proxy is using vX, the agents must always respect this version.
-func Version(ctx context.Context, versionURL string) (string, error) {
- versionURL, err := getVersionURL(versionURL)
- if err != nil {
- return "", trace.Wrap(err)
- }
-
- resp, err := sendRequest(ctx, versionURL)
- if err != nil {
- return "", trace.Wrap(err)
- }
-
- return resp, nil
-}
-
-// Critical returns true if a critical upgrade is available.
-func Critical(ctx context.Context, criticalURL string) (bool, error) {
- criticalURL, err := getCriticalURL(criticalURL)
- if err != nil {
- return false, trace.Wrap(err)
- }
-
- critical, err := sendRequest(ctx, criticalURL)
- if err != nil {
- return false, trace.Wrap(err)
- }
-
- // Expects critical endpoint to return either the string "yes" or "no"
- switch critical {
- case "yes":
- return true, nil
- case "no":
- return false, nil
- default:
- return false, trace.BadParameter("critical endpoint returned an unexpected value: %v", critical)
- }
-}
-
-// sendRequest sends a GET request to the reqURL and returns the response value
-func sendRequest(ctx context.Context, reqURL string) (string, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
- if err != nil {
- return "", trace.Wrap(err)
- }
-
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return "", trace.Wrap(err)
- }
- defer resp.Body.Close()
-
- body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
- if err != nil {
- return "", trace.Wrap(err)
- }
-
- if resp.StatusCode != http.StatusOK {
- return "", trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body))
- }
-
- return strings.TrimSpace(string(body)), trace.Wrap(err)
-}
-
-// getVersionURL returns the versionURL or the default stable/cloud version url.
-func getVersionURL(versionURL string) (string, error) {
- if versionURL != "" {
- return versionURL, nil
- }
- cloudStableVersionURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudVersionPath)
- if err != nil {
- return "", trace.Wrap(err)
- }
- return cloudStableVersionURL, nil
-}
-
-// getCriticalURL returns the criticalURL or the default stable/cloud critical url.
-func getCriticalURL(criticalURL string) (string, error) {
- if criticalURL != "" {
- return criticalURL, nil
- }
- cloudStableCriticalURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudCriticalPath)
- if err != nil {
- return "", trace.Wrap(err)
- }
- return cloudStableCriticalURL, nil
-}
diff --git a/lib/automaticupgrades/version/basichttp.go b/lib/automaticupgrades/version/basichttp.go
new file mode 100644
index 0000000000000..8887bd265faee
--- /dev/null
+++ b/lib/automaticupgrades/version/basichttp.go
@@ -0,0 +1,83 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package version
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/automaticupgrades/basichttp"
+ "github.com/gravitational/teleport/lib/automaticupgrades/cache"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+// basicHTTPVersionClient retrieves the version from an HTTP endpoint
+// it should not be invoked immediately and must be wrapped in a cache layer in
+// order to avoid spamming the version server in case of reconciliation errors.
+// use BasicHTTPVersionGetter if you need to get a version.
+type basicHTTPVersionClient struct {
+ baseURL *url.URL
+ client *basichttp.Client
+}
+
+// Get sends an HTTP GET request and returns the version prefixed by "v".
+// It expects the endpoint to be unauthenticated, return 200 and the response
+// body to contain a valid semver tag without the "v".
+func (b *basicHTTPVersionClient) Get(ctx context.Context) (string, error) {
+ versionURL := b.baseURL.JoinPath(constants.VersionPath)
+ body, err := b.client.GetContent(ctx, *versionURL)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ response := string(body)
+ if response == constants.NoVersion {
+ return "", &NoNewVersionError{Message: "version server did not advertise a version"}
+ }
+ // We trim spaces because the value might end with one or many newlines
+ version, err := EnsureSemver(strings.TrimSpace(response))
+ return version, trace.Wrap(err)
+}
+
+// BasicHTTPVersionGetter gets the version from an HTTP response containing
+// only the version. This is used typically to get version from a file hosted
+// in a s3 bucket or raw file served over HTTP.
+// BasicHTTPVersionGetter uses basicHTTPVersionClient and wraps it in a cache
+// in order to mitigate the impact of too frequent reconciliations.
+// The structure implements the version.Getter interface.
+type BasicHTTPVersionGetter struct {
+ versionGetter func(context.Context) (string, error)
+}
+
+func (g BasicHTTPVersionGetter) GetVersion(ctx context.Context) (string, error) {
+ return g.versionGetter(ctx)
+}
+
+func NewBasicHTTPVersionGetter(baseURL *url.URL) Getter {
+ client := &http.Client{
+ Timeout: constants.HTTPTimeout,
+ }
+ httpVersionClient := &basicHTTPVersionClient{
+ baseURL: baseURL,
+ client: &basichttp.Client{Client: client},
+ }
+
+ return BasicHTTPVersionGetter{cache.NewTimedMemoize[string](httpVersionClient.Get, constants.CacheDuration).Get}
+}
diff --git a/lib/automaticupgrades/version/basichttp_test.go b/lib/automaticupgrades/version/basichttp_test.go
new file mode 100644
index 0000000000000..05db877fdfbbf
--- /dev/null
+++ b/lib/automaticupgrades/version/basichttp_test.go
@@ -0,0 +1,101 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package version
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ basichttp2 "github.com/gravitational/teleport/lib/automaticupgrades/basichttp"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+const basicHTTPTestPath = "/v1/cloud-stable"
+
+func Test_basicHTTPVersionClient_Get(t *testing.T) {
+ mock := basichttp2.NewServerMock(basicHTTPTestPath + "/" + constants.VersionPath)
+ t.Cleanup(mock.Srv.Close)
+ serverURL, err := url.Parse(mock.Srv.URL)
+ serverURL.Path = basicHTTPTestPath
+ require.NoError(t, err)
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ statusCode int
+ response string
+ expected string
+ assertErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "all good",
+ statusCode: http.StatusOK,
+ response: "12.0.3",
+ expected: "v12.0.3",
+ assertErr: require.NoError,
+ },
+ {
+ name: "all good with newline",
+ statusCode: http.StatusOK,
+ response: "12.0.3\n",
+ expected: "v12.0.3",
+ assertErr: require.NoError,
+ },
+ {
+ name: "non-semver",
+ statusCode: http.StatusOK,
+ response: "hello",
+ expected: "",
+ assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) {
+ require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2))
+ },
+ },
+ {
+ name: "empty",
+ statusCode: http.StatusOK,
+ response: "",
+ expected: "",
+ assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) {
+ require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2))
+ },
+ },
+ {
+ name: "non-200 response",
+ statusCode: http.StatusInternalServerError,
+ response: "ERROR - SOMETHING WENT WRONG",
+ expected: "",
+ assertErr: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ b := &basicHTTPVersionClient{
+ baseURL: serverURL,
+ client: &basichttp2.Client{Client: mock.Srv.Client()},
+ }
+ mock.SetResponse(t, tt.statusCode, tt.response)
+ result, err := b.Get(ctx)
+ tt.assertErr(t, err)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
diff --git a/lib/automaticupgrades/version/errors.go b/lib/automaticupgrades/version/errors.go
new file mode 100644
index 0000000000000..96344c60858de
--- /dev/null
+++ b/lib/automaticupgrades/version/errors.go
@@ -0,0 +1,36 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package version
+
+import "fmt"
+
+// NoNewVersionError indicates that no new version was found and the controller did not reconcile.
+type NoNewVersionError struct {
+ Message string `json:"message"`
+ CurrentVersion string `json:"currentVersion"`
+ NextVersion string `json:"nextVersion"`
+}
+
+// Error returns log friendly description of an error
+func (e *NoNewVersionError) Error() string {
+ if e.Message != "" {
+ return e.Message
+ }
+ return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion)
+}
diff --git a/lib/automaticupgrades/version/static.go b/lib/automaticupgrades/version/static.go
new file mode 100644
index 0000000000000..ec3e629c1d2b7
--- /dev/null
+++ b/lib/automaticupgrades/version/static.go
@@ -0,0 +1,56 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package version
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+// StaticGetter is a fake version.Getter that return a static answer. This is used
+// for testing purposes.
+type StaticGetter struct {
+ version string
+ err error
+}
+
+// GetVersion returns the statically defined version.
+func (v StaticGetter) GetVersion(_ context.Context) (string, error) {
+ return v.version, v.err
+}
+
+// NewStaticGetter creates a StaticGetter
+func NewStaticGetter(version string, err error) Getter {
+ if version == constants.NoVersion {
+ return StaticGetter{
+ version: "",
+ err: &NoNewVersionError{Message: fmt.Sprintf("target version set to '%s'", constants.NoVersion)},
+ }
+ }
+
+ semVersion := version
+ if semVersion != "" && !strings.HasPrefix(semVersion, "v") {
+ semVersion = "v" + version
+ }
+ return StaticGetter{
+ version: semVersion,
+ err: err,
+ }
+}
diff --git a/lib/automaticupgrades/version/versionget.go b/lib/automaticupgrades/version/versionget.go
new file mode 100644
index 0000000000000..00ffdc31fda3f
--- /dev/null
+++ b/lib/automaticupgrades/version/versionget.go
@@ -0,0 +1,65 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package version
+
+import (
+ "context"
+ "strings"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/mod/semver"
+ ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// Getter gets the target image version for an external source. It should cache
+// the result to reduce io and avoid potential rate-limits and is safe to call
+// multiple times over a short period.
+// If the version source intentionally returns no version, a NoNewVersionError is
+// returned.
+type Getter interface {
+ GetVersion(context.Context) (string, error)
+}
+
+// ValidVersionChange receives the current version and the candidate next version
+// and evaluates if the version transition is valid.
+func ValidVersionChange(ctx context.Context, current, next string) bool {
+ log := ctrllog.FromContext(ctx).V(1)
+ // Cannot upgrade to a non-valid version
+ if !semver.IsValid(next) {
+ log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next)
+ return false
+ }
+ switch semver.Compare(next, current) {
+ // No need to upgrade if version is the same
+ case 0:
+ return false
+ default:
+ return true
+ }
+}
+
+// EnsureSemver adds the 'v' prefix if needed and ensures the provided version
+// is semver-compliant.
+func EnsureSemver(current string) (string, error) {
+ if !strings.HasPrefix(current, "v") {
+ current = "v" + current
+ }
+ if !semver.IsValid(current) {
+ return "", trace.BadParameter("tag %s is not following semver", current)
+ }
+ return current, nil
+}
diff --git a/lib/automaticupgrades/version/versionget_test.go b/lib/automaticupgrades/version/versionget_test.go
new file mode 100644
index 0000000000000..f1d1b46fe56bf
--- /dev/null
+++ b/lib/automaticupgrades/version/versionget_test.go
@@ -0,0 +1,71 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package version
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ semverLow = "v11.3.2"
+ semverMid = "v11.5.4"
+ semverHigh = "v12.2.1"
+ invalidSemverHigh = "12.2.1"
+)
+
+func TestValidVersionChange(t *testing.T) {
+ ctx := context.Background()
+ tests := []struct {
+ name string
+ current string
+ next string
+ want bool
+ }{
+ {
+ name: "upgrade",
+ current: semverMid,
+ next: semverHigh,
+ want: true,
+ },
+ {
+ name: "same version",
+ current: semverMid,
+ next: semverMid,
+ want: false,
+ },
+ {
+ name: "unknown current version",
+ current: "",
+ next: semverMid,
+ want: true,
+ },
+ {
+ name: "non-semver current version",
+ current: semverMid,
+ next: invalidSemverHigh,
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next))
+ })
+ }
+}
diff --git a/lib/automaticupgrades/version_test.go b/lib/automaticupgrades/version_test.go
deleted file mode 100644
index 4ba8d299201e3..0000000000000
--- a/lib/automaticupgrades/version_test.go
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
-Copyright 2023 Gravitational, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package automaticupgrades
-
-import (
- "context"
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
-
- "github.com/gravitational/trace"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestVersion(t *testing.T) {
- ctx := context.Background()
-
- isBadParameterErr := func(tt require.TestingT, err error, i ...any) {
- require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
- }
-
- for _, tt := range []struct {
- name string
- mockStatusCode int
- mockResponseString string
- errCheck require.ErrorAssertionFunc
- expectedVersion string
- }{
- {
- name: "real response",
- mockStatusCode: http.StatusOK,
- mockResponseString: "v13.1.1\n",
- errCheck: require.NoError,
- expectedVersion: "v13.1.1",
- },
- {
- name: "invalid status code (500)",
- mockStatusCode: http.StatusInternalServerError,
- errCheck: isBadParameterErr,
- },
- {
- name: "invalid status code (403)",
- mockStatusCode: http.StatusForbidden,
- errCheck: isBadParameterErr,
- },
- {
- name: "valid but has spaces",
- mockStatusCode: http.StatusOK,
- mockResponseString: " v13.1.1 \n \r\n",
- errCheck: require.NoError,
- expectedVersion: "v13.1.1",
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, "/v1/stable/cloud/version", r.URL.Path)
- w.WriteHeader(tt.mockStatusCode)
- w.Write([]byte(tt.mockResponseString))
- }))
- defer httpTestServer.Close()
-
- versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version")
- require.NoError(t, err)
-
- v, err := Version(ctx, versionURL)
- tt.errCheck(t, err)
- if err != nil {
- return
- }
-
- require.Equal(t, tt.expectedVersion, v)
- })
- }
-}
-
-func TestCritical(t *testing.T) {
- ctx := context.Background()
-
- isBadParameterErr := func(tt require.TestingT, err error, i ...any) {
- require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
- }
-
- for _, tt := range []struct {
- name string
- mockStatusCode int
- mockResponseString string
- errCheck require.ErrorAssertionFunc
- expectedCritical bool
- }{
- {
- name: "critical available",
- mockStatusCode: http.StatusOK,
- mockResponseString: "yes\n",
- errCheck: require.NoError,
- expectedCritical: true,
- },
- {
- name: "critical is not available",
- mockStatusCode: http.StatusOK,
- mockResponseString: "no\n",
- errCheck: require.NoError,
- expectedCritical: false,
- },
- {
- name: "invalid status code (500)",
- mockStatusCode: http.StatusInternalServerError,
- errCheck: isBadParameterErr,
- },
- {
- name: "invalid status code (403)",
- mockStatusCode: http.StatusForbidden,
- errCheck: isBadParameterErr,
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, "/v1/stable/cloud/critical", r.URL.Path)
- w.WriteHeader(tt.mockStatusCode)
- w.Write([]byte(tt.mockResponseString))
- }))
- defer httpTestServer.Close()
-
- criticalURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/critical")
- require.NoError(t, err)
-
- v, err := Critical(ctx, criticalURL)
- tt.errCheck(t, err)
- if err != nil {
- return
- }
-
- require.Equal(t, tt.expectedCritical, v)
- })
- }
-}
-
-func TestGetVersionURL(t *testing.T) {
- for _, tt := range []struct {
- name string
- versionURL string
- expectedURL string
- }{
- {
- name: "default stable/cloud version url",
- versionURL: "",
- expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/version",
- },
- {
- name: "custom version url",
- versionURL: "https://custom.dev/version",
- expectedURL: "https://custom.dev/version",
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- v, err := getVersionURL(tt.versionURL)
- require.NoError(t, err)
- require.Equal(t, tt.expectedURL, v)
- })
- }
-}
-
-func TestGetCriticalURL(t *testing.T) {
- for _, tt := range []struct {
- name string
- criticalURL string
- expectedURL string
- }{
- {
- name: "default stable/cloud critical url",
- criticalURL: "",
- expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/critical",
- },
- {
- name: "custom critical url",
- criticalURL: "https://custom.dev/critical",
- expectedURL: "https://custom.dev/critical",
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- v, err := getCriticalURL(tt.criticalURL)
- require.NoError(t, err)
- require.Equal(t, tt.expectedURL, v)
- })
- }
-}
diff --git a/lib/config/configuration.go b/lib/config/configuration.go
index 08b51f59fb332..75e0675df5a7e 100644
--- a/lib/config/configuration.go
+++ b/lib/config/configuration.go
@@ -1003,6 +1003,10 @@ func applyProxyConfig(fc *FileConfig, cfg *servicecfg.Config) error {
}
}
+ if fc.Proxy.AutomaticUpgradesChannels != nil {
+ cfg.Proxy.AutomaticUpgradesChannels = fc.Proxy.AutomaticUpgradesChannels
+ }
+
if fc.Proxy.MySQLServerVersion != "" {
cfg.Proxy.MySQLServerVersion = fc.Proxy.MySQLServerVersion
}
diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go
index c6173f32535c6..1b880520e8e57 100644
--- a/lib/config/fileconf.go
+++ b/lib/config/fileconf.go
@@ -45,6 +45,7 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
awsapiutils "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/api/utils/tlsutils"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
@@ -2181,6 +2182,11 @@ type Proxy struct {
// the "X-Forwarded-For" headers for web APIs received from layer 7 load
// balancers or reverse proxies.
TrustXForwardedFor types.Bool `yaml:"trust_x_forwarded_for,omitempty"`
+
+ // AutomaticUpgradesChannels is a map of all version channels used by the
+ // proxy built-in version server to retrieve target versions. This is part
+ // of the automatic upgrades.
+ AutomaticUpgradesChannels automaticupgrades.Channels `yaml:"automatic_upgrades_channels,omitempty"`
}
// UIConfig provides config options for the web UI served by the proxy service.
diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go
index 07b5bc468b765..9ae8892b4d55f 100644
--- a/lib/service/awsoidc.go
+++ b/lib/service/awsoidc.go
@@ -19,7 +19,6 @@ package service
import (
"context"
"fmt"
- "net/url"
"strings"
"time"
@@ -67,18 +66,13 @@ func (process *TeleportProcess) initAWSOIDCDeployServiceUpdater() error {
return nil
}
- // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used
- var criticalEndpoint string
- var versionEndpoint string
- if automaticupgrades.GetChannel() != "" {
- criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical")
- if err != nil {
- return trace.Wrap(err)
- }
- versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version")
- if err != nil {
- return trace.Wrap(err)
- }
+ // TODO: use the proxy channel if available?
+ // This would require to pass the proxy configuration there, but would avoid
+ // future inconsistencies: if the proxy is manually configured to serve a
+ // static version, it will not be picked up by the AWS OIDC deploy updater.
+ upgradeChannel, err := automaticupgrades.NewDefaultChannel()
+ if err != nil {
+ return trace.Wrap(err)
}
issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr)
@@ -98,8 +92,7 @@ func (process *TeleportProcess) initAWSOIDCDeployServiceUpdater() error {
TeleportClusterName: clusterNameConfig.GetClusterName(),
TeleportClusterVersion: resp.GetServerVersion(),
AWSOIDCProviderAddr: issuer,
- CriticalEndpoint: criticalEndpoint,
- VersionEndpoint: versionEndpoint,
+ UpgradeChannel: upgradeChannel,
})
if err != nil {
return trace.Wrap(err)
@@ -123,10 +116,8 @@ type AWSOIDCDeployServiceUpdaterConfig struct {
TeleportClusterVersion string
// AWSOIDCProvderAddr specifies the AWS OIDC provider address used to generate AWS OIDC tokens
AWSOIDCProviderAddr string
- // CriticalEndpoint specifies the endpoint to check for critical updates
- CriticalEndpoint string
- // VersionEndpoint specifies the endpoint to check for current teleport version
- VersionEndpoint string
+ // UpgradeChannel is the channel that serves the version used by the updater.
+ UpgradeChannel *automaticupgrades.Channel
}
// CheckAndSetDefaults checks and sets default config values.
@@ -201,7 +192,7 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployServices(ctx cont
return trace.Wrap(err)
}
- critical, err := automaticupgrades.Critical(ctx, updater.CriticalEndpoint)
+ critical, err := updater.UpgradeChannel.GetCritical(ctx)
if err != nil {
return trace.Wrap(err)
}
@@ -212,7 +203,7 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployServices(ctx cont
return nil
}
- stableVersion, err := automaticupgrades.Version(ctx, updater.VersionEndpoint)
+ stableVersion, err := updater.UpgradeChannel.GetVersion(ctx)
if err != nil {
return trace.Wrap(err)
}
diff --git a/lib/service/service.go b/lib/service/service.go
index c47ac8b1fad38..193fdef2d11b9 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -3947,33 +3947,34 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
webConfig := web.Config{
- Proxy: tsrv,
- AuthServers: cfg.AuthServerAddresses()[0],
- DomainName: cfg.Hostname,
- ProxyClient: conn.Client,
- ProxySSHAddr: proxySSHAddr,
- ProxyWebAddr: cfg.Proxy.WebAddr,
- ProxyPublicAddrs: cfg.Proxy.PublicAddrs,
- CipherSuites: cfg.CipherSuites,
- FIPS: cfg.FIPS,
- AccessPoint: accessPoint,
- Emitter: streamEmitter,
- PluginRegistry: process.PluginRegistry,
- HostUUID: process.Config.HostUUID,
- Context: process.ExitContext(),
- StaticFS: fs,
- ClusterFeatures: process.getClusterFeatures(),
- UI: cfg.Proxy.UI,
- ProxySettings: proxySettings,
- PublicProxyAddr: process.proxyPublicAddr().Addr,
- ALPNHandler: alpnHandlerForWeb.HandleConnection,
- ProxyKubeAddr: proxyKubeAddr,
- TraceClient: traceClt,
- Router: proxyRouter,
- SessionControl: sessionController,
- PROXYSigner: proxySigner,
- OpenAIConfig: cfg.Testing.OpenAIConfig,
- NodeWatcher: nodeWatcher,
+ Proxy: tsrv,
+ AuthServers: cfg.AuthServerAddresses()[0],
+ DomainName: cfg.Hostname,
+ ProxyClient: conn.Client,
+ ProxySSHAddr: proxySSHAddr,
+ ProxyWebAddr: cfg.Proxy.WebAddr,
+ ProxyPublicAddrs: cfg.Proxy.PublicAddrs,
+ CipherSuites: cfg.CipherSuites,
+ FIPS: cfg.FIPS,
+ AccessPoint: accessPoint,
+ Emitter: streamEmitter,
+ PluginRegistry: process.PluginRegistry,
+ HostUUID: process.Config.HostUUID,
+ Context: process.ExitContext(),
+ StaticFS: fs,
+ ClusterFeatures: process.getClusterFeatures(),
+ UI: cfg.Proxy.UI,
+ ProxySettings: proxySettings,
+ PublicProxyAddr: process.proxyPublicAddr().Addr,
+ ALPNHandler: alpnHandlerForWeb.HandleConnection,
+ ProxyKubeAddr: proxyKubeAddr,
+ TraceClient: traceClt,
+ Router: proxyRouter,
+ SessionControl: sessionController,
+ PROXYSigner: proxySigner,
+ OpenAIConfig: cfg.Testing.OpenAIConfig,
+ NodeWatcher: nodeWatcher,
+ AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels,
}
webHandler, err := web.NewHandler(webConfig)
if err != nil {
diff --git a/lib/service/servicecfg/proxy.go b/lib/service/servicecfg/proxy.go
index af2ca611fc2ee..27891645b94e7 100644
--- a/lib/service/servicecfg/proxy.go
+++ b/lib/service/servicecfg/proxy.go
@@ -24,6 +24,7 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/utils"
@@ -151,6 +152,11 @@ type ProxyConfig struct {
// as a label and used by reverse tunnel agents in proxy peering mode. Zero
// is a valid generation.
ProxyGroupGeneration uint64
+
+ // AutomaticUpgradesChannels is a map of all version channels used by the
+ // proxy built-in version server to retrieve target versions. This is part
+ // of the automatic upgrades.
+ AutomaticUpgradesChannels automaticupgrades.Channels
}
// WebPublicAddr returns the address for the web endpoint on this proxy that
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index b4e757fbb4db0..7b6ab279d3961 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -269,11 +269,10 @@ type Config struct {
// the proxy's cache and get nodes in real time.
NodeWatcher *services.NodeWatcher
- // AutomaticUpgradesVersionURL is the URL which returns the target agent version.
- // This URL must returns a valid version string.
- // Eg, v13.4.3
- // Optional: uses cloud/stable channel when omitted.
- AutomaticUpgradesVersionURL string
+ // AutomaticUpgradesChannels is a map of all version channels used by the
+ // proxy built-in version server to retrieve target versions. This is part
+ // of the automatic upgrades.
+ AutomaticUpgradesChannels automaticupgrades.Channels
}
// SetDefaults ensures proper default values are set if
@@ -350,6 +349,17 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
h.assistantLimiter = rate.NewLimiter(rate.Inf, 0)
}
+ if automaticUpgrades(cfg.ClusterFeatures) && h.cfg.AutomaticUpgradesChannels == nil {
+ h.cfg.AutomaticUpgradesChannels = automaticupgrades.Channels{}
+ }
+
+ if h.cfg.AutomaticUpgradesChannels != nil {
+ err := h.cfg.AutomaticUpgradesChannels.CheckAndSetDefaults(cfg.ClusterFeatures)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
// for properly handling url-encoded parameter values.
h.UseRawPath = true
@@ -828,6 +838,10 @@ func (h *Handler) bindDefaultEndpoints() {
// Updates the user's preferences
h.PUT("/webapi/user/preferences", h.WithAuth(h.updateUserPreferences))
+
+ // Implements the agent version server.
+ // Channel can contain "/", hence the use of a catch-all parameter
+ h.GET("/webapi/automaticupgrades/channel/*request", h.WithLimiter(h.automaticUpgrades))
}
// GetProxyClient returns authenticated auth server client
@@ -1477,7 +1491,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou
automaticUpgradesEnabled := clusterFeatures.GetAutomaticUpgrades()
var automaticUpgradesTargetVersion string
if automaticUpgradesEnabled {
- automaticUpgradesTargetVersion, err = automaticupgrades.Version(r.Context(), h.cfg.AutomaticUpgradesVersionURL)
+ automaticUpgradesTargetVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(r.Context())
if err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go
index 0c34b1621ddc5..49575adb4dc4f 100644
--- a/lib/web/apiserver_test.go
+++ b/lib/web/apiserver_test.go
@@ -96,6 +96,7 @@ import (
"github.com/gravitational/teleport/lib/auth/testauthority"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/authz"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/conntest"
@@ -4584,20 +4585,21 @@ func TestGetWebConfig(t *testing.T) {
}
env.proxies[0].handler.handler.cfg.ProxySettings = mockProxySetting
- httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version")
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("v99.0.1"))
- }))
- defer httpTestServer.Close()
- versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version")
require.NoError(t, err)
- env.proxies[0].handler.handler.cfg.AutomaticUpgradesVersionURL = versionURL
+ // This version is too high and MUST NOT be used
+ testVersion := "v99.0.1"
+ channels := automaticupgrades.Channels{
+ automaticupgrades.DefaultCloudChannelName: {
+ StaticVersion: testVersion,
+ },
+ }
+ require.NoError(t, channels.CheckAndSetDefaults(authproto.Features{AutomaticUpgrades: true, Cloud: true}))
+ env.proxies[0].handler.handler.cfg.AutomaticUpgradesChannels = channels
expectedCfg.IsCloud = true
expectedCfg.IsUsageBasedBilling = true
expectedCfg.AutomaticUpgrades = true
- expectedCfg.AutomaticUpgradesTargetVersion = "v99.0.1"
+ expectedCfg.AutomaticUpgradesTargetVersion = teleport.Version
expectedCfg.AssistEnabled = false
// request and verify enabled features are enabled.
diff --git a/lib/web/automaticupgrades.go b/lib/web/automaticupgrades.go
new file mode 100644
index 0000000000000..6b7833dc629e2
--- /dev/null
+++ b/lib/web/automaticupgrades.go
@@ -0,0 +1,119 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package web
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/julienschmidt/httprouter"
+
+ "github.com/gravitational/teleport/lib/automaticupgrades"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+ "github.com/gravitational/teleport/lib/automaticupgrades/version"
+)
+
+const defaultChannelTimeout = 5 * time.Second
+
+// automaticUpgrades implements a version server in the Teleport Proxy.
+// It is configured through the Teleport Proxy configuration and tells agent updaters
+// which version they should install.
+func (h *Handler) automaticUpgrades(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
+ if h.cfg.AutomaticUpgradesChannels == nil {
+ return nil, trace.AccessDenied("This proxy is not configured to serve automatic upgrades channels.")
+ }
+
+ // The request format is "/{version,critical}"
+ // As might contain "/" we have to split, pop the last part
+ // and re-construct the channel name.
+ channelAndType := p.ByName("request")
+
+ reqParts := strings.Split(strings.Trim(channelAndType, "/"), "/")
+ if len(reqParts) < 2 {
+ return nil, trace.BadParameter("path format should be /webapi/automaticupgrades/channel//{version,critical}")
+ }
+ requestType := reqParts[len(reqParts)-1]
+ channelName := strings.Join(reqParts[:len(reqParts)-1], "/")
+
+ if channelName == "" {
+ return nil, trace.BadParameter("a channel name is required")
+ }
+
+ // We check if the channel is configured
+ channel, ok := h.cfg.AutomaticUpgradesChannels[channelName]
+ if !ok {
+ return nil, trace.NotFound("channel %s not found", channelName)
+ }
+
+ // Finally, we treat the request based on its type
+ switch requestType {
+ case "version":
+ h.log.Debugf("Agent requesting version for channel %s", channelName)
+ return h.automaticUpgradesVersion(w, r, channel)
+ case "critical":
+ h.log.Debugf("Agent requesting criticality for channel %s", channelName)
+ return h.automaticUpgradesCritical(w, r, channel)
+ default:
+ return nil, trace.BadParameter("requestType path must end with 'version' or 'critical'")
+ }
+}
+
+// automaticUpgradesVersion handles version requests from upgraders
+func (h *Handler) automaticUpgradesVersion(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
+ defer cancel()
+
+ targetVersion, err := channel.GetVersion(ctx)
+ if err != nil {
+ // If the error is that the upstream channel has no version
+ // We gracefully handle by serving "none"
+ var NoNewVersionErr *version.NoNewVersionError
+ if errors.As(trace.Unwrap(err), &NoNewVersionErr) {
+ _, err = w.Write([]byte(constants.NoVersion))
+ return nil, trace.Wrap(err)
+ }
+ // Else we propagate the error
+ return nil, trace.Wrap(err)
+ }
+
+ _, err = w.Write([]byte(targetVersion))
+ return nil, trace.Wrap(err)
+}
+
+// automaticUpgradesCritical handles criticality requests from upgraders
+func (h *Handler) automaticUpgradesCritical(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
+ defer cancel()
+
+ critical, err := channel.GetCritical(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ response := "no"
+ if critical {
+ response = "yes"
+ }
+ _, err = w.Write([]byte(response))
+ return nil, trace.Wrap(err)
+}
diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go
index 8a6b393fa3eb7..1eb3b228e53e8 100644
--- a/lib/web/integrations_awsoidc.go
+++ b/lib/web/integrations_awsoidc.go
@@ -26,7 +26,6 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/aws"
- "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
"github.com/gravitational/teleport/lib/reversetunnelclient"
@@ -149,7 +148,7 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p
teleportVersionTag := teleport.Version
if automaticUpgrades(h.ClusterFeatures) {
- cloudStableVersion, err := automaticupgrades.Version(ctx, "" /* use default version server */)
+ cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx)
if err != nil {
return "", trace.Wrap(err)
}
diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go
index 9423b24b662da..1045bd8cd95c6 100644
--- a/lib/web/join_tokens.go
+++ b/lib/web/join_tokens.go
@@ -40,7 +40,6 @@ import (
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth"
- "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/modules"
@@ -77,9 +76,10 @@ type scriptSettings struct {
databaseInstallMode bool
installUpdater bool
- // automaticUpgradesVersionURL is the URL for getting the version when using the cloud/stable channel.
- // Optional.
- automaticUpgradesVersionURL string
+ // automaticUpgradesVersion is the target automatic upgrades version.
+ // The version must be valid semver, with the leading 'v'. e.g. v15.0.0-dev
+ // Required when installUpdater is true.
+ automaticUpgradesVersion string
}
// automaticUpgrades returns whether automaticUpgrades should be enabled.
@@ -206,14 +206,38 @@ func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, para
}, nil
}
+// getAutoUpgrades checks if automaticUpgrades are enabled and returns the
+// version that should be used according to auto upgrades default channel.
+func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string, error) {
+ var autoUpgradesVersion string
+ var err error
+ autoUpgrades := automaticUpgrades(h.ClusterFeatures)
+ if autoUpgrades {
+ autoUpgradesVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx)
+ if err != nil {
+ log.WithError(err).Info("Failed to get auto upgrades version.")
+ return false, "", trace.Wrap(err)
+ }
+ }
+ return autoUpgrades, autoUpgradesVersion, nil
+
+}
+
func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) {
httplib.SetScriptHeaders(w.Header())
+ autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context())
+ if err != nil {
+ w.Write(scripts.ErrorBashScript)
+ return nil, nil
+ }
+
settings := scriptSettings{
- token: params.ByName("token"),
- appInstallMode: false,
- joinMethod: r.URL.Query().Get("method"),
- installUpdater: automaticUpgrades(h.ClusterFeatures),
+ token: params.ByName("token"),
+ appInstallMode: false,
+ joinMethod: r.URL.Query().Get("method"),
+ installUpdater: autoUpgrades,
+ automaticUpgradesVersion: autoUpgradesVersion,
}
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
@@ -250,12 +274,19 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
return nil, nil
}
+ autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context())
+ if err != nil {
+ w.Write(scripts.ErrorBashScript)
+ return nil, nil
+ }
+
settings := scriptSettings{
- token: params.ByName("token"),
- appInstallMode: true,
- appName: name,
- appURI: uri,
- installUpdater: automaticUpgrades(h.ClusterFeatures),
+ token: params.ByName("token"),
+ appInstallMode: true,
+ appName: name,
+ appURI: uri,
+ installUpdater: autoUpgrades,
+ automaticUpgradesVersion: autoUpgradesVersion,
}
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
@@ -277,10 +308,17 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
func (h *Handler) getDatabaseJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) {
httplib.SetScriptHeaders(w.Header())
+ autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context())
+ if err != nil {
+ w.Write(scripts.ErrorBashScript)
+ return nil, nil
+ }
+
settings := scriptSettings{
- token: params.ByName("token"),
- databaseInstallMode: true,
- installUpdater: automaticUpgrades(h.ClusterFeatures),
+ token: params.ByName("token"),
+ databaseInstallMode: true,
+ installUpdater: autoUpgrades,
+ automaticUpgradesVersion: autoUpgradesVersion,
}
script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
@@ -391,17 +429,17 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter
// The install script will install the updater (teleport-ent-updater) for Cloud customers enrolled in Automatic Upgrades.
// The repo channel used must be `stable/cloud` which has the available packages for the Cloud Customer's agents.
- // It pins the teleport version to the one specified by https://updates.releases.teleport.dev/v1/stable/cloud/version
+ // It pins the teleport version to the one specified by the default version channel
// This ensures the initial installed version is the same as the `teleport-ent-updater` would install.
if settings.installUpdater {
- repoChannel = stableCloudChannelRepo
- cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionURL)
- if err != nil {
- return "", trace.Wrap(err)
+ if settings.automaticUpgradesVersion == "" {
+ return "", trace.Wrap(err, "automatic upgrades version must be set when installUpdater is true")
}
- // cloudStableVersion has vX.Y.Z format, however the script expects the version to not include the `v`
- version = strings.TrimPrefix(cloudStableVersion, "v")
+ repoChannel = stableCloudChannelRepo
+ // automaticUpgradesVersion has vX.Y.Z format, however the script
+ // expects the version to not include the `v` so we strip it
+ version = strings.TrimPrefix(settings.automaticUpgradesVersion, "v")
}
// This section relies on Go's default zero values to make sure that the settings
diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go
index eb2c88e403e98..a7f008dd3b8c8 100644
--- a/lib/web/join_tokens_test.go
+++ b/lib/web/join_tokens_test.go
@@ -20,14 +20,10 @@ import (
"context"
"encoding/hex"
"fmt"
- "net/http"
- "net/http/httptest"
- "net/url"
"regexp"
"testing"
"github.com/gravitational/trace"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport"
@@ -935,18 +931,7 @@ func TestJoinScript(t *testing.T) {
t.Run("using repo", func(t *testing.T) {
t.Run("installUpdater is true", func(t *testing.T) {
currentStableCloudVersion := "v99.1.1"
-
- httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version")
- w.WriteHeader(http.StatusOK)
- w.Write([]byte(currentStableCloudVersion))
- }))
- defer httpTestServer.Close()
-
- versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version")
- require.NoError(t, err)
-
- script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionURL: versionURL}, m)
+ script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersion: currentStableCloudVersion}, m)
require.NoError(t, err)
// list of packages must include the updater