Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,41 +735,40 @@ func TestServer_Authenticate_headless(t *testing.T) {
}

for _, tc := range []struct {
name string
update updateHeadlessAuthnFn
checkErr require.ErrorAssertionFunc
name string
update updateHeadlessAuthnFn
expectErr bool
}{
{
name: "OK approved",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {
ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED
ha.MfaDevice = mfa
},
checkErr: require.NoError,
}, {
name: "NOK approved without MFA",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {
ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED
},
checkErr: require.Error,
expectErr: true,
}, {
name: "NOK user mismatch",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {
ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED
ha.MfaDevice = mfa
ha.User = "other-user"
},
checkErr: require.Error,
expectErr: true,
}, {
name: "NOK denied",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {
ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED
},
checkErr: require.Error,
expectErr: true,
}, {
name: "NOK timeout",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {},
checkErr: require.Error,
name: "NOK timeout",
update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {},
expectErr: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -785,6 +784,20 @@ func TestServer_Authenticate_headless(t *testing.T) {
mfa := configureForMFA(t, srv)
username := mfa.User

// Fail a login attempt so have a non-empty list of attempts.
_, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{
AuthenticateUserRequest: AuthenticateUserRequest{
Username: username,
Webauthn: &wanlib.CredentialAssertionResponse{}, // bad response
PublicKey: []byte(sshPubKey),
},
TTL: 24 * time.Hour,
})
require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied")
attempts, err := srv.Auth().GetUserLoginAttempts(username)
require.NoError(t, err)
require.NotEmpty(t, attempts, "Want at least one failed login attempt")

t.Cleanup(func() {
srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID)
})
Expand All @@ -795,17 +808,34 @@ func TestServer_Authenticate_headless(t *testing.T) {
errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update)
_, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{
AuthenticateUserRequest: AuthenticateUserRequest{
// HeadlessAuthenticationID should take precedence over WebAuthn and OTP fields.
HeadlessAuthenticationID: headlessID,
Webauthn: &wanlib.CredentialAssertionResponse{},
OTP: &OTPCreds{},
Username: username,
PublicKey: []byte(sshPubKey),
HeadlessAuthenticationID: headlessID,
ClientMetadata: &ForwardedClientMetadata{
RemoteAddr: "0.0.0.0",
},
},
TTL: defaults.CallbackTimeout,
})
tc.checkErr(t, err)
require.NoError(t, <-errC)
if tc.expectErr {
require.Error(t, err)
// Verify login attempts unchanged. This is a proxy for various other user
// checks (locked, etc).
updatedAttempts, err := srv.Auth().GetUserLoginAttempts(username)
require.NoError(t, err)
require.Equal(t, attempts, updatedAttempts, "Login attempts unexpectedly changed")
} else {
require.NoError(t, err)
// Verify zeroed login attempts. This is a proxy for various other user
// checks (locked, etc).
updatedAttempts, err := srv.Auth().GetUserLoginAttempts(username)
require.NoError(t, err)
require.Empty(t, updatedAttempts, "Login attempts not reset")
}
})
}
}
Expand Down
24 changes: 12 additions & 12 deletions lib/auth/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque
var authErr error // error message kept obscure on purpose, use logging for details
switch {
// cases in order of preference
case req.HeadlessAuthenticationID != "":
Comment thread
codingllama marked this conversation as resolved.
// handle authentication before the user lock to prevent locking out users
// due to timed-out/canceled headless authentication attempts.
mfaDevice, err := s.authenticateHeadless(ctx, req)
if err != nil {
log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err)
return nil, "", trace.Wrap(invalidHeadlessAuthenticationError)
}
authenticateFn = func() (*types.MFADevice, error) {
return mfaDevice, nil
}
authErr = invalidHeadlessAuthenticationError
case req.Webauthn != nil:
authenticateFn = func() (*types.MFADevice, error) {
mfaResponse := &proto.MFAAuthenticateResponse{
Expand All @@ -224,18 +236,6 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque
return res.mfaDev, nil
}
authErr = invalidUserPass2FError
case req.HeadlessAuthenticationID != "":
// handle authentication before the user lock to prevent locking out users
// due to timed-out/canceled headless authentication attempts.
mfaDevice, err := s.authenticateHeadless(ctx, req)
if err != nil {
log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err)
return nil, "", trace.Wrap(invalidHeadlessAuthenticationError)
}
authenticateFn = func() (*types.MFADevice, error) {
return mfaDevice, nil
}
authErr = invalidHeadlessAuthenticationError
}
if authenticateFn != nil {
var dev *types.MFADevice
Expand Down
63 changes: 34 additions & 29 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3124,50 +3124,55 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro
return nil, trace.Wrap(err)
}

