Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
return nil, trace.Wrap(err)
}

if cfg.VersionStorage == nil {
return nil, trace.BadParameter("version storage is not set")
}
if cfg.Trust == nil {
cfg.Trust = local.NewCAService(cfg.Backend)
}
Expand Down
14 changes: 10 additions & 4 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ import (
)

type testPack struct {
bk backend.Backend
clusterName types.ClusterName
a *Server
mockEmitter *eventstest.MockRecorderEmitter
bk backend.Backend
versionStorage VersionStorage
clusterName types.ClusterName
a *Server
mockEmitter *eventstest.MockRecorderEmitter
}

func newTestPack(
Expand All @@ -109,9 +110,13 @@ func newTestPack(
return p, trace.Wrap(err)
}

p.versionStorage = NewFakeTeleportVersion()

p.mockEmitter = &eventstest.MockRecorderEmitter{}
authConfig := &InitConfig{
DataDir: dataDir,
Backend: p.bk,
VersionStorage: p.versionStorage,
ClusterName: p.clusterName,
Authority: testauthority.New(),
Emitter: p.mockEmitter,
Expand Down Expand Up @@ -1110,6 +1115,7 @@ func TestUpdateConfig(t *testing.T) {
authConfig := &InitConfig{
ClusterName: clusterName,
Backend: s.bk,
VersionStorage: s.versionStorage,
Authority: testauthority.New(),
SkipPeriodicOperations: true,
}
Expand Down
4 changes: 4 additions & 0 deletions lib/auth/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ func setupGithubContext(ctx context.Context, t *testing.T) *githubContext {
ClusterName: "me.localhost",
})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, tt.b.Close())
})

authConfig := &InitConfig{
ClusterName: clusterName,
Backend: tt.b,
VersionStorage: NewFakeTeleportVersion(),
Authority: authority.New(),
SkipPeriodicOperations: true,
}
Expand Down
21 changes: 21 additions & 0 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/coreos/go-semver/semver"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -281,7 +282,9 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
}

srv.AuthServer, err = NewServer(&InitConfig{
DataDir: cfg.Dir,
Backend: srv.Backend,
VersionStorage: NewFakeTeleportVersion(),
Authority: authority.NewWithClock(cfg.Clock),
Access: access,
Identity: identity,
Expand Down Expand Up @@ -1121,6 +1124,24 @@ func (t *TestTLSServer) Stop() error {
return err
}

// FakeTeleportVersion fake version storage implementation always return current version.
type FakeTeleportVersion struct{}

// NewFakeTeleportVersion creates fake version storage.
func NewFakeTeleportVersion() *FakeTeleportVersion {
return &FakeTeleportVersion{}
}

// GetTeleportVersion returns current Teleport version.
func (s FakeTeleportVersion) GetTeleportVersion(_ context.Context) (*semver.Version, error) {
return teleport.SemVersion, nil
}

// WriteTeleportVersion stub function for writing.
func (s FakeTeleportVersion) WriteTeleportVersion(_ context.Context, _ *semver.Version) error {
return nil
}

// NewServerIdentity generates new server identity, used in tests
func NewServerIdentity(clt *Server, hostID string, role types.SystemRole) (*state.Identity, error) {
priv, pub, err := native.GenerateKeyPair()
Expand Down
15 changes: 15 additions & 0 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sync"
"time"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -71,11 +72,22 @@ var log = logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentAuth,
})

// VersionStorage local storage for saving the version.
type VersionStorage interface {
// GetTeleportVersion reads the last known Teleport version from storage.
GetTeleportVersion(ctx context.Context) (*semver.Version, error)
// WriteTeleportVersion writes the last known Teleport version to the storage.
WriteTeleportVersion(ctx context.Context, version *semver.Version) error
}

// InitConfig is auth server init config
type InitConfig struct {
// Backend is auth backend to use
Backend backend.Backend

// VersionStorage is a version storage for local process
VersionStorage VersionStorage

// Authority is key generator that we use
Authority sshca.Authority

Expand Down Expand Up @@ -336,6 +348,9 @@ func initCluster(ctx context.Context, cfg InitConfig, asrv *Server) error {
if err != nil {
return trace.Wrap(err)
}
if err := validateAndUpdateTeleportVersion(ctx, cfg.VersionStorage, teleport.SemVersion, firstStart); err != nil {
return trace.Wrap(err)
}

// if bootstrap resources are supplied, use them to bootstrap backend state
// on initial startup.
Expand Down
95 changes: 95 additions & 0 deletions lib/auth/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import (
"context"
"fmt"
"math"
"path/filepath"
"slices"
"strings"
"sync"
"testing"
"time"

"github.com/coreos/go-semver/semver"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
Expand All @@ -42,7 +44,9 @@ import (
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/storage"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
Expand Down Expand Up @@ -1008,16 +1012,28 @@ func setupConfig(t *testing.T) InitConfig {
bk, err := lite.New(context.TODO(), backend.Params{"path": tempDir})
require.NoError(t, err)

processStorage, err := storage.NewProcessStorage(
context.Background(),
filepath.Join(tempDir, teleport.ComponentProcess),
)
require.NoError(t, err)

clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
ClusterName: "me.localhost",
})
require.NoError(t, err)

t.Cleanup(func() {
bk.Close()
processStorage.Close()
})

return InitConfig{
DataDir: tempDir,
HostUUID: "00000000-0000-0000-0000-000000000000",
NodeName: "foo",
Backend: bk,
VersionStorage: processStorage,
Authority: testauthority.New(),
ClusterAuditConfig: types.DefaultClusterAuditConfig(),
ClusterNetworkingConfig: types.DefaultClusterNetworkingConfig(),
Expand Down Expand Up @@ -1772,3 +1788,82 @@ func TestInitCreatesCertsIfMissing(t *testing.T) {
require.Len(t, cert, 1)
}
}

func TestTeleportProcessAuthVersionUpgradeCheck(t *testing.T) {
lib.SetInsecureDevMode(true)
defer lib.SetInsecureDevMode(false)

tests := []struct {
name string
initialVersion string
expectedVersion string
expectError bool
skipCheck bool
}{
{
name: "first-launch",
initialVersion: "",
expectedVersion: teleport.Version,
expectError: false,
},
{
name: "old-version-upgrade",
initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-1),
expectedVersion: teleport.Version,
expectError: false,
},
{
name: "major-upgrade-fail",
initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2),
expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2),
expectError: true,
},
{
name: "major-upgrade-with-dev-skip-check",
initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2),
expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major-2),
expectError: false,
skipCheck: true,
},
{
name: "major-downgrade-fail",
initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2),
expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2),
expectError: true,
},
{
name: "major-downgrade-with-dev-skip-check",
initialVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2),
expectedVersion: fmt.Sprintf("%d.0.0", teleport.SemVersion.Major+2),
expectError: false,
skipCheck: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

