diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 204686c268b09..dd13e284b2f32 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -168,6 +168,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { return nil, trace.Wrap(err) } + if cfg.VersionStorage == nil { + return nil, trace.BadParameter("version storage is not set") + } if cfg.Trust == nil { cfg.Trust = local.NewCAService(cfg.Backend) } diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 3ff1627c61757..092debe7a2e88 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -85,10 +85,11 @@ import ( ) type testPack struct { - bk backend.Backend - clusterName types.ClusterName - a *Server - mockEmitter *eventstest.MockRecorderEmitter + bk backend.Backend + versionStorage VersionStorage + clusterName types.ClusterName + a *Server + mockEmitter *eventstest.MockRecorderEmitter } func newTestPack( @@ -109,9 +110,13 @@ func newTestPack( return p, trace.Wrap(err) } + p.versionStorage = NewFakeTeleportVersion() + p.mockEmitter = &eventstest.MockRecorderEmitter{} authConfig := &InitConfig{ + DataDir: dataDir, Backend: p.bk, + VersionStorage: p.versionStorage, ClusterName: p.clusterName, Authority: testauthority.New(), Emitter: p.mockEmitter, @@ -1110,6 +1115,7 @@ func TestUpdateConfig(t *testing.T) { authConfig := &InitConfig{ ClusterName: clusterName, Backend: s.bk, + VersionStorage: s.versionStorage, Authority: testauthority.New(), SkipPeriodicOperations: true, } diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index 56c64e2ec2f65..c06d771d76f37 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -75,10 +75,14 @@ func setupGithubContext(ctx context.Context, t *testing.T) *githubContext { ClusterName: "me.localhost", }) require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, tt.b.Close()) + }) authConfig := &InitConfig{ ClusterName: clusterName, Backend: tt.b, + VersionStorage: NewFakeTeleportVersion(), Authority: authority.New(), SkipPeriodicOperations: true, } diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 4e0d8a43fbb70..f4410a22eddd1 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/coreos/go-semver/semver" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -281,7 +282,9 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { } srv.AuthServer, err = NewServer(&InitConfig{ + DataDir: cfg.Dir, Backend: srv.Backend, + VersionStorage: NewFakeTeleportVersion(), Authority: authority.NewWithClock(cfg.Clock), Access: access, Identity: identity, @@ -1121,6 +1124,24 @@ func (t *TestTLSServer) Stop() error { return err } +// FakeTeleportVersion fake version storage implementation always return current version. +type FakeTeleportVersion struct{} + +// NewFakeTeleportVersion creates fake version storage. +func NewFakeTeleportVersion() *FakeTeleportVersion { + return &FakeTeleportVersion{} +} + +// GetTeleportVersion returns current Teleport version. +func (s FakeTeleportVersion) GetTeleportVersion(_ context.Context) (*semver.Version, error) { + return teleport.SemVersion, nil +} + +// WriteTeleportVersion stub function for writing. +func (s FakeTeleportVersion) WriteTeleportVersion(_ context.Context, _ *semver.Version) error { + return nil +} + // NewServerIdentity generates new server identity, used in tests func NewServerIdentity(clt *Server, hostID string, role types.SystemRole) (*state.Identity, error) { priv, pub, err := native.GenerateKeyPair() diff --git a/lib/auth/init.go b/lib/auth/init.go index 779b17bcb08f8..64cf79285a71b 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -31,6 +31,7 @@ import ( "sync" "time" + "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" @@ -71,11 +72,22 @@ var log = logrus.WithFields(logrus.Fields{ teleport.ComponentKey: teleport.ComponentAuth, }) +// VersionStorage local storage for saving the version. +type VersionStorage interface { + // GetTeleportVersion reads the last known Teleport version from storage. + GetTeleportVersion(ctx context.Context) (*semver.Version, error) + // WriteTeleportVersion writes the last known Teleport version to the storage. + WriteTeleportVersion(ctx context.Context, version *semver.Version) error +} + // InitConfig is auth server init config type InitConfig struct { // Backend is auth backend to use Backend backend.Backend + // VersionStorage is a version storage for local process + VersionStorage VersionStorage + // Authority is key generator that we use Authority sshca.Authority @@ -336,6 +348,9 @@ func initCluster(ctx context.Context, cfg InitConfig, asrv *Server) error { if err != nil { return trace.Wrap(err) } + if err := validateAndUpdateTeleportVersion(ctx, cfg.VersionStorage, teleport.SemVersion, firstStart); err != nil { + return trace.Wrap(err) + } // if bootstrap resources are supplied, use them to bootstrap backend state // on initial startup. diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 10e7338d01629..a988f6577e30d 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -22,12 +22,14 @@ import ( "context" "fmt" "math" + "path/filepath" "slices" "strings" "sync" "testing" "time" + "github.com/coreos/go-semver/semver" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" @@ -42,7 +44,9 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth/state" + "github.com/gravitational/teleport/lib/auth/storage" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/lite" @@ -1008,16 +1012,28 @@ func setupConfig(t *testing.T) InitConfig { bk, err := lite.New(context.TODO(), backend.Params{"path": tempDir}) require.NoError(t, err) + processStorage, err := storage.NewProcessStorage( + context.Background(), + filepath.Join(tempDir, teleport.ComponentProcess), + ) + require.NoError(t, err) + clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ ClusterName: "me.localhost", }) require.NoError(t, err) + t.Cleanup(func() { + bk.Close() + processStorage.Close() + }) + return InitConfig{ DataDir: tempDir, HostUUID: "00000000-0000-0000-0000-000000000000", NodeName: "foo", Backend: bk, + VersionStorage: processStorage, Authority: testauthority.New(), ClusterAuditConfig: types.DefaultClusterAuditConfig(), ClusterNetworkingConfig: types.DefaultClusterNetworkingConfig(), @@ -1772,3 +1788,82 @@ func TestInitCreatesCertsIfMissing(t *testing.T) { require.Len(t, cert, 1) } } + +func TestTeleportProcessAuthVersionUpgradeCheck(t *testing.T) { + lib.SetInsecureDevMode(true) + defer lib.SetInsecureDevMode(false) + + tests := []struct { + name string + initialVersion string + expectedVersion string + expectError bool + skipCheck bool + }{ + { + name: "first-launch", + initialVersion: "", + expectedVersion: teleport.Version, + expectError: false, + }, + { + name: "old-version-upgrade", + initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-1), + expectedVersion: teleport.Version, + expectError: false, + }, + { + name: "major-upgrade-fail", + initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2), + expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2), + expectError: true, + }, + { + name: "major-upgrade-with-dev-skip-check", + initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2), + expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2), + expectError: false, + skipCheck: true, + }, + { + name: "major-downgrade-fail", + initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2), + expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2), + expectError: true, + }, + { + name: "major-downgrade-with-dev-skip-check", + initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2), + expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2), + expectError: false, + skipCheck: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + authCfg := setupConfig(t) + + if test.initialVersion != "" { + err := authCfg.VersionStorage.WriteTeleportVersion(ctx, semver.New(test.initialVersion)) + require.NoError(t, err) + } + if test.skipCheck { + t.Setenv(skipVersionUpgradeCheckEnv, "yes") + } + + _, err := Init(ctx, authCfg) + if test.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + lastKnownVersion, err := authCfg.VersionStorage.GetTeleportVersion(ctx) + require.NoError(t, err) + require.Equal(t, test.expectedVersion, lastKnownVersion.String()) + }) + } +} diff --git a/lib/auth/password_test.go b/lib/auth/password_test.go index 90fa0e9540fc9..af5a8b559ac4c 100644 --- a/lib/auth/password_test.go +++ b/lib/auth/password_test.go @@ -77,9 +77,14 @@ func setupPasswordSuite(t *testing.T) *passwordSuite { ClusterName: "me.localhost", }) require.NoError(t, err) + t.Cleanup(func() { + s.bk.Close() + }) + authConfig := &InitConfig{ ClusterName: clusterName, Backend: s.bk, + VersionStorage: NewFakeTeleportVersion(), Authority: authority.New(), SkipPeriodicOperations: true, } diff --git a/lib/auth/state/state.go b/lib/auth/state/state.go index aa2f7348d0b52..325115a825d3e 100644 --- a/lib/auth/state/state.go +++ b/lib/auth/state/state.go @@ -75,7 +75,7 @@ func (s *StateV2) CheckAndSetDefaults() error { return nil } -// UnknownVersion is a sentinel value used to distinguish between InitialLocalVersion being missing from +// UnknownLocalVersion is a sentinel value used to distinguish between InitialLocalVersion being missing from // state due to malformed input and InitialLocalVersion being missing due to the state having been created before // teleport started recording InitialLocalVersion. const UnknownLocalVersion = "unknown" diff --git a/lib/auth/storage/storage.go b/lib/auth/storage/storage.go index 2930760b66b49..8caf0ca93684c 100644 --- a/lib/auth/storage/storage.go +++ b/lib/auth/storage/storage.go @@ -29,6 +29,7 @@ import ( "encoding/json" "strings" + "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" @@ -45,6 +46,10 @@ const ( statesPrefix = "states" // idsPrefix is a key prefix for identities idsPrefix = "ids" + // teleportPrefix is a key prefix to store internal data + teleportPrefix = "teleport" + // lastKnownVersion is a key for storing version of teleport + lastKnownVersion = "last-known-version" ) // stateBackend implements abstraction over local or remote storage backend methods @@ -203,6 +208,31 @@ func (p *ProcessStorage) WriteIdentity(name string, id state.Identity) error { return trace.Wrap(err) } +// GetTeleportVersion reads the last known Teleport version from storage. +func (p *ProcessStorage) GetTeleportVersion(ctx context.Context) (*semver.Version, error) { + item, err := p.stateStorage.Get(ctx, backend.Key(teleportPrefix, lastKnownVersion)) + if err != nil { + return nil, trace.Wrap(err) + } + return semver.NewVersion(string(item.Value)) +} + +// WriteTeleportVersion writes the last known Teleport version to the storage. +func (p *ProcessStorage) WriteTeleportVersion(ctx context.Context, version *semver.Version) error { + if version == nil { + return trace.BadParameter("wrong version parameter") + } + item := backend.Item{ + Key: backend.Key(teleportPrefix, lastKnownVersion), + Value: []byte(version.String()), + } + _, err := p.stateStorage.Put(ctx, item) + if err != nil { + return trace.Wrap(err) + } + return nil +} + // ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the // key storage (dataDir). func ReadLocalIdentity(dataDir string, id state.IdentityID) (*state.Identity, error) { diff --git a/lib/auth/trustedcluster_test.go b/lib/auth/trustedcluster_test.go index ed80a476b08d6..ee7f52c7a89ae 100644 --- a/lib/auth/trustedcluster_test.go +++ b/lib/auth/trustedcluster_test.go @@ -400,7 +400,6 @@ func TestValidateTrustedCluster(t *testing.T) { func newTestAuthServer(ctx context.Context, t *testing.T, name ...string) *Server { bk, err := memory.New(memory.Config{}) require.NoError(t, err) - t.Cleanup(func() { bk.Close() }) clusterName := "me.localhost" if len(name) != 0 { @@ -414,12 +413,18 @@ func newTestAuthServer(ctx context.Context, t *testing.T, name ...string) *Serve authConfig := &InitConfig{ ClusterName: clusterNameRes, Backend: bk, + VersionStorage: NewFakeTeleportVersion(), Authority: authority.New(), SkipPeriodicOperations: true, } a, err := NewServer(authConfig) require.NoError(t, err) - t.Cleanup(func() { a.Close() }) + + t.Cleanup(func() { + bk.Close() + a.Close() + }) + require.NoError(t, a.SetClusterAuditConfig(ctx, types.DefaultClusterAuditConfig())) _, err = a.UpsertClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig()) require.NoError(t, err) diff --git a/lib/auth/version.go b/lib/auth/version.go new file mode 100644 index 0000000000000..b601898060030 --- /dev/null +++ b/lib/auth/version.go @@ -0,0 +1,91 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + "os" + + "github.com/coreos/go-semver/semver" + "github.com/gravitational/trace" +) + +const ( + // majorVersionConstraint is the major version constraint when previous major version must be + // present in the storage, if not - we must refuse to start. + // TODO(vapopov): DELETE IN 18.0.0 + majorVersionConstraint = 18 + + // skipVersionUpgradeCheckEnv is environment variable key for disabling the check + // major version upgrade check. + skipVersionUpgradeCheckEnv = "TELEPORT_UNSTABLE_SKIP_VERSION_UPGRADE_CHECK" +) + +// validateAndUpdateTeleportVersion validates that the major version persistent in the backend +// meets our upgrade compatibility guide. +func validateAndUpdateTeleportVersion( + ctx context.Context, + storage VersionStorage, + currentVersion *semver.Version, + firstTimeStart bool, +) error { + if skip := os.Getenv(skipVersionUpgradeCheckEnv); skip != "" { + return nil + } + + lastKnownVersion, err := storage.GetTeleportVersion(ctx) + if trace.IsNotFound(err) { + // When this is not the first start, we have to ensure that previous versions, + // introduced before this check, were also verified. Therefore, not having a version + // in the database means the last known version is = majorVersionConstraint && !firstTimeStart { + return trace.BadParameter("Unsupported upgrade path detected: to %v. "+ + "Teleport supports direct upgrades to the next major version only.\n "+ + "For instance, if you have version 15.x.x, you must upgrade to version 16.x.x first. "+ + "See compatibility guarantees for details: "+ + "https://goteleport.com/docs/upgrading/overview/#component-compatibility.", + currentVersion.String()) + } + if err := storage.WriteTeleportVersion(ctx, currentVersion); err != nil { + return trace.Wrap(err) + } + return nil + } else if err != nil { + return trace.Wrap(err) + } + + if currentVersion.Major-lastKnownVersion.Major > 1 { + return trace.BadParameter("Unsupported upgrade path detected: from %v to %v. "+ + "Teleport supports direct upgrades to the next major version only.\n Please upgrade "+ + "your cluster to version %d.x.x first. See compatibility guarantees for details: "+ + "https://goteleport.com/docs/upgrading/overview/#component-compatibility.", + lastKnownVersion, currentVersion.String(), lastKnownVersion.Major+1) + } + if lastKnownVersion.Major-currentVersion.Major > 1 { + return trace.BadParameter("Unsupported downgrade path detected: from %v to %v. "+ + "Teleport doesn't support major version downgrade.\n Please downgrade "+ + "your cluster to version %d.x.x first. See compatibility guarantees for details: "+ + "https://goteleport.com/docs/upgrading/overview/#component-compatibility.", + lastKnownVersion, currentVersion.String(), lastKnownVersion.Major-1) + } + if err := storage.WriteTeleportVersion(ctx, currentVersion); err != nil { + return trace.Wrap(err) + } + return nil +} diff --git a/lib/service/service.go b/lib/service/service.go index 7b128d75d3a44..6c4f6537e4d3d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1860,6 +1860,7 @@ func (process *TeleportProcess) initAuthService() error { emitter = localLog } } + clusterName := cfg.Auth.ClusterName.GetClusterName() ident, err := process.storage.ReadIdentity(state.IdentityCurrent, types.RoleAdmin) if err != nil && !trace.IsNotFound(err) { @@ -1911,6 +1912,7 @@ func (process *TeleportProcess) initAuthService() error { process.ExitContext(), auth.InitConfig{ Backend: b, + VersionStorage: process.storage, Authority: cfg.Keygen, ClusterConfiguration: cfg.ClusterConfiguration, ClusterAuditConfig: cfg.Auth.AuditConfig, diff --git a/lib/srv/mock.go b/lib/srv/mock.go index 0a443872f8f0e..5fbc95ba0e9b5 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -127,12 +127,16 @@ func newMockServer(t *testing.T) *mockServer { StaticTokens: []types.ProvisionTokenV1{}, }) require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, bk.Close()) + }) authCfg := &auth.InitConfig{ - Backend: bk, - Authority: testauthority.New(), - ClusterName: clusterName, - StaticTokens: staticTokens, + Backend: bk, + VersionStorage: auth.NewFakeTeleportVersion(), + Authority: testauthority.New(), + ClusterName: clusterName, + StaticTokens: staticTokens, } authServer, err := auth.NewServer(authCfg, auth.WithClock(clock))