diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 57b214cb54257..2250e5cfe9e05 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -735,9 +735,9 @@ func TestServer_Authenticate_headless(t *testing.T) { } for _, tc := range []struct { - name string - update updateHeadlessAuthnFn - checkErr require.ErrorAssertionFunc + name string + update updateHeadlessAuthnFn + expectErr bool }{ { name: "OK approved", @@ -745,13 +745,12 @@ func TestServer_Authenticate_headless(t *testing.T) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED ha.MfaDevice = mfa }, - checkErr: require.NoError, }, { name: "NOK approved without MFA", update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED }, - checkErr: require.Error, + expectErr: true, }, { name: "NOK user mismatch", update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { @@ -759,17 +758,17 @@ func TestServer_Authenticate_headless(t *testing.T) { ha.MfaDevice = mfa ha.User = "other-user" }, - checkErr: require.Error, + expectErr: true, }, { name: "NOK denied", update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED }, - checkErr: require.Error, + expectErr: true, }, { - name: "NOK timeout", - update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {}, - checkErr: require.Error, + name: "NOK timeout", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {}, + expectErr: true, }, } { t.Run(tc.name, func(t *testing.T) { @@ -785,6 +784,20 @@ func TestServer_Authenticate_headless(t *testing.T) { mfa := configureForMFA(t, srv) username := mfa.User + // Fail a login attempt so have a non-empty list of attempts. + _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: AuthenticateUserRequest{ + Username: username, + Webauthn: &wanlib.CredentialAssertionResponse{}, // bad response + PublicKey: []byte(sshPubKey), + }, + TTL: 24 * time.Hour, + }) + require.True(t, trace.IsAccessDenied(err), "got err = %v, want AccessDenied") + attempts, err := srv.Auth().GetUserLoginAttempts(username) + require.NoError(t, err) + require.NotEmpty(t, attempts, "Want at least one failed login attempt") + t.Cleanup(func() { srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID) }) @@ -795,17 +808,34 @@ func TestServer_Authenticate_headless(t *testing.T) { errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update) _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ AuthenticateUserRequest: AuthenticateUserRequest{ + // HeadlessAuthenticationID should take precedence over WebAuthn and OTP fields. + HeadlessAuthenticationID: headlessID, + Webauthn: &wanlib.CredentialAssertionResponse{}, + OTP: &OTPCreds{}, Username: username, PublicKey: []byte(sshPubKey), - HeadlessAuthenticationID: headlessID, ClientMetadata: &ForwardedClientMetadata{ RemoteAddr: "0.0.0.0", }, }, TTL: defaults.CallbackTimeout, }) - tc.checkErr(t, err) require.NoError(t, <-errC) + if tc.expectErr { + require.Error(t, err) + // Verify login attempts unchanged. This is a proxy for various other user + // checks (locked, etc). + updatedAttempts, err := srv.Auth().GetUserLoginAttempts(username) + require.NoError(t, err) + require.Equal(t, attempts, updatedAttempts, "Login attempts unexpectedly changed") + } else { + require.NoError(t, err) + // Verify zeroed login attempts. This is a proxy for various other user + // checks (locked, etc). + updatedAttempts, err := srv.Auth().GetUserLoginAttempts(username) + require.NoError(t, err) + require.Empty(t, updatedAttempts, "Login attempts not reset") + } }) } } diff --git a/lib/auth/methods.go b/lib/auth/methods.go index a578ec6594890..9fab2dfc651f9 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -202,6 +202,18 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque var authErr error // error message kept obscure on purpose, use logging for details switch { // cases in order of preference + case req.HeadlessAuthenticationID != "": + // handle authentication before the user lock to prevent locking out users + // due to timed-out/canceled headless authentication attempts. + mfaDevice, err := s.authenticateHeadless(ctx, req) + if err != nil { + log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err) + return nil, "", trace.Wrap(invalidHeadlessAuthenticationError) + } + authenticateFn = func() (*types.MFADevice, error) { + return mfaDevice, nil + } + authErr = invalidHeadlessAuthenticationError case req.Webauthn != nil: authenticateFn = func() (*types.MFADevice, error) { mfaResponse := &proto.MFAAuthenticateResponse{ @@ -224,18 +236,6 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque return res.mfaDev, nil } authErr = invalidUserPass2FError - case req.HeadlessAuthenticationID != "": - // handle authentication before the user lock to prevent locking out users - // due to timed-out/canceled headless authentication attempts. - mfaDevice, err := s.authenticateHeadless(ctx, req) - if err != nil { - log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err) - return nil, "", trace.Wrap(invalidHeadlessAuthenticationError) - } - authenticateFn = func() (*types.MFADevice, error) { - return mfaDevice, nil - } - authErr = invalidHeadlessAuthenticationError } if authenticateFn != nil { var dev *types.MFADevice diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 5fe77ab670e72..45addeb74a0bc 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3124,50 +3124,55 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro return nil, trace.Wrap(err) } - var authenticationClient interface { - AuthenticateSSHUser(ctx context.Context, req auth.AuthenticateSSHRequest) (*auth.SSHLoginResponse, error) - } = authClient + authSSHUserReq := auth.AuthenticateSSHRequest{ + AuthenticateUserRequest: auth.AuthenticateUserRequest{ + Username: req.User, + PublicKey: req.PubKey, + ClientMetadata: clientMetaFromReq(r), + }, + CompatibilityMode: req.Compatibility, + TTL: req.TTL, + RouteToCluster: req.RouteToCluster, + KubernetesCluster: req.KubernetesCluster, + AttestationStatement: req.AttestationStatement, + } + + if req.HeadlessAuthenticationID != "" { + // We need to use the default callback timeout rather than the standard client timeout. + // However, authClient is shared across all Proxy->Auth requests, so we need to create + // a new client to avoid applying the callback timeout to other concurrent requests. To + // this end, we create a clone of the HTTP Client with the desired timeout instead. + httpClient, err := authClient.CloneHTTPClient( + auth.ClientParamTimeout(defaults.CallbackTimeout), + auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout), + ) + if err != nil { + return nil, trace.Wrap(err) + } - authReq := auth.AuthenticateUserRequest{ - Username: req.User, - PublicKey: req.PubKey, - ClientMetadata: clientMetaFromReq(r), + authSSHUserReq.AuthenticateUserRequest.HeadlessAuthenticationID = req.HeadlessAuthenticationID + loginResp, err := httpClient.AuthenticateSSHUser(r.Context(), authSSHUserReq) + if err != nil { + return nil, trace.Wrap(err) + } + return loginResp, nil } switch cap.GetSecondFactor() { case constants.SecondFactorOff: - authReq.Pass = &auth.PassCreds{ + authSSHUserReq.AuthenticateUserRequest.Pass = &auth.PassCreds{ Password: []byte(req.Password), } case constants.SecondFactorOTP, constants.SecondFactorOn, constants.SecondFactorOptional: - authReq.OTP = &auth.OTPCreds{ + authSSHUserReq.AuthenticateUserRequest.OTP = &auth.OTPCreds{ Password: []byte(req.Password), Token: req.OTPToken, } - case constants.SecondFactorWebauthn: - // WebAuthn only supports this endpoint for headless login. - authReq.HeadlessAuthenticationID = req.HeadlessAuthenticationID - - // create a new http client with a standard callback timeout. - authenticationClient, err = authClient.CloneHTTPClient( - auth.ClientParamTimeout(defaults.CallbackTimeout), - auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout), - ) - if err != nil { - return nil, trace.Wrap(err) - } default: return nil, trace.AccessDenied("unsupported second factor type: %q", cap.GetSecondFactor()) } - loginResp, err := authenticationClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{ - AuthenticateUserRequest: authReq, - CompatibilityMode: req.Compatibility, - TTL: req.TTL, - RouteToCluster: req.RouteToCluster, - KubernetesCluster: req.KubernetesCluster, - AttestationStatement: req.AttestationStatement, - }) + loginResp, err := authClient.AuthenticateSSHUser(r.Context(), authSSHUserReq) if err != nil { return nil, trace.Wrap(err) }