var authenticationClient interface {
AuthenticateSSHUser(ctx context.Context, req auth.AuthenticateSSHRequest) (*auth.SSHLoginResponse, error)
} = authClient
authSSHUserReq := auth.AuthenticateSSHRequest{
AuthenticateUserRequest: auth.AuthenticateUserRequest{
Username: req.User,
PublicKey: req.PubKey,
ClientMetadata: clientMetaFromReq(r),
},
CompatibilityMode: req.Compatibility,
TTL: req.TTL,
RouteToCluster: req.RouteToCluster,
KubernetesCluster: req.KubernetesCluster,
AttestationStatement: req.AttestationStatement,
}

if req.HeadlessAuthenticationID != "" {
// We need to use the default callback timeout rather than the standard client timeout.
// However, authClient is shared across all Proxy->Auth requests, so we need to create
// a new client to avoid applying the callback timeout to other concurrent requests. To
// this end, we create a clone of the HTTP Client with the desired timeout instead.
httpClient, err := authClient.CloneHTTPClient(
auth.ClientParamTimeout(defaults.CallbackTimeout),
auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout),
)
if err != nil {
return nil, trace.Wrap(err)
}

authReq := auth.AuthenticateUserRequest{
Username: req.User,
PublicKey: req.PubKey,
ClientMetadata: clientMetaFromReq(r),
authSSHUserReq.AuthenticateUserRequest.HeadlessAuthenticationID = req.HeadlessAuthenticationID
loginResp, err := httpClient.AuthenticateSSHUser(r.Context(), authSSHUserReq)
if err != nil {
return nil, trace.Wrap(err)
}
return loginResp, nil
}

switch cap.GetSecondFactor() {
case constants.SecondFactorOff:
authReq.Pass = &auth.PassCreds{
authSSHUserReq.AuthenticateUserRequest.Pass = &auth.PassCreds{
Password: []byte(req.Password),
}
case constants.SecondFactorOTP, constants.SecondFactorOn, constants.SecondFactorOptional:
authReq.OTP = &auth.OTPCreds{
authSSHUserReq.AuthenticateUserRequest.OTP = &auth.OTPCreds{
Password: []byte(req.Password),
Token: req.OTPToken,
}
case constants.SecondFactorWebauthn:
// WebAuthn only supports this endpoint for headless login.
authReq.HeadlessAuthenticationID = req.HeadlessAuthenticationID

// create a new http client with a standard callback timeout.
authenticationClient, err = authClient.CloneHTTPClient(
auth.ClientParamTimeout(defaults.CallbackTimeout),
auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout),
)
if err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.AccessDenied("unsupported second factor type: %q", cap.GetSecondFactor())
}

loginResp, err := authenticationClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{
AuthenticateUserRequest: authReq,
CompatibilityMode: req.Compatibility,
TTL: req.TTL,
RouteToCluster: req.RouteToCluster,
KubernetesCluster: req.KubernetesCluster,
AttestationStatement: req.AttestationStatement,
})
loginResp, err := authClient.AuthenticateSSHUser(r.Context(), authSSHUserReq)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down