Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,6 @@ func TestServer_Authenticate_headless(t *testing.T) {

ctx := context.Background()
headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey))
const timeout = time.Millisecond * 200

type updateHeadlessAuthnFn func(*types.HeadlessAuthentication, *types.MFADevice)
updateHeadlessAuthnInGoroutine := func(ctx context.Context, srv *TestTLSServer, mfa *types.MFADevice, update updateHeadlessAuthnFn) chan error {
Expand Down Expand Up @@ -784,7 +783,7 @@ func TestServer_Authenticate_headless(t *testing.T) {
mfa := configureForMFA(t, srv)
username := mfa.User

// Fail a login attempt so have a non-empty list of attempts.
// Fail a login attempt so we have a non-empty list of attempts.
_, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{
AuthenticateUserRequest: AuthenticateUserRequest{
Username: username,
Expand All @@ -793,16 +792,12 @@ func TestServer_Authenticate_headless(t *testing.T) {
},
TTL: 24 * time.Hour,
})
require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied")
require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied", err)
attempts, err := srv.Auth().GetUserLoginAttempts(username)
require.NoError(t, err)
require.NotEmpty(t, attempts, "Want at least one failed login attempt")

t.Cleanup(func() {
srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID)
})

ctx, cancel := context.WithTimeout(ctx, timeout)
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()

errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update)
Expand All @@ -821,6 +816,7 @@ func TestServer_Authenticate_headless(t *testing.T) {
TTL: defaults.CallbackTimeout,
})
require.NoError(t, <-errC)

if tc.expectErr {
require.Error(t, err)
// Verify login attempts unchanged. This is a proxy for various other user
Expand Down
9 changes: 4 additions & 5 deletions lib/services/local/headlessauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package local

import (
"context"
"time"

"github.com/gravitational/trace"

Expand All @@ -31,7 +30,8 @@ import (

// CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend.
func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) {
headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, s.Clock().Now().Add(defaults.CallbackTimeout))
expires := s.Clock().Now().Add(defaults.CallbackTimeout)
headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, expires)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -115,6 +115,7 @@ func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name
return trace.Wrap(err)
}

// MarshalHeadlessAuthenticationToItem marshals a headless authentication to a backend.Item.
func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) {
if err := headlessAuthn.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
Expand All @@ -132,15 +133,13 @@ func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentica
}, nil
}

// unmarshalHeadlessAuthenticationFromItem unmarshals a headless authentication from a backend.Item.
func unmarshalHeadlessAuthenticationFromItem(item *backend.Item) (*types.HeadlessAuthentication, error) {
var headlessAuthn types.HeadlessAuthentication
if err := utils.FastUnmarshal(item.Value, &headlessAuthn); err != nil {
return nil, trace.Wrap(err, "error unmarshalling headless authentication from storage")
}

// Copy item.Expires without pointer to avoid race conditions with memory backend.
headlessAuthn.Metadata.Expires = new(time.Time)
*headlessAuthn.Metadata.Expires = item.Expires
if err := headlessAuthn.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down