From eec189f7b8b80268b85a565ecfdf47f6418e659d Mon Sep 17 00:00:00 2001 From: Russell Jones Date: Wed, 6 Jul 2022 13:38:10 -0700 Subject: [PATCH 1/2] Refactor tests under services package. Refactored "lib/services/suite" and "lib/services/local" packages to use testify instead of gocheck. Swapped backend from "lite" to "memory" to cut time to run tests in half. --- lib/auth/github_test.go | 116 +- lib/auth/tls_test.go | 1709 ++++++++++++---------- lib/services/local/access_test.go | 11 +- lib/services/local/apps_test.go | 8 +- lib/services/local/configuration_test.go | 120 +- lib/services/local/databases_test.go | 8 +- lib/services/local/resource_test.go | 9 +- lib/services/local/services_test.go | 157 +- lib/services/local/unstable_test.go | 12 +- lib/services/local/users_test.go | 9 +- lib/services/suite/presence_test.go | 37 +- lib/services/suite/suite.go | 837 +++++------ 12 files changed, 1586 insertions(+), 1447 deletions(-) diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index 722dbee2953c9..bb92d364d9e0c 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -33,96 +33,107 @@ import ( "github.com/gravitational/trace" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/jonboulle/clockwork" - "gopkg.in/check.v1" ) -func TestAPI(t *testing.T) { check.TestingT(t) } - -type GithubSuite struct { +type githubContext struct { a *Server mockEmitter *eventstest.MockEmitter b backend.Backend c clockwork.FakeClock } -var _ = check.Suite(&GithubSuite{}) +func setupGithubContext(ctx context.Context, t *testing.T) *githubContext { + var tt githubContext + t.Cleanup(func() { tt.Close() }) -func (s *GithubSuite) SetUpSuite(c *check.C) { - s.c = clockwork.NewFakeClockAt(time.Now()) + tt.c = clockwork.NewFakeClockAt(time.Now()) var err error - s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ - Path: c.MkDir(), + tt.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: t.TempDir(), PollStreamPeriod: 200 * time.Millisecond, - Clock: s.c, + Clock: tt.c, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ ClusterName: "me.localhost", }) - c.Assert(err, check.IsNil) + require.NoError(t, err) authConfig := &InitConfig{ ClusterName: clusterName, - Backend: s.b, + Backend: tt.b, Authority: authority.New(), SkipPeriodicOperations: true, } - s.a, err = NewServer(authConfig) - c.Assert(err, check.IsNil) + tt.a, err = NewServer(authConfig) + require.NoError(t, err) - s.mockEmitter = &eventstest.MockEmitter{} - s.a.emitter = s.mockEmitter + tt.mockEmitter = &eventstest.MockEmitter{} + tt.a.emitter = tt.mockEmitter + + return &tt +} + +func (tt *githubContext) Close() error { + return trace.NewAggregate( + tt.a.Close(), + tt.b.Close()) } -func (s *GithubSuite) TestPopulateClaims(c *check.C) { +func TestPopulateClaims(t *testing.T) { claims, err := populateGithubClaims(&testGithubAPIClient{}) - c.Assert(err, check.IsNil) - c.Assert(claims, check.DeepEquals, &types.GithubClaims{ + require.NoError(t, err) + require.Empty(t, cmp.Diff(claims, &types.GithubClaims{ Username: "octocat", OrganizationToTeams: map[string][]string{ "org1": {"team1", "team2"}, "org2": {"team1"}, }, Teams: []string{"team1", "team2", "team1"}, - }) + })) + } -func (s *GithubSuite) TestCreateGithubUser(c *check.C) { +func TestCreateGithubUser(t *testing.T) { + ctx := context.Background() + tt := setupGithubContext(ctx, t) + // Dry-run creation of Github user. - user, err := s.a.createGithubUser(context.Background(), &createUserParams{ + user, err := tt.a.createGithubUser(context.Background(), &createUserParams{ connectorName: "github", username: "foo@example.com", roles: []string{"admin"}, sessionTTL: 1 * time.Minute, }, true) - c.Assert(err, check.IsNil) - c.Assert(user.GetName(), check.Equals, "foo@example.com") + require.NoError(t, err) + require.Equal(t, user.GetName(), "foo@example.com") // Dry-run must not create a user. - _, err = s.a.GetUser("foo@example.com", false) - c.Assert(err, check.NotNil) + _, err = tt.a.GetUser("foo@example.com", false) + require.Error(t, err) // Create GitHub user with 1 minute expiry. - _, err = s.a.createGithubUser(context.Background(), &createUserParams{ + _, err = tt.a.createGithubUser(context.Background(), &createUserParams{ connectorName: "github", username: "foo", roles: []string{"admin"}, sessionTTL: 1 * time.Minute, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Within that 1 minute period the user should still exist. - _, err = s.a.GetUser("foo", false) - c.Assert(err, check.IsNil) + _, err = tt.a.GetUser("foo", false) + require.NoError(t, err) // Advance time 2 minutes, the user should be gone. - s.c.Advance(2 * time.Minute) - _, err = s.a.GetUser("foo", false) - c.Assert(err, check.NotNil) + tt.c.Advance(2 * time.Minute) + _, err = tt.a.GetUser("foo", false) + require.Error(t, err) } type testGithubAPIClient struct{} @@ -151,7 +162,10 @@ func (c *testGithubAPIClient) getTeams() ([]teamResponse, error) { }, nil } -func (s *GithubSuite) TestValidateGithubAuthCallbackEventsEmitted(c *check.C) { +func TestValidateGithubAuthCallbackEventsEmitted(t *testing.T) { + ctx := context.Background() + tt := setupGithubContext(ctx, t) + auth := &GithubAuthResponse{ Username: "test-name", } @@ -178,20 +192,20 @@ func (s *GithubSuite) TestValidateGithubAuthCallbackEventsEmitted(c *check.C) { diagCtx.info.GithubClaims = claims return auth, nil } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, s.a.emitter) - c.Assert(s.mockEmitter.LastEvent().GetType(), check.Equals, events.UserLoginEvent) - c.Assert(s.mockEmitter.LastEvent().GetCode(), check.Equals, events.UserSSOLoginCode) - c.Assert(ssoDiagInfoCalls, check.Equals, 0) - s.mockEmitter.Reset() + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + require.Equal(t, tt.mockEmitter.LastEvent().GetType(), events.UserLoginEvent) + require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOLoginCode) + require.Equal(t, ssoDiagInfoCalls, 0) + tt.mockEmitter.Reset() // Test failure event. m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { diagCtx.info.GithubClaims = claims return auth, trace.BadParameter("") } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, s.a.emitter) - c.Assert(s.mockEmitter.LastEvent().GetCode(), check.Equals, events.UserSSOLoginFailureCode) - c.Assert(ssoDiagInfoCalls, check.Equals, 0) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOLoginFailureCode) + require.Equal(t, ssoDiagInfoCalls, 0) // TestFlow: true m.testFlow = true @@ -201,20 +215,20 @@ func (s *GithubSuite) TestValidateGithubAuthCallbackEventsEmitted(c *check.C) { diagCtx.info.GithubClaims = claims return auth, nil } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, s.a.emitter) - c.Assert(s.mockEmitter.LastEvent().GetType(), check.Equals, events.UserLoginEvent) - c.Assert(s.mockEmitter.LastEvent().GetCode(), check.Equals, events.UserSSOTestFlowLoginCode) - c.Assert(ssoDiagInfoCalls, check.Equals, 1) - s.mockEmitter.Reset() + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + require.Equal(t, tt.mockEmitter.LastEvent().GetType(), events.UserLoginEvent) + require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOTestFlowLoginCode) + require.Equal(t, ssoDiagInfoCalls, 1) + tt.mockEmitter.Reset() // Test failure event. m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { diagCtx.info.GithubClaims = claims return auth, trace.BadParameter("") } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, s.a.emitter) - c.Assert(s.mockEmitter.LastEvent().GetCode(), check.Equals, events.UserSSOTestFlowLoginFailureCode) - c.Assert(ssoDiagInfoCalls, check.Equals, 2) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOTestFlowLoginFailureCode) + require.Equal(t, ssoDiagInfoCalls, 2) } type mockedGithubManager struct { diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 79a7a6dccd9d9..5f356c065d438 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -35,7 +35,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" - "gopkg.in/check.v1" "github.com/gravitational/trace" @@ -59,10 +58,6 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// https://github.com/gravitational/teleport/commit/f7144fe8600a47d0cd40d4d3f4e3929cf0fdc43c -// replaced TLSSuite and check with authContext and testify, but is not being -// backported. I'm adding both here to backport TestWebSessionMultiAccessRequests - type authContext struct { dataDir string server *TestTLSServer @@ -92,177 +87,162 @@ func (a *authContext) Close() error { return a.server.Close() } -type TLSSuite struct { - dataDir string - server *TestTLSServer - clock clockwork.FakeClock -} - -var _ = check.Suite(&TLSSuite{}) - -func (s *TLSSuite) SetUpTest(c *check.C) { - s.dataDir = c.MkDir() - s.clock = clockwork.NewFakeClock() - - testAuthServer, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: s.dataDir, - Clock: s.clock, - }) - c.Assert(err, check.IsNil) - s.server, err = testAuthServer.NewTestTLSServer() - c.Assert(err, check.IsNil) -} - -func (s *TLSSuite) TearDownTest(c *check.C) { - if s.server != nil { - s.server.Close() - } -} - // TestRemoteBuiltinRole tests remote builtin role // that gets mapped to remote proxy readonly role -func (s *TLSSuite) TestRemoteBuiltinRole(c *check.C) { +func TestRemoteBuiltinRole(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: c.MkDir(), + Dir: t.TempDir(), ClusterName: "remote", - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - certPool, err := s.server.CertPool() - c.Assert(err, check.IsNil) + certPool, err := tt.server.CertPool() + require.NoError(t, err) // without trust, proxy server will get rejected // remote auth server will get rejected because it is not supported remoteProxy, err := remoteServer.NewRemoteClient( - TestBuiltin(types.RoleProxy), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleProxy), tt.server.Addr(), certPool) + require.NoError(t, err) // certificate authority is not recognized, because // the trust has not been established yet _, err = remoteProxy.GetNodes(ctx, apidefaults.Namespace) - fixtures.ExpectConnectionProblem(c, err) + require.True(t, trace.IsConnectionProblem(err)) // after trust is established, things are good - err = s.server.AuthServer.Trust(ctx, remoteServer, nil) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.Trust(ctx, remoteServer, nil) + require.NoError(t, err) // re initialize client with trust established. remoteProxy, err = remoteServer.NewRemoteClient( - TestBuiltin(types.RoleProxy), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleProxy), tt.server.Addr(), certPool) + require.NoError(t, err) _, err = remoteProxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // remote auth server will get rejected even with established trust remoteAuth, err := remoteServer.NewRemoteClient( - TestBuiltin(types.RoleAuth), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleAuth), tt.server.Addr(), certPool) + require.NoError(t, err) _, err = remoteAuth.GetDomainName(ctx) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } // TestAcceptedUsage tests scenario when server is set up // to accept certificates with certain usage metadata restrictions // encoded -func (s *TLSSuite) TestAcceptedUsage(c *check.C) { +func TestAcceptedUsage(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + server, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: c.MkDir(), + Dir: t.TempDir(), ClusterName: "remote", AcceptedUsage: []string{"usage:k8s"}, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) user, _, err := CreateUserAndRole(server.AuthServer, "user", []string{"role"}) - c.Assert(err, check.IsNil) + require.NoError(t, err) tlsServer, err := server.NewTestTLSServer() - c.Assert(err, check.IsNil) + require.NoError(t, err) defer tlsServer.Close() // Unrestricted clients can use restricted servers client, err := tlsServer.NewClient(TestUser(user.GetName())) - c.Assert(err, check.IsNil) + require.NoError(t, err) // certificate authority is not recognized, because // the trust has not been established yet _, err = client.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // restricted clients can use restricted servers if restrictions // exactly match identity := TestUser(user.GetName()) identity.AcceptedUsage = []string{"usage:k8s"} client, err = tlsServer.NewClient(identity) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = client.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // restricted clients can will be rejected if usage does not match identity = TestUser(user.GetName()) identity.AcceptedUsage = []string{"usage:extra"} client, err = tlsServer.NewClient(identity) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = client.GetNodes(ctx, apidefaults.Namespace) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // restricted clients can will be rejected, for now if there is any mismatch, // including extra usage. identity = TestUser(user.GetName()) identity.AcceptedUsage = []string{"usage:k8s", "usage:unknown"} client, err = tlsServer.NewClient(identity) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = client.GetNodes(ctx, apidefaults.Namespace) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } // TestRemoteRotation tests remote builtin role // that attempts certificate authority rotation -func (s *TLSSuite) TestRemoteRotation(c *check.C) { - ctx := context.TODO() +func TestRemoteRotation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + var ok bool remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: c.MkDir(), + Dir: t.TempDir(), ClusterName: "remote", - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - certPool, err := s.server.CertPool() - c.Assert(err, check.IsNil) + certPool, err := tt.server.CertPool() + require.NoError(t, err) // after trust is established, things are good - err = s.server.AuthServer.Trust(ctx, remoteServer, nil) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.Trust(ctx, remoteServer, nil) + require.NoError(t, err) remoteProxy, err := remoteServer.NewRemoteClient( - TestBuiltin(types.RoleProxy), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleProxy), tt.server.Addr(), certPool) + require.NoError(t, err) remoteAuth, err := remoteServer.NewRemoteClient( - TestBuiltin(types.RoleAuth), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleAuth), tt.server.Addr(), certPool) + require.NoError(t, err) // remote cluster starts rotation gracePeriod := time.Hour remoteServer.AuthServer.privateKey, ok = fixtures.PEMBytes["rsa"] - c.Assert(ok, check.Equals, true) + require.Equal(t, ok, true) err = remoteServer.AuthServer.RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // moves to update clients err = remoteServer.AuthServer.RotateCertAuthority(ctx, RotateRequest{ @@ -271,454 +251,478 @@ func (s *TLSSuite) TestRemoteRotation(c *check.C) { TargetPhase: types.RotationPhaseUpdateClients, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) remoteCA, err := remoteServer.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ DomainName: remoteServer.ClusterName, Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) // remote proxy should be rejected when trying to rotate ca // that is not associated with the remote cluster clone := remoteCA.Clone() - clone.SetName(s.server.ClusterName()) + clone.SetName(tt.server.ClusterName()) err = remoteProxy.RotateExternalCertAuthority(ctx, clone) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // remote proxy can't upsert the certificate authority, // only to rotate it (in remote rotation only certain fields are updated) err = remoteProxy.UpsertCertAuthority(remoteCA) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // remote proxy can't read local cert authority with secrets _, err = remoteProxy.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, true) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // no secrets read is allowed _, err = remoteProxy.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) // remote auth server will get rejected err = remoteAuth.RotateExternalCertAuthority(ctx, remoteCA) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // remote proxy should be able to perform remote cert authority // rotation err = remoteProxy.RotateExternalCertAuthority(ctx, remoteCA) - c.Assert(err, check.IsNil) + require.NoError(t, err) // newRemoteProxy should be trusted by the auth server newRemoteProxy, err := remoteServer.NewRemoteClient( - TestBuiltin(types.RoleProxy), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestBuiltin(types.RoleProxy), tt.server.Addr(), certPool) + require.NoError(t, err) _, err = newRemoteProxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // old proxy client is still trusted - _, err = s.server.CloneClient(remoteProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(remoteProxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) } // TestLocalProxyPermissions tests new local proxy permissions // as it's now allowed to update host cert authorities of remote clusters -func (s *TLSSuite) TestLocalProxyPermissions(c *check.C) { +func TestLocalProxyPermissions(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: c.MkDir(), + Dir: t.TempDir(), ClusterName: "remote", - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // after trust is established, things are good - err = s.server.AuthServer.Trust(ctx, remoteServer, nil) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.Trust(ctx, remoteServer, nil) + require.NoError(t, err) - ca, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // local proxy can't update local cert authorities err = proxy.UpsertCertAuthority(ca) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // local proxy is allowed to update host CA of remote cert authorities - remoteCA, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + remoteCA, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ DomainName: remoteServer.ClusterName, Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = proxy.UpsertCertAuthority(remoteCA) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestAutoRotation tests local automatic rotation -func (s *TLSSuite) TestAutoRotation(c *check.C) { +func TestAutoRotation(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + var ok bool // create proxy client - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // client works before rotation is initiated _, err = proxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // starts rotation - s.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] - c.Assert(ok, check.Equals, true) + tt.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] + require.Equal(t, ok, true) gracePeriod := time.Hour - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, Mode: types.RotationModeAuto, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // advance rotation by clock - s.clock.Advance(gracePeriod/3 + time.Minute) - err = s.server.Auth().autoRotateCertAuthorities(ctx) - c.Assert(err, check.IsNil) + tt.clock.Advance(gracePeriod/3 + time.Minute) + err = tt.server.Auth().autoRotateCertAuthorities(ctx) + require.NoError(t, err) - ca, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) - c.Assert(ca.GetRotation().Phase, check.Equals, types.RotationPhaseUpdateClients) + require.NoError(t, err) + require.Equal(t, ca.GetRotation().Phase, types.RotationPhaseUpdateClients) // old clients should work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // new clients work as well - _, err = s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + _, err = tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // advance rotation by clock - s.clock.Advance((gracePeriod*2)/3 + time.Minute) - err = s.server.Auth().autoRotateCertAuthorities(ctx) - c.Assert(err, check.IsNil) + tt.clock.Advance((gracePeriod*2)/3 + time.Minute) + err = tt.server.Auth().autoRotateCertAuthorities(ctx) + require.NoError(t, err) - ca, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) - c.Assert(ca.GetRotation().Phase, check.Equals, types.RotationPhaseUpdateServers) + require.NoError(t, err) + require.Equal(t, ca.GetRotation().Phase, types.RotationPhaseUpdateServers) // old clients should work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // new clients work as well - newProxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + newProxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) _, err = newProxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // complete rotation - advance rotation by clock - s.clock.Advance(gracePeriod/3 + time.Minute) - err = s.server.Auth().autoRotateCertAuthorities(ctx) - c.Assert(err, check.IsNil) - ca, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + tt.clock.Advance(gracePeriod/3 + time.Minute) + err = tt.server.Auth().autoRotateCertAuthorities(ctx) + require.NoError(t, err) + ca, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) - c.Assert(ca.GetRotation().Phase, check.Equals, types.RotationPhaseStandby) - c.Assert(err, check.IsNil) + require.NoError(t, err) + require.Equal(t, ca.GetRotation().Phase, types.RotationPhaseStandby) + require.NoError(t, err) // old clients should no longer work // new client has to be created here to force re-create the new // connection instead of re-using the one from pool // this is not going to be a problem in real teleport // as it reloads the full server after reload - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.ErrorContains(t, err, "bad certificate") // new clients work - _, err = s.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) } // TestAutoFallback tests local automatic rotation fallback, // when user intervenes with rollback and rotation gets switched // to manual mode -func (s *TLSSuite) TestAutoFallback(c *check.C) { +func TestAutoFallback(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + var ok bool // create proxy client just for test purposes - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // client works before rotation is initiated _, err = proxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // starts rotation - s.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] - c.Assert(ok, check.Equals, true) + tt.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] + require.Equal(t, ok, true) gracePeriod := time.Hour - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, Mode: types.RotationModeAuto, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // advance rotation by clock - s.clock.Advance(gracePeriod/3 + time.Minute) - err = s.server.Auth().autoRotateCertAuthorities(ctx) - c.Assert(err, check.IsNil) + tt.clock.Advance(gracePeriod/3 + time.Minute) + err = tt.server.Auth().autoRotateCertAuthorities(ctx) + require.NoError(t, err) - ca, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) - c.Assert(ca.GetRotation().Phase, check.Equals, types.RotationPhaseUpdateClients) - c.Assert(ca.GetRotation().Mode, check.Equals, types.RotationModeAuto) + require.NoError(t, err) + require.Equal(t, ca.GetRotation().Phase, types.RotationPhaseUpdateClients) + require.Equal(t, ca.GetRotation().Mode, types.RotationModeAuto) // rollback rotation - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseRollback, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - ca, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) - c.Assert(ca.GetRotation().Phase, check.Equals, types.RotationPhaseRollback) - c.Assert(ca.GetRotation().Mode, check.Equals, types.RotationModeManual) + require.NoError(t, err) + require.Equal(t, ca.GetRotation().Phase, types.RotationPhaseRollback) + require.Equal(t, ca.GetRotation().Mode, types.RotationModeManual) } // TestManualRotation tests local manual rotation // that performs full-cycle certificate authority rotation -func (s *TLSSuite) TestManualRotation(c *check.C) { +func TestManualRotation(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + var ok bool // create proxy client just for test purposes - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // client works before rotation is initiated _, err = proxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // can't jump to mid-phase gracePeriod := time.Hour - s.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] - c.Assert(ok, check.Equals, true) - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + tt.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] + require.Equal(t, ok, true) + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateServers, Mode: types.RotationModeManual, }) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) // starts rotation - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // old clients should work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // clients reconnect - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateClients, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // old clients should work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // new clients work as well - newProxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + newProxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) _, err = newProxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // can't jump to standy - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseStandby, Mode: types.RotationModeManual, }) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) // advance rotation: - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateServers, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // old clients should work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // new clients work as well - _, err = s.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // complete rotation - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseStandby, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // old clients should no longer work // new client has to be created here to force re-create the new // connection instead of re-using the one from pool // this is not going to be a problem in real teleport // as it reloads the full server after reload - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.ErrorContains(t, err, "bad certificate") // new clients work - _, err = s.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) } // TestRollback tests local manual rotation rollback -func (s *TLSSuite) TestRollback(c *check.C) { +func TestRollback(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + var ok bool // create proxy client just for test purposes - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // client works before rotation is initiated _, err = proxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // starts rotation gracePeriod := time.Hour - s.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] - c.Assert(ok, check.Equals, true) - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + tt.server.Auth().privateKey, ok = fixtures.PEMBytes["rsa"] + require.Equal(t, ok, true) + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // move to update clients phase - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateClients, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // new clients work - newProxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + newProxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) _, err = newProxy.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) // advance rotation: - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateServers, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // rollback rotation - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseRollback, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // new clients work, server still accepts the creds // because new clients should re-register and receive new certs - _, err = s.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) // can't jump to other phases - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateClients, Mode: types.RotationModeManual, }) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) // complete rollback - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseStandby, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // clients with new creds will no longer work - _, err = s.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.ErrorMatches, ".*bad certificate.*") + _, err = tt.server.CloneClient(newProxy).GetNodes(ctx, apidefaults.Namespace) + require.ErrorContains(t, err, "bad certificate") // clients with old creds will still work - _, err = s.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) + _, err = tt.server.CloneClient(proxy).GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) } // TestAppTokenRotation checks that JWT tokens can be rotated and tokens can or // can not be validated at the appropriate phase. -func (s *TLSSuite) TestAppTokenRotation(c *check.C) { +func TestAppTokenRotation(t *testing.T) { + t.Parallel() + ctx := context.Background() - client, err := s.server.NewClient(TestBuiltin(types.RoleApp)) - c.Assert(err, check.IsNil) + tt := setupAuthContext(ctx, t) + + client, err := tt.server.NewClient(TestBuiltin(types.RoleApp)) + require.NoError(t, err) // Create a JWT using the current CA, this will become the "old" CA during // rotation. @@ -727,55 +731,55 @@ func (s *TLSSuite) TestAppTokenRotation(c *check.C) { Username: "foo", Roles: []string{"bar", "baz"}, URI: "http://localhost:8080", - Expires: s.clock.Now().Add(1 * time.Minute), + Expires: tt.clock.Now().Add(1 * time.Minute), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Check that the "old" CA can be used to verify tokens. - oldCA, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + oldCA, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.JWTSigner, }, true) - c.Assert(err, check.IsNil) - c.Assert(oldCA.GetTrustedJWTKeyPairs(), check.HasLen, 1) + require.NoError(t, err) + require.Len(t, oldCA.GetTrustedJWTKeyPairs(), 1) // Verify that the JWT token validates with the JWT authority. - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), oldCA.GetTrustedJWTKeyPairs(), oldJWT) - c.Assert(err, check.IsNil) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), oldCA.GetTrustedJWTKeyPairs(), oldJWT) + require.NoError(t, err) // Start rotation and move to initial phase. A new CA will be added (for // verification), but requests will continue to be signed by the old CA. gracePeriod := time.Hour - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.JWTSigner, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // At this point in rotation, two JWT key pairs should exist. - oldCA, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + oldCA, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.JWTSigner, }, true) - c.Assert(err, check.IsNil) - c.Assert(oldCA.GetRotation().Phase, check.Equals, types.RotationPhaseInit) - c.Assert(oldCA.GetTrustedJWTKeyPairs(), check.HasLen, 2) + require.NoError(t, err) + require.Equal(t, oldCA.GetRotation().Phase, types.RotationPhaseInit) + require.Len(t, oldCA.GetTrustedJWTKeyPairs(), 2) // Verify that the JWT token validates with the JWT authority. - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), oldCA.GetTrustedJWTKeyPairs(), oldJWT) - c.Assert(err, check.IsNil) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), oldCA.GetTrustedJWTKeyPairs(), oldJWT) + require.NoError(t, err) // Move rotation into the update client phase. In this phase, requests will // be signed by the new CA, but the old CA will be around to verify requests. - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.JWTSigner, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateClients, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // New tokens should now fail to validate with the old key. newJWT, err := client.GenerateAppToken(ctx, @@ -783,176 +787,189 @@ func (s *TLSSuite) TestAppTokenRotation(c *check.C) { Username: "foo", Roles: []string{"bar", "baz"}, URI: "http://localhost:8080", - Expires: s.clock.Now().Add(1 * time.Minute), + Expires: tt.clock.Now().Add(1 * time.Minute), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // New tokens will validate with the new key. - newCA, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + newCA, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.JWTSigner, }, true) - c.Assert(err, check.IsNil) - c.Assert(newCA.GetRotation().Phase, check.Equals, types.RotationPhaseUpdateClients) - c.Assert(newCA.GetTrustedJWTKeyPairs(), check.HasLen, 2) + require.NoError(t, err) + require.Equal(t, newCA.GetRotation().Phase, types.RotationPhaseUpdateClients) + require.Len(t, newCA.GetTrustedJWTKeyPairs(), 2) // Both JWT should now validate. - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) - c.Assert(err, check.IsNil) - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) - c.Assert(err, check.IsNil) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) + require.NoError(t, err) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) + require.NoError(t, err) // Move rotation into update servers phase. - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.JWTSigner, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseUpdateServers, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // At this point only the phase on the CA should have changed. - newCA, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + newCA, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.JWTSigner, }, true) - c.Assert(err, check.IsNil) - c.Assert(newCA.GetRotation().Phase, check.Equals, types.RotationPhaseUpdateServers) - c.Assert(newCA.GetTrustedJWTKeyPairs(), check.HasLen, 2) + require.NoError(t, err) + require.Equal(t, newCA.GetRotation().Phase, types.RotationPhaseUpdateServers) + require.Len(t, newCA.GetTrustedJWTKeyPairs(), 2) // Both JWT should continue to validate. - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) - c.Assert(err, check.IsNil) - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) - c.Assert(err, check.IsNil) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) + require.NoError(t, err) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) + require.NoError(t, err) // Complete rotation. The old CA will be removed. - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.JWTSigner, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseStandby, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // The new CA should now only have a single key. - newCA, err = s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + newCA, err = tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.JWTSigner, }, true) - c.Assert(err, check.IsNil) - c.Assert(newCA.GetRotation().Phase, check.Equals, types.RotationPhaseStandby) - c.Assert(newCA.GetTrustedJWTKeyPairs(), check.HasLen, 1) + require.NoError(t, err) + require.Equal(t, newCA.GetRotation().Phase, types.RotationPhaseStandby) + require.Len(t, newCA.GetTrustedJWTKeyPairs(), 1) // Old token should no longer validate. - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) - c.Assert(err, check.NotNil) - _, err = s.verifyJWT(s.clock, s.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) - c.Assert(err, check.IsNil) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), oldJWT) + require.Error(t, err) + _, err = verifyJWT(tt.clock, tt.server.ClusterName(), newCA.GetTrustedJWTKeyPairs(), newJWT) + require.NoError(t, err) } // TestRemoteUser tests scenario when remote user connects to the local // auth server and some edge cases. -func (s *TLSSuite) TestRemoteUser(c *check.C) { +func TestRemoteUser(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + remoteServer, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: c.MkDir(), + Dir: t.TempDir(), ClusterName: "remote", - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) remoteUser, remoteRole, err := CreateUserAndRole(remoteServer.AuthServer, "remote-user", []string{"remote-role"}) - c.Assert(err, check.IsNil) + require.NoError(t, err) - certPool, err := s.server.CertPool() - c.Assert(err, check.IsNil) + certPool, err := tt.server.CertPool() + require.NoError(t, err) remoteClient, err := remoteServer.NewRemoteClient( - TestUser(remoteUser.GetName()), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestUser(remoteUser.GetName()), tt.server.Addr(), certPool) + require.NoError(t, err) // User is not authorized to perform any actions // as local cluster does not trust the remote cluster yet _, err = remoteClient.GetDomainName(ctx) - fixtures.ExpectConnectionProblem(c, err) + require.True(t, trace.IsConnectionProblem(err)) // Establish trust, the request will still fail, there is // no role mapping set up - err = s.server.AuthServer.Trust(ctx, remoteServer, nil) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.Trust(ctx, remoteServer, nil) + require.NoError(t, err) // Create fresh client now trust is established remoteClient, err = remoteServer.NewRemoteClient( - TestUser(remoteUser.GetName()), s.server.Addr(), certPool) - c.Assert(err, check.IsNil) + TestUser(remoteUser.GetName()), tt.server.Addr(), certPool) + require.NoError(t, err) _, err = remoteClient.GetDomainName(ctx) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // Establish trust and map remote role to local admin role - _, localRole, err := CreateUserAndRole(s.server.Auth(), "local-user", []string{"local-role"}) - c.Assert(err, check.IsNil) + _, localRole, err := CreateUserAndRole(tt.server.Auth(), "local-user", []string{"local-role"}) + require.NoError(t, err) - err = s.server.AuthServer.Trust(ctx, remoteServer, types.RoleMap{{Remote: remoteRole.GetName(), Local: []string{localRole.GetName()}}}) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.Trust(ctx, remoteServer, types.RoleMap{{Remote: remoteRole.GetName(), Local: []string{localRole.GetName()}}}) + require.NoError(t, err) _, err = remoteClient.GetDomainName(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestNopUser tests user with no permissions except // the ones that require other authentication methods ("nop" user) -func (s *TLSSuite) TestNopUser(c *check.C) { +func TestNopUser(t *testing.T) { + t.Parallel() + ctx := context.Background() - client, err := s.server.NewClient(TestNop()) - c.Assert(err, check.IsNil) + tt := setupAuthContext(ctx, t) + + client, err := tt.server.NewClient(TestNop()) + require.NoError(t, err) // Nop User can get cluster name _, err = client.GetDomainName(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) // But can not get users or nodes _, err = client.GetUsers(false) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) _, err = client.GetNodes(ctx, apidefaults.Namespace) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // Endpoints that allow current user access should return access denied to // the Nop user. err = client.CheckPassword("foo", nil, "") - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } // TestOwnRole tests that user can read roles assigned to them (used by web UI) -func (s *TLSSuite) TestReadOwnRole(c *check.C) { +func TestReadOwnRole(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user1, userRole, err := CreateUserAndRoleWithoutRoles(clt, "user1", []string{"user1"}) - c.Assert(err, check.IsNil) + require.NoError(t, err) user2, _, err := CreateUserAndRoleWithoutRoles(clt, "user2", []string{"user2"}) - c.Assert(err, check.IsNil) + require.NoError(t, err) // user should be able to read their own roles - userClient, err := s.server.NewClient(TestUser(user1.GetName())) - c.Assert(err, check.IsNil) + userClient, err := tt.server.NewClient(TestUser(user1.GetName())) + require.NoError(t, err) _, err = userClient.GetRole(ctx, userRole.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) // user2 can't read user1 role - userClient2, err := s.server.NewClient(TestIdentity{I: LocalUser{Username: user2.GetName()}}) - c.Assert(err, check.IsNil) + userClient2, err := tt.server.NewClient(TestIdentity{I: LocalUser{Username: user2.GetName()}}) + require.NoError(t, err) _, err = userClient2.GetRole(ctx, userRole.GetName()) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } func TestGetCurrentUser(t *testing.T) { + t.Parallel() + ctx := context.Background() srv := newTestTLSServer(t) @@ -997,90 +1014,130 @@ func TestGetCurrentUserRoles(t *testing.T) { require.Empty(t, cmp.Diff(roles, []types.Role{user1Role}, cmpopts.IgnoreFields(types.Metadata{}, "ID"))) } -func (s *TLSSuite) TestAuthPreference(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestAuthPreferenceSettings(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clt, } - suite.AuthPreference(c) + suite.AuthPreference(t) } -func (s *TLSSuite) TestTunnelConnectionsCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestTunnelConnectionsCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ PresenceS: clt, Clock: clockwork.NewFakeClock(), } - suite.TunnelConnectionsCRUD(c) + suite.TunnelConnectionsCRUD(t) } -func (s *TLSSuite) TestRemoteClustersCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestRemoteClustersCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ PresenceS: clt, } - suite.RemoteClustersCRUD(c) + suite.RemoteClustersCRUD(t) } -func (s *TLSSuite) TestServersCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestServersCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ PresenceS: clt, } - suite.ServerCRUD(c) + suite.ServerCRUD(t) } // TestAppServerCRUD tests CRUD functionality for services.App using an auth client. -func (s *TLSSuite) TestAppServerCRUD(c *check.C) { - clt, err := s.server.NewClient(TestBuiltin(types.RoleApp)) - c.Assert(err, check.IsNil) +func TestAppServerCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestBuiltin(types.RoleApp)) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ PresenceS: clt, } - suite.AppServerCRUD(c) + suite.AppServerCRUD(t) } -func (s *TLSSuite) TestReverseTunnelsCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestReverseTunnelsCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ PresenceS: clt, } - suite.ReverseTunnelsCRUD(c) + suite.ReverseTunnelsCRUD(t) } -func (s *TLSSuite) TestUsersCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestUsersCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) err = clt.UpsertPassword("user1", []byte("some pass")) - c.Assert(err, check.IsNil) + require.NoError(t, err) users, err := clt.GetUsers(false) - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - c.Assert(users[0].GetName(), check.Equals, "user1") + require.NoError(t, err) + require.Equal(t, len(users), 1) + require.Equal(t, users[0].GetName(), "user1") - c.Assert(clt.DeleteUser(context.TODO(), "user1"), check.IsNil) + require.NoError(t, clt.DeleteUser(context.TODO(), "user1")) users, err = clt.GetUsers(false) - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(users), 0) } -func (s *TLSSuite) TestPasswordGarbage(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestPasswordGarbage(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) garbage := [][]byte{ nil, make([]byte, defaults.MaxPasswordLength+1), @@ -1088,50 +1145,64 @@ func (s *TLSSuite) TestPasswordGarbage(c *check.C) { } for _, g := range garbage { err := clt.CheckPassword("user1", g, "123456") - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) } } -func (s *TLSSuite) TestPasswordCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestPasswordCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) pass := []byte("abc123") rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) err = clt.CheckPassword("user1", pass, "123456") - c.Assert(err, check.NotNil) + require.Error(t, err) err = clt.UpsertPassword("user1", pass) - c.Assert(err, check.IsNil) + require.NoError(t, err) - dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, check.IsNil) - ctx := context.Background() - err = s.server.Auth().UpsertMFADevice(ctx, "user1", dev) - c.Assert(err, check.IsNil) + dev, err := services.NewTOTPDevice("otp", otpSecret, tt.clock.Now()) + require.NoError(t, err) + + err = tt.server.Auth().UpsertMFADevice(ctx, "user1", dev) + require.NoError(t, err) - validToken, err := totp.GenerateCode(otpSecret, s.server.Clock().Now()) - c.Assert(err, check.IsNil) + validToken, err := totp.GenerateCode(otpSecret, tt.server.Clock().Now()) + require.NoError(t, err) err = clt.CheckPassword("user1", pass, validToken) - c.Assert(err, check.IsNil) + require.NoError(t, err) } -func (s *TLSSuite) TestTokens(c *check.C) { +func TestTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) out, err := clt.GenerateToken(ctx, &proto.GenerateTokenRequest{Roles: types.SystemRoles{types.RoleNode}}) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Not(check.Equals), 0) + require.NoError(t, err) + require.NotEqual(t, out, 0) } -func (s *TLSSuite) TestOTPCRUD(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestOTPCRUD(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user := "user1" pass := []byte("abc123") @@ -1140,16 +1211,16 @@ func (s *TLSSuite) TestOTPCRUD(c *check.C) { // upsert a password and totp secret err = clt.UpsertPassword("user1", pass) - c.Assert(err, check.IsNil) - dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, check.IsNil) - ctx := context.Background() - err = s.server.Auth().UpsertMFADevice(ctx, user, dev) - c.Assert(err, check.IsNil) + require.NoError(t, err) + dev, err := services.NewTOTPDevice("otp", otpSecret, tt.clock.Now()) + require.NoError(t, err) + + err = tt.server.Auth().UpsertMFADevice(ctx, user, dev) + require.NoError(t, err) // a completely invalid token should return access denied err = clt.CheckPassword("user1", pass, "123456") - c.Assert(err, check.NotNil) + require.Error(t, err) // an invalid token should return access denied // @@ -1158,39 +1229,43 @@ func (s *TLSSuite) TestOTPCRUD(c *check.C) { // valid for 30 seconds + 30 second skew before and after for a usability // reasons. so a token made between seconds 31 and 60 is still valid, and // invalidity starts at 61 seconds in the future. - invalidToken, err := totp.GenerateCode(otpSecret, s.server.Clock().Now().Add(61*time.Second)) - c.Assert(err, check.IsNil) + invalidToken, err := totp.GenerateCode(otpSecret, tt.server.Clock().Now().Add(61*time.Second)) + require.NoError(t, err) err = clt.CheckPassword("user1", pass, invalidToken) - c.Assert(err, check.NotNil) + require.Error(t, err) // a valid token (created right now and from a valid key) should return success - validToken, err := totp.GenerateCode(otpSecret, s.server.Clock().Now()) - c.Assert(err, check.IsNil) + validToken, err := totp.GenerateCode(otpSecret, tt.server.Clock().Now()) + require.NoError(t, err) err = clt.CheckPassword("user1", pass, validToken) - c.Assert(err, check.IsNil) + require.NoError(t, err) // try the same valid token now it should fail because we don't allow re-use of tokens err = clt.CheckPassword("user1", pass, validToken) - c.Assert(err, check.NotNil) + require.Error(t, err) } // TestWebSessions tests web sessions flow for web user, // that logs in, extends web session and tries to perform administratvie action // but fails -func (s *TLSSuite) TestWebSessionWithoutAccessRequest(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestWebSessionWithoutAccessRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user := "user1" pass := []byte("abc123") _, _, err = CreateUserAndRole(clt, user, []string{user}) - c.Assert(err, check.IsNil) + require.NoError(t, err) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) req := AuthenticateUserRequest{ Username: user, @@ -1200,44 +1275,44 @@ func (s *TLSSuite) TestWebSessionWithoutAccessRequest(c *check.C) { } // authentication attempt fails with no password set up _, err = proxy.AuthenticateWebUser(ctx, req) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) err = clt.UpsertPassword(user, pass) - c.Assert(err, check.IsNil) + require.NoError(t, err) // success with password set up ws, err := proxy.AuthenticateWebUser(ctx, req) - c.Assert(err, check.IsNil) - c.Assert(ws, check.Not(check.Equals), "") + require.NoError(t, err) + require.NotEqual(t, ws, "") - web, err := s.server.NewClientFromWebSession(ws) - c.Assert(err, check.IsNil) + web, err := tt.server.NewClientFromWebSession(ws) + require.NoError(t, err) _, err = web.GetWebSessionInfo(ctx, user, ws.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) - new, err := web.ExtendWebSession(ctx, WebSessionReq{ + ns, err := web.ExtendWebSession(ctx, WebSessionReq{ User: user, PrevSessionID: ws.GetName(), }) - c.Assert(err, check.IsNil) - c.Assert(new, check.NotNil) + require.NoError(t, err) + require.NotNil(t, ns) // Requesting forbidden action for user fails err = web.DeleteUser(ctx, user) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) err = clt.DeleteWebSession(ctx, user, ws.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = web.GetWebSessionInfo(ctx, user, ws.GetName()) - c.Assert(err, check.NotNil) + require.Error(t, err) _, err = web.ExtendWebSession(ctx, WebSessionReq{ User: user, PrevSessionID: ws.GetName(), }) - c.Assert(err, check.NotNil) + require.Error(t, err) } func TestWebSessionMultiAccessRequests(t *testing.T) { @@ -1454,21 +1529,25 @@ func TestWebSessionMultiAccessRequests(t *testing.T) { } } -func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestWebSessionWithApprovedAccessRequestAndSwitchback(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user := "user2" pass := []byte("abc123") newUser, err := CreateUserRoleAndRequestable(clt, user, "test-request-role") - c.Assert(err, check.IsNil) - c.Assert(newUser.GetRoles(), check.HasLen, 1) - c.Assert(newUser.GetRoles(), check.DeepEquals, []string{"user:user2"}) + require.NoError(t, err) + require.Len(t, newUser.GetRoles(), 1) + require.Empty(t, cmp.Diff(newUser.GetRoles(), []string{"user:user2"})) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // Create a user to create a web session for. req := AuthenticateUserRequest{ @@ -1479,45 +1558,45 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check } err = clt.UpsertPassword(user, pass) - c.Assert(err, check.IsNil) + require.NoError(t, err) ws, err := proxy.AuthenticateWebUser(ctx, req) - c.Assert(err, check.IsNil) + require.NoError(t, err) - web, err := s.server.NewClientFromWebSession(ws) - c.Assert(err, check.IsNil) + web, err := tt.server.NewClientFromWebSession(ws) + require.NoError(t, err) initialRole := newUser.GetRoles()[0] initialSession, err := web.GetWebSessionInfo(ctx, user, ws.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Create a approved access request. accessReq, err := services.NewAccessRequest(user, []string{"test-request-role"}...) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Set a lesser expiry date, to test switching back to default expiration later. - accessReq.SetAccessExpiry(s.clock.Now().Add(time.Minute * 10)) + accessReq.SetAccessExpiry(tt.clock.Now().Add(time.Minute * 10)) accessReq.SetState(types.RequestState_APPROVED) err = clt.CreateAccessRequest(ctx, accessReq) - c.Assert(err, check.IsNil) + require.NoError(t, err) sess1, err := web.ExtendWebSession(ctx, WebSessionReq{ User: user, PrevSessionID: ws.GetName(), AccessRequestID: accessReq.GetMetadata().Name, }) - c.Assert(err, check.IsNil) - c.Assert(sess1.Expiry(), check.Equals, s.clock.Now().Add(time.Minute*10)) - c.Assert(sess1.GetLoginTime(), check.Equals, initialSession.GetLoginTime()) + require.NoError(t, err) + require.Equal(t, sess1.Expiry(), tt.clock.Now().Add(time.Minute*10)) + require.Equal(t, sess1.GetLoginTime(), initialSession.GetLoginTime()) sshcert, err := sshutils.ParseCertificate(sess1.GetPub()) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Roles extracted from cert should contain the initial role and the role assigned with access request. roles, err := services.ExtractRolesFromCert(sshcert) - c.Assert(err, check.IsNil) - c.Assert(roles, check.HasLen, 2) + require.NoError(t, err) + require.Len(t, roles, 2) mappedRole := map[string]string{ roles[0]: "", @@ -1525,141 +1604,151 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check } _, hasRole := mappedRole[initialRole] - c.Assert(hasRole, check.Equals, true) + require.Equal(t, hasRole, true) _, hasRole = mappedRole["test-request-role"] - c.Assert(hasRole, check.Equals, true) + require.Equal(t, hasRole, true) // certRequests extracts the active requests from a PEM encoded TLS cert. certRequests := func(tlsCert []byte) []string { cert, err := tlsca.ParseCertificatePEM(tlsCert) - c.Assert(err, check.IsNil) + require.NoError(t, err) identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) - c.Assert(err, check.IsNil) + require.NoError(t, err) return identity.ActiveRequests } - c.Assert(certRequests(sess1.GetTLSCert()), check.DeepEquals, []string{accessReq.GetName()}) + require.Empty(t, cmp.Diff(certRequests(sess1.GetTLSCert()), []string{accessReq.GetName()})) // Test switch back to default role and expiry. - sess2, err := web.ExtendWebSession(ctx, WebSessionReq{ + sess2, err := web.ExtendWebSession(context.TODO(), WebSessionReq{ User: user, PrevSessionID: ws.GetName(), Switchback: true, }) - c.Assert(err, check.IsNil) - c.Assert(sess2.GetExpiryTime(), check.Equals, initialSession.GetExpiryTime()) - c.Assert(sess2.GetLoginTime(), check.Equals, initialSession.GetLoginTime()) + require.NoError(t, err) + require.Equal(t, sess2.GetExpiryTime(), initialSession.GetExpiryTime()) + require.Equal(t, sess2.GetLoginTime(), initialSession.GetLoginTime()) sshcert, err = sshutils.ParseCertificate(sess2.GetPub()) - c.Assert(err, check.IsNil) + require.NoError(t, err) roles, err = services.ExtractRolesFromCert(sshcert) - c.Assert(err, check.IsNil) - c.Assert(roles, check.DeepEquals, []string{initialRole}) + require.NoError(t, err) + require.Empty(t, cmp.Diff(roles, []string{initialRole})) - c.Assert(certRequests(sess2.GetTLSCert()), check.HasLen, 0) + require.Len(t, certRequests(sess2.GetTLSCert()), 0) } // TestGetCertAuthority tests certificate authority permissions -func (s *TLSSuite) TestGetCertAuthority(c *check.C) { +func TestGetCertAuthority(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + // generate server keys for node - nodeClt, err := s.server.NewClient(TestIdentity{I: BuiltinRole{Username: "00000000-0000-0000-0000-000000000000", Role: types.RoleNode}}) - c.Assert(err, check.IsNil) + nodeClt, err := tt.server.NewClient(TestIdentity{I: BuiltinRole{Username: "00000000-0000-0000-0000-000000000000", Role: types.RoleNode}}) + require.NoError(t, err) defer nodeClt.Close() // node is authorized to fetch CA without secrets ca, err := nodeClt.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) for _, keyPair := range ca.GetActiveKeys().TLS { - c.Assert(keyPair.Key, check.IsNil) + fmt.Printf("--> keyPair.Key: %v.\n", keyPair) + require.Nil(t, keyPair.Key) } for _, keyPair := range ca.GetActiveKeys().SSH { - c.Assert(keyPair.PrivateKey, check.IsNil) + require.Nil(t, keyPair.PrivateKey) } // node is not authorized to fetch CA with secrets _, err = nodeClt.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, true) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // non-admin users are not allowed to get access to private key material user, err := types.NewUser("bob") - c.Assert(err, check.IsNil) + require.NoError(t, err) role := services.RoleForUser(user) role.SetLogins(types.Allow, []string{user.GetName()}) - err = s.server.Auth().UpsertRole(ctx, role) - c.Assert(err, check.IsNil) + err = tt.server.Auth().UpsertRole(ctx, role) + require.NoError(t, err) user.AddRole(role.GetName()) - err = s.server.Auth().UpsertUser(user) - c.Assert(err, check.IsNil) + err = tt.server.Auth().UpsertUser(user) + require.NoError(t, err) - userClt, err := s.server.NewClient(TestUser(user.GetName())) - c.Assert(err, check.IsNil) + userClt, err := tt.server.NewClient(TestUser(user.GetName())) + require.NoError(t, err) defer userClt.Close() // user is authorized to fetch CA without secrets _, err = userClt.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) // user is not authorized to fetch CA with secrets _, err = userClt.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, true) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *TLSSuite) TestPluginData(c *check.C) { +func TestPluginData(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + priv, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) + require.NoError(t, err) // make sure we can parse the private and public key privateKey, err := ssh.ParseRawPrivateKey(priv) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, _, _, _, err = ssh.ParseAuthorizedKey(pub) - c.Assert(err, check.IsNil) + require.NoError(t, err) user := "user1" role := "some-role" - _, err = CreateUserRoleAndRequestable(s.server.Auth(), user, role) - c.Assert(err, check.IsNil) + _, err = CreateUserRoleAndRequestable(tt.server.Auth(), user, role) + require.NoError(t, err) testUser := TestUser(user) testUser.TTL = time.Hour - userClient, err := s.server.NewClient(testUser) - c.Assert(err, check.IsNil) + userClient, err := tt.server.NewClient(testUser) + require.NoError(t, err) plugin := "my-plugin" - _, err = CreateAccessPluginUser(context.TODO(), s.server.Auth(), plugin) - c.Assert(err, check.IsNil) + _, err = CreateAccessPluginUser(context.TODO(), tt.server.Auth(), plugin) + require.NoError(t, err) pluginUser := TestUser(plugin) pluginUser.TTL = time.Hour - pluginClient, err := s.server.NewClient(pluginUser) - c.Assert(err, check.IsNil) + pluginClient, err := tt.server.NewClient(pluginUser) + require.NoError(t, err) req, err := services.NewAccessRequest(user, role) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(userClient.CreateAccessRequest(context.TODO(), req), check.IsNil) + require.NoError(t, userClient.CreateAccessRequest(context.TODO(), req)) err = pluginClient.UpdatePluginData(context.TODO(), types.PluginDataUpdateParams{ Kind: types.KindAccessRequest, @@ -1669,18 +1758,18 @@ func (s *TLSSuite) TestPluginData(c *check.C) { "foo": "bar", }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) data, err := pluginClient.GetPluginData(context.TODO(), types.PluginDataFilter{ Kind: types.KindAccessRequest, Resource: req.GetName(), }) - c.Assert(err, check.IsNil) - c.Assert(len(data), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(data), 1) entry, ok := data[0].Entries()[plugin] - c.Assert(ok, check.Equals, true) - c.Assert(entry.Data, check.DeepEquals, map[string]string{"foo": "bar"}) + require.Equal(t, ok, true) + require.Empty(t, cmp.Diff(entry.Data, map[string]string{"foo": "bar"})) err = pluginClient.UpdatePluginData(context.TODO(), types.PluginDataUpdateParams{ Kind: types.KindAccessRequest, @@ -1694,24 +1783,27 @@ func (s *TLSSuite) TestPluginData(c *check.C) { "foo": "bar", }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) data, err = pluginClient.GetPluginData(context.TODO(), types.PluginDataFilter{ Kind: types.KindAccessRequest, Resource: req.GetName(), }) - c.Assert(err, check.IsNil) - c.Assert(len(data), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(data), 1) entry, ok = data[0].Entries()[plugin] - c.Assert(ok, check.Equals, true) - c.Assert(entry.Data, check.DeepEquals, map[string]string{"spam": "eggs"}) + require.Equal(t, ok, true) + require.Empty(t, cmp.Diff(entry.Data, map[string]string{"spam": "eggs"})) } // TestGenerateCerts tests edge cases around authorization of // certificate generation for servers and users func TestGenerateCerts(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv := newTestTLSServer(t) priv, pub, err := native.GenerateKeyPair() require.NoError(t, err) @@ -2098,45 +2190,48 @@ func TestGenerateCerts(t *testing.T) { // TestGenerateAppToken checks the identity of the caller and makes sure only // certain roles can request JWT tokens. -func (s *TLSSuite) TestGenerateAppToken(c *check.C) { - authClient, err := s.server.NewClient(TestBuiltin(types.RoleAdmin)) - c.Assert(err, check.IsNil) +func TestGenerateAppToken(t *testing.T) { + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + authClient, err := tt.server.NewClient(TestBuiltin(types.RoleAdmin)) + require.NoError(t, err) ca, err := authClient.GetCertAuthority(context.Background(), types.CertAuthID{ Type: types.JWTSigner, - DomainName: s.server.ClusterName(), + DomainName: tt.server.ClusterName(), }, true) - c.Assert(err, check.IsNil) + require.NoError(t, err) - signer, err := s.server.AuthServer.AuthServer.GetKeyStore().GetJWTSigner(ca) - c.Assert(err, check.IsNil) - key, err := services.GetJWTSigner(signer, ca.GetClusterName(), s.clock) - c.Assert(err, check.IsNil) + signer, err := tt.server.AuthServer.AuthServer.GetKeyStore().GetJWTSigner(ca) + require.NoError(t, err) + key, err := services.GetJWTSigner(signer, ca.GetClusterName(), tt.clock) + require.NoError(t, err) tests := []struct { inMachineRole types.SystemRole - inComment check.CommentInterface + inComment string outError bool }{ { inMachineRole: types.RoleNode, - inComment: check.Commentf("nodes should not have the ability to generate tokens"), + inComment: "nodes should not have the ability to generate tokens", outError: true, }, { inMachineRole: types.RoleProxy, - inComment: check.Commentf("proxies should not have the ability to generate tokens"), + inComment: "proxies should not have the ability to generate tokens", outError: true, }, { inMachineRole: types.RoleApp, - inComment: check.Commentf("only apps should have the ability to generate tokens"), + inComment: "only apps should have the ability to generate tokens", outError: false, }, } - for _, tt := range tests { - client, err := s.server.NewClient(TestBuiltin(tt.inMachineRole)) - c.Assert(err, check.IsNil, tt.inComment) + for _, ts := range tests { + client, err := tt.server.NewClient(TestBuiltin(ts.inMachineRole)) + require.NoError(t, err, ts.inComment) token, err := client.GenerateAppToken( context.Background(), @@ -2144,42 +2239,44 @@ func (s *TLSSuite) TestGenerateAppToken(c *check.C) { Username: "foo@example.com", Roles: []string{"bar", "baz"}, URI: "http://localhost:8080", - Expires: s.clock.Now().Add(1 * time.Minute), + Expires: tt.clock.Now().Add(1 * time.Minute), }) - c.Assert(err != nil, check.Equals, tt.outError, tt.inComment) - if !tt.outError { + require.Equal(t, err != nil, ts.outError, ts.inComment) + if !ts.outError { claims, err := key.Verify(jwt.VerifyParams{ Username: "foo@example.com", RawToken: token, URI: "http://localhost:8080", }) - c.Assert(err, check.IsNil, tt.inComment) - c.Assert(claims.Username, check.Equals, "foo@example.com", tt.inComment) - c.Assert(claims.Roles, check.DeepEquals, []string{"bar", "baz"}, tt.inComment) + require.NoError(t, err, ts.inComment) + require.Equal(t, claims.Username, "foo@example.com", ts.inComment) + require.Empty(t, cmp.Diff(claims.Roles, []string{"bar", "baz"}), ts.inComment) } } } // TestCertificateFormat makes sure that certificates are generated with the // correct format. -func (s *TLSSuite) TestCertificateFormat(c *check.C) { +func TestCertificateFormat(t *testing.T) { ctx := context.Background() + tt := setupAuthContext(ctx, t) + priv, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) + require.NoError(t, err) // make sure we can parse the private and public key _, err = ssh.ParsePrivateKey(priv) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, _, _, _, err = ssh.ParseAuthorizedKey(pub) - c.Assert(err, check.IsNil) + require.NoError(t, err) // use admin client to create user and role - user, userRole, err := CreateUserAndRole(s.server.Auth(), "user", []string{"user"}) - c.Assert(err, check.IsNil) + user, userRole, err := CreateUserAndRole(tt.server.Auth(), "user", []string{"user"}) + require.NoError(t, err) pass := []byte("very secure password") - err = s.server.Auth().UpsertPassword(user.GetName(), pass) - c.Assert(err, check.IsNil) + err = tt.server.Auth().UpsertPassword(user.GetName(), pass) + require.NoError(t, err) tests := []struct { inRoleCertificateFormat string @@ -2200,15 +2297,15 @@ func (s *TLSSuite) TestCertificateFormat(c *check.C) { }, } - for _, tt := range tests { + for _, ts := range tests { roleOptions := userRole.GetOptions() - roleOptions.CertificateFormat = tt.inRoleCertificateFormat + roleOptions.CertificateFormat = ts.inRoleCertificateFormat userRole.SetOptions(roleOptions) - err := s.server.Auth().UpsertRole(ctx, userRole) - c.Assert(err, check.IsNil) + err := tt.server.Auth().UpsertRole(ctx, userRole) + require.NoError(t, err) - proxyClient, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxyClient, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) // authentication attempt fails with password auth only re, err := proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ @@ -2218,59 +2315,66 @@ func (s *TLSSuite) TestCertificateFormat(c *check.C) { Password: pass, }, }, - CompatibilityMode: tt.inClientCertificateFormat, + CompatibilityMode: ts.inClientCertificateFormat, TTL: apidefaults.CertDuration, PublicKey: pub, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) parsedCert, err := sshutils.ParseCertificate(re.Cert) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, ok := parsedCert.Extensions[teleport.CertExtensionTeleportRoles] - c.Assert(ok, check.Equals, tt.outCertContainsRole) + require.Equal(t, ok, ts.outCertContainsRole) } } // TestClusterConfigContext checks that the cluster configuration gets passed // along in the context and permissions get updated accordingly. -func (s *TLSSuite) TestClusterConfigContext(c *check.C) { +func TestClusterConfigContext(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) _, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) + require.NoError(t, err) // try and generate a host cert, this should fail because we are recording // at the nodes not at the proxy _, err = proxy.GenerateHostCert(pub, "a", "b", nil, "localhost", types.RoleProxy, 0) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // update cluster config to record at the proxy recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtProxy, }) - c.Assert(err, check.IsNil) - err = s.server.Auth().SetSessionRecordingConfig(ctx, recConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) + err = tt.server.Auth().SetSessionRecordingConfig(ctx, recConfig) + require.NoError(t, err) // try and generate a host cert, now the proxy should be able to generate a // host cert because it's in recording mode. _, err = proxy.GenerateHostCert(pub, "a", "b", nil, "localhost", types.RoleProxy, 0) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestAuthenticateWebUserOTP tests web authentication flow for password + OTP -func (s *TLSSuite) TestAuthenticateWebUserOTP(c *check.C) { +func TestAuthenticateWebUserOTP(t *testing.T) { + t.Parallel() + ctx := context.Background() - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user := "ws-test" pass := []byte("ws-abc123") @@ -2278,44 +2382,44 @@ func (s *TLSSuite) TestAuthenticateWebUserOTP(c *check.C) { otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) _, _, err = CreateUserAndRole(clt, user, []string{user}) - c.Assert(err, check.IsNil) + require.NoError(t, err) - err = s.server.Auth().UpsertPassword(user, pass) - c.Assert(err, check.IsNil) + err = tt.server.Auth().UpsertPassword(user, pass) + require.NoError(t, err) - dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, check.IsNil) - err = s.server.Auth().UpsertMFADevice(ctx, user, dev) - c.Assert(err, check.IsNil) + dev, err := services.NewTOTPDevice("otp", otpSecret, tt.clock.Now()) + require.NoError(t, err) + err = tt.server.Auth().UpsertMFADevice(ctx, user, dev) + require.NoError(t, err) // create a valid otp token - validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) - c.Assert(err, check.IsNil) + validToken, err := totp.GenerateCode(otpSecret, tt.clock.Now()) + require.NoError(t, err) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, check.IsNil) - err = s.server.Auth().SetAuthPreference(ctx, authPreference) - c.Assert(err, check.IsNil) + require.NoError(t, err) + err = tt.server.Auth().SetAuthPreference(ctx, authPreference) + require.NoError(t, err) // authentication attempt fails with wrong password _, err = proxy.AuthenticateWebUser(ctx, AuthenticateUserRequest{ Username: user, OTP: &OTPCreds{Password: []byte("wrong123"), Token: validToken}, }) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // authentication attempt fails with wrong otp _, err = proxy.AuthenticateWebUser(ctx, AuthenticateUserRequest{ Username: user, OTP: &OTPCreds{Password: pass, Token: "wrong123"}, }) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // authentication attempt fails with password auth only _, err = proxy.AuthenticateWebUser(ctx, AuthenticateUserRequest{ @@ -2324,46 +2428,50 @@ func (s *TLSSuite) TestAuthenticateWebUserOTP(c *check.C) { Password: pass, }, }) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // authentication succeeds ws, err := proxy.AuthenticateWebUser(ctx, AuthenticateUserRequest{ Username: user, OTP: &OTPCreds{Password: pass, Token: validToken}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - userClient, err := s.server.NewClientFromWebSession(ws) - c.Assert(err, check.IsNil) + userClient, err := tt.server.NewClientFromWebSession(ws) + require.NoError(t, err) _, err = userClient.GetWebSessionInfo(ctx, user, ws.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = clt.DeleteWebSession(ctx, user, ws.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = userClient.GetWebSessionInfo(ctx, user, ws.GetName()) - c.Assert(err, check.NotNil) + require.Error(t, err) } // TestLoginAttempts makes sure the login attempt counter is incremented and // reset correctly. -func (s *TLSSuite) TestLoginAttempts(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestLoginAttempts(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) user := "user1" pass := []byte("abc123") _, _, err = CreateUserAndRole(clt, user, []string{user}) - c.Assert(err, check.IsNil) + require.NoError(t, err) - proxy, err := s.server.NewClient(TestBuiltin(types.RoleProxy)) - c.Assert(err, check.IsNil) + proxy, err := tt.server.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) err = clt.UpsertPassword(user, pass) - c.Assert(err, check.IsNil) + require.NoError(t, err) req := AuthenticateUserRequest{ Username: user, @@ -2373,112 +2481,120 @@ func (s *TLSSuite) TestLoginAttempts(c *check.C) { } // authentication attempt fails with bad password _, err = proxy.AuthenticateWebUser(ctx, req) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // creates first failed login attempt - loginAttempts, err := s.server.Auth().GetUserLoginAttempts(user) - c.Assert(err, check.IsNil) - c.Assert(loginAttempts, check.HasLen, 1) + loginAttempts, err := tt.server.Auth().GetUserLoginAttempts(user) + require.NoError(t, err) + require.Len(t, loginAttempts, 1) // try second time with wrong pass req.Pass.Password = pass _, err = proxy.AuthenticateWebUser(ctx, req) - c.Assert(err, check.IsNil) + require.NoError(t, err) // clears all failed attempts after success - loginAttempts, err = s.server.Auth().GetUserLoginAttempts(user) - c.Assert(err, check.IsNil) - c.Assert(loginAttempts, check.HasLen, 0) + loginAttempts, err = tt.server.Auth().GetUserLoginAttempts(user) + require.NoError(t, err) + require.Len(t, loginAttempts, 0) } -func (s *TLSSuite) TestChangeUserAuthentication(c *check.C) { +func TestChangeUserAuthenticationSettings(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + authPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ AllowLocalAuth: types.NewBoolOption(true), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(ctx, authPref) - c.Assert(err, check.IsNil) + err = tt.server.Auth().SetAuthPreference(ctx, authPref) + require.NoError(t, err) authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(ctx, authPreference) - c.Assert(err, check.IsNil) + err = tt.server.Auth().SetAuthPreference(ctx, authPreference) + require.NoError(t, err) username := "user1" // Create a local user. - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) _, _, err = CreateUserAndRole(clt, username, []string{"role1"}) - c.Assert(err, check.IsNil) + require.NoError(t, err) - token, err := s.server.Auth().CreateResetPasswordToken(ctx, CreateUserTokenRequest{ + token, err := tt.server.Auth().CreateResetPasswordToken(ctx, CreateUserTokenRequest{ Name: username, TTL: time.Hour, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - res, err := s.server.Auth().CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ + res, err := tt.server.Auth().CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ TokenID: token.GetName(), DeviceType: proto.DeviceType_DEVICE_TYPE_TOTP, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - otpToken, err := totp.GenerateCode(res.GetTOTP().GetSecret(), s.server.Clock().Now()) - c.Assert(err, check.IsNil) + otpToken, err := totp.GenerateCode(res.GetTOTP().GetSecret(), tt.server.Clock().Now()) + require.NoError(t, err) - _, err = s.server.Auth().ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ + _, err = tt.server.Auth().ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ TokenID: token.GetName(), NewPassword: []byte("qweqweqwe"), NewMFARegisterResponse: &proto.MFARegisterResponse{Response: &proto.MFARegisterResponse_TOTP{ TOTP: &proto.TOTPRegisterResponse{Code: otpToken}, }}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestLoginNoLocalAuth makes sure that logins for local accounts can not be // performed when local auth is disabled. -func (s *TLSSuite) TestLoginNoLocalAuth(c *check.C) { +func TestLoginNoLocalAuth(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + user := "foo" pass := []byte("barbaz") // Create a local user. - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) _, _, err = CreateUserAndRole(clt, user, []string{user}) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = clt.UpsertPassword(user, pass) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Set auth preference to disallow local auth. authPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ AllowLocalAuth: types.NewBoolOption(false), }) - c.Assert(err, check.IsNil) - err = s.server.Auth().SetAuthPreference(ctx, authPref) - c.Assert(err, check.IsNil) + require.NoError(t, err) + err = tt.server.Auth().SetAuthPreference(ctx, authPref) + require.NoError(t, err) // Make sure access is denied for web login. - _, err = s.server.Auth().AuthenticateWebUser(ctx, AuthenticateUserRequest{ + _, err = tt.server.Auth().AuthenticateWebUser(ctx, AuthenticateUserRequest{ Username: user, Pass: &PassCreds{ Password: pass, }, }) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) // Make sure access is denied for SSH login. _, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) - _, err = s.server.Auth().AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + require.NoError(t, err) + _, err = tt.server.Auth().AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ AuthenticateUserRequest: AuthenticateUserRequest{ Username: user, Pass: &PassCreds{ @@ -2487,19 +2603,22 @@ func (s *TLSSuite) TestLoginNoLocalAuth(c *check.C) { }, PublicKey: pub, }) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) } // TestCipherSuites makes sure that clients with invalid cipher suites can // not connect. -func (s *TLSSuite) TestCipherSuites(c *check.C) { - otherServer, err := s.server.AuthServer.NewTestTLSServer() - c.Assert(err, check.IsNil) +func TestCipherSuites(t *testing.T) { + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + otherServer, err := tt.server.AuthServer.NewTestTLSServer() + require.NoError(t, err) defer otherServer.Close() // Create a client with ciphersuites that the server does not support. - tlsConfig, err := s.server.ClientTLSConfig(TestNop()) - c.Assert(err, check.IsNil) + tlsConfig, err := tt.server.ClientTLSConfig(TestNop()) + require.NoError(t, err) tlsConfig.CipherSuites = []uint16{ tls.TLS_RSA_WITH_AES_128_CBC_SHA, tls.TLS_RSA_WITH_AES_256_CBC_SHA, @@ -2507,7 +2626,7 @@ func (s *TLSSuite) TestCipherSuites(c *check.C) { addrs := []string{ otherServer.Addr().String(), - s.server.Addr().String(), + tt.server.Addr().String(), } client, err := NewClient(client.Config{ Addrs: addrs, @@ -2516,26 +2635,30 @@ func (s *TLSSuite) TestCipherSuites(c *check.C) { }, CircuitBreakerConfig: breaker.NoopBreakerConfig(), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Requests should fail. _, err = client.GetClusterName() - c.Assert(err, check.NotNil) + require.Error(t, err) } // TestTLSFailover tests HTTP client failover between two tls servers -func (s *TLSSuite) TestTLSFailover(c *check.C) { - otherServer, err := s.server.AuthServer.NewTestTLSServer() - c.Assert(err, check.IsNil) - defer otherServer.Close() +func TestTLSFailover(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) - tlsConfig, err := s.server.ClientTLSConfig(TestNop()) - c.Assert(err, check.IsNil) + otherServer, err := tt.server.AuthServer.NewTestTLSServer() + require.NoError(t, err) + defer otherServer.Close() + + tlsConfig, err := tt.server.ClientTLSConfig(TestNop()) + require.NoError(t, err) addrs := []string{ otherServer.Addr().String(), - s.server.Addr().String(), + tt.server.Addr().String(), } client, err := NewClient(client.Config{ Addrs: addrs, @@ -2544,57 +2667,61 @@ func (s *TLSSuite) TestTLSFailover(c *check.C) { }, CircuitBreakerConfig: breaker.NoopBreakerConfig(), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // couple of runs to get enough connections for i := 0; i < 4; i++ { _, err = client.Get(ctx, client.Endpoint("not", "exist"), url.Values{}) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } // stop the server to get response err = otherServer.Stop() - c.Assert(err, check.IsNil) + require.NoError(t, err) // client detects closed sockets and reconnect to the backup server for i := 0; i < 4; i++ { _, err = client.Get(ctx, client.Endpoint("not", "exist"), url.Values{}) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } } // TestRegisterCAPin makes sure that registration only works with a valid // CA pin. -func (s *TLSSuite) TestRegisterCAPin(c *check.C) { +func TestRegisterCAPin(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + // Generate a token to use. - token, err := s.server.AuthServer.AuthServer.GenerateToken(ctx, &proto.GenerateTokenRequest{ + token, err := tt.server.AuthServer.AuthServer.GenerateToken(ctx, &proto.GenerateTokenRequest{ Roles: types.SystemRoles{ types.RoleProxy, }, TTL: proto.Duration(time.Hour), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Generate public and private keys for node. priv, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) + require.NoError(t, err) privateKey, err := ssh.ParseRawPrivateKey(priv) - c.Assert(err, check.IsNil) + require.NoError(t, err) pubTLS, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Calculate what CA pin should be. - localCAResponse, err := s.server.AuthServer.AuthServer.GetClusterCACert(ctx) - c.Assert(err, check.IsNil) + localCAResponse, err := tt.server.AuthServer.AuthServer.GetClusterCACert(ctx) + require.NoError(t, err) caPins, err := tlsca.CalculatePins(localCAResponse.TLSCA) - c.Assert(err, check.IsNil) - c.Assert(caPins, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, caPins, 1) caPin := caPins[0] // Attempt to register with valid CA pin, should work. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2605,14 +2732,14 @@ func (s *TLSSuite) TestRegisterCAPin(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPins: []string{caPin}, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Attempt to register with multiple CA pins where the auth server only // matches one, should work. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2623,13 +2750,13 @@ func (s *TLSSuite) TestRegisterCAPin(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPins: []string{"sha256:123", caPin}, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Attempt to register with invalid CA pin, should fail. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2640,13 +2767,13 @@ func (s *TLSSuite) TestRegisterCAPin(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPins: []string{"sha256:123"}, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.NotNil) + require.Error(t, err) // Attempt to register with multiple invalid CA pins, should fail. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2657,32 +2784,32 @@ func (s *TLSSuite) TestRegisterCAPin(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPins: []string{"sha256:123", "sha256:456"}, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.NotNil) + require.Error(t, err) // Add another cert to the CA (dupe the current one for simplicity) - hostCA, err := s.server.AuthServer.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.AuthServer.ClusterName, + hostCA, err := tt.server.AuthServer.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.AuthServer.ClusterName, Type: types.HostCA, }, true) - c.Assert(err, check.IsNil) + require.NoError(t, err) activeKeys := hostCA.GetActiveKeys() activeKeys.TLS = append(activeKeys.TLS, activeKeys.TLS...) hostCA.SetActiveKeys(activeKeys) - err = s.server.AuthServer.AuthServer.UpsertCertAuthority(hostCA) - c.Assert(err, check.IsNil) + err = tt.server.AuthServer.AuthServer.UpsertCertAuthority(hostCA) + require.NoError(t, err) // Calculate what CA pins should be. - localCAResponse, err = s.server.AuthServer.AuthServer.GetClusterCACert(ctx) - c.Assert(err, check.IsNil) + localCAResponse, err = tt.server.AuthServer.AuthServer.GetClusterCACert(ctx) + require.NoError(t, err) caPins, err = tlsca.CalculatePins(localCAResponse.TLSCA) - c.Assert(err, check.IsNil) - c.Assert(caPins, check.HasLen, 2) + require.NoError(t, err) + require.Len(t, caPins, 2) // Attempt to register with multiple CA pins, should work _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2693,35 +2820,39 @@ func (s *TLSSuite) TestRegisterCAPin(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPins: caPins, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestRegisterCAPath makes sure registration only works with a valid CA // file on disk. -func (s *TLSSuite) TestRegisterCAPath(c *check.C) { +func TestRegisterCAPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + // Generate a token to use. - token, err := s.server.AuthServer.AuthServer.GenerateToken(ctx, &proto.GenerateTokenRequest{ + token, err := tt.server.AuthServer.AuthServer.GenerateToken(ctx, &proto.GenerateTokenRequest{ Roles: types.SystemRoles{ types.RoleProxy, }, TTL: proto.Duration(time.Hour), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Generate public and private keys for node. priv, pub, err := native.GenerateKeyPair() - c.Assert(err, check.IsNil) + require.NoError(t, err) privateKey, err := ssh.ParseRawPrivateKey(priv) - c.Assert(err, check.IsNil) + require.NoError(t, err) pubTLS, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Attempt to register with nothing at the CA path, should work. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2731,26 +2862,26 @@ func (s *TLSSuite) TestRegisterCAPath(c *check.C) { AdditionalPrincipals: []string{"example.com"}, PublicSSHKey: pub, PublicTLSKey: pubTLS, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Extract the root CA public key and write it out to the data dir. - hostCA, err := s.server.AuthServer.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.AuthServer.ClusterName, + hostCA, err := tt.server.AuthServer.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.AuthServer.ClusterName, Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) certs := services.GetTLSCerts(hostCA) - c.Assert(certs, check.HasLen, 1) + require.Len(t, certs, 1) certPem := certs[0] - caPath := filepath.Join(s.dataDir, defaults.CACertFile) + caPath := filepath.Join(tt.dataDir, defaults.CACertFile) err = os.WriteFile(caPath, certPem, teleport.FileMaskOwnerOnly) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Attempt to register with valid CA path, should work. _, err = Register(RegisterParams{ - Servers: []utils.NetAddr{utils.FromAddr(s.server.Addr())}, + Servers: []utils.NetAddr{utils.FromAddr(tt.server.Addr())}, Token: token, ID: IdentityID{ HostUUID: "once", @@ -2761,15 +2892,19 @@ func (s *TLSSuite) TestRegisterCAPath(c *check.C) { PublicSSHKey: pub, PublicTLSKey: pubTLS, CAPath: caPath, - Clock: s.clock, + Clock: tt.clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) } // TestEventsNodePresence tests streaming node presence API - // announcing node and keeping node alive -func (s *TLSSuite) TestEventsNodePresence(c *check.C) { +func TestEventsNodePresence(t *testing.T) { + t.Parallel() + ctx := context.Background() + tt := setupAuthContext(ctx, t) + node := &types.ServerV2{ Kind: types.KindNode, Version: types.V2, @@ -2782,21 +2917,21 @@ func (s *TLSSuite) TestEventsNodePresence(c *check.C) { }, } node.SetExpiry(time.Now().Add(2 * time.Second)) - clt, err := s.server.NewClient(TestIdentity{ + clt, err := tt.server.NewClient(TestIdentity{ I: BuiltinRole{ Role: types.RoleNode, - Username: fmt.Sprintf("%v.%v", node.Metadata.Name, s.server.ClusterName()), + Username: fmt.Sprintf("%v.%v", node.Metadata.Name, tt.server.ClusterName()), }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer clt.Close() keepAlive, err := clt.UpsertNode(ctx, node) - c.Assert(err, check.IsNil) - c.Assert(keepAlive, check.NotNil) + require.NoError(t, err) + require.NotNil(t, keepAlive) keepAliver, err := clt.NewKeepAliver(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer keepAliver.Close() keepAlive.Expires = time.Now().Add(2 * time.Second) @@ -2804,21 +2939,21 @@ func (s *TLSSuite) TestEventsNodePresence(c *check.C) { case keepAliver.KeepAlives() <- *keepAlive: // ok case <-time.After(time.Second): - c.Fatalf("time out sending keep ailve") + t.Fatalf("time out sending keep ailve") case <-keepAliver.Done(): - c.Fatalf("unknown problem sending keep ailve") + t.Fatalf("unknown problem sending keep ailve") } // upsert node and keep alives will fail for users with no privileges - nopClt, err := s.server.NewClient(TestBuiltin(types.RoleNop)) - c.Assert(err, check.IsNil) + nopClt, err := tt.server.NewClient(TestBuiltin(types.RoleNop)) + require.NoError(t, err) defer nopClt.Close() _, err = nopClt.UpsertNode(ctx, node) - fixtures.ExpectAccessDenied(c, err) + require.True(t, trace.IsAccessDenied(err)) k2, err := nopClt.NewKeepAliver(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) keepAlive.Expires = time.Now().Add(2 * time.Second) go func() { @@ -2830,49 +2965,53 @@ func (s *TLSSuite) TestEventsNodePresence(c *check.C) { select { case <-time.After(time.Second): - c.Fatalf("time out expecting error") + t.Fatalf("time out expecting error") case <-k2.Done(): } - fixtures.ExpectAccessDenied(c, k2.Error()) + require.True(t, trace.IsAccessDenied(k2.Error())) } // TestEventsPermissions tests events with regards // to certificate authority rotation -func (s *TLSSuite) TestEventsPermissions(c *check.C) { - clt, err := s.server.NewClient(TestBuiltin(types.RoleNode)) - c.Assert(err, check.IsNil) +func TestEventsPermissions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestBuiltin(types.RoleNode)) + require.NoError(t, err) defer clt.Close() - ctx := context.TODO() w, err := clt.NewWatcher(ctx, types.Watch{Kinds: []types.WatchKind{{Kind: types.KindCertAuthority}}}) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer w.Close() select { case <-time.After(2 * time.Second): - c.Fatalf("Timeout waiting for init event") + t.Fatalf("Timeout waiting for init event") case event := <-w.Events(): - c.Assert(event.Type, check.Equals, types.OpInit) + require.Equal(t, event.Type, types.OpInit) } // start rotation gracePeriod := time.Hour - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - ca, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) - suite.ExpectResource(c, w, 3*time.Second, ca) + suite.ExpectResource(t, w, 3*time.Second, ca) type testCase struct { name string @@ -2921,14 +3060,14 @@ func (s *TLSSuite) TestEventsPermissions(c *check.C) { } tryWatch := func(tc testCase) { - client, err := s.server.NewClient(tc.identity) - c.Assert(err, check.IsNil) + client, err := tt.server.NewClient(tc.identity) + require.NoError(t, err) defer client.Close() watcher, err := client.NewWatcher(ctx, types.Watch{ Kinds: tc.watches, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer watcher.Close() go func() { @@ -2940,11 +3079,11 @@ func (s *TLSSuite) TestEventsPermissions(c *check.C) { select { case <-time.After(time.Second): - c.Fatalf("time out expecting error in test %q", tc.name) + t.Fatalf("time out expecting error in test %q", tc.name) case <-watcher.Done(): } - fixtures.ExpectAccessDenied(c, watcher.Error()) + require.True(t, trace.IsAccessDenied(watcher.Error())) } for _, tc := range testCases { @@ -2953,9 +3092,14 @@ func (s *TLSSuite) TestEventsPermissions(c *check.C) { } // TestEvents tests events suite -func (s *TLSSuite) TestEvents(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestEvents(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clt, @@ -2966,16 +3110,20 @@ func (s *TLSSuite) TestEvents(c *check.C) { Access: clt, UsersS: clt, } - suite.Events(c) + suite.Events(t) } // TestEventsClusterConfig test cluster configuration -func (s *TLSSuite) TestEventsClusterConfig(c *check.C) { - clt, err := s.server.NewClient(TestBuiltin(types.RoleAdmin)) - c.Assert(err, check.IsNil) +func TestEventsClusterConfig(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestBuiltin(types.RoleAdmin)) + require.NoError(t, err) defer clt.Close() - ctx := context.TODO() w, err := clt.NewWatcher(ctx, types.Watch{Kinds: []types.WatchKind{ {Kind: types.KindCertAuthority, LoadSecrets: true}, {Kind: types.KindStaticTokens}, @@ -2983,33 +3131,33 @@ func (s *TLSSuite) TestEventsClusterConfig(c *check.C) { {Kind: types.KindClusterAuditConfig}, {Kind: types.KindClusterName}, }}) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer w.Close() select { case <-time.After(2 * time.Second): - c.Fatalf("Timeout waiting for init event") + t.Fatalf("Timeout waiting for init event") case event := <-w.Events(): - c.Assert(event.Type, check.Equals, types.OpInit) + require.Equal(t, event.Type, types.OpInit) } // start rotation gracePeriod := time.Hour - err = s.server.Auth().RotateCertAuthority(ctx, RotateRequest{ + err = tt.server.Auth().RotateCertAuthority(ctx, RotateRequest{ Type: types.HostCA, GracePeriod: &gracePeriod, TargetPhase: types.RotationPhaseInit, Mode: types.RotationModeManual, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - ca, err := s.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ - DomainName: s.server.ClusterName(), + ca, err := tt.server.Auth().GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), Type: types.HostCA, }, true) - c.Assert(err, check.IsNil) + require.NoError(t, err) - suite.ExpectResource(c, w, 3*time.Second, ca) + suite.ExpectResource(t, w, 3*time.Second, ca) // set static tokens staticTokens, err := types.NewStaticTokens(types.StaticTokensSpecV2{ @@ -3021,32 +3169,32 @@ func (s *TLSSuite) TestEventsClusterConfig(c *check.C) { }, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - err = s.server.Auth().SetStaticTokens(staticTokens) - c.Assert(err, check.IsNil) + err = tt.server.Auth().SetStaticTokens(staticTokens) + require.NoError(t, err) - staticTokens, err = s.server.Auth().GetStaticTokens() - c.Assert(err, check.IsNil) - suite.ExpectResource(c, w, 3*time.Second, staticTokens) + staticTokens, err = tt.server.Auth().GetStaticTokens() + require.NoError(t, err) + suite.ExpectResource(t, w, 3*time.Second, staticTokens) // create provision token and expect the update event token, err := types.NewProvisionToken( "tok2", types.SystemRoles{types.RoleProxy}, time.Now().UTC().Add(3*time.Hour)) - c.Assert(err, check.IsNil) + require.NoError(t, err) - err = s.server.Auth().UpsertToken(ctx, token) - c.Assert(err, check.IsNil) + err = tt.server.Auth().UpsertToken(ctx, token) + require.NoError(t, err) - token, err = s.server.Auth().GetToken(ctx, token.GetName()) - c.Assert(err, check.IsNil) + token, err = tt.server.Auth().GetToken(ctx, token.GetName()) + require.NoError(t, err) - suite.ExpectResource(c, w, 3*time.Second, token) + suite.ExpectResource(t, w, 3*time.Second, token) // delete token and expect delete event - err = s.server.Auth().DeleteToken(ctx, token.GetName()) - c.Assert(err, check.IsNil) - suite.ExpectDeleteResource(c, w, 3*time.Second, &types.ResourceHeader{ + err = tt.server.Auth().DeleteToken(ctx, token.GetName()) + require.NoError(t, err) + suite.ExpectDeleteResource(t, w, 3*time.Second, &types.ResourceHeader{ Kind: types.KindToken, Version: types.V2, Metadata: types.Metadata{ @@ -3059,17 +3207,17 @@ func (s *TLSSuite) TestEventsClusterConfig(c *check.C) { auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{ AuditEventsURI: []string{"dynamodb://audit_table_name", "file:///home/log"}, }) - c.Assert(err, check.IsNil) - err = s.server.Auth().SetClusterAuditConfig(ctx, auditConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) + err = tt.server.Auth().SetClusterAuditConfig(ctx, auditConfig) + require.NoError(t, err) - auditConfigResource, err := s.server.Auth().GetClusterAuditConfig(ctx) - c.Assert(err, check.IsNil) - suite.ExpectResource(c, w, 3*time.Second, auditConfigResource) + auditConfigResource, err := tt.server.Auth().GetClusterAuditConfig(ctx) + require.NoError(t, err) + suite.ExpectResource(t, w, 3*time.Second, auditConfigResource) // update cluster name resource metadata - clusterNameResource, err := s.server.Auth().GetClusterName() - c.Assert(err, check.IsNil) + clusterNameResource, err := tt.server.Auth().GetClusterName() + require.NoError(t, err) // update the resource with different labels to test the change clusterName := &types.ClusterNameV2{ @@ -3085,28 +3233,33 @@ func (s *TLSSuite) TestEventsClusterConfig(c *check.C) { Spec: clusterNameResource.(*types.ClusterNameV2).Spec, } - err = s.server.Auth().DeleteClusterName() - c.Assert(err, check.IsNil) - err = s.server.Auth().SetClusterName(clusterName) - c.Assert(err, check.IsNil) + err = tt.server.Auth().DeleteClusterName() + require.NoError(t, err) + err = tt.server.Auth().SetClusterName(clusterName) + require.NoError(t, err) - clusterNameResource, err = s.server.Auth().ClusterConfiguration.GetClusterName() - c.Assert(err, check.IsNil) - suite.ExpectResource(c, w, 3*time.Second, clusterNameResource) + clusterNameResource, err = tt.server.Auth().ClusterConfiguration.GetClusterName() + require.NoError(t, err) + suite.ExpectResource(t, w, 3*time.Second, clusterNameResource) } -func (s *TLSSuite) TestNetworkRestrictions(c *check.C) { - clt, err := s.server.NewClient(TestAdmin()) - c.Assert(err, check.IsNil) +func TestNetworkRestrictions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tt := setupAuthContext(ctx, t) + + clt, err := tt.server.NewClient(TestAdmin()) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ RestrictionsS: clt, } - suite.NetworkRestrictions(c) + suite.NetworkRestrictions(t) } // verifyJWT verifies that the token was signed by one the passed in key pair. -func (s *TLSSuite) verifyJWT(clock clockwork.Clock, clusterName string, pairs []*types.JWTKeyPair, token string) (*jwt.Claims, error) { +func verifyJWT(clock clockwork.Clock, clusterName string, pairs []*types.JWTKeyPair, token string) (*jwt.Claims, error) { errs := []error{} for _, pair := range pairs { publicKey, err := utils.ParsePublicKey(pair.PublicKey) diff --git a/lib/services/local/access_test.go b/lib/services/local/access_test.go index 284ca5e91a908..5d3d61770a5cb 100644 --- a/lib/services/local/access_test.go +++ b/lib/services/local/access_test.go @@ -24,18 +24,23 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" ) func TestLockCRUD(t *testing.T) { ctx := context.Background() - lite, err := lite.NewWithConfig(ctx, lite.Config{Path: t.TempDir()}) + + backend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), + }) require.NoError(t, err) - access := NewAccessService(lite) + access := NewAccessService(backend) lock1, err := types.NewLock("lock1", types.LockSpecV2{ Target: types.LockTarget{ diff --git a/lib/services/local/apps_test.go b/lib/services/local/apps_test.go index 5a9931c786e7f..3894c85ee8acd 100644 --- a/lib/services/local/apps_test.go +++ b/lib/services/local/apps_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -34,9 +34,9 @@ import ( func TestAppsCRUD(t *testing.T) { ctx := context.Background() - backend, err := lite.NewWithConfig(ctx, lite.Config{ - Path: t.TempDir(), - Clock: clockwork.NewFakeClock(), + backend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), }) require.NoError(t, err) diff --git a/lib/services/local/configuration_test.go b/lib/services/local/configuration_test.go index 26c0d664e371c..5b71cde7a34ad 100644 --- a/lib/services/local/configuration_test.go +++ b/lib/services/local/configuration_test.go @@ -18,102 +18,122 @@ package local import ( "context" + "testing" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/lite" - "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/suite" - "gopkg.in/check.v1" + "github.com/google/go-cmp/cmp" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) -type ClusterConfigurationSuite struct { +type configContext struct { bk backend.Backend } -var _ = check.Suite(&ClusterConfigurationSuite{}) +func setupConfigContext(ctx context.Context, t *testing.T) *configContext { + var tt configContext + t.Cleanup(func() { tt.Close() }) + + clock := clockwork.NewFakeClock() -func (s *ClusterConfigurationSuite) SetUpTest(c *check.C) { var err error - s.bk, err = lite.New(context.TODO(), backend.Params{"path": c.MkDir()}) - c.Assert(err, check.IsNil) + tt.bk, err = memory.New(memory.Config{ + Context: context.Background(), + Clock: clock, + }) + require.NoError(t, err) + + return &tt } -func (s *ClusterConfigurationSuite) TearDownTest(c *check.C) { - c.Assert(s.bk.Close(), check.IsNil) +func (tt *configContext) Close() error { + return tt.bk.Close() } -func (s *ClusterConfigurationSuite) TestAuthPreference(c *check.C) { - clusterConfig, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) +func TestAuthPreference(t *testing.T) { + tt := setupConfigContext(context.Background(), t) + + clusterConfig, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clusterConfig, } - suite.AuthPreference(c) + suite.AuthPreference(t) } -func (s *ClusterConfigurationSuite) TestClusterName(c *check.C) { - clusterConfig, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) +func TestClusterName(t *testing.T) { + tt := setupConfigContext(context.Background(), t) + + clusterConfig, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clusterConfig, } - suite.ClusterName(c) + suite.ClusterName(t) } -func (s *ClusterConfigurationSuite) TestClusterNetworkingConfig(c *check.C) { - clusterConfig, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) +func TestClusterNetworkingConfig(t *testing.T) { + tt := setupConfigContext(context.Background(), t) + + clusterConfig, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clusterConfig, } - suite.ClusterNetworkingConfig(c) + suite.ClusterNetworkingConfig(t) } -func (s *ClusterConfigurationSuite) TestSessionRecordingConfig(c *check.C) { - clusterConfig, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) +func TestSessionRecordingConfig(t *testing.T) { + tt := setupConfigContext(context.Background(), t) + + clusterConfig, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clusterConfig, } - suite.SessionRecordingConfig(c) + suite.SessionRecordingConfig(t) } -func (s *ClusterConfigurationSuite) TestStaticTokens(c *check.C) { - clusterConfig, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) +func TestStaticTokens(t *testing.T) { + tt := setupConfigContext(context.Background(), t) + + clusterConfig, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) suite := &suite.ServicesTestSuite{ ConfigS: clusterConfig, } - suite.StaticTokens(c) + suite.StaticTokens(t) } -func (s *ClusterConfigurationSuite) TestSessionRecording(c *check.C) { +func TestSessionRecording(t *testing.T) { // don't allow invalid session recording values _, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: "foo", }) - c.Assert(err, check.NotNil) + require.Error(t, err) // default is to record at the node recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{}) - c.Assert(err, check.IsNil) - c.Assert(recConfig.GetMode(), check.Equals, types.RecordAtNode) + require.NoError(t, err) + require.Equal(t, recConfig.GetMode(), types.RecordAtNode) // update sessions to be recorded at the proxy and check again recConfig.SetMode(types.RecordAtProxy) - c.Assert(recConfig.GetMode(), check.Equals, types.RecordAtProxy) + require.Equal(t, recConfig.GetMode(), types.RecordAtProxy) } -func (s *ClusterConfigurationSuite) TestAuditConfig(c *check.C) { +func TestAuditConfig(t *testing.T) { testCases := []struct { spec types.ClusterAuditConfigSpecV2 config string @@ -150,22 +170,22 @@ audit_events_uri: 'dynamodb://audit_table_name' for _, tc := range testCases { in, err := types.NewClusterAuditConfig(tc.spec) - c.Assert(err, check.IsNil) + require.NoError(t, err) var data map[string]interface{} err = yaml.Unmarshal([]byte(tc.config), &data) - c.Assert(err, check.IsNil) + require.NoError(t, err) configSpec, err := services.ClusterAuditConfigSpecFromObject(data) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := types.NewClusterAuditConfig(*configSpec) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, out, in) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out, in)) } } -func (s *ClusterConfigurationSuite) TestAuditConfigMarshal(c *check.C) { +func TestAuditConfigMarshal(t *testing.T) { // single audit_events uri value auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{ Region: "us-west-1", @@ -173,14 +193,14 @@ func (s *ClusterConfigurationSuite) TestAuditConfigMarshal(c *check.C) { AuditSessionsURI: "file:///home/log", AuditEventsURI: []string{"dynamodb://audit_table_name"}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) data, err := services.MarshalClusterAuditConfig(auditConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := services.UnmarshalClusterAuditConfig(data) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, auditConfig, out) + require.NoError(t, err) + require.Empty(t, cmp.Diff(auditConfig, out)) // multiple events uri values auditConfig, err = types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{ @@ -189,12 +209,12 @@ func (s *ClusterConfigurationSuite) TestAuditConfigMarshal(c *check.C) { AuditSessionsURI: "file:///home/log", AuditEventsURI: []string{"dynamodb://audit_table_name", "file:///home/test/log"}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) data, err = services.MarshalClusterAuditConfig(auditConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = services.UnmarshalClusterAuditConfig(data) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, auditConfig, out) + require.NoError(t, err) + require.Empty(t, cmp.Diff(auditConfig, out)) } diff --git a/lib/services/local/databases_test.go b/lib/services/local/databases_test.go index 35d46220df0c0..b08172f6a007e 100644 --- a/lib/services/local/databases_test.go +++ b/lib/services/local/databases_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/defaults" "github.com/google/go-cmp/cmp" @@ -35,9 +35,9 @@ import ( func TestDatabasesCRUD(t *testing.T) { ctx := context.Background() - backend, err := lite.NewWithConfig(ctx, lite.Config{ - Path: t.TempDir(), - Clock: clockwork.NewFakeClock(), + backend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), }) require.NoError(t, err) diff --git a/lib/services/local/resource_test.go b/lib/services/local/resource_test.go index e9fff8cc28241..d2630d2b427a9 100644 --- a/lib/services/local/resource_test.go +++ b/lib/services/local/resource_test.go @@ -27,7 +27,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/suite" @@ -48,10 +48,9 @@ func (r *ResourceSuite) SetUpTest(c *check.C) { clock := clockwork.NewFakeClockAt(time.Now()) - r.bk, err = lite.NewWithConfig(context.TODO(), lite.Config{ - Path: c.MkDir(), - PollStreamPeriod: 200 * time.Millisecond, - Clock: clock, + r.bk, err = memory.New(memory.Config{ + Context: context.Background(), + Clock: clock, }) c.Assert(err, check.IsNil) } diff --git a/lib/services/local/services_test.go b/lib/services/local/services_test.go index 97d01330c847b..de9deea2005e8 100644 --- a/lib/services/local/services_test.go +++ b/lib/services/local/services_test.go @@ -20,14 +20,13 @@ import ( "context" "os" "testing" - "time" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/services/suite" "github.com/gravitational/teleport/lib/utils" "github.com/jonboulle/clockwork" - "gopkg.in/check.v1" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -35,135 +34,77 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -type ServicesSuite struct { +type servicesContext struct { bk backend.Backend suite *suite.ServicesTestSuite } -var _ = check.Suite(&ServicesSuite{}) - -func (s *ServicesSuite) SetUpTest(c *check.C) { - var err error - ctx := context.Background() +func setupServicesContext(ctx context.Context, t *testing.T) *servicesContext { + var tt servicesContext + t.Cleanup(func() { tt.Close() }) clock := clockwork.NewFakeClock() - s.bk, err = lite.NewWithConfig(ctx, lite.Config{ - Path: c.MkDir(), - PollStreamPeriod: 200 * time.Millisecond, - Clock: clock, + var err error + tt.bk, err = memory.New(memory.Config{ + Clock: clock, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) - configService, err := NewClusterConfigurationService(s.bk) - c.Assert(err, check.IsNil) + configService, err := NewClusterConfigurationService(tt.bk) + require.NoError(t, err) - eventsService := NewEventsService(s.bk) - presenceService := NewPresenceService(s.bk) + eventsService := NewEventsService(tt.bk) + presenceService := NewPresenceService(tt.bk) - s.suite = &suite.ServicesTestSuite{ - CAS: NewCAService(s.bk), + tt.suite = &suite.ServicesTestSuite{ + CAS: NewCAService(tt.bk), PresenceS: presenceService, - ProvisioningS: NewProvisioningService(s.bk), - WebS: NewIdentityService(s.bk), - Access: NewAccessService(s.bk), + ProvisioningS: NewProvisioningService(tt.bk), + WebS: NewIdentityService(tt.bk), + Access: NewAccessService(tt.bk), EventsS: eventsService, ChangesC: make(chan interface{}), ConfigS: configService, - RestrictionsS: NewRestrictionsService(s.bk), + RestrictionsS: NewRestrictionsService(tt.bk), Clock: clock, } -} - -func (s *ServicesSuite) TearDownTest(c *check.C) { - c.Assert(s.bk.Close(), check.IsNil) -} - -func (s *ServicesSuite) TestUserCACRUD(c *check.C) { - s.suite.CertAuthCRUD(c) -} - -func (s *ServicesSuite) TestServerCRUD(c *check.C) { - s.suite.ServerCRUD(c) -} - -// TestAppServerCRUD tests CRUD functionality for services.App. -func (s *ServicesSuite) TestAppServerCRUD(c *check.C) { - s.suite.AppServerCRUD(c) -} - -func (s *ServicesSuite) TestReverseTunnelsCRUD(c *check.C) { - s.suite.ReverseTunnelsCRUD(c) -} - -func (s *ServicesSuite) TestUsersCRUD(c *check.C) { - s.suite.UsersCRUD(c) -} - -func (s *ServicesSuite) TestUsersExpiry(c *check.C) { - s.suite.UsersExpiry(c) -} - -func (s *ServicesSuite) TestLoginAttempts(c *check.C) { - s.suite.LoginAttempts(c) -} -func (s *ServicesSuite) TestPasswordHashCRUD(c *check.C) { - s.suite.PasswordHashCRUD(c) + return &tt } -func (s *ServicesSuite) TestWebSessionCRUD(c *check.C) { - s.suite.WebSessionCRUD(c) +func (tt *servicesContext) Close() error { + return tt.bk.Close() } -func (s *ServicesSuite) TestToken(c *check.C) { - s.suite.TokenCRUD(c) -} - -func (s *ServicesSuite) TestRoles(c *check.C) { - s.suite.RolesCRUD(c) -} - -func (s *ServicesSuite) TestSAMLCRUD(c *check.C) { - s.suite.SAMLCRUD(c) -} - -func (s *ServicesSuite) TestTunnelConnectionsCRUD(c *check.C) { - s.suite.TunnelConnectionsCRUD(c) -} - -func (s *ServicesSuite) TestGithubConnectorCRUD(c *check.C) { - s.suite.GithubConnectorCRUD(c) -} - -func (s *ServicesSuite) TestRemoteClustersCRUD(c *check.C) { - s.suite.RemoteClustersCRUD(c) -} - -func (s *ServicesSuite) TestEvents(c *check.C) { - s.suite.Events(c) -} - -func (s *ServicesSuite) TestEventsClusterConfig(c *check.C) { - s.suite.EventsClusterConfig(c) -} +func TestCRUD(t *testing.T) { + tt := setupServicesContext(context.Background(), t) -func (s *ServicesSuite) TestSemaphoreLock(c *check.C) { - s.suite.SemaphoreLock(c) + t.Run("TestUserCACRUD", tt.suite.CertAuthCRUD) + t.Run("TestServerCRUD", tt.suite.ServerCRUD) + t.Run("TestAppServerCRUD", tt.suite.AppServerCRUD) + t.Run("TestReverseTunnelsCRUD", tt.suite.ReverseTunnelsCRUD) + t.Run("TestUsersCRUD", tt.suite.UsersCRUD) + t.Run("TestUsersExpiry", tt.suite.UsersExpiry) + t.Run("TestLoginAttempts", tt.suite.LoginAttempts) + t.Run("TestPasswordHashCRUD", tt.suite.PasswordHashCRUD) + t.Run("TestWebSessionCRUD", tt.suite.WebSessionCRUD) + t.Run("TestToken", tt.suite.TokenCRUD) + t.Run("TestRoles", tt.suite.RolesCRUD) + t.Run("TestSAMLCRUD", tt.suite.SAMLCRUD) + t.Run("TestTunnelConnectionsCRUD", tt.suite.TunnelConnectionsCRUD) + t.Run("TestGithubConnectorCRUD", tt.suite.GithubConnectorCRUD) + t.Run("TestRemoteClustersCRUD", tt.suite.RemoteClustersCRUD) + t.Run("TestEvents", tt.suite.Events) + t.Run("TestEventsClusterConfig", tt.suite.EventsClusterConfig) + t.Run("TestNetworkRestrictions", func(t *testing.T) { tt.suite.NetworkRestrictions(t) }) } -func (s *ServicesSuite) TestSemaphoreConcurrency(c *check.C) { - s.suite.SemaphoreConcurrency(c) -} - -func (s *ServicesSuite) TestSemaphoreContention(c *check.C) { - s.suite.SemaphoreContention(c) -} - -func (s *ServicesSuite) TestSemaphoreFlakiness(c *check.C) { - s.suite.SemaphoreFlakiness(c) -} +func TestSemaphore(t *testing.T) { + tt := setupServicesContext(context.Background(), t) -func (s *ServicesSuite) TestNetworkRestrictions(c *check.C) { - s.suite.NetworkRestrictions(c) + t.Run("TestSemaphoreLock", tt.suite.SemaphoreLock) + t.Run("TestSemaphoreConcurrency", tt.suite.SemaphoreConcurrency) + t.Run("TestSemaphoreContention", tt.suite.SemaphoreContention) + t.Run("TestSemaphoreFlakiness", tt.suite.SemaphoreFlakiness) } diff --git a/lib/services/local/unstable_test.go b/lib/services/local/unstable_test.go index e917cc0787d5f..e6abec5e0b641 100644 --- a/lib/services/local/unstable_test.go +++ b/lib/services/local/unstable_test.go @@ -22,7 +22,8 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/jonboulle/clockwork" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -35,12 +36,15 @@ func TestSystemRoleAssertions(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - lite, err := lite.NewWithConfig(ctx, lite.Config{Path: t.TempDir()}) + backend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), + }) require.NoError(t, err) - defer lite.Close() + defer backend.Close() - unstable := NewUnstableService(lite) + unstable := NewUnstableService(backend) _, err = unstable.GetSystemRoleAssertions(ctx, serverID, assertionID) require.True(t, trace.IsNotFound(err)) diff --git a/lib/services/local/users_test.go b/lib/services/local/users_test.go index d850592aa2898..9ac230d1d2ae5 100644 --- a/lib/services/local/users_test.go +++ b/lib/services/local/users_test.go @@ -27,7 +27,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -38,10 +38,9 @@ import ( func newIdentityService(t *testing.T, clock clockwork.Clock) *local.IdentityService { t.Helper() - backend, err := lite.NewWithConfig(context.Background(), lite.Config{ - Path: t.TempDir(), - PollStreamPeriod: 200 * time.Millisecond, - Clock: clock, + backend, err := memory.New(memory.Config{ + Context: context.Background(), + Clock: clockwork.NewFakeClock(), }) require.NoError(t, err) return local.NewIdentityService(backend) diff --git a/lib/services/suite/presence_test.go b/lib/services/suite/presence_test.go index d6aa6c96839ad..c7cffe11e46d2 100644 --- a/lib/services/suite/presence_test.go +++ b/lib/services/suite/presence_test.go @@ -22,24 +22,18 @@ import ( "github.com/gravitational/teleport/api/types" - "gopkg.in/check.v1" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { check.TestingT(t) } - -type PresenceSuite struct { -} - -var _ = check.Suite(&PresenceSuite{}) - -func (s *PresenceSuite) TestServerLabels(c *check.C) { +func TestServerLabels(t *testing.T) { emptyLabels := make(map[string]string) // empty server := &types.ServerV2{} - c.Assert(server.GetAllLabels(), check.DeepEquals, emptyLabels) - c.Assert(server.LabelsString(), check.Equals, "") - c.Assert(server.MatchAgainst(emptyLabels), check.Equals, true) - c.Assert(server.MatchAgainst(map[string]string{"a": "b"}), check.Equals, false) + require.Empty(t, cmp.Diff(server.GetAllLabels(), emptyLabels)) + require.Equal(t, server.LabelsString(), "") + require.Equal(t, server.MatchAgainst(emptyLabels), true) + require.Equal(t, server.MatchAgainst(map[string]string{"a": "b"}), false) // more complex server = &types.ServerV2{ @@ -59,14 +53,15 @@ func (s *PresenceSuite) TestServerLabels(c *check.C) { }, } - c.Assert(server.GetAllLabels(), check.DeepEquals, map[string]string{ + require.Empty(t, cmp.Diff(server.GetAllLabels(), map[string]string{ "role": "database", "time": "now", - }) - c.Assert(server.LabelsString(), check.Equals, "role=database,time=now") - c.Assert(server.MatchAgainst(emptyLabels), check.Equals, true) - c.Assert(server.MatchAgainst(map[string]string{"a": "b"}), check.Equals, false) - c.Assert(server.MatchAgainst(map[string]string{"role": "database"}), check.Equals, true) - c.Assert(server.MatchAgainst(map[string]string{"time": "now"}), check.Equals, true) - c.Assert(server.MatchAgainst(map[string]string{"time": "now", "role": "database"}), check.Equals, true) + })) + + require.Equal(t, server.LabelsString(), "role=database,time=now") + require.Equal(t, server.MatchAgainst(emptyLabels), true) + require.Equal(t, server.MatchAgainst(map[string]string{"a": "b"}), false) + require.Equal(t, server.MatchAgainst(map[string]string{"role": "database"}), true) + require.Equal(t, server.MatchAgainst(map[string]string{"time": "now"}), true) + require.Equal(t, server.MatchAgainst(map[string]string{"time": "now", "role": "database"}), true) } diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index f4f270a66583e..50f5534a5cca1 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -20,11 +20,14 @@ import ( "context" "crypto/rsa" "crypto/x509/pkix" + "fmt" "sort" "sync" "sync/atomic" + "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" @@ -38,8 +41,8 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" - "gopkg.in/check.v1" ) // NewTestCA returns new test authority with a test key as a public and @@ -152,19 +155,19 @@ func (s *ServicesTestSuite) Users() services.UsersService { return s.UsersS } -func userSlicesEqual(c *check.C, a []types.User, b []types.User) { - comment := check.Commentf("a: %#v b: %#v", a, b) - c.Assert(len(a), check.Equals, len(b), comment) +func userSlicesEqual(t *testing.T, a []types.User, b []types.User) { + require.EqualValuesf(t, len(a), len(b), "a: %#v b: %#v", a, b) + sort.Sort(services.Users(a)) sort.Sort(services.Users(b)) + for i := range a { - usersEqual(c, a[i], b[i]) + usersEqual(t, a[i], b[i]) } } -func usersEqual(c *check.C, a types.User, b types.User) { - comment := check.Commentf("a: %#v b: %#v", a, b) - c.Assert(services.UsersEquals(a, b), check.Equals, true, comment) +func usersEqual(t *testing.T, a types.User, b types.User) { + require.True(t, services.UsersEquals(a, b), cmp.Diff(a, b)) } func newUser(name string, roles []string) types.User { @@ -181,49 +184,50 @@ func newUser(name string, roles []string) types.User { } } -func (s *ServicesTestSuite) UsersCRUD(c *check.C) { +func (s *ServicesTestSuite) UsersCRUD(t *testing.T) { ctx := context.Background() + u, err := s.WebS.GetUsers(false) - c.Assert(err, check.IsNil) - c.Assert(len(u), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(u), 0) - c.Assert(s.WebS.UpsertPasswordHash("user1", []byte("hash")), check.IsNil) - c.Assert(s.WebS.UpsertPasswordHash("user2", []byte("hash2")), check.IsNil) + require.NoError(t, s.WebS.UpsertPasswordHash("user1", []byte("hash"))) + require.NoError(t, s.WebS.UpsertPasswordHash("user2", []byte("hash2"))) u, err = s.WebS.GetUsers(false) - c.Assert(err, check.IsNil) - userSlicesEqual(c, u, []types.User{newUser("user1", nil), newUser("user2", nil)}) + require.NoError(t, err) + userSlicesEqual(t, u, []types.User{newUser("user1", nil), newUser("user2", nil)}) out, err := s.WebS.GetUser("user1", false) - c.Assert(err, check.IsNil) - usersEqual(c, out, u[0]) + require.NoError(t, err) + usersEqual(t, out, u[0]) user := newUser("user1", []string{"admin", "user"}) - c.Assert(s.WebS.UpsertUser(user), check.IsNil) + require.NoError(t, s.WebS.UpsertUser(user)) out, err = s.WebS.GetUser("user1", false) - c.Assert(err, check.IsNil) - usersEqual(c, out, user) + require.NoError(t, err) + usersEqual(t, out, user) out, err = s.WebS.GetUser("user1", false) - c.Assert(err, check.IsNil) - usersEqual(c, out, user) + require.NoError(t, err) + usersEqual(t, out, user) - c.Assert(s.WebS.DeleteUser(ctx, "user1"), check.IsNil) + require.NoError(t, s.WebS.DeleteUser(ctx, "user1")) u, err = s.WebS.GetUsers(false) - c.Assert(err, check.IsNil) - userSlicesEqual(c, u, []types.User{newUser("user2", nil)}) + require.NoError(t, err) + userSlicesEqual(t, u, []types.User{newUser("user2", nil)}) err = s.WebS.DeleteUser(ctx, "user1") - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) // bad username err = s.WebS.UpsertUser(newUser("", nil)) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) } -func (s *ServicesTestSuite) UsersExpiry(c *check.C) { +func (s *ServicesTestSuite) UsersExpiry(t *testing.T) { expiresAt := s.Clock.Now().Add(1 * time.Minute) err := s.WebS.UpsertUser(&types.UserV2{ @@ -236,56 +240,58 @@ func (s *ServicesTestSuite) UsersExpiry(c *check.C) { }, Spec: types.UserSpecV2{}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Make sure the user exists. u, err := s.WebS.GetUser("foo", false) - c.Assert(err, check.IsNil) - c.Assert(u.GetName(), check.Equals, "foo") + require.NoError(t, err) + require.Equal(t, u.GetName(), "foo") s.Clock.Advance(2 * time.Minute) // Make sure the user is now gone. _, err = s.WebS.GetUser("foo", false) - c.Assert(err, check.NotNil) + require.Error(t, err) } -func (s *ServicesTestSuite) LoginAttempts(c *check.C) { - user := newUser("user1", []string{"admin", "user"}) - c.Assert(s.WebS.UpsertUser(user), check.IsNil) +func (s *ServicesTestSuite) LoginAttempts(t *testing.T) { + user1 := uuid.NewString() + + user := newUser(user1, []string{"admin", "user"}) + require.NoError(t, s.WebS.UpsertUser(user)) attempts, err := s.WebS.GetUserLoginAttempts(user.GetName()) - c.Assert(err, check.IsNil) - c.Assert(len(attempts), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(attempts), 0) clock := clockwork.NewFakeClock() attempt1 := services.LoginAttempt{Time: clock.Now().UTC(), Success: false} err = s.WebS.AddUserLoginAttempt(user.GetName(), attempt1, defaults.AttemptTTL) - c.Assert(err, check.IsNil) + require.NoError(t, err) attempt2 := services.LoginAttempt{Time: clock.Now().UTC(), Success: false} err = s.WebS.AddUserLoginAttempt(user.GetName(), attempt2, defaults.AttemptTTL) - c.Assert(err, check.IsNil) + require.NoError(t, err) attempts, err = s.WebS.GetUserLoginAttempts(user.GetName()) - c.Assert(err, check.IsNil) - c.Assert(attempts, check.DeepEquals, []services.LoginAttempt{attempt1, attempt2}) - c.Assert(services.LastFailed(3, attempts), check.Equals, false) - c.Assert(services.LastFailed(2, attempts), check.Equals, true) + require.NoError(t, err) + require.Empty(t, cmp.Diff(attempts, []services.LoginAttempt{attempt1, attempt2})) + require.Equal(t, services.LastFailed(3, attempts), false) + require.Equal(t, services.LastFailed(2, attempts), true) } -func (s *ServicesTestSuite) CertAuthCRUD(c *check.C) { +func (s *ServicesTestSuite) CertAuthCRUD(t *testing.T) { ctx := context.Background() ca := NewTestCA(types.UserCA, "example.com") - c.Assert(s.CAS.UpsertCertAuthority(ca), check.IsNil) + require.NoError(t, s.CAS.UpsertCertAuthority(ca)) out, err := s.CAS.GetCertAuthority(ctx, ca.GetID(), true) - c.Assert(err, check.IsNil) + require.NoError(t, err) ca.SetResourceID(out.GetResourceID()) - fixtures.DeepCompare(c, out, ca) + require.Equal(t, out, ca) cas, err := s.CAS.GetCertAuthorities(ctx, types.UserCA, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) ca2 := ca.Clone().(*types.CertAuthorityV2) ca2.Spec.ActiveKeys.SSH[0].PrivateKey = nil ca2.Spec.SigningKeys = nil @@ -293,22 +299,22 @@ func (s *ServicesTestSuite) CertAuthCRUD(c *check.C) { ca2.Spec.TLSKeyPairs[0].Key = nil ca2.Spec.ActiveKeys.JWT[0].PrivateKey = nil ca2.Spec.JWTKeyPairs[0].PrivateKey = nil - fixtures.DeepCompare(c, cas[0], ca2) + require.Equal(t, cas[0], ca2) cas, err = s.CAS.GetCertAuthorities(ctx, types.UserCA, true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, cas[0], ca) + require.NoError(t, err) + require.Equal(t, cas[0], ca) cas, err = s.CAS.GetCertAuthorities(ctx, types.UserCA, true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, cas[0], ca) + require.NoError(t, err) + require.Equal(t, cas[0], ca) err = s.CAS.DeleteCertAuthority(*ca.ID()) - c.Assert(err, check.IsNil) + require.NoError(t, err) // test compare and swap ca = NewTestCA(types.UserCA, "example.com") - c.Assert(s.CAS.CreateCertAuthority(ca), check.IsNil) + require.NoError(t, s.CAS.CreateCertAuthority(ca)) clock := clockwork.NewFakeClock() newCA := *ca @@ -321,12 +327,12 @@ func (s *ServicesTestSuite) CertAuthCRUD(c *check.C) { newCA.SetRotation(rotation) err = s.CAS.CompareAndSwapCertAuthority(&newCA, ca) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.CAS.GetCertAuthority(ctx, ca.GetID(), true) - c.Assert(err, check.IsNil) + require.NoError(t, err) newCA.SetResourceID(out.GetResourceID()) - fixtures.DeepCompare(c, &newCA, out) + require.Equal(t, &newCA, out) } // NewServer creates a new server resource @@ -345,99 +351,100 @@ func NewServer(kind, name, addr, namespace string) *types.ServerV2 { } } -func (s *ServicesTestSuite) ServerCRUD(c *check.C) { +func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { ctx := context.Background() + // SSH service. out, err := s.PresenceS.GetNodes(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace) _, err = s.PresenceS.UpsertNode(ctx, srv) - c.Assert(err, check.IsNil) + require.NoError(t, err) node, err := s.PresenceS.GetNode(ctx, srv.Metadata.Namespace, srv.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) srv.SetResourceID(node.GetResourceID()) - fixtures.DeepCompare(c, node, srv) + require.Empty(t, cmp.Diff(node, srv)) out, err = s.PresenceS.GetNodes(ctx, srv.Metadata.Namespace) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, out, 1) srv.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out, []types.Server{srv}) + require.Empty(t, cmp.Diff(out, []types.Server{srv})) err = s.PresenceS.DeleteNode(ctx, srv.Metadata.Namespace, srv.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetNodes(ctx, srv.Metadata.Namespace) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 0) + require.NoError(t, err) + require.Len(t, out, 0) // Proxy service. out, err = s.PresenceS.GetProxies() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) proxy := NewServer(types.KindProxy, "proxy1", "127.0.0.1:2023", apidefaults.Namespace) - c.Assert(s.PresenceS.UpsertProxy(proxy), check.IsNil) + require.NoError(t, s.PresenceS.UpsertProxy(proxy)) out, err = s.PresenceS.GetProxies() - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, out, 1) proxy.SetResourceID(out[0].GetResourceID()) - c.Assert(out, check.DeepEquals, []types.Server{proxy}) + require.Empty(t, cmp.Diff(out, []types.Server{proxy})) err = s.PresenceS.DeleteProxy(proxy.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetProxies() - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 0) + require.NoError(t, err) + require.Len(t, out, 0) // Auth service. out, err = s.PresenceS.GetAuthServers() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) auth := NewServer(types.KindAuthServer, "auth1", "127.0.0.1:2025", apidefaults.Namespace) - c.Assert(s.PresenceS.UpsertAuthServer(auth), check.IsNil) + require.NoError(t, s.PresenceS.UpsertAuthServer(auth)) out, err = s.PresenceS.GetAuthServers() - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, out, 1) auth.SetResourceID(out[0].GetResourceID()) - c.Assert(out, check.DeepEquals, []types.Server{auth}) + require.Empty(t, cmp.Diff(out, []types.Server{auth})) // Kubernetes service. out, err = s.PresenceS.GetKubeServices(ctx) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) kube1 := NewServer(types.KindKubeService, "kube1", "10.0.0.1:3026", apidefaults.Namespace) _, err = s.PresenceS.UpsertKubeServiceV2(ctx, kube1) - c.Assert(err, check.IsNil) + require.NoError(t, err) kube2 := NewServer(types.KindKubeService, "kube2", "10.0.0.2:3026", apidefaults.Namespace) _, err = s.PresenceS.UpsertKubeServiceV2(ctx, kube2) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetKubeServices(ctx) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 2) + require.NoError(t, err) + require.Len(t, out, 2) kube1.SetResourceID(out[0].GetResourceID()) kube2.SetResourceID(out[1].GetResourceID()) - c.Assert(out, check.DeepEquals, []types.Server{kube1, kube2}) + require.Empty(t, cmp.Diff(out, []types.Server{kube1, kube2})) - c.Assert(s.PresenceS.DeleteKubeService(ctx, kube1.GetName()), check.IsNil) + require.NoError(t, s.PresenceS.DeleteKubeService(ctx, kube1.GetName())) out, err = s.PresenceS.GetKubeServices(ctx) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) - c.Assert(out, check.DeepEquals, []types.Server{kube2}) + require.NoError(t, err) + require.Len(t, out, 1) + require.Empty(t, cmp.Diff(out, []types.Server{kube2})) - c.Assert(s.PresenceS.DeleteAllKubeServices(ctx), check.IsNil) + require.NoError(t, s.PresenceS.DeleteAllKubeServices(ctx)) out, err = s.PresenceS.GetKubeServices(ctx) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 0) + require.NoError(t, err) + require.Len(t, out, 0) } // NewAppServer creates a new application server resource. @@ -462,7 +469,7 @@ func NewAppServer(name string, internalAddr string, publicAddr string) *types.Se } // AppServerCRUD tests CRUD functionality for services.Server. -func (s *ServicesTestSuite) AppServerCRUD(c *check.C) { +func (s *ServicesTestSuite) AppServerCRUD(t *testing.T) { ctx := context.Background() // Create application. @@ -470,28 +477,28 @@ func (s *ServicesTestSuite) AppServerCRUD(c *check.C) { // Expect not to be returned any applications and trace.NotFound. out, err := s.PresenceS.GetAppServers(ctx, apidefaults.Namespace) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) // Upsert application. _, err = s.PresenceS.UpsertAppServer(ctx, server) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Check again, expect a single application to be found. out, err = s.PresenceS.GetAppServers(ctx, server.GetNamespace()) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, out, 1) server.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, []types.Server{server}, out) + require.Empty(t, cmp.Diff([]types.Server{server}, out)) // Remove the application. err = s.PresenceS.DeleteAppServer(ctx, server.Metadata.Namespace, server.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Now expect no applications to be returned. out, err = s.PresenceS.GetAppServers(ctx, server.Metadata.Namespace) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 0) + require.NoError(t, err) + require.Len(t, out, 0) } func newReverseTunnel(clusterName string, dialAddrs []string) *types.ReverseTunnelV2 { @@ -509,61 +516,61 @@ func newReverseTunnel(clusterName string, dialAddrs []string) *types.ReverseTunn } } -func (s *ServicesTestSuite) ReverseTunnelsCRUD(c *check.C) { +func (s *ServicesTestSuite) ReverseTunnelsCRUD(t *testing.T) { out, err := s.PresenceS.GetReverseTunnels(context.Background()) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) tunnel := newReverseTunnel("example.com", []string{"example.com:2023"}) - c.Assert(s.PresenceS.UpsertReverseTunnel(tunnel), check.IsNil) + require.NoError(t, s.PresenceS.UpsertReverseTunnel(tunnel)) out, err = s.PresenceS.GetReverseTunnels(context.Background()) - c.Assert(err, check.IsNil) - c.Assert(out, check.HasLen, 1) + require.NoError(t, err) + require.Len(t, out, 1) tunnel.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out, []types.ReverseTunnel{tunnel}) + require.Empty(t, cmp.Diff(out, []types.ReverseTunnel{tunnel})) err = s.PresenceS.DeleteReverseTunnel(tunnel.Spec.ClusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetReverseTunnels(context.Background()) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) err = s.PresenceS.UpsertReverseTunnel(newReverseTunnel("", []string{"127.0.0.1:1234"})) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) err = s.PresenceS.UpsertReverseTunnel(newReverseTunnel("example.com", []string{""})) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) err = s.PresenceS.UpsertReverseTunnel(newReverseTunnel("example.com", []string{})) - fixtures.ExpectBadParameter(c, err) + require.True(t, trace.IsBadParameter(err)) } -func (s *ServicesTestSuite) PasswordHashCRUD(c *check.C) { +func (s *ServicesTestSuite) PasswordHashCRUD(t *testing.T) { _, err := s.WebS.GetPasswordHash("user1") - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("%#v", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("%#v", err)) err = s.WebS.UpsertPasswordHash("user1", []byte("hello123")) - c.Assert(err, check.IsNil) + require.NoError(t, err) hash, err := s.WebS.GetPasswordHash("user1") - c.Assert(err, check.IsNil) - c.Assert(hash, check.DeepEquals, []byte("hello123")) + require.NoError(t, err) + require.Empty(t, cmp.Diff(hash, []byte("hello123"))) err = s.WebS.UpsertPasswordHash("user1", []byte("hello321")) - c.Assert(err, check.IsNil) + require.NoError(t, err) hash, err = s.WebS.GetPasswordHash("user1") - c.Assert(err, check.IsNil) - c.Assert(hash, check.DeepEquals, []byte("hello321")) + require.NoError(t, err) + require.Empty(t, cmp.Diff(hash, []byte("hello321"))) } -func (s *ServicesTestSuite) WebSessionCRUD(c *check.C) { +func (s *ServicesTestSuite) WebSessionCRUD(t *testing.T) { ctx := context.Background() req := types.GetWebSessionRequest{User: "user1", SessionID: "sid1"} _, err := s.WebS.WebSessions().Get(ctx, req) - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("%#v", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("%#v", err)) dt := s.Clock.Now().Add(1 * time.Minute) ws, err := types.NewWebSession("sid1", types.KindWebSession, @@ -573,14 +580,14 @@ func (s *ServicesTestSuite) WebSessionCRUD(c *check.C) { Priv: []byte("priv123"), Expires: dt, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.WebSessions().Upsert(ctx, ws) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.WebS.WebSessions().Get(ctx, req) - c.Assert(err, check.IsNil) - c.Assert(out, check.DeepEquals, ws) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out, ws)) ws1, err := types.NewWebSession("sid1", types.KindWebSession, types.WebSessionSpecV2{ @@ -589,48 +596,48 @@ func (s *ServicesTestSuite) WebSessionCRUD(c *check.C) { Priv: []byte("priv321"), Expires: dt, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.WebSessions().Upsert(ctx, ws1) - c.Assert(err, check.IsNil) + require.NoError(t, err) out2, err := s.WebS.WebSessions().Get(ctx, req) - c.Assert(err, check.IsNil) - c.Assert(out2, check.DeepEquals, ws1) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out2, ws1)) - c.Assert(s.WebS.WebSessions().Delete(ctx, types.DeleteWebSessionRequest{ + require.NoError(t, s.WebS.WebSessions().Delete(ctx, types.DeleteWebSessionRequest{ User: req.User, SessionID: req.SessionID, - }), check.IsNil) + })) _, err = s.WebS.WebSessions().Get(ctx, req) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } -func (s *ServicesTestSuite) TokenCRUD(c *check.C) { +func (s *ServicesTestSuite) TokenCRUD(t *testing.T) { ctx := context.Background() _, err := s.ProvisioningS.GetToken(ctx, "token") - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) - t, err := types.NewProvisionToken("token", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) - c.Assert(err, check.IsNil) + tok, err := types.NewProvisionToken("token", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) + require.NoError(t, err) - c.Assert(s.ProvisioningS.UpsertToken(ctx, t), check.IsNil) + require.NoError(t, s.ProvisioningS.UpsertToken(ctx, tok)) token, err := s.ProvisioningS.GetToken(ctx, "token") - c.Assert(err, check.IsNil) - c.Assert(token.GetRoles().Include(types.RoleAuth), check.Equals, true) - c.Assert(token.GetRoles().Include(types.RoleNode), check.Equals, true) - c.Assert(token.GetRoles().Include(types.RoleProxy), check.Equals, false) + require.NoError(t, err) + require.Equal(t, token.GetRoles().Include(types.RoleAuth), true) + require.Equal(t, token.GetRoles().Include(types.RoleNode), true) + require.Equal(t, token.GetRoles().Include(types.RoleProxy), false) diff := s.Clock.Now().UTC().Add(defaults.ProvisioningTokenTTL).Second() - token.Expiry().Second() if diff > 1 { - c.Fatalf("expected diff to be within one second, got %v instead", diff) + t.Fatalf("expected diff to be within one second, got %v instead", diff) } - c.Assert(s.ProvisioningS.DeleteToken(ctx, "token"), check.IsNil) + require.NoError(t, s.ProvisioningS.DeleteToken(ctx, "token")) _, err = s.ProvisioningS.GetToken(ctx, "token") - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) // check tokens backwards compatibility and marshal/unmarshal expiry := time.Now().UTC().Add(time.Hour) @@ -640,48 +647,48 @@ func (s *ServicesTestSuite) TokenCRUD(c *check.C) { Expires: expiry, } v2, err := types.NewProvisionToken(v1.Token, v1.Roles, expiry) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Tokens in different version formats are backwards and forwards // compatible - fixtures.DeepCompare(c, v1.V2(), v2) - fixtures.DeepCompare(c, v2.V1(), v1) + require.Empty(t, cmp.Diff(v1.V2(), v2)) + require.Empty(t, cmp.Diff(v2.V1(), v1)) // Marshal V1, unmarshal V2 data, err := services.MarshalProvisionToken(v2, services.WithVersion(types.V1)) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := services.UnmarshalProvisionToken(data) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, out, v2) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out, v2)) // Test delete all tokens - t, err = types.NewProvisionToken("token1", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) - c.Assert(err, check.IsNil) - c.Assert(s.ProvisioningS.UpsertToken(ctx, t), check.IsNil) + tok, err = types.NewProvisionToken("token1", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) + require.NoError(t, err) + require.NoError(t, s.ProvisioningS.UpsertToken(ctx, tok)) - t, err = types.NewProvisionToken("token2", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) - c.Assert(err, check.IsNil) - c.Assert(s.ProvisioningS.UpsertToken(ctx, t), check.IsNil) + tok, err = types.NewProvisionToken("token2", types.SystemRoles{types.RoleAuth, types.RoleNode}, time.Time{}) + require.NoError(t, err) + require.NoError(t, s.ProvisioningS.UpsertToken(ctx, tok)) tokens, err := s.ProvisioningS.GetTokens(ctx) - c.Assert(err, check.IsNil) - c.Assert(tokens, check.HasLen, 2) + require.NoError(t, err) + require.Len(t, tokens, 2) err = s.ProvisioningS.DeleteAllTokens() - c.Assert(err, check.IsNil) + require.NoError(t, err) tokens, err = s.ProvisioningS.GetTokens(ctx) - c.Assert(err, check.IsNil) - c.Assert(tokens, check.HasLen, 0) + require.NoError(t, err) + require.Len(t, tokens, 0) } -func (s *ServicesTestSuite) RolesCRUD(c *check.C) { +func (s *ServicesTestSuite) RolesCRUD(t *testing.T) { ctx := context.Background() out, err := s.Access.GetRoles(ctx) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) role := types.RoleV5{ Kind: types.KindRole, @@ -715,31 +722,31 @@ func (s *ServicesTestSuite) RolesCRUD(c *check.C) { } err = s.Access.UpsertRole(ctx, &role) - c.Assert(err, check.IsNil) + require.NoError(t, err) rout, err := s.Access.GetRole(ctx, role.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) role.SetResourceID(rout.GetResourceID()) - fixtures.DeepCompare(c, rout, &role) + require.Empty(t, cmp.Diff(rout, &role)) role.Spec.Allow.Logins = []string{"bob"} err = s.Access.UpsertRole(ctx, &role) - c.Assert(err, check.IsNil) + require.NoError(t, err) rout, err = s.Access.GetRole(ctx, role.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) role.SetResourceID(rout.GetResourceID()) - c.Assert(rout, check.DeepEquals, &role) + require.Empty(t, cmp.Diff(rout, &role)) err = s.Access.DeleteRole(ctx, role.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = s.Access.GetRole(ctx, role.Metadata.Name) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } -func (s *ServicesTestSuite) NamespacesCRUD(c *check.C) { +func (s *ServicesTestSuite) NamespacesCRUD(t *testing.T) { out, err := s.PresenceS.GetNamespaces() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) ns := types.Namespace{ Kind: types.KindNamespace, @@ -750,19 +757,19 @@ func (s *ServicesTestSuite) NamespacesCRUD(c *check.C) { }, } err = s.PresenceS.UpsertNamespace(ns) - c.Assert(err, check.IsNil) + require.NoError(t, err) nsout, err := s.PresenceS.GetNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) - c.Assert(nsout, check.DeepEquals, &ns) + require.NoError(t, err) + require.Empty(t, cmp.Diff(nsout, &ns)) err = s.PresenceS.DeleteNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = s.PresenceS.GetNamespace(ns.Metadata.Name) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } -func (s *ServicesTestSuite) SAMLCRUD(c *check.C) { +func (s *ServicesTestSuite) SAMLCRUD(t *testing.T) { ctx := context.Background() connector := &types.SAMLConnectorV2{ Kind: types.KindSAML, @@ -788,42 +795,42 @@ func (s *ServicesTestSuite) SAMLCRUD(c *check.C) { }, } err := services.ValidateSAMLConnector(connector) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.UpsertSAMLConnector(ctx, connector) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.WebS.GetSAMLConnector(ctx, connector.GetName(), true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, out, connector) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out, connector)) connectors, err := s.WebS.GetSAMLConnectors(ctx, true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, []types.SAMLConnector{connector}, connectors) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]types.SAMLConnector{connector}, connectors)) out2, err := s.WebS.GetSAMLConnector(ctx, connector.GetName(), false) - c.Assert(err, check.IsNil) + require.NoError(t, err) connectorNoSecrets := *connector connectorNoSecrets.Spec.SigningKeyPair.PrivateKey = "" - fixtures.DeepCompare(c, out2, &connectorNoSecrets) + require.Empty(t, cmp.Diff(out2, &connectorNoSecrets)) connectorsNoSecrets, err := s.WebS.GetSAMLConnectors(ctx, false) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, []types.SAMLConnector{&connectorNoSecrets}, connectorsNoSecrets) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]types.SAMLConnector{&connectorNoSecrets}, connectorsNoSecrets)) err = s.WebS.DeleteSAMLConnector(ctx, connector.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.DeleteSAMLConnector(ctx, connector.GetName()) - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected not found, got %T", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("expected not found, got %T", err)) _, err = s.WebS.GetSAMLConnector(ctx, connector.GetName(), true) - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected not found, got %T", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("expected not found, got %T", err)) } -func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *check.C) { +func (s *ServicesTestSuite) TunnelConnectionsCRUD(t *testing.T) { clusterName := "example.com" out, err := s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) dt := s.Clock.Now() conn, err := types.NewTunnelConnection("conn1", types.TunnelConnectionSpecV2{ @@ -831,63 +838,63 @@ func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *check.C) { ProxyName: "p1", LastHeartbeat: dt, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.UpsertTunnelConnection(conn) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(out), 1) conn.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out[0], conn) + require.Empty(t, cmp.Diff(out[0], conn)) out, err = s.PresenceS.GetAllTunnelConnections() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) - fixtures.DeepCompare(c, out[0], conn) + require.NoError(t, err) + require.Equal(t, len(out), 1) + require.Empty(t, cmp.Diff(out[0], conn)) dt = dt.Add(time.Hour) conn.SetLastHeartbeat(dt) err = s.PresenceS.UpsertTunnelConnection(conn) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(out), 1) conn.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out[0], conn) + require.Empty(t, cmp.Diff(out[0], conn)) err = s.PresenceS.DeleteAllTunnelConnections() - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) err = s.PresenceS.DeleteAllTunnelConnections() - c.Assert(err, check.IsNil) + require.NoError(t, err) // test delete individual connection err = s.PresenceS.UpsertTunnelConnection(conn) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(out), 1) conn.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out[0], conn) + require.Empty(t, cmp.Diff(out[0], conn)) err = s.PresenceS.DeleteTunnelConnection(clusterName, conn.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetTunnelConnections(clusterName) - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) } -func (s *ServicesTestSuite) GithubConnectorCRUD(c *check.C) { +func (s *ServicesTestSuite) GithubConnectorCRUD(t *testing.T) { ctx := context.Background() connector := &types.GithubConnectorV3{ Kind: types.KindGithubConnector, @@ -912,122 +919,122 @@ func (s *ServicesTestSuite) GithubConnectorCRUD(c *check.C) { }, } err := connector.CheckAndSetDefaults() - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.UpsertGithubConnector(ctx, connector) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.WebS.GetGithubConnector(ctx, connector.GetName(), true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, out, connector) + require.NoError(t, err) + require.Empty(t, cmp.Diff(out, connector)) connectors, err := s.WebS.GetGithubConnectors(ctx, true) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, []types.GithubConnector{connector}, connectors) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]types.GithubConnector{connector}, connectors)) out2, err := s.WebS.GetGithubConnector(ctx, connector.GetName(), false) - c.Assert(err, check.IsNil) + require.NoError(t, err) connectorNoSecrets := *connector connectorNoSecrets.Spec.ClientSecret = "" - fixtures.DeepCompare(c, out2, &connectorNoSecrets) + require.Empty(t, cmp.Diff(out2, &connectorNoSecrets)) connectorsNoSecrets, err := s.WebS.GetGithubConnectors(ctx, false) - c.Assert(err, check.IsNil) - fixtures.DeepCompare(c, []types.GithubConnector{&connectorNoSecrets}, connectorsNoSecrets) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]types.GithubConnector{&connectorNoSecrets}, connectorsNoSecrets)) err = s.WebS.DeleteGithubConnector(ctx, connector.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.WebS.DeleteGithubConnector(ctx, connector.GetName()) - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected not found, got %T", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("expected not found, got %T", err)) _, err = s.WebS.GetGithubConnector(ctx, connector.GetName(), true) - c.Assert(trace.IsNotFound(err), check.Equals, true, check.Commentf("expected not found, got %T", err)) + require.Equal(t, trace.IsNotFound(err), true, fmt.Sprintf("expected not found, got %T", err)) } -func (s *ServicesTestSuite) RemoteClustersCRUD(c *check.C) { +func (s *ServicesTestSuite) RemoteClustersCRUD(t *testing.T) { clusterName := "example.com" out, err := s.PresenceS.GetRemoteClusters() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) rc, err := types.NewRemoteCluster(clusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) rc.SetConnectionStatus(teleport.RemoteClusterStatusOffline) err = s.PresenceS.CreateRemoteCluster(rc) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.CreateRemoteCluster(rc) - fixtures.ExpectAlreadyExists(c, err) + require.True(t, trace.IsAlreadyExists(err)) out, err = s.PresenceS.GetRemoteClusters() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) + require.NoError(t, err) + require.Equal(t, len(out), 1) rc.SetResourceID(out[0].GetResourceID()) - fixtures.DeepCompare(c, out[0], rc) + require.Empty(t, cmp.Diff(out[0], rc)) err = s.PresenceS.DeleteAllRemoteClusters() - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetRemoteClusters() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 0) + require.NoError(t, err) + require.Equal(t, len(out), 0) // test delete individual connection err = s.PresenceS.CreateRemoteCluster(rc) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err = s.PresenceS.GetRemoteClusters() - c.Assert(err, check.IsNil) - c.Assert(len(out), check.Equals, 1) - fixtures.DeepCompare(c, out[0], rc) + require.NoError(t, err) + require.Equal(t, len(out), 1) + require.Empty(t, cmp.Diff(out[0], rc)) err = s.PresenceS.DeleteRemoteCluster(clusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteRemoteCluster(clusterName) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } // AuthPreference tests authentication preference service -func (s *ServicesTestSuite) AuthPreference(c *check.C) { +func (s *ServicesTestSuite) AuthPreference(t *testing.T) { ctx := context.Background() ap, err := types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{ Type: "local", SecondFactor: "otp", DisconnectExpiredCert: types.NewBoolOption(true), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetAuthPreference(ctx, ap) - c.Assert(err, check.IsNil) + require.NoError(t, err) gotAP, err := s.ConfigS.GetAuthPreference(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(gotAP.GetType(), check.Equals, "local") - c.Assert(gotAP.GetSecondFactor(), check.Equals, constants.SecondFactorOTP) - c.Assert(gotAP.GetDisconnectExpiredCert(), check.Equals, true) + require.Equal(t, gotAP.GetType(), "local") + require.Equal(t, gotAP.GetSecondFactor(), constants.SecondFactorOTP) + require.Equal(t, gotAP.GetDisconnectExpiredCert(), true) } // SessionRecordingConfig tests session recording configuration. -func (s *ServicesTestSuite) SessionRecordingConfig(c *check.C) { +func (s *ServicesTestSuite) SessionRecordingConfig(t *testing.T) { ctx := context.Background() recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtProxy, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetSessionRecordingConfig(ctx, recConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) gotrecConfig, err := s.ConfigS.GetSessionRecordingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(gotrecConfig.GetMode(), check.Equals, types.RecordAtProxy) + require.Equal(t, gotrecConfig.GetMode(), types.RecordAtProxy) } -func (s *ServicesTestSuite) StaticTokens(c *check.C) { +func (s *ServicesTestSuite) StaticTokens(t *testing.T) { // set static tokens staticTokens, err := types.NewStaticTokens(types.StaticTokensSpecV2{ StaticTokens: []types.ProvisionTokenV1{ @@ -1038,21 +1045,21 @@ func (s *ServicesTestSuite) StaticTokens(c *check.C) { }, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetStaticTokens(staticTokens) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetStaticTokens() - c.Assert(err, check.IsNil) + require.NoError(t, err) staticTokens.SetResourceID(out.GetResourceID()) - fixtures.DeepCompare(c, staticTokens, out) + require.Empty(t, cmp.Diff(staticTokens, out)) err = s.ConfigS.DeleteStaticTokens() - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = s.ConfigS.GetStaticTokens() - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } // Options provides functional arguments @@ -1082,51 +1089,51 @@ func CollectOptions(opts ...Option) Options { } // ClusterName tests cluster name. -func (s *ServicesTestSuite) ClusterName(c *check.C, opts ...Option) { +func (s *ServicesTestSuite) ClusterName(t *testing.T, opts ...Option) { clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ ClusterName: "example.com", }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetClusterName(clusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) gotName, err := s.ConfigS.GetClusterName() - c.Assert(err, check.IsNil) + require.NoError(t, err) clusterName.SetResourceID(gotName.GetResourceID()) - fixtures.DeepCompare(c, clusterName, gotName) + require.Empty(t, cmp.Diff(clusterName, gotName)) err = s.ConfigS.DeleteClusterName() - c.Assert(err, check.IsNil) + require.NoError(t, err) _, err = s.ConfigS.GetClusterName() - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) err = s.ConfigS.UpsertClusterName(clusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) gotName, err = s.ConfigS.GetClusterName() - c.Assert(err, check.IsNil) + require.NoError(t, err) clusterName.SetResourceID(gotName.GetResourceID()) - fixtures.DeepCompare(c, clusterName, gotName) + require.Empty(t, cmp.Diff(clusterName, gotName)) } // ClusterNetworkingConfig tests cluster networking configuration. -func (s *ServicesTestSuite) ClusterNetworkingConfig(c *check.C) { +func (s *ServicesTestSuite) ClusterNetworkingConfig(t *testing.T) { ctx := context.Background() netConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ ClientIdleTimeout: types.NewDuration(17 * time.Second), KeepAliveCountMax: 3000, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetClusterNetworkingConfig(ctx, netConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) gotNetConfig, err := s.ConfigS.GetClusterNetworkingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(gotNetConfig.GetClientIdleTimeout(), check.Equals, 17*time.Second) - c.Assert(gotNetConfig.GetKeepAliveCountMax(), check.Equals, int64(3000)) + require.Equal(t, gotNetConfig.GetClientIdleTimeout(), 17*time.Second) + require.Equal(t, gotNetConfig.GetKeepAliveCountMax(), int64(3000)) } // sem wrapper is a helper for overriding the keepalive @@ -1140,7 +1147,7 @@ func (w *semWrapper) KeepAliveSemaphoreLease(ctx context.Context, lease types.Se return w.keepAlive(ctx, lease) } -func (s *ServicesTestSuite) SemaphoreFlakiness(c *check.C) { +func (s *ServicesTestSuite) SemaphoreFlakiness(t *testing.T) { ctx := context.Background() const renewals = 3 // wrap our services.Semaphores instance to cause two out of three lease @@ -1174,16 +1181,16 @@ func (s *ServicesTestSuite) SemaphoreFlakiness(c *check.C) { defer cancel() lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg) - c.Assert(err, check.IsNil) + require.NoError(t, err) for i := 0; i < renewals; i++ { select { case <-lock.Renewed(): continue case <-lock.Done(): - c.Fatalf("Lost semaphore lock: %v", lock.Wait()) + t.Fatalf("Lost semaphore lock: %v", lock.Wait()) case <-time.After(time.Second): - c.Fatalf("Timeout waiting for renewals") + t.Fatalf("Timeout waiting for renewals") } } } @@ -1195,7 +1202,7 @@ func (s *ServicesTestSuite) SemaphoreFlakiness(c *check.C) { // fairly small. Semaphores aren't cheap and the auth server is expected // to start returning "too much contention" errors at around 100 concurrent // attempts. -func (s *ServicesTestSuite) SemaphoreContention(c *check.C) { +func (s *ServicesTestSuite) SemaphoreContention(t *testing.T) { ctx := context.Background() const locks int64 = 50 const iters = 5 @@ -1219,21 +1226,22 @@ func (s *ServicesTestSuite) SemaphoreContention(c *check.C) { go func() { defer wg.Done() _, err := services.AcquireSemaphoreLock(cancelCtx, cfg) - c.Assert(err, check.IsNil) + require.NoError(t, err) }() } wg.Wait() cancel() - c.Assert(s.PresenceS.DeleteSemaphore(ctx, types.SemaphoreFilter{ + require.NoError(t, s.PresenceS.DeleteSemaphore(ctx, types.SemaphoreFilter{ SemaphoreKind: cfg.Params.SemaphoreKind, SemaphoreName: cfg.Params.SemaphoreName, - }), check.IsNil) + })) + } } // SemaphoreConcurrency verifies that a large number of concurrent // acquisitions result in the correct number of successful acquisitions. -func (s *ServicesTestSuite) SemaphoreConcurrency(c *check.C) { +func (s *ServicesTestSuite) SemaphoreConcurrency(t *testing.T) { ctx := context.Background() const maxLeases int64 = 20 const attempts int64 = 200 @@ -1267,14 +1275,15 @@ func (s *ServicesTestSuite) SemaphoreConcurrency(c *check.C) { }() } wg.Wait() - c.Assert(atomic.LoadInt64(&success), check.Equals, maxLeases) - c.Assert(atomic.LoadInt64(&failure), check.Equals, attempts-maxLeases) + require.Equal(t, atomic.LoadInt64(&success), maxLeases) + require.Equal(t, atomic.LoadInt64(&failure), attempts-maxLeases) } // SemaphoreLock verifies correct functionality of the basic // semaphore lock scenarios. -func (s *ServicesTestSuite) SemaphoreLock(c *check.C) { +func (s *ServicesTestSuite) SemaphoreLock(t *testing.T) { ctx := context.Background() + cfg := services.SemaphoreLockConfig{ Service: s.PresenceS, Expiry: time.Hour, @@ -1287,51 +1296,51 @@ func (s *ServicesTestSuite) SemaphoreLock(c *check.C) { cancelCtx, cancel := context.WithCancel(ctx) defer cancel() lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg) - c.Assert(err, check.IsNil) + require.NoError(t, err) // MaxLeases is 1, so second acquire op fails. _, err = services.AcquireSemaphoreLock(cancelCtx, cfg) - fixtures.ExpectLimitExceeded(c, err) + require.True(t, trace.IsLimitExceeded(err)) // Lock is successfully released. lock.Stop() - c.Assert(lock.Wait(), check.IsNil) + require.NoError(t, lock.Wait()) // Acquire new lock with short expiry // and high tick rate to verify renewals. cfg.Expiry = time.Second cfg.TickRate = time.Millisecond * 50 lock, err = services.AcquireSemaphoreLock(cancelCtx, cfg) - c.Assert(err, check.IsNil) + require.NoError(t, err) timeout := time.After(time.Second) for i := 0; i < 3; i++ { select { case <-lock.Done(): - c.Fatalf("Unexpected lock failure: %v", lock.Wait()) + t.Fatalf("Unexpected lock failure: %v", lock.Wait()) case <-timeout: - c.Fatalf("Timeout waiting for lock renewal %d", i) + t.Fatalf("Timeout waiting for lock renewal %d", i) case <-lock.Renewed(): } } // forcibly delete the semaphore - c.Assert(s.PresenceS.DeleteSemaphore(ctx, types.SemaphoreFilter{ + require.NoError(t, s.PresenceS.DeleteSemaphore(ctx, types.SemaphoreFilter{ SemaphoreKind: cfg.Params.SemaphoreKind, SemaphoreName: cfg.Params.SemaphoreName, - }), check.IsNil) + })) select { case <-lock.Done(): - fixtures.ExpectNotFound(c, lock.Wait()) + require.True(t, trace.IsNotFound(lock.Wait())) case <-time.After(time.Millisecond * 1500): - c.Errorf("timeout waiting for semaphore lock failure") + t.Errorf("timeout waiting for semaphore lock failure") } } // Events tests various events variations -func (s *ServicesTestSuite) Events(c *check.C) { +func (s *ServicesTestSuite) Events(t *testing.T) { ctx := context.Background() testCases := []eventTest{ { @@ -1342,17 +1351,17 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, crud: func(context.Context) types.Resource { ca := NewTestCA(types.UserCA, "example.com") - c.Assert(s.CAS.UpsertCertAuthority(ca), check.IsNil) + require.NoError(t, s.CAS.UpsertCertAuthority(ca)) out, err := s.CAS.GetCertAuthority(ctx, *ca.ID(), true) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(s.CAS.DeleteCertAuthority(*ca.ID()), check.IsNil) + require.NoError(t, s.CAS.DeleteCertAuthority(*ca.ID())) return out }, }, } - s.runEventsTests(c, testCases) + s.runEventsTests(t, testCases) testCases = []eventTest{ { @@ -1363,17 +1372,17 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, crud: func(context.Context) types.Resource { ca := NewTestCA(types.UserCA, "example.com") - c.Assert(s.CAS.UpsertCertAuthority(ca), check.IsNil) + require.NoError(t, s.CAS.UpsertCertAuthority(ca)) out, err := s.CAS.GetCertAuthority(ctx, *ca.ID(), false) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(s.CAS.DeleteCertAuthority(*ca.ID()), check.IsNil) + require.NoError(t, s.CAS.DeleteCertAuthority(*ca.ID())) return out }, }, } - s.runEventsTests(c, testCases) + s.runEventsTests(t, testCases) testCases = []eventTest{ { @@ -1383,16 +1392,16 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, crud: func(context.Context) types.Resource { expires := time.Now().UTC().Add(time.Hour) - t, err := types.NewProvisionToken("token", + tok, err := types.NewProvisionToken("token", types.SystemRoles{types.RoleAuth, types.RoleNode}, expires) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(s.ProvisioningS.UpsertToken(ctx, t), check.IsNil) + require.NoError(t, s.ProvisioningS.UpsertToken(ctx, tok)) token, err := s.ProvisioningS.GetToken(ctx, "token") - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(s.ProvisioningS.DeleteToken(ctx, "token"), check.IsNil) + require.NoError(t, s.ProvisioningS.DeleteToken(ctx, "token")) return token }, }, @@ -1411,13 +1420,13 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, } err := s.PresenceS.UpsertNamespace(ns) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.PresenceS.GetNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, @@ -1437,16 +1446,16 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetStaticTokens(staticTokens) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetStaticTokens() - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.DeleteStaticTokens() - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, @@ -1467,16 +1476,16 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, Deny: types.RoleConditions{}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.Access.UpsertRole(ctx, role) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.Access.GetRole(ctx, role.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.Access.DeleteRole(ctx, role.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, @@ -1489,12 +1498,12 @@ func (s *ServicesTestSuite) Events(c *check.C) { crud: func(context.Context) types.Resource { user := newUser("user1", []string{"admin"}) err := s.Users().UpsertUser(user) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.Users().GetUser(user.GetName(), false) - c.Assert(err, check.IsNil) + require.NoError(t, err) - c.Assert(s.Users().DeleteUser(ctx, user.GetName()), check.IsNil) + require.NoError(t, s.Users().DeleteUser(ctx, user.GetName())) return out }, }, @@ -1507,13 +1516,13 @@ func (s *ServicesTestSuite) Events(c *check.C) { srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace) _, err := s.PresenceS.UpsertNode(ctx, srv) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.PresenceS.GetNodes(ctx, srv.Metadata.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteAllNodes(ctx, srv.Metadata.Namespace) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out[0] }, @@ -1527,13 +1536,13 @@ func (s *ServicesTestSuite) Events(c *check.C) { srv := NewServer(types.KindProxy, "srv1", "127.0.0.1:2022", apidefaults.Namespace) err := s.PresenceS.UpsertProxy(srv) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.PresenceS.GetProxies() - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteAllProxies() - c.Assert(err, check.IsNil) + require.NoError(t, err) return out[0] }, @@ -1549,16 +1558,16 @@ func (s *ServicesTestSuite) Events(c *check.C) { ProxyName: "p1", LastHeartbeat: time.Now().UTC(), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.UpsertTunnelConnection(conn) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.PresenceS.GetTunnelConnections("example.com") - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteAllTunnelConnections() - c.Assert(err, check.IsNil) + require.NoError(t, err) return out[0] }, @@ -1570,13 +1579,13 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, crud: func(context.Context) types.Resource { tunnel := newReverseTunnel("example.com", []string{"example.com:2023"}) - c.Assert(s.PresenceS.UpsertReverseTunnel(tunnel), check.IsNil) + require.NoError(t, s.PresenceS.UpsertReverseTunnel(tunnel)) out, err := s.PresenceS.GetReverseTunnels(context.Background()) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteReverseTunnel(tunnel.Spec.ClusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out[0] }, @@ -1589,20 +1598,20 @@ func (s *ServicesTestSuite) Events(c *check.C) { crud: func(context.Context) types.Resource { rc, err := types.NewRemoteCluster("example.com") rc.SetConnectionStatus(teleport.RemoteClusterStatusOffline) - c.Assert(err, check.IsNil) - c.Assert(s.PresenceS.CreateRemoteCluster(rc), check.IsNil) + require.NoError(t, err) + require.NoError(t, s.PresenceS.CreateRemoteCluster(rc)) out, err := s.PresenceS.GetRemoteClusters() - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteRemoteCluster(rc.GetName()) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out[0] }, }, } - s.runEventsTests(c, testCases) + s.runEventsTests(t, testCases) // Namespace with a name testCases = []eventTest{ @@ -1622,23 +1631,23 @@ func (s *ServicesTestSuite) Events(c *check.C) { }, } err := s.PresenceS.UpsertNamespace(ns) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.PresenceS.GetNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.PresenceS.DeleteNamespace(ns.Metadata.Name) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, }, } - s.runEventsTests(c, testCases) + s.runEventsTests(t, testCases) } // EventsClusterConfig tests cluster config resource events -func (s *ServicesTestSuite) EventsClusterConfig(c *check.C) { +func (s *ServicesTestSuite) EventsClusterConfig(t *testing.T) { testCases := []eventTest{ { name: "Cluster name", @@ -1649,16 +1658,16 @@ func (s *ServicesTestSuite) EventsClusterConfig(c *check.C) { clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ ClusterName: "example.com", }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.UpsertClusterName(clusterName) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetClusterName() - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.DeleteClusterName() - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, }, @@ -1674,16 +1683,16 @@ func (s *ServicesTestSuite) EventsClusterConfig(c *check.C) { AuditSessionsURI: "file:///home/log", AuditEventsURI: []string{"dynamodb://audit_table_name", "file:///home/test/log"}, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetClusterAuditConfig(ctx, auditConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetClusterAuditConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.DeleteClusterAuditConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, }, @@ -1696,16 +1705,16 @@ func (s *ServicesTestSuite) EventsClusterConfig(c *check.C) { netConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ ClientIdleTimeout: types.Duration(5 * time.Second), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetClusterNetworkingConfig(ctx, netConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetClusterNetworkingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.DeleteClusterNetworkingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, }, @@ -1718,33 +1727,33 @@ func (s *ServicesTestSuite) EventsClusterConfig(c *check.C) { recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtProxySync, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.SetSessionRecordingConfig(ctx, recConfig) - c.Assert(err, check.IsNil) + require.NoError(t, err) out, err := s.ConfigS.GetSessionRecordingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.ConfigS.DeleteSessionRecordingConfig(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) return out }, }, } - s.runEventsTests(c, testCases) + s.runEventsTests(t, testCases) } // NetworkRestrictions tests network restrictions. -func (s *ServicesTestSuite) NetworkRestrictions(c *check.C, opts ...Option) { +func (s *ServicesTestSuite) NetworkRestrictions(t *testing.T, opts ...Option) { ctx := context.Background() // blank slate, should be get/delete should fail _, err := s.RestrictionsS.GetNetworkRestrictions(ctx) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) err = s.RestrictionsS.DeleteNetworkRestrictions(ctx) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) allow := []types.AddressCondition{ {CIDR: "10.0.1.0/24"}, @@ -1761,37 +1770,37 @@ func (s *ServicesTestSuite) NetworkRestrictions(c *check.C, opts ...Option) { // set and make sure we get it back err = s.RestrictionsS.SetNetworkRestrictions(ctx, expected) - c.Assert(err, check.IsNil) + require.NoError(t, err) actual, err := s.RestrictionsS.GetNetworkRestrictions(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) - fixtures.DeepCompare(c, expected.GetAllow(), actual.GetAllow()) - fixtures.DeepCompare(c, expected.GetDeny(), actual.GetDeny()) + require.Empty(t, cmp.Diff(expected.GetAllow(), actual.GetAllow())) + require.Empty(t, cmp.Diff(expected.GetDeny(), actual.GetDeny())) // now delete should work ok and get should fail again err = s.RestrictionsS.DeleteNetworkRestrictions(ctx) - c.Assert(err, check.IsNil) + require.NoError(t, err) err = s.RestrictionsS.DeleteNetworkRestrictions(ctx) - fixtures.ExpectNotFound(c, err) + require.True(t, trace.IsNotFound(err)) } -func (s *ServicesTestSuite) runEventsTests(c *check.C, testCases []eventTest) { +func (s *ServicesTestSuite) runEventsTests(t *testing.T, testCases []eventTest) { ctx := context.Background() w, err := s.EventsS.NewWatcher(ctx, types.Watch{ Kinds: eventsTestKinds(testCases), }) - c.Assert(err, check.IsNil) + require.NoError(t, err) defer w.Close() select { case event := <-w.Events(): - c.Assert(event.Type, check.Equals, types.OpInit) + require.Equal(t, event.Type, types.OpInit) case <-w.Done(): - c.Fatalf("Watcher exited with error %v", w.Error()) + t.Fatalf("Watcher exited with error %v", w.Error()) case <-time.After(2 * time.Second): - c.Fatalf("Timeout waiting for init event") + t.Fatalf("Timeout waiting for init event") } // filter out all events that could have been inserted @@ -1805,15 +1814,15 @@ skiploop: default: break skiploop case <-w.Done(): - c.Fatalf("Watcher exited with error %v", w.Error()) + t.Fatalf("Watcher exited with error %v", w.Error()) } } for _, tc := range testCases { - c.Logf("test case %q", tc.name) + t.Logf("test case %q", tc.name) resource := tc.crud(ctx) - ExpectResource(c, w, 3*time.Second, resource) + ExpectResource(t, w, 3*time.Second, resource) meta := resource.GetMetadata() header := &types.ResourceHeader{ @@ -1827,7 +1836,7 @@ skiploop: } // delete events don't have IDs yet header.SetResourceID(0) - ExpectDeleteResource(c, w, 3*time.Second, header) + ExpectDeleteResource(t, w, 3*time.Second, header) } } @@ -1846,15 +1855,15 @@ func eventsTestKinds(tests []eventTest) []types.WatchKind { } // ExpectResource expects a Put event of a certain resource -func ExpectResource(c *check.C, w types.Watcher, timeout time.Duration, resource types.Resource) { +func ExpectResource(t *testing.T, w types.Watcher, timeout time.Duration, resource types.Resource) { timeoutC := time.After(timeout) waitLoop: for { select { case <-timeoutC: - c.Fatalf("Timeout waiting for event") + t.Fatalf("Timeout waiting for event") case <-w.Done(): - c.Fatalf("Watcher exited with error %v", w.Error()) + t.Fatalf("Watcher exited with error %v", w.Error()) case event := <-w.Events(): if event.Type != types.OpPut { log.Debugf("Skipping event %+v", event) @@ -1868,28 +1877,28 @@ waitLoop: log.Debugf("Skipping event %v resource %v, expecting %v", event.Type, event.Resource.GetMetadata(), event.Resource.GetMetadata()) continue waitLoop } - fixtures.DeepCompare(c, resource, event.Resource) + require.Empty(t, cmp.Diff(resource, event.Resource)) break waitLoop } } } // ExpectDeleteResource expects a delete event of a certain kind -func ExpectDeleteResource(c *check.C, w types.Watcher, timeout time.Duration, resource types.Resource) { +func ExpectDeleteResource(t *testing.T, w types.Watcher, timeout time.Duration, resource types.Resource) { timeoutC := time.After(timeout) waitLoop: for { select { case <-timeoutC: - c.Fatalf("Timeout waiting for delete resource %v", resource) + t.Fatalf("Timeout waiting for delete resource %v", resource) case <-w.Done(): - c.Fatalf("Watcher exited with error %v", w.Error()) + t.Fatalf("Watcher exited with error %v", w.Error()) case event := <-w.Events(): if event.Type != types.OpDelete { log.Debugf("Skipping stale event %v %v", event.Type, event.Resource.GetName()) continue } - fixtures.DeepCompare(c, resource, event.Resource) + require.Empty(t, cmp.Diff(resource, event.Resource)) break waitLoop } } From 88e0d9a9c470e0957490feebe4fe6b4e267a78a9 Mon Sep 17 00:00:00 2001 From: Forrest Marshall Date: Tue, 12 Jul 2022 20:20:44 +0000 Subject: [PATCH 2/2] improve semaphore retries and tests --- lib/services/local/presence.go | 6 +++--- lib/services/suite/suite.go | 13 ++++++------- lib/utils/retry.go | 20 ++++++++++++++++++++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index a97dd02b021a0..07ad793c4d31d 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -52,7 +52,7 @@ type backendItemToResourceFunc func(item backend.Item) (types.ResourceWithLabels func NewPresenceService(b backend.Backend) *PresenceService { return &PresenceService{ log: logrus.WithFields(logrus.Fields{trace.Component: "Presence"}), - jitter: utils.NewJitter(), + jitter: utils.NewFullJitter(), Backend: b, } } @@ -689,10 +689,10 @@ func (s *PresenceService) DeleteAllRemoteClusters() error { } // this combination of backoff parameters leads to worst-case total time spent -// in backoff between 1200ms and 2400ms depending on jitter. tests are in +// in backoff between 1ms and 2000ms depending on jitter. tests are in // place to verify that this is sufficient to resolve a 20-lease contention // event, which is worse than should ever occur in practice. -const baseBackoff = time.Millisecond * 300 +const baseBackoff = time.Millisecond * 400 const leaseRetryAttempts int64 = 6 // AcquireSemaphore attempts to acquire the specified semaphore. AcquireSemaphore will automatically handle diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 50f5534a5cca1..b32676997039f 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -1212,7 +1212,7 @@ func (s *ServicesTestSuite) SemaphoreContention(t *testing.T) { Expiry: time.Hour, Params: types.AcquireSemaphoreRequest{ SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: "alice", + SemaphoreName: fmt.Sprintf("sem-%d", i), // avoid overlap between iterations MaxLeases: locks, }, } @@ -1220,22 +1220,21 @@ func (s *ServicesTestSuite) SemaphoreContention(t *testing.T) { // context-based cancellation is needed to cleanup the // background keepalive activity. cancelCtx, cancel := context.WithCancel(ctx) - var wg sync.WaitGroup + acquireErrs := make(chan error, locks) for i := int64(0); i < locks; i++ { - wg.Add(1) go func() { - defer wg.Done() _, err := services.AcquireSemaphoreLock(cancelCtx, cfg) - require.NoError(t, err) + acquireErrs <- err }() } - wg.Wait() + for i := int64(0); i < locks; i++ { + require.NoError(t, <-acquireErrs) + } cancel() require.NoError(t, s.PresenceS.DeleteSemaphore(ctx, types.SemaphoreFilter{ SemaphoreKind: cfg.Params.SemaphoreKind, SemaphoreName: cfg.Params.SemaphoreName, })) - } } diff --git a/lib/utils/retry.go b/lib/utils/retry.go index a2b8d23885345..2a1bd6d0a45ab 100644 --- a/lib/utils/retry.go +++ b/lib/utils/retry.go @@ -85,6 +85,26 @@ func NewSeventhJitter() Jitter { } } +// NewFullJitter builds a new jitter on the range (0,n]. Most use-cases +// are better served by a jitter with a meaningful minimum value, but if +// the *only* purpose of the jitter is to spread out retries to the greatest +// extent possible (e.g. when retrying a CompareAndSwap operation), a full jitter +// may be appropriate. +func NewFullJitter() Jitter { + var mu sync.Mutex + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return func(d time.Duration) time.Duration { + // values less than 1 cause rng to panic, and some logic + // relies on treating zero duration as non-blocking case. + if d < 1 { + return 0 + } + mu.Lock() + defer mu.Unlock() + return time.Duration(1) + time.Duration(rng.Int63n(int64(d))) + } +} + // Retry is an interface that provides retry logic type Retry interface { // Reset resets retry state