diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 2250e5cfe9e05..befdca96a6d28 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -707,7 +707,6 @@ func TestServer_Authenticate_headless(t *testing.T) { ctx := context.Background() headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - const timeout = time.Millisecond * 200 type updateHeadlessAuthnFn func(*types.HeadlessAuthentication, *types.MFADevice) updateHeadlessAuthnInGoroutine := func(ctx context.Context, srv *TestTLSServer, mfa *types.MFADevice, update updateHeadlessAuthnFn) chan error { @@ -784,7 +783,7 @@ func TestServer_Authenticate_headless(t *testing.T) { mfa := configureForMFA(t, srv) username := mfa.User - // Fail a login attempt so have a non-empty list of attempts. + // Fail a login attempt so we have a non-empty list of attempts. _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ AuthenticateUserRequest: AuthenticateUserRequest{ Username: username, @@ -793,16 +792,12 @@ func TestServer_Authenticate_headless(t *testing.T) { }, TTL: 24 * time.Hour, }) - require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied") + require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied", err) attempts, err := srv.Auth().GetUserLoginAttempts(username) require.NoError(t, err) require.NotEmpty(t, attempts, "Want at least one failed login attempt") - t.Cleanup(func() { - srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID) - }) - - ctx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update) @@ -821,6 +816,7 @@ func TestServer_Authenticate_headless(t *testing.T) { TTL: defaults.CallbackTimeout, }) require.NoError(t, <-errC) + if tc.expectErr { require.Error(t, err) // Verify login attempts unchanged. This is a proxy for various other user diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index 9537761020aa7..ca69c24cce98d 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -18,7 +18,6 @@ package local import ( "context" - "time" "github.com/gravitational/trace" @@ -31,7 +30,8 @@ import ( // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, s.Clock().Now().Add(defaults.CallbackTimeout)) + expires := s.Clock().Now().Add(defaults.CallbackTimeout) + headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, expires) if err != nil { return nil, trace.Wrap(err) } @@ -115,6 +115,7 @@ func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name return trace.Wrap(err) } +// MarshalHeadlessAuthenticationToItem marshals a headless authentication to a backend.Item. func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) { if err := headlessAuthn.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) @@ -132,15 +133,13 @@ func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentica }, nil } +// unmarshalHeadlessAuthenticationFromItem unmarshals a headless authentication from a backend.Item. func unmarshalHeadlessAuthenticationFromItem(item *backend.Item) (*types.HeadlessAuthentication, error) { var headlessAuthn types.HeadlessAuthentication if err := utils.FastUnmarshal(item.Value, &headlessAuthn); err != nil { return nil, trace.Wrap(err, "error unmarshalling headless authentication from storage") } - // Copy item.Expires without pointer to avoid race conditions with memory backend. - headlessAuthn.Metadata.Expires = new(time.Time) - *headlessAuthn.Metadata.Expires = item.Expires if err := headlessAuthn.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) }