diff --git a/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml b/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml index b9d547638df9a..e7359243d1601 100644 --- a/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml +++ b/examples/chart/teleport-kube-agent/templates/updater/deployment.yaml @@ -65,7 +65,7 @@ spec: - "--agent-name={{ .Release.Name }}" - "--agent-namespace={{ .Release.Namespace }}" - "--base-image={{ include "teleport-kube-agent.baseImage" . }}" - - "--version-server={{ $updater.versionServer }}" + - "--version-server={{ tpl $updater.versionServer . }}" - "--version-channel={{ $updater.releaseChannel }}" {{- if .Values.updater.extraArgs }} {{- toYaml .Values.updater.extraArgs | nindent 10 }} diff --git a/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml b/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml index 032c8348f88d5..b6b28f40bee22 100644 --- a/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml +++ b/examples/chart/teleport-kube-agent/tests/updater_deployment_test.yaml @@ -57,6 +57,16 @@ tests: - contains: path: spec.template.spec.containers[0].args content: "--agent-namespace=my-namespace" + - it: defaults the updater version server to the proxy address + set: + proxyAddr: proxy.teleport.example.com:443 + roles: "custom" + updater: + enabled: true + asserts: + - contains: + path: spec.template.spec.containers[0].args + content: "--version-server=https://proxy.teleport.example.com:443/v1/webapi/automaticupgrades/channel" - it: sets the updater version server values: - ../.lint/updater.yaml diff --git a/examples/chart/teleport-kube-agent/values.yaml b/examples/chart/teleport-kube-agent/values.yaml index 567973ecbbd34..0c8ba7411616e 100644 --- a/examples/chart/teleport-kube-agent/values.yaml +++ b/examples/chart/teleport-kube-agent/values.yaml @@ -146,7 +146,8 @@ updater: # `updater.versionServer` is the URL of the version server the agent fetches # the target version from. The complete version endpoint is built by # concatenating `versionServer` and `releaseChannel`. - versionServer: "https://updates.releases.teleport.dev/v1/" + # This field supports gotemplate. + versionServer: "https://{{ .Values.proxyAddr }}/v1/webapi/automaticupgrades/channel" # Release channel the agent subscribes to. releaseChannel: "stable/cloud" image: public.ecr.aws/gravitational/teleport-kube-agent-updater diff --git a/integration/proxy/automaticupgrades_test.go b/integration/proxy/automaticupgrades_test.go new file mode 100644 index 0000000000000..28fa4d3cd49e6 --- /dev/null +++ b/integration/proxy/automaticupgrades_test.go @@ -0,0 +1,189 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package proxy + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net/http" + "net/url" + "path/filepath" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/integration/helpers" + "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/automaticupgrades/basichttp" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/utils" +) + +func createProxyWithChannels(t *testing.T, channels automaticupgrades.Channels) string { + features := proto.Features{} + require.NoError(t, channels.CheckAndSetDefaults(features)) + testDir := t.TempDir() + + cfg := helpers.InstanceConfig{ + ClusterName: "root.example.com", + HostID: uuid.New().String(), + NodeName: helpers.Loopback, + Log: utils.NewLoggerForTests(), + } + cfg.Listeners = helpers.SingleProxyPortSetup(t, &cfg.Fds) + rc := helpers.NewInstance(t, cfg) + + var err error + rcConf := servicecfg.MakeDefaultConfig() + rcConf.DataDir = filepath.Join(testDir, "data") + rcConf.Auth.Enabled = true + rcConf.Proxy.Enabled = true + rcConf.SSH.Enabled = false + rcConf.Proxy.DisableWebInterface = true + rcConf.Version = "v3" + rcConf.Proxy.AutomaticUpgradesChannels = channels + + err = rc.CreateEx(t, nil, rcConf) + require.NoError(t, err) + err = rc.Start() + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, rc.StopAll()) + }) + + return cfg.Listeners.Web +} + +func TestVersionServer(t *testing.T) { + // Test setup + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testVersion := "v12.2.6" + testVersionMajorTooHigh := "v99.1.3" + + staticChannel := "static/ok" + staticHighChannel := "static/high" + staticNoVersionChannel := "static/none" + forwardChannel := "forward/ok" + forwardHighChannel := "forward/high" + forwardNoVersionChannel := "forward/none" + forwardPath := "/version-server/" + + upstreamServer := basichttp.NewServerMock(forwardPath + constants.VersionPath) + upstreamServer.SetResponse(t, http.StatusOK, testVersion) + upstreamHighServer := basichttp.NewServerMock(forwardPath + constants.VersionPath) + upstreamHighServer.SetResponse(t, http.StatusOK, testVersionMajorTooHigh) + upstreamNoVersionServer := basichttp.NewServerMock(forwardPath + constants.VersionPath) + upstreamNoVersionServer.SetResponse(t, http.StatusOK, constants.NoVersion) + + channels := automaticupgrades.Channels{ + staticChannel: { + StaticVersion: testVersion, + }, + staticHighChannel: { + StaticVersion: testVersionMajorTooHigh, + }, + staticNoVersionChannel: { + StaticVersion: constants.NoVersion, + }, + forwardChannel: { + ForwardURL: upstreamServer.Srv.URL + forwardPath, + }, + forwardHighChannel: { + ForwardURL: upstreamHighServer.Srv.URL + forwardPath, + }, + forwardNoVersionChannel: { + ForwardURL: upstreamNoVersionServer.Srv.URL + forwardPath, + }, + } + + proxyAddr := createProxyWithChannels(t, channels) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + httpClient := http.Client{Transport: tr} + + // Test execution + tests := []struct { + name string + channel string + expectedResponse string + }{ + { + name: "static version OK", + channel: staticChannel, + expectedResponse: testVersion, + }, + { + name: "static version too high", + channel: staticHighChannel, + expectedResponse: teleport.Version, + }, + { + name: "static version none", + channel: staticNoVersionChannel, + expectedResponse: constants.NoVersion, + }, + { + name: "forward version OK", + channel: forwardChannel, + expectedResponse: testVersion, + }, + { + name: "forward version too high", + channel: forwardHighChannel, + expectedResponse: teleport.Version, + }, + { + name: "forward version none", + channel: forwardNoVersionChannel, + expectedResponse: constants.NoVersion, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + channelUrl, err := url.Parse( + fmt.Sprintf("https://%s/v1/webapi/automaticupgrades/channel/%s/version", proxyAddr, tt.channel), + ) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, channelUrl.String(), nil) + require.NoError(t, err) + res, err := httpClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, tt.expectedResponse, string(body)) + }) + } +} diff --git a/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go b/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go index 30e2f485cffc0..1a98dfa316d0c 100644 --- a/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go +++ b/integrations/kube-agent-updater/cmd/teleport-kube-agent-updater/main.go @@ -188,6 +188,8 @@ func main() { os.Exit(1) } + ctrl.Log.Info("starting the updater", "url", versionServerURL.String()) + if err := mgr.Start(ctx); err != nil { ctrl.Log.Error(err, "failed to start manager, exiting") os.Exit(1) diff --git a/integrations/kube-agent-updater/pkg/constants/constants.go b/integrations/kube-agent-updater/pkg/constants/constants.go index 65c9e9c9d4e5b..2ca709962c89c 100644 --- a/integrations/kube-agent-updater/pkg/constants/constants.go +++ b/integrations/kube-agent-updater/pkg/constants/constants.go @@ -29,4 +29,8 @@ const ( VersionPath = "version" HTTPTimeout = 10 * time.Second CacheDuration = time.Minute + + // NoVersion is returned by the version endpoint when there is no valid target version. + // This can be caused by the target version being incompatible with the cluster version. + NoVersion = "none" ) diff --git a/integrations/kube-agent-updater/pkg/controller/deployment.go b/integrations/kube-agent-updater/pkg/controller/deployment.go index 1d208f2f43b41..c377bea0dbef9 100644 --- a/integrations/kube-agent-updater/pkg/controller/deployment.go +++ b/integrations/kube-agent-updater/pkg/controller/deployment.go @@ -27,6 +27,8 @@ import ( ctrl "sigs.k8s.io/controller-runtime" kclient "sigs.k8s.io/controller-runtime/pkg/client" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/version" ) // DeploymentVersionUpdater Reconciles a podSpec by changing its image @@ -74,7 +76,7 @@ func (r *DeploymentVersionUpdater) Reconcile(ctx context.Context, req ctrl.Reque image, err := r.GetVersion(ctx, &obj, currentVersion) var ( - noNewVersionErr *NoNewVersionError + noNewVersionErr *version.NoNewVersionError maintenanceErr *MaintenanceNotTriggeredError ) switch { diff --git a/integrations/kube-agent-updater/pkg/controller/errors.go b/integrations/kube-agent-updater/pkg/controller/errors.go index 157f7de32a816..fccbbad75756c 100644 --- a/integrations/kube-agent-updater/pkg/controller/errors.go +++ b/integrations/kube-agent-updater/pkg/controller/errors.go @@ -16,10 +16,6 @@ limitations under the License. package controller -import ( - "fmt" -) - // MaintenanceNotTriggeredError indicates that no trigger returned true and the controller did not reconcile. type MaintenanceNotTriggeredError struct { Message string `json:"message"` @@ -32,18 +28,3 @@ func (e *MaintenanceNotTriggeredError) Error() string { } return "maintenance not triggered" } - -// NoNewVersionError indicates that no new version was found and the controller did not reconcile. -type NoNewVersionError struct { - Message string `json:"message"` - CurrentVersion string `json:"currentVersion"` - NextVersion string `json:"nextVersion"` -} - -// Error returns log friendly description of an error -func (e *NoNewVersionError) Error() string { - if e.Message != "" { - return e.Message - } - return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion) -} diff --git a/integrations/kube-agent-updater/pkg/controller/statefulset.go b/integrations/kube-agent-updater/pkg/controller/statefulset.go index 5b78227197a2f..4c014c80f3fbc 100644 --- a/integrations/kube-agent-updater/pkg/controller/statefulset.go +++ b/integrations/kube-agent-updater/pkg/controller/statefulset.go @@ -32,6 +32,7 @@ import ( ctrllog "sigs.k8s.io/controller-runtime/pkg/log" "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/podutils" + "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/version" ) type StatefulSetVersionUpdater struct { @@ -99,7 +100,7 @@ func (r *StatefulSetVersionUpdater) Reconcile(ctx context.Context, req ctrl.Requ image, err := r.GetVersion(ctx, &obj, currentVersion) var ( - noNewVersionErr *NoNewVersionError + noNewVersionErr *version.NoNewVersionError maintenanceErr *MaintenanceNotTriggeredError ) switch { diff --git a/integrations/kube-agent-updater/pkg/controller/updater.go b/integrations/kube-agent-updater/pkg/controller/updater.go index 1ecc87647f2c9..764e5f7f948e3 100644 --- a/integrations/kube-agent-updater/pkg/controller/updater.go +++ b/integrations/kube-agent-updater/pkg/controller/updater.go @@ -61,7 +61,7 @@ func (r *VersionUpdater) GetVersion(ctx context.Context, obj client.Object, curr log.Info("New version candidate", "nextVersion", nextVersion) if !version.ValidVersionChange(ctx, currentVersion, nextVersion) { - return nil, &NoNewVersionError{CurrentVersion: currentVersion, NextVersion: nextVersion} + return nil, &version.NoNewVersionError{CurrentVersion: currentVersion, NextVersion: nextVersion} } log.Info("Version change is valid, building img candidate") diff --git a/integrations/kube-agent-updater/pkg/controller/updater_test.go b/integrations/kube-agent-updater/pkg/controller/updater_test.go index d33678577b1c5..51ae579aaf837 100644 --- a/integrations/kube-agent-updater/pkg/controller/updater_test.go +++ b/integrations/kube-agent-updater/pkg/controller/updater_test.go @@ -42,8 +42,8 @@ const ( ) var ( - alwaysTrigger = maintenance.NewMaintenanceTriggerMock("always trigger", true) - neverTrigger = maintenance.NewMaintenanceTriggerMock("never trigger", false) + alwaysTrigger = maintenance.NewMaintenanceStaticTrigger("always trigger", true) + neverTrigger = maintenance.NewMaintenanceStaticTrigger("never trigger", false) alwaysValid = img.NewImageValidatorMock( "always", true, @@ -80,7 +80,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: versionMid, - versionGetter: version.NewGetterMock(versionHigh, nil), + versionGetter: version.NewStaticGetter(versionHigh, nil), maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, imageCheckers: []img.Validator{alwaysValid}, assertErr: require.NoError, @@ -91,7 +91,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: "", - versionGetter: version.NewGetterMock(versionHigh, nil), + versionGetter: version.NewStaticGetter(versionHigh, nil), maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, imageCheckers: []img.Validator{alwaysValid}, assertErr: require.NoError, @@ -102,10 +102,21 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: versionMid, - versionGetter: version.NewGetterMock(versionMid, nil), + versionGetter: version.NewStaticGetter(versionMid, nil), maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, imageCheckers: []img.Validator{alwaysValid}, - assertErr: errorIsType(&NoNewVersionError{}), + assertErr: errorIsType(&version.NoNewVersionError{}), + expectedImage: "", + }, + { + name: "no version", + releaseRegistry: defaultTestRegistry, + releasePath: defaultTestPath, + currentVersion: versionMid, + versionGetter: version.NewStaticGetter("", &version.NoNewVersionError{Message: "version server did not advertise a version"}), + maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, + imageCheckers: []img.Validator{alwaysValid}, + assertErr: errorIsType(&version.NoNewVersionError{}), expectedImage: "", }, { @@ -113,7 +124,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: versionMid, - versionGetter: version.NewGetterMock(versionHigh, nil), + versionGetter: version.NewStaticGetter(versionHigh, nil), maintenanceTriggers: []maintenance.Trigger{neverTrigger}, imageCheckers: []img.Validator{alwaysValid}, assertErr: errorIsType(&MaintenanceNotTriggeredError{}), @@ -124,7 +135,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: versionMid, - versionGetter: version.NewGetterMock(versionHigh, nil), + versionGetter: version.NewStaticGetter(versionHigh, nil), maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, imageCheckers: []img.Validator{neverValid}, assertErr: errorIsType(&trace.TrustError{}), @@ -135,7 +146,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { releaseRegistry: defaultTestRegistry, releasePath: defaultTestPath, currentVersion: versionMid, - versionGetter: version.NewGetterMock("", &trace.ConnectionProblemError{}), + versionGetter: version.NewStaticGetter("", &trace.ConnectionProblemError{}), maintenanceTriggers: []maintenance.Trigger{alwaysTrigger}, imageCheckers: []img.Validator{neverValid}, assertErr: errorIsType(&trace.ConnectionProblemError{}), @@ -159,7 +170,7 @@ func Test_VersionUpdater_GetVersion(t *testing.T) { baseImage: baseImage, } - // We need a dummy Kubernetes object, it is not used by the TriggerMock + // We need a dummy Kubernetes object, it is not used by the StaticTrigger obj := &core.Pod{} // Doing the test diff --git a/integrations/kube-agent-updater/pkg/maintenance/mock.go b/integrations/kube-agent-updater/pkg/maintenance/mock.go index 0c0c6b5738d21..a777bfca0764c 100644 --- a/integrations/kube-agent-updater/pkg/maintenance/mock.go +++ b/integrations/kube-agent-updater/pkg/maintenance/mock.go @@ -22,33 +22,33 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -// TriggerMock is a fake Trigger that return a static answer. This is used +// StaticTrigger is a fake Trigger that return a static answer. This is used // for testing purposes and is inherently disruptive. -type TriggerMock struct { +type StaticTrigger struct { name string canStart bool } -// Name returns the TriggerMock name. -func (m TriggerMock) Name() string { +// Name returns the StaticTrigger name. +func (m StaticTrigger) Name() string { return m.name } // CanStart returns the statically defined maintenance approval result. -func (m TriggerMock) CanStart(_ context.Context, _ client.Object) (bool, error) { +func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) { return m.canStart, nil } // Default returns the default behavior if the trigger fails. This cannot -// happen for a TriggerMock and is here solely to implement the Trigger +// happen for a StaticTrigger and is here solely to implement the Trigger // interface. -func (m TriggerMock) Default() bool { +func (m StaticTrigger) Default() bool { return m.canStart } -// NewMaintenanceTriggerMock creates a TriggerMock -func NewMaintenanceTriggerMock(name string, canStart bool) Trigger { - return TriggerMock{ +// NewMaintenanceStaticTrigger creates a StaticTrigger +func NewMaintenanceStaticTrigger(name string, canStart bool) Trigger { + return StaticTrigger{ name: name, canStart: canStart, } diff --git a/integrations/kube-agent-updater/pkg/version/basichttp.go b/integrations/kube-agent-updater/pkg/version/basichttp.go index 40b69cae91da8..cb0418e703dd5 100644 --- a/integrations/kube-agent-updater/pkg/version/basichttp.go +++ b/integrations/kube-agent-updater/pkg/version/basichttp.go @@ -47,8 +47,12 @@ func (b *basicHTTPVersionClient) Get(ctx context.Context) (string, error) { if err != nil { return "", trace.Wrap(err) } + response := string(body) + if response == constants.NoVersion { + return "", &NoNewVersionError{Message: "version server did not advertise a version"} + } // We trim spaces because the value might end with one or many newlines - version, err := EnsureSemver(strings.TrimSpace(string(body))) + version, err := EnsureSemver(strings.TrimSpace(response)) return version, trace.Wrap(err) } diff --git a/integrations/kube-agent-updater/pkg/version/errors.go b/integrations/kube-agent-updater/pkg/version/errors.go new file mode 100644 index 0000000000000..96344c60858de --- /dev/null +++ b/integrations/kube-agent-updater/pkg/version/errors.go @@ -0,0 +1,36 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package version + +import "fmt" + +// NoNewVersionError indicates that no new version was found and the controller did not reconcile. +type NoNewVersionError struct { + Message string `json:"message"` + CurrentVersion string `json:"currentVersion"` + NextVersion string `json:"nextVersion"` +} + +// Error returns log friendly description of an error +func (e *NoNewVersionError) Error() string { + if e.Message != "" { + return e.Message + } + return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion) +} diff --git a/integrations/kube-agent-updater/pkg/version/mock.go b/integrations/kube-agent-updater/pkg/version/mock.go index 6746ebcd62a67..725d8758794f5 100644 --- a/integrations/kube-agent-updater/pkg/version/mock.go +++ b/integrations/kube-agent-updater/pkg/version/mock.go @@ -18,28 +18,38 @@ package version import ( "context" + "fmt" "strings" + + "github.com/gravitational/teleport/integrations/kube-agent-updater/pkg/constants" ) -// GetterMock is a fake version.Getter that return a static answer. This is used +// StaticGetter is a fake version.Getter that return a static answer. This is used // for testing purposes. -type GetterMock struct { +type StaticGetter struct { version string err error } // GetVersion returns the statically defined version. -func (v GetterMock) GetVersion(_ context.Context) (string, error) { +func (v StaticGetter) GetVersion(_ context.Context) (string, error) { return v.version, v.err } -// NewGetterMock creates a GetterMock -func NewGetterMock(version string, err error) Getter { +// NewStaticGetter creates a StaticGetter +func NewStaticGetter(version string, err error) Getter { + if version == constants.NoVersion { + return StaticGetter{ + version: "", + err: &NoNewVersionError{Message: fmt.Sprintf("target version set to '%s'", constants.NoVersion)}, + } + } + semVersion := version if semVersion != "" && !strings.HasPrefix(semVersion, "v") { semVersion = "v" + version } - return GetterMock{ + return StaticGetter{ version: semVersion, err: err, } diff --git a/integrations/kube-agent-updater/pkg/version/versionget.go b/integrations/kube-agent-updater/pkg/version/versionget.go index ace71623b1ae0..00ffdc31fda3f 100644 --- a/integrations/kube-agent-updater/pkg/version/versionget.go +++ b/integrations/kube-agent-updater/pkg/version/versionget.go @@ -28,6 +28,8 @@ import ( // Getter gets the target image version for an external source. It should cache // the result to reduce io and avoid potential rate-limits and is safe to call // multiple times over a short period. +// If the version source intentionally returns no version, a NoNewVersionError is +// returned. type Getter interface { GetVersion(context.Context) (string, error) } @@ -35,7 +37,6 @@ type Getter interface { // ValidVersionChange receives the current version and the candidate next version // and evaluates if the version transition is valid. func ValidVersionChange(ctx context.Context, current, next string) bool { - // TODO: clarify rollback constraints regarding previous version and add a "previous" parameter log := ctrllog.FromContext(ctx).V(1) // Cannot upgrade to a non-valid version if !semver.IsValid(next) { diff --git a/lib/automaticupgrades/basichttp/client.go b/lib/automaticupgrades/basichttp/client.go new file mode 100644 index 0000000000000..b1684a2e42087 --- /dev/null +++ b/lib/automaticupgrades/basichttp/client.go @@ -0,0 +1,59 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package basichttp + +import ( + "context" + "io" + "net/http" + "net/url" + + "github.com/gravitational/trace" +) + +// Client extends the regular http.Client by adding a GetContent method that does +// a GET query to a given URL and returns an error if the status is non-200. +// This is typically used to retrieve small files stored in a S3 bucket like the +// maintenance.BasicHTTPMaintenanceTrigger or the version.BasicHTTPVersionGetter +// are doing. +type Client struct { + *http.Client +} + +// GetContent sends a GET HTTP request and fails if the response is not 200. +func (c *Client) GetContent(ctx context.Context, targetURL url.URL) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL.String(), nil) + if err != nil { + return []byte{}, trace.Wrap(err) + } + res, err := c.Do(req) + if err != nil { + return []byte{}, trace.Wrap(err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return []byte{}, trace.Wrap(err) + } + + if res.StatusCode != http.StatusOK { + return []byte{}, trace.Errorf("non-200 status code received: '%d'", res.StatusCode) + } + + return body, nil +} diff --git a/lib/automaticupgrades/basichttp/servermock.go b/lib/automaticupgrades/basichttp/servermock.go new file mode 100644 index 0000000000000..fa245248d4f0d --- /dev/null +++ b/lib/automaticupgrades/basichttp/servermock.go @@ -0,0 +1,59 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package basichttp + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +// ServerMock is a HTTP server whose response can be controlled from the tests. +// This is used to mock external dependencies like s3 buckets or a remote HTTP server. +type ServerMock struct { + Srv *httptest.Server + + t *testing.T + code int + response string + path string +} + +// SetResponse sets the ServerMock's response. +func (m *ServerMock) SetResponse(t *testing.T, code int, response string) { + m.t = t + m.code = code + m.response = response +} + +// ServeHTTP implements the http.Handler interface. +func (m *ServerMock) ServeHTTP(w http.ResponseWriter, r *http.Request) { + require.Equal(m.t, m.path, r.URL.Path) + w.WriteHeader(m.code) + _, err := io.WriteString(w, m.response) + require.NoError(m.t, err) +} + +// NewServerMock builds and returns a +func NewServerMock(path string) *ServerMock { + mock := ServerMock{path: path} + mock.Srv = httptest.NewServer(http.HandlerFunc(mock.ServeHTTP)) + return &mock +} diff --git a/lib/automaticupgrades/cache/cache.go b/lib/automaticupgrades/cache/cache.go new file mode 100644 index 0000000000000..a03555b3cdd51 --- /dev/null +++ b/lib/automaticupgrades/cache/cache.go @@ -0,0 +1,66 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cache + +import ( + "context" + "sync" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" +) + +// TimedMemoize wraps a function returning (value, error) and caches the +// value AND the error for a specific time. This cache is mainly used to ensure +// external calls rates are reasonable. The cache is thread-safe. +type TimedMemoize[T any] struct { + clock clockwork.Clock + mutex sync.Mutex + cachedValue T + cachedError error + validUntil time.Time + cacheDuration time.Duration + getter func(ctx context.Context) (T, error) +} + +// Get does a cache lookup and updates the cache in case of cache miss. +func (c *TimedMemoize[T]) Get(ctx context.Context) (T, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.validUntil.After(c.clock.Now()) { + // TimedMemoize hit, we return cached result + return c.cachedValue, trace.Wrap(c.cachedError) + } + + // Cache miss, we do a query and update the cache + value, err := c.getter(ctx) + c.validUntil = c.clock.Now().Add(c.cacheDuration) + c.cachedValue = value + c.cachedError = newCachedError(err, c.validUntil) + return value, trace.Wrap(err) +} + +// NewTimedMemoize builds and returns a TimedMemoize +func NewTimedMemoize[T any](getter func(ctx context.Context) (T, error), duration time.Duration) *TimedMemoize[T] { + return &TimedMemoize[T]{ + clock: clockwork.NewRealClock(), + getter: getter, + cacheDuration: duration, + } +} diff --git a/lib/automaticupgrades/cache/cache_test.go b/lib/automaticupgrades/cache/cache_test.go new file mode 100644 index 0000000000000..7e200a79be882 --- /dev/null +++ b/lib/automaticupgrades/cache/cache_test.go @@ -0,0 +1,125 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cache + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func failIfCalled(t *testing.T, value string, err error) func(context.Context) (string, error) { + return func(_ context.Context) (string, error) { + require.Fail(t, "should not be called") + return value, err + } +} + +func failIfCalledTwice(t *testing.T, value string, err error) func(context.Context) (string, error) { + var fuse bool + return func(_ context.Context) (string, error) { + if !fuse { + fuse = true + return value, err + } + require.Fail(t, "should not be called twice") + return value, err + } +} + +func TestTimedMemoize_Get(t *testing.T) { + ctx := context.Background() + + now := time.Now() + longBefore := now.Add(-2 * time.Hour) + longAfter := now.Add(2 * time.Hour) + + upstreamValue := "upstream" + cachedValue := "cached" + upstreamError := trace.LimitExceeded("rate-limited") + oldUpstreamError := trace.CompareFailed("comparison failed") + + assertUncachedError := func(t2 require.TestingT, err error, _ ...interface{}) { + require.Equal(t2, err, upstreamError) + } + assertCachedError := func(t2 require.TestingT, err error, _ ...interface{}) { + _, ok := trace.Unwrap(err).(cachedError) + require.True(t2, ok) + } + + tests := []struct { + name string + cachedValue string + cachedError error + validUntil time.Time + getter func(*testing.T, string, error) func(context.Context) (string, error) + expectedValue string + assertErr require.ErrorAssertionFunc + }{ + { + name: "fresh cache", + cachedValue: "", + cachedError: nil, + validUntil: time.Time{}, + getter: failIfCalledTwice, + expectedValue: upstreamValue, + assertErr: assertUncachedError, + }, + { + name: "valid cache", + cachedValue: cachedValue, + cachedError: newCachedError(oldUpstreamError, longAfter), + validUntil: longAfter, + getter: failIfCalled, + expectedValue: cachedValue, + assertErr: assertCachedError, + }, + { + name: "expired cache", + cachedValue: cachedValue, + cachedError: newCachedError(oldUpstreamError, longBefore), + validUntil: longBefore, + getter: failIfCalledTwice, + expectedValue: upstreamValue, + assertErr: assertUncachedError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := TimedMemoize[string]{ + cachedValue: tt.cachedValue, + cachedError: tt.cachedError, + validUntil: tt.validUntil, + clock: clockwork.NewFakeClockAt(now), + cacheDuration: 30 * time.Minute, + getter: tt.getter(t, upstreamValue, upstreamError), + } + result, err := cache.Get(ctx) + require.Equal(t, tt.expectedValue, result) + // The first error might or might not be cached + tt.assertErr(t, err) + result, err = cache.Get(ctx) + require.Equal(t, tt.expectedValue, result) + // The second error must be cached + assertCachedError(t, err) + }) + } +} diff --git a/lib/automaticupgrades/cache/error.go b/lib/automaticupgrades/cache/error.go new file mode 100644 index 0000000000000..952ba0c362dad --- /dev/null +++ b/lib/automaticupgrades/cache/error.go @@ -0,0 +1,54 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cache + +import ( + "fmt" + "time" +) + +// cachedError wraps an error before storing it into cache. This adds more +// context into the original error by clearly indicating the error have been +// cached and for how long. +type cachedError struct { + err error + until time.Time +} + +func (e cachedError) Error() string { + return fmt.Sprintf("error cached until '%s': %s", e.until, e.err) +} + +// OrigError returns the original error. This implements trace.ErrorWrapper +// and allows to be unwrapped by trace.Unwrap(). +func (e cachedError) OrigError() error { + return e.err +} + +// Unwrap returns the original error. +func (e cachedError) Unwrap() error { + return e.err +} + +// newCachedError takes an error and wraps it into a cachedError. If there is no +// error, it returns nothing. +func newCachedError(err error, until time.Time) error { + if err == nil { + return nil + } + return cachedError{err, until} +} diff --git a/lib/automaticupgrades/channel.go b/lib/automaticupgrades/channel.go new file mode 100644 index 0000000000000..de2ba83acb796 --- /dev/null +++ b/lib/automaticupgrades/channel.go @@ -0,0 +1,212 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package automaticupgrades + +import ( + "context" + "net/url" + "strconv" + "strings" + + "github.com/gravitational/trace" + "golang.org/x/mod/semver" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/automaticupgrades/maintenance" + "github.com/gravitational/teleport/lib/automaticupgrades/version" +) + +const ( + DefaultChannelName = "default" + DefaultCloudChannelName = "stable/cloud" + stableCloudVersionBaseURL = "https://updates.releases.teleport.dev/v1/stable/cloud" +) + +// Channels is a map of Channel objects. +type Channels map[string]*Channel + +// CheckAndSetDefaults checks that every Channel is valid and initializes them. +// It also creates default channels if they are not already present. +// Cloud must have the `default` and `stable/cloud` channels. +// Self-hosted with automatic upgrades must have the `default` channel. +func (c Channels) CheckAndSetDefaults(features proto.Features) error { + defaultChannel, err := NewDefaultChannel() + if err != nil { + return trace.Wrap(err) + } + + // If we're on cloud, we need at least "cloud/stable" and "default" + if features.GetCloud() { + if _, ok := c[DefaultCloudChannelName]; !ok { + c[DefaultCloudChannelName] = defaultChannel + } + if _, ok := c[DefaultChannelName]; !ok { + c[DefaultChannelName] = c[DefaultCloudChannelName] + } + } + + // If we're on self-hosted with automatic upgrades, we need a "default" channel + // We don't want to break existing setups so we'll automatically point to the + // `cloud/stable` channel. + // TODO: in v15 make this a hard requirement and error if `default` is not set + // and automatic upgrades are enabled + if features.GetAutomaticUpgrades() { + if _, ok := c[DefaultChannelName]; !ok { + c[DefaultChannelName] = defaultChannel + } + } + + var errs []error + for name, channel := range c { + // Wrapping is not mandatory here, but it adds the channel name in the + // error, which makes troubleshooting easier. + err = trace.Wrap(channel.CheckAndSetDefaults(), "failed to create channel %s", name) + if err != nil { + errs = append(errs, err) + } + } + return trace.NewAggregate(errs...) +} + +// DefaultVersion returns the version served by the default upgrade channel. +func (c Channels) DefaultVersion(ctx context.Context) (string, error) { + channel, ok := c[DefaultChannelName] + if !ok { + return "", trace.NotFound("default version channel not found") + } + targetVersion, err := channel.GetVersion(ctx) + return targetVersion, trace.Wrap(err) +} + +// Channel describes an automatic update channel configuration. +// It can be configured to serve a static version, or forward version requests +// to an upstream version server. Forwarded results are cached for 1 minute. +type Channel struct { + // ForwardURL is the URL of the upstream version server providing the channel version/criticality. + ForwardURL string `yaml:"forward_url,omitempty"` + // StaticVersion is a static version the channel must serve. With or without the leading "v". + StaticVersion string `yaml:"static_version,omitempty"` + // Critical is whether the static version channel should be marked as a critical update. + Critical bool `yaml:"critical"` + + // versionGetter gets the version of the channel. It is populated by CheckAndSetDefaults. + versionGetter version.Getter + // criticalTrigger gets the criticality of the channel. It is populated by CheckAndSetDefaults. + criticalTrigger maintenance.Trigger + // teleportMajor stores the current teleport major for comparison. + // This field is initialized during CheckAndSetDefaults. + teleportMajor int +} + +// CheckAndSetDefaults checks that the Channel configuration is valid and inits +// the version getter and maintenance trigger of the Channel based on its +// configuration. This function must be called before using the channel. +func (c *Channel) CheckAndSetDefaults() error { + switch { + case c.ForwardURL != "" && (c.StaticVersion != "" || c.Critical): + return trace.BadParameter("cannot set both ForwardURL and (StaticVersion or Critical)") + case c.ForwardURL != "": + baseURL, err := url.Parse(c.ForwardURL) + if err != nil { + return trace.Wrap(err) + } + c.versionGetter = version.NewBasicHTTPVersionGetter(baseURL) + c.criticalTrigger = maintenance.NewBasicHTTPMaintenanceTrigger("remote", baseURL) + case c.StaticVersion != "": + c.versionGetter = version.NewStaticGetter(c.StaticVersion, nil) + c.criticalTrigger = maintenance.NewMaintenanceStaticTrigger("remote", c.Critical) + default: + return trace.BadParameter("either ForwardURL or StaticVersion must be set") + } + + var err error + c.teleportMajor, err = parseMajorFromVersionString(teleport.Version) + if err != nil { + return trace.Wrap(err, "failed to process teleport version") + } + + return nil +} + +// GetVersion returns the current version of the channel. If io is involved, +// this function implements cache and is safe to call frequently. +// If the target version major is higher than the Teleport version (the one +// in the Teleport binary, this is usually the proxy version), this function +// returns the Teleport version instead. +// If the version source intentionally did not specify a version, a +// NoNewVersionError is returned. +func (c *Channel) GetVersion(ctx context.Context) (string, error) { + targetVersion, err := c.versionGetter.GetVersion(ctx) + if err != nil { + return "", trace.Wrap(err) + } + + targetMajor, err := parseMajorFromVersionString(targetVersion) + if err != nil { + return "", trace.Wrap(err, "failed to process target version") + } + + // The target version is officially incompatible with our version, + // we prefer returning our version rather than having a broken client + if targetMajor > c.teleportMajor { + return teleport.Version, nil + } + + return targetVersion, nil +} + +// GetCritical returns the current criticality of the channel. If io is involved, +// this function implements cache and is safe to call frequently. +func (c *Channel) GetCritical(ctx context.Context) (bool, error) { + return c.criticalTrigger.CanStart(ctx, nil) +} + +// NewDefaultChannel creates a default automatic upgrade channel +// It looks up the environment variable, and if not found uses the default +// base URL. This default channel can be used in the proxy (to back its own version server) +// or in other Teleport process such as integration services deploying and +// updating teleport agents. +func NewDefaultChannel() (*Channel, error) { + forwardURL := GetChannel() + if forwardURL == "" { + forwardURL = stableCloudVersionBaseURL + } + defaultChannel := &Channel{ + ForwardURL: forwardURL, + } + if err := defaultChannel.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return defaultChannel, nil +} + +func parseMajorFromVersionString(v string) (int, error) { + v, err := version.EnsureSemver(v) + if err != nil { + return 0, trace.Wrap(err, "invalid semver: %s", v) + } + majorStr := semver.Major(v) + if majorStr == "" { + return 0, trace.BadParameter("cannot detect version major") + } + + major, err := strconv.Atoi(strings.TrimPrefix(majorStr, "v")) + return major, trace.Wrap(err, "cannot convert version major to int") +} diff --git a/lib/automaticupgrades/channel_test.go b/lib/automaticupgrades/channel_test.go new file mode 100644 index 0000000000000..e69e1a6ec3b68 --- /dev/null +++ b/lib/automaticupgrades/channel_test.go @@ -0,0 +1,211 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package automaticupgrades + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" + "github.com/gravitational/teleport/lib/automaticupgrades/maintenance" + "github.com/gravitational/teleport/lib/automaticupgrades/version" +) + +const testVersion = "v1.2.3" + +func Test_Channels_CheckAndSetDefaults(t *testing.T) { + noFeatures := proto.Features{} + cloudFeatures := proto.Features{Cloud: true, AutomaticUpgrades: true} + customChannelURL := "https://foo.example.com/bar" + t.Run("no channels", func(t *testing.T) { + c := Channels{} + require.NoError(t, c.CheckAndSetDefaults(noFeatures)) + }) + t.Run("single channel", func(t *testing.T) { + channel := &Channel{StaticVersion: testVersion} + c := Channels{"foo": channel} + require.NoError(t, c.CheckAndSetDefaults(noFeatures)) + require.NotNil(t, channel.versionGetter) + require.NotNil(t, channel.criticalTrigger) + }) + t.Run("many channels", func(t *testing.T) { + channel1 := &Channel{StaticVersion: testVersion} + channel2 := &Channel{StaticVersion: testVersion} + channel3 := &Channel{StaticVersion: testVersion} + c := Channels{"foo": channel1, "bar": channel2, "baz": channel3} + require.NoError(t, c.CheckAndSetDefaults(noFeatures)) + require.NotNil(t, channel1.versionGetter) + require.NotNil(t, channel1.criticalTrigger) + require.NotNil(t, channel2.versionGetter) + require.NotNil(t, channel2.criticalTrigger) + require.NotNil(t, channel3.versionGetter) + require.NotNil(t, channel3.criticalTrigger) + }) + t.Run("default channels for cloud", func(t *testing.T) { + // Cloud must have `default` and `stable/cloud` channels by default + c := Channels{} + require.NoError(t, c.CheckAndSetDefaults(cloudFeatures)) + require.Len(t, c, 2) + defaultChannel, ok := c[DefaultChannelName] + require.True(t, ok) + require.Equal(t, stableCloudVersionBaseURL, defaultChannel.ForwardURL) + stableCloudChannel, ok := c[DefaultCloudChannelName] + require.True(t, ok) + require.Equal(t, stableCloudVersionBaseURL, stableCloudChannel.ForwardURL) + }) + t.Run("cloud override stable/cloud", func(t *testing.T) { + // When "stable/cloud" channel is configured, CheckAndSetDefaults + // must honor it AND also use it as the "default" channel. + c := Channels{DefaultCloudChannelName: &Channel{ForwardURL: customChannelURL}} + require.NoError(t, c.CheckAndSetDefaults(cloudFeatures)) + require.Len(t, c, 2) + stableCloudChannel, ok := c[DefaultCloudChannelName] + require.True(t, ok) + require.Equal(t, customChannelURL, stableCloudChannel.ForwardURL) + defaultChannel, ok := c[DefaultChannelName] + require.True(t, ok) + require.Equal(t, customChannelURL, defaultChannel.ForwardURL) + }) + t.Run("cloud override default", func(t *testing.T) { + // When the "default" channel is manually configured, CheckAndSetDefaults + // must honor it. + // In this test, only the "default" channel must be custom, the + // "stable/cloud" channel must point to the standard cloud URL. + c := Channels{DefaultChannelName: &Channel{ForwardURL: customChannelURL}} + require.NoError(t, c.CheckAndSetDefaults(cloudFeatures)) + require.Len(t, c, 2) + defaultChannel, ok := c[DefaultChannelName] + require.True(t, ok) + require.Equal(t, customChannelURL, defaultChannel.ForwardURL) + stableCloudChannel, ok := c[DefaultCloudChannelName] + require.True(t, ok) + require.Equal(t, stableCloudVersionBaseURL, stableCloudChannel.ForwardURL) + }) + t.Run("self-hosted no channel", func(t *testing.T) { + // In self-hosted automatic-upgrades setups, we need a default channel. + // For backward compatibility we should add it instead of requiring it. + c := Channels{} + require.NoError(t, c.CheckAndSetDefaults(proto.Features{AutomaticUpgrades: true})) + require.Len(t, c, 1) + defaultChannel, ok := c[DefaultChannelName] + require.True(t, ok) + require.Equal(t, stableCloudVersionBaseURL, defaultChannel.ForwardURL) + _, ok = c[DefaultCloudChannelName] + require.False(t, ok) + }) + +} + +func Test_Channel_CheckAndSetDefaults(t *testing.T) { + + tests := []struct { + name string + channel Channel + assertError require.ErrorAssertionFunc + expectedVersionGetterType interface{} + expectedCriticalTriggerType interface{} + }{ + { + name: "empty (invalid)", + channel: Channel{}, + assertError: require.Error, + }, + { + name: "forward URL (valid)", + channel: Channel{ + ForwardURL: stableCloudVersionBaseURL, + }, + assertError: require.NoError, + expectedVersionGetterType: version.BasicHTTPVersionGetter{}, + expectedCriticalTriggerType: maintenance.BasicHTTPMaintenanceTrigger{}, + }, + { + name: "static version (valid)", + channel: Channel{ + StaticVersion: testVersion, + }, + assertError: require.NoError, + expectedVersionGetterType: version.StaticGetter{}, + expectedCriticalTriggerType: maintenance.StaticTrigger{}, + }, + { + name: "all set (invalid)", + channel: Channel{ + ForwardURL: stableCloudVersionBaseURL, + StaticVersion: testVersion, + }, + assertError: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertError(t, tt.channel.CheckAndSetDefaults()) + require.IsType(t, tt.expectedVersionGetterType, tt.channel.versionGetter) + require.IsType(t, tt.expectedCriticalTriggerType, tt.channel.criticalTrigger) + }) + } +} + +func Test_Channel_GetVersion(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + targetVersion string + expectedVersion string + assertErr require.ErrorAssertionFunc + }{ + { + name: "normal version", + targetVersion: "v1.2.3", + expectedVersion: "v1.2.3", + assertErr: require.NoError, + }, + { + name: "no version", + targetVersion: constants.NoVersion, + expectedVersion: "", + assertErr: require.Error, + }, + { + name: "version too high", + targetVersion: "v99.1.1", + expectedVersion: teleport.Version, + assertErr: require.NoError, + }, + { + name: "version invalid", + targetVersion: "foobar", + assertErr: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := Channel{StaticVersion: tt.targetVersion} + require.NoError(t, ch.CheckAndSetDefaults()) + + result, err := ch.GetVersion(ctx) + tt.assertErr(t, err) + require.Equal(t, tt.expectedVersion, result) + }) + } +} diff --git a/lib/automaticupgrades/constants/constants.go b/lib/automaticupgrades/constants/constants.go new file mode 100644 index 0000000000000..ec40a052f8240 --- /dev/null +++ b/lib/automaticupgrades/constants/constants.go @@ -0,0 +1,34 @@ +/* +Copyright 2023 Gravitational, Inc. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package constants + +import "time" + +const ( + // MaintenancePath is the version discovery endpoint representing if the + // target version represents a critical update as defined in + // https://github.com/gravitational/teleport/blob/master/rfd/0109-cloud-agent-upgrades.md#version-discovery-endpoint + MaintenancePath = "critical" + // VersionPath is the version discovery endpoint returning the current + // target version as defined in + // https://github.com/gravitational/teleport/blob/master/rfd/0109-cloud-agent-upgrades.md#version-discovery-endpoint + VersionPath = "version" + HTTPTimeout = 10 * time.Second + CacheDuration = time.Minute + + // NoVersion is returned by the version endpoint when there is no valid target version. + // This can be caused by the target version being incompatible with the cluster version. + NoVersion = "none" +) diff --git a/lib/automaticupgrades/maintenance/basichttp.go b/lib/automaticupgrades/maintenance/basichttp.go new file mode 100644 index 0000000000000..5df16fbb4a2c0 --- /dev/null +++ b/lib/automaticupgrades/maintenance/basichttp.go @@ -0,0 +1,110 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package maintenance + +import ( + "context" + "net/http" + "net/url" + "strings" + + "github.com/gravitational/trace" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/gravitational/teleport/lib/automaticupgrades/basichttp" + "github.com/gravitational/teleport/lib/automaticupgrades/cache" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +// basicHTTPMaintenanceClient retrieves whether the target version represents a +// critical update from an HTTP endpoint. It should not be invoked immediately +// and must be wrapped in a cache layer in order to avoid spamming the version +// server in case of reconciliation errors. +// use BasicHTTPMaintenanceTrigger if you need to check if an update is critical. +type basicHTTPMaintenanceClient struct { + baseURL *url.URL + client *basichttp.Client +} + +// Get sends an HTTP GET request and returns whether the current target version +// represents a critical update. +func (b *basicHTTPMaintenanceClient) Get(ctx context.Context) (bool, error) { + versionURL := b.baseURL.JoinPath(constants.MaintenancePath) + body, err := b.client.GetContent(ctx, *versionURL) + if err != nil { + return false, trace.Wrap(err) + } + // Validating early that the payload can be converted to a boolean allows to + // gracefully catch connectivity error caused by mitm infrastructure such as + // corporate proxies. + result, err := stringToBool(strings.TrimSpace(string(body))) + return result, trace.Wrap(err) +} + +// BasicHTTPMaintenanceTrigger gets the critical status from an HTTP response +// containing only a truthy or falsy string. +// This is used typically to trigger emergency maintenances from a file hosted +// in a s3 bucket or raw file served over HTTP. +// BasicHTTPMaintenanceTrigger uses basicHTTPMaintenanceClient and wraps it in a cache +// in order to mitigate the impact of too frequent reconciliations. +// The structure implements the maintenance.Trigger interface. +type BasicHTTPMaintenanceTrigger struct { + name string + cachedGetter func(context.Context) (bool, error) +} + +// Name implements maintenance.Triggernd returns the trigger name for logging +// and debugging pursposes. +func (g BasicHTTPMaintenanceTrigger) Name() string { + return g.name +} + +// Default returns what to do if the trigger can't be evaluated. +// BasicHTTPMaintenanceTrigger should fail open, so the function returns true. +func (g BasicHTTPMaintenanceTrigger) Default() bool { + return false +} + +// CanStart implements maintenance.Trigger +func (g BasicHTTPMaintenanceTrigger) CanStart(ctx context.Context, _ client.Object) (bool, error) { + result, err := g.cachedGetter(ctx) + return result, trace.Wrap(err) +} + +// NewBasicHTTPMaintenanceTrigger builds and return a Trigger checking a public HTTP endpoint. +func NewBasicHTTPMaintenanceTrigger(name string, baseURL *url.URL) Trigger { + client := &http.Client{ + Timeout: constants.HTTPTimeout, + } + httpMaintenanceClient := &basicHTTPMaintenanceClient{ + baseURL: baseURL, + client: &basichttp.Client{Client: client}, + } + + return BasicHTTPMaintenanceTrigger{name, cache.NewTimedMemoize[bool](httpMaintenanceClient.Get, constants.CacheDuration).Get} +} + +func stringToBool(input string) (bool, error) { + switch { + case strings.EqualFold("true", input), strings.EqualFold("yes", input): + return true, nil + case strings.EqualFold("false", input), strings.EqualFold("no", input): + return false, nil + default: + return false, trace.BadParameter("cannot convert input to boolean: %s", input) + } +} diff --git a/lib/automaticupgrades/maintenance/basichttp_test.go b/lib/automaticupgrades/maintenance/basichttp_test.go new file mode 100644 index 0000000000000..4304a78ee16e2 --- /dev/null +++ b/lib/automaticupgrades/maintenance/basichttp_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package maintenance + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + basichttp2 "github.com/gravitational/teleport/lib/automaticupgrades/basichttp" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +const basicHTTPTestPath = "/v1/cloud-stable" + +func Test_basicHTTPMaintenanceClient_Get(t *testing.T) { + mock := basichttp2.NewServerMock(basicHTTPTestPath + "/" + constants.MaintenancePath) + t.Cleanup(mock.Srv.Close) + serverURL, err := url.Parse(mock.Srv.URL) + serverURL.Path = basicHTTPTestPath + require.NoError(t, err) + ctx := context.Background() + + tests := []struct { + name string + statusCode int + response string + expected bool + assertErr require.ErrorAssertionFunc + }{ + { + name: "all good - no maintenance", + statusCode: http.StatusOK, + response: "no", + expected: false, + assertErr: require.NoError, + }, + { + name: "all good - maintenance", + statusCode: http.StatusOK, + response: "yes", + expected: true, + assertErr: require.NoError, + }, + { + name: "all good with newline", + statusCode: http.StatusOK, + response: "yes\n", + expected: true, + assertErr: require.NoError, + }, + { + name: "invalid response", + statusCode: http.StatusOK, + response: "hello", + expected: false, + assertErr: require.Error, + }, + { + name: "empty", + statusCode: http.StatusOK, + response: "", + expected: false, + assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) { + require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2)) + }, + }, + { + name: "non-200 response", + statusCode: http.StatusInternalServerError, + response: "ERROR - SOMETHING WENT WRONG", + expected: false, + assertErr: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &basicHTTPMaintenanceClient{ + baseURL: serverURL, + client: &basichttp2.Client{Client: mock.Srv.Client()}, + } + mock.SetResponse(t, tt.statusCode, tt.response) + result, err := b.Get(ctx) + tt.assertErr(t, err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/lib/automaticupgrades/maintenance/mock.go b/lib/automaticupgrades/maintenance/mock.go new file mode 100644 index 0000000000000..a777bfca0764c --- /dev/null +++ b/lib/automaticupgrades/maintenance/mock.go @@ -0,0 +1,55 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package maintenance + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// StaticTrigger is a fake Trigger that return a static answer. This is used +// for testing purposes and is inherently disruptive. +type StaticTrigger struct { + name string + canStart bool +} + +// Name returns the StaticTrigger name. +func (m StaticTrigger) Name() string { + return m.name +} + +// CanStart returns the statically defined maintenance approval result. +func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) { + return m.canStart, nil +} + +// Default returns the default behavior if the trigger fails. This cannot +// happen for a StaticTrigger and is here solely to implement the Trigger +// interface. +func (m StaticTrigger) Default() bool { + return m.canStart +} + +// NewMaintenanceStaticTrigger creates a StaticTrigger +func NewMaintenanceStaticTrigger(name string, canStart bool) Trigger { + return StaticTrigger{ + name: name, + canStart: canStart, + } +} diff --git a/lib/automaticupgrades/maintenance/trigger.go b/lib/automaticupgrades/maintenance/trigger.go new file mode 100644 index 0000000000000..457078f34fa49 --- /dev/null +++ b/lib/automaticupgrades/maintenance/trigger.go @@ -0,0 +1,62 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package maintenance + +import ( + "context" + + "sigs.k8s.io/controller-runtime/pkg/client" + ctrllog "sigs.k8s.io/controller-runtime/pkg/log" +) + +// Trigger is evaluated to decide whether a maintenance can happen or not. +// Maintenances can happen because of multiple reasons like: +// - attempt to recover from a broken state +// - we are in a maintenance window +// - emergency security patch +// Each Trigger has a Name() for logging purposes and a Default() method +// returning whether the trigger should allow the maintenance or not in case\ +// of error. +type Trigger interface { + Name() string + CanStart(ctx context.Context, object client.Object) (bool, error) + Default() bool +} + +// Triggers is a list of Trigger. Triggers are OR-ed: any trigger firing in the +// list will cause the maintenance to be triggered. +type Triggers []Trigger + +// CanStart checks if the maintenance can be started. It will return true if at +// least a Trigger approves the maintenance. +func (t Triggers) CanStart(ctx context.Context, object client.Object) bool { + log := ctrllog.FromContext(ctx).V(1) + for _, trigger := range t { + start, err := trigger.CanStart(ctx, object) + if err != nil { + start = trigger.Default() + log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start) + } else { + log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start) + } + if start { + log.Info("maintenance triggered", "trigger", trigger.Name()) + return true + } + } + return false +} diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go deleted file mode 100644 index d1f6c393f1605..0000000000000 --- a/lib/automaticupgrades/version.go +++ /dev/null @@ -1,129 +0,0 @@ -/* -Copyright 2023 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package automaticupgrades - -import ( - "context" - "net/http" - "net/url" - "strings" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/utils" -) - -const ( - // stableCloudVersionBaseURL is the base URL for the server that returns the current stable/cloud version. - stableCloudVersionBaseURL = "https://updates.releases.teleport.dev" - - // stableCloudVersionPath is the URL path that returns the current stable/cloud version. - stableCloudVersionPath = "/v1/stable/cloud/version" - - // stableCloudCriticalPath is the URL path that returns the stable/cloud critical flag. - stableCloudCriticalPath = "/v1/stable/cloud/critical" -) - -// Version returns the version that should be used for installing Teleport Services -// This is used when installing agents using scripts. -// Even when Teleport Auth/Proxy is using vX, the agents must always respect this version. -func Version(ctx context.Context, versionURL string) (string, error) { - versionURL, err := getVersionURL(versionURL) - if err != nil { - return "", trace.Wrap(err) - } - - resp, err := sendRequest(ctx, versionURL) - if err != nil { - return "", trace.Wrap(err) - } - - return resp, nil -} - -// Critical returns true if a critical upgrade is available. -func Critical(ctx context.Context, criticalURL string) (bool, error) { - criticalURL, err := getCriticalURL(criticalURL) - if err != nil { - return false, trace.Wrap(err) - } - - critical, err := sendRequest(ctx, criticalURL) - if err != nil { - return false, trace.Wrap(err) - } - - // Expects critical endpoint to return either the string "yes" or "no" - switch critical { - case "yes": - return true, nil - case "no": - return false, nil - default: - return false, trace.BadParameter("critical endpoint returned an unexpected value: %v", critical) - } -} - -// sendRequest sends a GET request to the reqURL and returns the response value -func sendRequest(ctx context.Context, reqURL string) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) - if err != nil { - return "", trace.Wrap(err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", trace.Wrap(err) - } - defer resp.Body.Close() - - body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) - if err != nil { - return "", trace.Wrap(err) - } - - if resp.StatusCode != http.StatusOK { - return "", trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body)) - } - - return strings.TrimSpace(string(body)), trace.Wrap(err) -} - -// getVersionURL returns the versionURL or the default stable/cloud version url. -func getVersionURL(versionURL string) (string, error) { - if versionURL != "" { - return versionURL, nil - } - cloudStableVersionURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudVersionPath) - if err != nil { - return "", trace.Wrap(err) - } - return cloudStableVersionURL, nil -} - -// getCriticalURL returns the criticalURL or the default stable/cloud critical url. -func getCriticalURL(criticalURL string) (string, error) { - if criticalURL != "" { - return criticalURL, nil - } - cloudStableCriticalURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudCriticalPath) - if err != nil { - return "", trace.Wrap(err) - } - return cloudStableCriticalURL, nil -} diff --git a/lib/automaticupgrades/version/basichttp.go b/lib/automaticupgrades/version/basichttp.go new file mode 100644 index 0000000000000..8887bd265faee --- /dev/null +++ b/lib/automaticupgrades/version/basichttp.go @@ -0,0 +1,83 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "context" + "net/http" + "net/url" + "strings" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/automaticupgrades/basichttp" + "github.com/gravitational/teleport/lib/automaticupgrades/cache" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +// basicHTTPVersionClient retrieves the version from an HTTP endpoint +// it should not be invoked immediately and must be wrapped in a cache layer in +// order to avoid spamming the version server in case of reconciliation errors. +// use BasicHTTPVersionGetter if you need to get a version. +type basicHTTPVersionClient struct { + baseURL *url.URL + client *basichttp.Client +} + +// Get sends an HTTP GET request and returns the version prefixed by "v". +// It expects the endpoint to be unauthenticated, return 200 and the response +// body to contain a valid semver tag without the "v". +func (b *basicHTTPVersionClient) Get(ctx context.Context) (string, error) { + versionURL := b.baseURL.JoinPath(constants.VersionPath) + body, err := b.client.GetContent(ctx, *versionURL) + if err != nil { + return "", trace.Wrap(err) + } + response := string(body) + if response == constants.NoVersion { + return "", &NoNewVersionError{Message: "version server did not advertise a version"} + } + // We trim spaces because the value might end with one or many newlines + version, err := EnsureSemver(strings.TrimSpace(response)) + return version, trace.Wrap(err) +} + +// BasicHTTPVersionGetter gets the version from an HTTP response containing +// only the version. This is used typically to get version from a file hosted +// in a s3 bucket or raw file served over HTTP. +// BasicHTTPVersionGetter uses basicHTTPVersionClient and wraps it in a cache +// in order to mitigate the impact of too frequent reconciliations. +// The structure implements the version.Getter interface. +type BasicHTTPVersionGetter struct { + versionGetter func(context.Context) (string, error) +} + +func (g BasicHTTPVersionGetter) GetVersion(ctx context.Context) (string, error) { + return g.versionGetter(ctx) +} + +func NewBasicHTTPVersionGetter(baseURL *url.URL) Getter { + client := &http.Client{ + Timeout: constants.HTTPTimeout, + } + httpVersionClient := &basicHTTPVersionClient{ + baseURL: baseURL, + client: &basichttp.Client{Client: client}, + } + + return BasicHTTPVersionGetter{cache.NewTimedMemoize[string](httpVersionClient.Get, constants.CacheDuration).Get} +} diff --git a/lib/automaticupgrades/version/basichttp_test.go b/lib/automaticupgrades/version/basichttp_test.go new file mode 100644 index 0000000000000..05db877fdfbbf --- /dev/null +++ b/lib/automaticupgrades/version/basichttp_test.go @@ -0,0 +1,101 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + basichttp2 "github.com/gravitational/teleport/lib/automaticupgrades/basichttp" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +const basicHTTPTestPath = "/v1/cloud-stable" + +func Test_basicHTTPVersionClient_Get(t *testing.T) { + mock := basichttp2.NewServerMock(basicHTTPTestPath + "/" + constants.VersionPath) + t.Cleanup(mock.Srv.Close) + serverURL, err := url.Parse(mock.Srv.URL) + serverURL.Path = basicHTTPTestPath + require.NoError(t, err) + ctx := context.Background() + + tests := []struct { + name string + statusCode int + response string + expected string + assertErr require.ErrorAssertionFunc + }{ + { + name: "all good", + statusCode: http.StatusOK, + response: "12.0.3", + expected: "v12.0.3", + assertErr: require.NoError, + }, + { + name: "all good with newline", + statusCode: http.StatusOK, + response: "12.0.3\n", + expected: "v12.0.3", + assertErr: require.NoError, + }, + { + name: "non-semver", + statusCode: http.StatusOK, + response: "hello", + expected: "", + assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) { + require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2)) + }, + }, + { + name: "empty", + statusCode: http.StatusOK, + response: "", + expected: "", + assertErr: func(t2 require.TestingT, err2 error, _ ...interface{}) { + require.IsType(t2, &trace.BadParameterError{}, trace.Unwrap(err2)) + }, + }, + { + name: "non-200 response", + statusCode: http.StatusInternalServerError, + response: "ERROR - SOMETHING WENT WRONG", + expected: "", + assertErr: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &basicHTTPVersionClient{ + baseURL: serverURL, + client: &basichttp2.Client{Client: mock.Srv.Client()}, + } + mock.SetResponse(t, tt.statusCode, tt.response) + result, err := b.Get(ctx) + tt.assertErr(t, err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/lib/automaticupgrades/version/errors.go b/lib/automaticupgrades/version/errors.go new file mode 100644 index 0000000000000..96344c60858de --- /dev/null +++ b/lib/automaticupgrades/version/errors.go @@ -0,0 +1,36 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package version + +import "fmt" + +// NoNewVersionError indicates that no new version was found and the controller did not reconcile. +type NoNewVersionError struct { + Message string `json:"message"` + CurrentVersion string `json:"currentVersion"` + NextVersion string `json:"nextVersion"` +} + +// Error returns log friendly description of an error +func (e *NoNewVersionError) Error() string { + if e.Message != "" { + return e.Message + } + return fmt.Sprintf("no new version (current: %q, next: %q)", e.CurrentVersion, e.NextVersion) +} diff --git a/lib/automaticupgrades/version/static.go b/lib/automaticupgrades/version/static.go new file mode 100644 index 0000000000000..ec3e629c1d2b7 --- /dev/null +++ b/lib/automaticupgrades/version/static.go @@ -0,0 +1,56 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "context" + "fmt" + "strings" + + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +// StaticGetter is a fake version.Getter that return a static answer. This is used +// for testing purposes. +type StaticGetter struct { + version string + err error +} + +// GetVersion returns the statically defined version. +func (v StaticGetter) GetVersion(_ context.Context) (string, error) { + return v.version, v.err +} + +// NewStaticGetter creates a StaticGetter +func NewStaticGetter(version string, err error) Getter { + if version == constants.NoVersion { + return StaticGetter{ + version: "", + err: &NoNewVersionError{Message: fmt.Sprintf("target version set to '%s'", constants.NoVersion)}, + } + } + + semVersion := version + if semVersion != "" && !strings.HasPrefix(semVersion, "v") { + semVersion = "v" + version + } + return StaticGetter{ + version: semVersion, + err: err, + } +} diff --git a/lib/automaticupgrades/version/versionget.go b/lib/automaticupgrades/version/versionget.go new file mode 100644 index 0000000000000..00ffdc31fda3f --- /dev/null +++ b/lib/automaticupgrades/version/versionget.go @@ -0,0 +1,65 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "context" + "strings" + + "github.com/gravitational/trace" + "golang.org/x/mod/semver" + ctrllog "sigs.k8s.io/controller-runtime/pkg/log" +) + +// Getter gets the target image version for an external source. It should cache +// the result to reduce io and avoid potential rate-limits and is safe to call +// multiple times over a short period. +// If the version source intentionally returns no version, a NoNewVersionError is +// returned. +type Getter interface { + GetVersion(context.Context) (string, error) +} + +// ValidVersionChange receives the current version and the candidate next version +// and evaluates if the version transition is valid. +func ValidVersionChange(ctx context.Context, current, next string) bool { + log := ctrllog.FromContext(ctx).V(1) + // Cannot upgrade to a non-valid version + if !semver.IsValid(next) { + log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next) + return false + } + switch semver.Compare(next, current) { + // No need to upgrade if version is the same + case 0: + return false + default: + return true + } +} + +// EnsureSemver adds the 'v' prefix if needed and ensures the provided version +// is semver-compliant. +func EnsureSemver(current string) (string, error) { + if !strings.HasPrefix(current, "v") { + current = "v" + current + } + if !semver.IsValid(current) { + return "", trace.BadParameter("tag %s is not following semver", current) + } + return current, nil +} diff --git a/lib/automaticupgrades/version/versionget_test.go b/lib/automaticupgrades/version/versionget_test.go new file mode 100644 index 0000000000000..f1d1b46fe56bf --- /dev/null +++ b/lib/automaticupgrades/version/versionget_test.go @@ -0,0 +1,71 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package version + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +const ( + semverLow = "v11.3.2" + semverMid = "v11.5.4" + semverHigh = "v12.2.1" + invalidSemverHigh = "12.2.1" +) + +func TestValidVersionChange(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + current string + next string + want bool + }{ + { + name: "upgrade", + current: semverMid, + next: semverHigh, + want: true, + }, + { + name: "same version", + current: semverMid, + next: semverMid, + want: false, + }, + { + name: "unknown current version", + current: "", + next: semverMid, + want: true, + }, + { + name: "non-semver current version", + current: semverMid, + next: invalidSemverHigh, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next)) + }) + } +} diff --git a/lib/automaticupgrades/version_test.go b/lib/automaticupgrades/version_test.go deleted file mode 100644 index 4ba8d299201e3..0000000000000 --- a/lib/automaticupgrades/version_test.go +++ /dev/null @@ -1,201 +0,0 @@ -/* -Copyright 2023 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package automaticupgrades - -import ( - "context" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/gravitational/trace" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestVersion(t *testing.T) { - ctx := context.Background() - - isBadParameterErr := func(tt require.TestingT, err error, i ...any) { - require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) - } - - for _, tt := range []struct { - name string - mockStatusCode int - mockResponseString string - errCheck require.ErrorAssertionFunc - expectedVersion string - }{ - { - name: "real response", - mockStatusCode: http.StatusOK, - mockResponseString: "v13.1.1\n", - errCheck: require.NoError, - expectedVersion: "v13.1.1", - }, - { - name: "invalid status code (500)", - mockStatusCode: http.StatusInternalServerError, - errCheck: isBadParameterErr, - }, - { - name: "invalid status code (403)", - mockStatusCode: http.StatusForbidden, - errCheck: isBadParameterErr, - }, - { - name: "valid but has spaces", - mockStatusCode: http.StatusOK, - mockResponseString: " v13.1.1 \n \r\n", - errCheck: require.NoError, - expectedVersion: "v13.1.1", - }, - } { - t.Run(tt.name, func(t *testing.T) { - httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/v1/stable/cloud/version", r.URL.Path) - w.WriteHeader(tt.mockStatusCode) - w.Write([]byte(tt.mockResponseString)) - })) - defer httpTestServer.Close() - - versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") - require.NoError(t, err) - - v, err := Version(ctx, versionURL) - tt.errCheck(t, err) - if err != nil { - return - } - - require.Equal(t, tt.expectedVersion, v) - }) - } -} - -func TestCritical(t *testing.T) { - ctx := context.Background() - - isBadParameterErr := func(tt require.TestingT, err error, i ...any) { - require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) - } - - for _, tt := range []struct { - name string - mockStatusCode int - mockResponseString string - errCheck require.ErrorAssertionFunc - expectedCritical bool - }{ - { - name: "critical available", - mockStatusCode: http.StatusOK, - mockResponseString: "yes\n", - errCheck: require.NoError, - expectedCritical: true, - }, - { - name: "critical is not available", - mockStatusCode: http.StatusOK, - mockResponseString: "no\n", - errCheck: require.NoError, - expectedCritical: false, - }, - { - name: "invalid status code (500)", - mockStatusCode: http.StatusInternalServerError, - errCheck: isBadParameterErr, - }, - { - name: "invalid status code (403)", - mockStatusCode: http.StatusForbidden, - errCheck: isBadParameterErr, - }, - } { - t.Run(tt.name, func(t *testing.T) { - httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/v1/stable/cloud/critical", r.URL.Path) - w.WriteHeader(tt.mockStatusCode) - w.Write([]byte(tt.mockResponseString)) - })) - defer httpTestServer.Close() - - criticalURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/critical") - require.NoError(t, err) - - v, err := Critical(ctx, criticalURL) - tt.errCheck(t, err) - if err != nil { - return - } - - require.Equal(t, tt.expectedCritical, v) - }) - } -} - -func TestGetVersionURL(t *testing.T) { - for _, tt := range []struct { - name string - versionURL string - expectedURL string - }{ - { - name: "default stable/cloud version url", - versionURL: "", - expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/version", - }, - { - name: "custom version url", - versionURL: "https://custom.dev/version", - expectedURL: "https://custom.dev/version", - }, - } { - t.Run(tt.name, func(t *testing.T) { - v, err := getVersionURL(tt.versionURL) - require.NoError(t, err) - require.Equal(t, tt.expectedURL, v) - }) - } -} - -func TestGetCriticalURL(t *testing.T) { - for _, tt := range []struct { - name string - criticalURL string - expectedURL string - }{ - { - name: "default stable/cloud critical url", - criticalURL: "", - expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/critical", - }, - { - name: "custom critical url", - criticalURL: "https://custom.dev/critical", - expectedURL: "https://custom.dev/critical", - }, - } { - t.Run(tt.name, func(t *testing.T) { - v, err := getCriticalURL(tt.criticalURL) - require.NoError(t, err) - require.Equal(t, tt.expectedURL, v) - }) - } -} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 08b51f59fb332..75e0675df5a7e 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -1003,6 +1003,10 @@ func applyProxyConfig(fc *FileConfig, cfg *servicecfg.Config) error { } } + if fc.Proxy.AutomaticUpgradesChannels != nil { + cfg.Proxy.AutomaticUpgradesChannels = fc.Proxy.AutomaticUpgradesChannels + } + if fc.Proxy.MySQLServerVersion != "" { cfg.Proxy.MySQLServerVersion = fc.Proxy.MySQLServerVersion } diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index c6173f32535c6..1b880520e8e57 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -45,6 +45,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" awsapiutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/api/utils/tlsutils" + "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -2181,6 +2182,11 @@ type Proxy struct { // the "X-Forwarded-For" headers for web APIs received from layer 7 load // balancers or reverse proxies. TrustXForwardedFor types.Bool `yaml:"trust_x_forwarded_for,omitempty"` + + // AutomaticUpgradesChannels is a map of all version channels used by the + // proxy built-in version server to retrieve target versions. This is part + // of the automatic upgrades. + AutomaticUpgradesChannels automaticupgrades.Channels `yaml:"automatic_upgrades_channels,omitempty"` } // UIConfig provides config options for the web UI served by the proxy service. diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 07b5bc468b765..9ae8892b4d55f 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -19,7 +19,6 @@ package service import ( "context" "fmt" - "net/url" "strings" "time" @@ -67,18 +66,13 @@ func (process *TeleportProcess) initAWSOIDCDeployServiceUpdater() error { return nil } - // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used - var criticalEndpoint string - var versionEndpoint string - if automaticupgrades.GetChannel() != "" { - criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") - if err != nil { - return trace.Wrap(err) - } - versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") - if err != nil { - return trace.Wrap(err) - } + // TODO: use the proxy channel if available? + // This would require to pass the proxy configuration there, but would avoid + // future inconsistencies: if the proxy is manually configured to serve a + // static version, it will not be picked up by the AWS OIDC deploy updater. + upgradeChannel, err := automaticupgrades.NewDefaultChannel() + if err != nil { + return trace.Wrap(err) } issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) @@ -98,8 +92,7 @@ func (process *TeleportProcess) initAWSOIDCDeployServiceUpdater() error { TeleportClusterName: clusterNameConfig.GetClusterName(), TeleportClusterVersion: resp.GetServerVersion(), AWSOIDCProviderAddr: issuer, - CriticalEndpoint: criticalEndpoint, - VersionEndpoint: versionEndpoint, + UpgradeChannel: upgradeChannel, }) if err != nil { return trace.Wrap(err) @@ -123,10 +116,8 @@ type AWSOIDCDeployServiceUpdaterConfig struct { TeleportClusterVersion string // AWSOIDCProvderAddr specifies the AWS OIDC provider address used to generate AWS OIDC tokens AWSOIDCProviderAddr string - // CriticalEndpoint specifies the endpoint to check for critical updates - CriticalEndpoint string - // VersionEndpoint specifies the endpoint to check for current teleport version - VersionEndpoint string + // UpgradeChannel is the channel that serves the version used by the updater. + UpgradeChannel *automaticupgrades.Channel } // CheckAndSetDefaults checks and sets default config values. @@ -201,7 +192,7 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployServices(ctx cont return trace.Wrap(err) } - critical, err := automaticupgrades.Critical(ctx, updater.CriticalEndpoint) + critical, err := updater.UpgradeChannel.GetCritical(ctx) if err != nil { return trace.Wrap(err) } @@ -212,7 +203,7 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployServices(ctx cont return nil } - stableVersion, err := automaticupgrades.Version(ctx, updater.VersionEndpoint) + stableVersion, err := updater.UpgradeChannel.GetVersion(ctx) if err != nil { return trace.Wrap(err) } diff --git a/lib/service/service.go b/lib/service/service.go index c47ac8b1fad38..193fdef2d11b9 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3947,33 +3947,34 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } webConfig := web.Config{ - Proxy: tsrv, - AuthServers: cfg.AuthServerAddresses()[0], - DomainName: cfg.Hostname, - ProxyClient: conn.Client, - ProxySSHAddr: proxySSHAddr, - ProxyWebAddr: cfg.Proxy.WebAddr, - ProxyPublicAddrs: cfg.Proxy.PublicAddrs, - CipherSuites: cfg.CipherSuites, - FIPS: cfg.FIPS, - AccessPoint: accessPoint, - Emitter: streamEmitter, - PluginRegistry: process.PluginRegistry, - HostUUID: process.Config.HostUUID, - Context: process.ExitContext(), - StaticFS: fs, - ClusterFeatures: process.getClusterFeatures(), - UI: cfg.Proxy.UI, - ProxySettings: proxySettings, - PublicProxyAddr: process.proxyPublicAddr().Addr, - ALPNHandler: alpnHandlerForWeb.HandleConnection, - ProxyKubeAddr: proxyKubeAddr, - TraceClient: traceClt, - Router: proxyRouter, - SessionControl: sessionController, - PROXYSigner: proxySigner, - OpenAIConfig: cfg.Testing.OpenAIConfig, - NodeWatcher: nodeWatcher, + Proxy: tsrv, + AuthServers: cfg.AuthServerAddresses()[0], + DomainName: cfg.Hostname, + ProxyClient: conn.Client, + ProxySSHAddr: proxySSHAddr, + ProxyWebAddr: cfg.Proxy.WebAddr, + ProxyPublicAddrs: cfg.Proxy.PublicAddrs, + CipherSuites: cfg.CipherSuites, + FIPS: cfg.FIPS, + AccessPoint: accessPoint, + Emitter: streamEmitter, + PluginRegistry: process.PluginRegistry, + HostUUID: process.Config.HostUUID, + Context: process.ExitContext(), + StaticFS: fs, + ClusterFeatures: process.getClusterFeatures(), + UI: cfg.Proxy.UI, + ProxySettings: proxySettings, + PublicProxyAddr: process.proxyPublicAddr().Addr, + ALPNHandler: alpnHandlerForWeb.HandleConnection, + ProxyKubeAddr: proxyKubeAddr, + TraceClient: traceClt, + Router: proxyRouter, + SessionControl: sessionController, + PROXYSigner: proxySigner, + OpenAIConfig: cfg.Testing.OpenAIConfig, + NodeWatcher: nodeWatcher, + AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels, } webHandler, err := web.NewHandler(webConfig) if err != nil { diff --git a/lib/service/servicecfg/proxy.go b/lib/service/servicecfg/proxy.go index af2ca611fc2ee..27891645b94e7 100644 --- a/lib/service/servicecfg/proxy.go +++ b/lib/service/servicecfg/proxy.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/utils" @@ -151,6 +152,11 @@ type ProxyConfig struct { // as a label and used by reverse tunnel agents in proxy peering mode. Zero // is a valid generation. ProxyGroupGeneration uint64 + + // AutomaticUpgradesChannels is a map of all version channels used by the + // proxy built-in version server to retrieve target versions. This is part + // of the automatic upgrades. + AutomaticUpgradesChannels automaticupgrades.Channels } // WebPublicAddr returns the address for the web endpoint on this proxy that diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index b4e757fbb4db0..7b6ab279d3961 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -269,11 +269,10 @@ type Config struct { // the proxy's cache and get nodes in real time. NodeWatcher *services.NodeWatcher - // AutomaticUpgradesVersionURL is the URL which returns the target agent version. - // This URL must returns a valid version string. - // Eg, v13.4.3 - // Optional: uses cloud/stable channel when omitted. - AutomaticUpgradesVersionURL string + // AutomaticUpgradesChannels is a map of all version channels used by the + // proxy built-in version server to retrieve target versions. This is part + // of the automatic upgrades. + AutomaticUpgradesChannels automaticupgrades.Channels } // SetDefaults ensures proper default values are set if @@ -350,6 +349,17 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { h.assistantLimiter = rate.NewLimiter(rate.Inf, 0) } + if automaticUpgrades(cfg.ClusterFeatures) && h.cfg.AutomaticUpgradesChannels == nil { + h.cfg.AutomaticUpgradesChannels = automaticupgrades.Channels{} + } + + if h.cfg.AutomaticUpgradesChannels != nil { + err := h.cfg.AutomaticUpgradesChannels.CheckAndSetDefaults(cfg.ClusterFeatures) + if err != nil { + return nil, trace.Wrap(err) + } + } + // for properly handling url-encoded parameter values. h.UseRawPath = true @@ -828,6 +838,10 @@ func (h *Handler) bindDefaultEndpoints() { // Updates the user's preferences h.PUT("/webapi/user/preferences", h.WithAuth(h.updateUserPreferences)) + + // Implements the agent version server. + // Channel can contain "/", hence the use of a catch-all parameter + h.GET("/webapi/automaticupgrades/channel/*request", h.WithLimiter(h.automaticUpgrades)) } // GetProxyClient returns authenticated auth server client @@ -1477,7 +1491,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou automaticUpgradesEnabled := clusterFeatures.GetAutomaticUpgrades() var automaticUpgradesTargetVersion string if automaticUpgradesEnabled { - automaticUpgradesTargetVersion, err = automaticupgrades.Version(r.Context(), h.cfg.AutomaticUpgradesVersionURL) + automaticUpgradesTargetVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(r.Context()) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 0c34b1621ddc5..49575adb4dc4f 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/conntest" @@ -4584,20 +4585,21 @@ func TestGetWebConfig(t *testing.T) { } env.proxies[0].handler.handler.cfg.ProxySettings = mockProxySetting - httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version") - w.WriteHeader(http.StatusOK) - w.Write([]byte("v99.0.1")) - })) - defer httpTestServer.Close() - versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") require.NoError(t, err) - env.proxies[0].handler.handler.cfg.AutomaticUpgradesVersionURL = versionURL + // This version is too high and MUST NOT be used + testVersion := "v99.0.1" + channels := automaticupgrades.Channels{ + automaticupgrades.DefaultCloudChannelName: { + StaticVersion: testVersion, + }, + } + require.NoError(t, channels.CheckAndSetDefaults(authproto.Features{AutomaticUpgrades: true, Cloud: true})) + env.proxies[0].handler.handler.cfg.AutomaticUpgradesChannels = channels expectedCfg.IsCloud = true expectedCfg.IsUsageBasedBilling = true expectedCfg.AutomaticUpgrades = true - expectedCfg.AutomaticUpgradesTargetVersion = "v99.0.1" + expectedCfg.AutomaticUpgradesTargetVersion = teleport.Version expectedCfg.AssistEnabled = false // request and verify enabled features are enabled. diff --git a/lib/web/automaticupgrades.go b/lib/web/automaticupgrades.go new file mode 100644 index 0000000000000..6b7833dc629e2 --- /dev/null +++ b/lib/web/automaticupgrades.go @@ -0,0 +1,119 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package web + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/julienschmidt/httprouter" + + "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" + "github.com/gravitational/teleport/lib/automaticupgrades/version" +) + +const defaultChannelTimeout = 5 * time.Second + +// automaticUpgrades implements a version server in the Teleport Proxy. +// It is configured through the Teleport Proxy configuration and tells agent updaters +// which version they should install. +func (h *Handler) automaticUpgrades(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + if h.cfg.AutomaticUpgradesChannels == nil { + return nil, trace.AccessDenied("This proxy is not configured to serve automatic upgrades channels.") + } + + // The request format is "/{version,critical}" + // As might contain "/" we have to split, pop the last part + // and re-construct the channel name. + channelAndType := p.ByName("request") + + reqParts := strings.Split(strings.Trim(channelAndType, "/"), "/") + if len(reqParts) < 2 { + return nil, trace.BadParameter("path format should be /webapi/automaticupgrades/channel//{version,critical}") + } + requestType := reqParts[len(reqParts)-1] + channelName := strings.Join(reqParts[:len(reqParts)-1], "/") + + if channelName == "" { + return nil, trace.BadParameter("a channel name is required") + } + + // We check if the channel is configured + channel, ok := h.cfg.AutomaticUpgradesChannels[channelName] + if !ok { + return nil, trace.NotFound("channel %s not found", channelName) + } + + // Finally, we treat the request based on its type + switch requestType { + case "version": + h.log.Debugf("Agent requesting version for channel %s", channelName) + return h.automaticUpgradesVersion(w, r, channel) + case "critical": + h.log.Debugf("Agent requesting criticality for channel %s", channelName) + return h.automaticUpgradesCritical(w, r, channel) + default: + return nil, trace.BadParameter("requestType path must end with 'version' or 'critical'") + } +} + +// automaticUpgradesVersion handles version requests from upgraders +func (h *Handler) automaticUpgradesVersion(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) { + ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout) + defer cancel() + + targetVersion, err := channel.GetVersion(ctx) + if err != nil { + // If the error is that the upstream channel has no version + // We gracefully handle by serving "none" + var NoNewVersionErr *version.NoNewVersionError + if errors.As(trace.Unwrap(err), &NoNewVersionErr) { + _, err = w.Write([]byte(constants.NoVersion)) + return nil, trace.Wrap(err) + } + // Else we propagate the error + return nil, trace.Wrap(err) + } + + _, err = w.Write([]byte(targetVersion)) + return nil, trace.Wrap(err) +} + +// automaticUpgradesCritical handles criticality requests from upgraders +func (h *Handler) automaticUpgradesCritical(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) { + ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout) + defer cancel() + + critical, err := channel.GetCritical(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + response := "no" + if critical { + response = "yes" + } + _, err = w.Write([]byte(response)) + return nil, trace.Wrap(err) +} diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 8a6b393fa3eb7..1eb3b228e53e8 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -26,7 +26,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/aws" - "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -149,7 +148,7 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p teleportVersionTag := teleport.Version if automaticUpgrades(h.ClusterFeatures) { - cloudStableVersion, err := automaticupgrades.Version(ctx, "" /* use default version server */) + cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index 9423b24b662da..1045bd8cd95c6 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/modules" @@ -77,9 +76,10 @@ type scriptSettings struct { databaseInstallMode bool installUpdater bool - // automaticUpgradesVersionURL is the URL for getting the version when using the cloud/stable channel. - // Optional. - automaticUpgradesVersionURL string + // automaticUpgradesVersion is the target automatic upgrades version. + // The version must be valid semver, with the leading 'v'. e.g. v15.0.0-dev + // Required when installUpdater is true. + automaticUpgradesVersion string } // automaticUpgrades returns whether automaticUpgrades should be enabled. @@ -206,14 +206,38 @@ func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, para }, nil } +// getAutoUpgrades checks if automaticUpgrades are enabled and returns the +// version that should be used according to auto upgrades default channel. +func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string, error) { + var autoUpgradesVersion string + var err error + autoUpgrades := automaticUpgrades(h.ClusterFeatures) + if autoUpgrades { + autoUpgradesVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) + if err != nil { + log.WithError(err).Info("Failed to get auto upgrades version.") + return false, "", trace.Wrap(err) + } + } + return autoUpgrades, autoUpgradesVersion, nil + +} + func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) { httplib.SetScriptHeaders(w.Header()) + autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context()) + if err != nil { + w.Write(scripts.ErrorBashScript) + return nil, nil + } + settings := scriptSettings{ - token: params.ByName("token"), - appInstallMode: false, - joinMethod: r.URL.Query().Get("method"), - installUpdater: automaticUpgrades(h.ClusterFeatures), + token: params.ByName("token"), + appInstallMode: false, + joinMethod: r.URL.Query().Get("method"), + installUpdater: autoUpgrades, + automaticUpgradesVersion: autoUpgradesVersion, } script, err := getJoinScript(r.Context(), settings, h.GetProxyClient()) @@ -250,12 +274,19 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request, return nil, nil } + autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context()) + if err != nil { + w.Write(scripts.ErrorBashScript) + return nil, nil + } + settings := scriptSettings{ - token: params.ByName("token"), - appInstallMode: true, - appName: name, - appURI: uri, - installUpdater: automaticUpgrades(h.ClusterFeatures), + token: params.ByName("token"), + appInstallMode: true, + appName: name, + appURI: uri, + installUpdater: autoUpgrades, + automaticUpgradesVersion: autoUpgradesVersion, } script, err := getJoinScript(r.Context(), settings, h.GetProxyClient()) @@ -277,10 +308,17 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request, func (h *Handler) getDatabaseJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) { httplib.SetScriptHeaders(w.Header()) + autoUpgrades, autoUpgradesVersion, err := h.getAutoUpgrades(r.Context()) + if err != nil { + w.Write(scripts.ErrorBashScript) + return nil, nil + } + settings := scriptSettings{ - token: params.ByName("token"), - databaseInstallMode: true, - installUpdater: automaticUpgrades(h.ClusterFeatures), + token: params.ByName("token"), + databaseInstallMode: true, + installUpdater: autoUpgrades, + automaticUpgradesVersion: autoUpgradesVersion, } script, err := getJoinScript(r.Context(), settings, h.GetProxyClient()) @@ -391,17 +429,17 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter // The install script will install the updater (teleport-ent-updater) for Cloud customers enrolled in Automatic Upgrades. // The repo channel used must be `stable/cloud` which has the available packages for the Cloud Customer's agents. - // It pins the teleport version to the one specified by https://updates.releases.teleport.dev/v1/stable/cloud/version + // It pins the teleport version to the one specified by the default version channel // This ensures the initial installed version is the same as the `teleport-ent-updater` would install. if settings.installUpdater { - repoChannel = stableCloudChannelRepo - cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionURL) - if err != nil { - return "", trace.Wrap(err) + if settings.automaticUpgradesVersion == "" { + return "", trace.Wrap(err, "automatic upgrades version must be set when installUpdater is true") } - // cloudStableVersion has vX.Y.Z format, however the script expects the version to not include the `v` - version = strings.TrimPrefix(cloudStableVersion, "v") + repoChannel = stableCloudChannelRepo + // automaticUpgradesVersion has vX.Y.Z format, however the script + // expects the version to not include the `v` so we strip it + version = strings.TrimPrefix(settings.automaticUpgradesVersion, "v") } // This section relies on Go's default zero values to make sure that the settings diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index eb2c88e403e98..a7f008dd3b8c8 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -20,14 +20,10 @@ import ( "context" "encoding/hex" "fmt" - "net/http" - "net/http/httptest" - "net/url" "regexp" "testing" "github.com/gravitational/trace" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" @@ -935,18 +931,7 @@ func TestJoinScript(t *testing.T) { t.Run("using repo", func(t *testing.T) { t.Run("installUpdater is true", func(t *testing.T) { currentStableCloudVersion := "v99.1.1" - - httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version") - w.WriteHeader(http.StatusOK) - w.Write([]byte(currentStableCloudVersion)) - })) - defer httpTestServer.Close() - - versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") - require.NoError(t, err) - - script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionURL: versionURL}, m) + script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersion: currentStableCloudVersion}, m) require.NoError(t, err) // list of packages must include the updater