authCfg := setupConfig(t)

if test.initialVersion != "" {
err := authCfg.VersionStorage.WriteTeleportVersion(ctx, semver.New(test.initialVersion))
require.NoError(t, err)
}
if test.skipCheck {
t.Setenv(skipVersionUpgradeCheckEnv, "yes")
}

_, err := Init(ctx, authCfg)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}

lastKnownVersion, err := authCfg.VersionStorage.GetTeleportVersion(ctx)
require.NoError(t, err)
require.Equal(t, test.expectedVersion, lastKnownVersion.String())
})
}
}
5 changes: 5 additions & 0 deletions lib/auth/password_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,14 @@ func setupPasswordSuite(t *testing.T) *passwordSuite {
ClusterName: "me.localhost",
})
require.NoError(t, err)
t.Cleanup(func() {
s.bk.Close()
})

authConfig := &InitConfig{
ClusterName: clusterName,
Backend: s.bk,
VersionStorage: NewFakeTeleportVersion(),
Authority: authority.New(),
SkipPeriodicOperations: true,
}
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *StateV2) CheckAndSetDefaults() error {
return nil
}

// UnknownVersion is a sentinel value used to distinguish between InitialLocalVersion being missing from
// UnknownLocalVersion is a sentinel value used to distinguish between InitialLocalVersion being missing from
// state due to malformed input and InitialLocalVersion being missing due to the state having been created before
// teleport started recording InitialLocalVersion.
const UnknownLocalVersion = "unknown"
Expand Down
30 changes: 30 additions & 0 deletions lib/auth/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"encoding/json"
"strings"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
Expand All @@ -45,6 +46,10 @@ const (
statesPrefix = "states"
// idsPrefix is a key prefix for identities
idsPrefix = "ids"
// teleportPrefix is a key prefix to store internal data
teleportPrefix = "teleport"
// lastKnownVersion is a key for storing version of teleport
lastKnownVersion = "last-known-version"
)

// stateBackend implements abstraction over local or remote storage backend methods
Expand Down Expand Up @@ -203,6 +208,31 @@ func (p *ProcessStorage) WriteIdentity(name string, id state.Identity) error {
return trace.Wrap(err)
}

// GetTeleportVersion reads the last known Teleport version from storage.
func (p *ProcessStorage) GetTeleportVersion(ctx context.Context) (*semver.Version, error) {
item, err := p.stateStorage.Get(ctx, backend.Key(teleportPrefix, lastKnownVersion))
if err != nil {
return nil, trace.Wrap(err)
}
return semver.NewVersion(string(item.Value))
}

// WriteTeleportVersion writes the last known Teleport version to the storage.
func (p *ProcessStorage) WriteTeleportVersion(ctx context.Context, version *semver.Version) error {
if version == nil {
return trace.BadParameter("wrong version parameter")
}
item := backend.Item{
Key: backend.Key(teleportPrefix, lastKnownVersion),
Value: []byte(version.String()),
}
_, err := p.stateStorage.Put(ctx, item)
if err != nil {
return trace.Wrap(err)
}
return nil
}

// ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the
// key storage (dataDir).
func ReadLocalIdentity(dataDir string, id state.IdentityID) (*state.Identity, error) {
Expand Down
9 changes: 7 additions & 2 deletions lib/auth/trustedcluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ func TestValidateTrustedCluster(t *testing.T) {
func newTestAuthServer(ctx context.Context, t *testing.T, name ...string) *Server {
bk, err := memory.New(memory.Config{})
require.NoError(t, err)
t.Cleanup(func() { bk.Close() })

clusterName := "me.localhost"
if len(name) != 0 {
Expand All @@ -414,12 +413,18 @@ func newTestAuthServer(ctx context.Context, t *testing.T, name ...string) *Serve
authConfig := &InitConfig{
ClusterName: clusterNameRes,
Backend: bk,
VersionStorage: NewFakeTeleportVersion(),
Authority: authority.New(),
SkipPeriodicOperations: true,
}
a, err := NewServer(authConfig)
require.NoError(t, err)
t.Cleanup(func() { a.Close() })

t.Cleanup(func() {
bk.Close()
a.Close()
})

require.NoError(t, a.SetClusterAuditConfig(ctx, types.DefaultClusterAuditConfig()))
_, err = a.UpsertClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig())
require.NoError(t, err)
Expand Down
Loading