diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 9cb653c6cb9fd..f199a30a3d684 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -75,6 +75,8 @@ type TestAuthServerConfig struct { AuditLog events.IAuditLog // TraceClient allows a test to configure the trace client TraceClient otlptrace.Client + // AuthPreferenceSpec is custom initial AuthPreference spec for the test. + AuthPreferenceSpec *types.AuthPreferenceSpecV2 } // CheckAndSetDefaults checks and sets defaults @@ -91,6 +93,12 @@ func (cfg *TestAuthServerConfig) CheckAndSetDefaults() error { if len(cfg.CipherSuites) == 0 { cfg.CipherSuites = utils.DefaultCipherSuites() } + if cfg.AuthPreferenceSpec == nil { + cfg.AuthPreferenceSpec = &types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOff, + } + } return nil } @@ -289,10 +297,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { return nil, trace.Wrap(err) } - authPreference, err := types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{ - Type: constants.Local, - SecondFactor: constants.SecondFactorOff, - }) + authPreference, err := types.NewAuthPreferenceFromConfigFile(*cfg.AuthPreferenceSpec) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/httplib/csrf/csrf.go b/lib/httplib/csrf/csrf.go index 1b0f21ed336c3..c5d3d3945a28f 100644 --- a/lib/httplib/csrf/csrf.go +++ b/lib/httplib/csrf/csrf.go @@ -40,6 +40,11 @@ const ( defaultMaxAge = 0 ) +// GenerateToken generates a random CSRF token. +func GenerateToken() (string, error) { + return utils.CryptoRandomHex(tokenLenBytes) +} + // AddCSRFProtection adds CSRF token into the user session via secure cookie, // it implements "double submit cookie" approach to check against CSRF attacks // https://www.owasp.org/index.php/Cross-Site_Request_Forgery_%28CSRF%29_Prevention_Cheat_Sheet#Double_Submit_Cookie @@ -47,7 +52,7 @@ func AddCSRFProtection(w http.ResponseWriter, r *http.Request) (string, error) { token, err := ExtractTokenFromCookie(r) // if there was an error retrieving the token, the token doesn't exist if err != nil || len(token) == 0 { - token, err = utils.CryptoRandomHex(tokenLenBytes) + token, err = GenerateToken() if err != nil { return "", trace.Wrap(err) } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 9f4a53707dd0d..d72fed4ee8d2c 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1755,6 +1755,11 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques return nil, trace.Wrap(err) } + err = h.trySettingConnectorNameToPasswordless(r.Context(), ctx, req) + if err != nil { + h.log.WithError(err).Error("Failed to set passwordless as connector name.") + } + if err := SetSessionCookie(w, sess.GetUser(), sess.GetName()); err != nil { return nil, trace.Wrap(err) } @@ -1776,6 +1781,45 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques }, nil } +// trySettingConnectorNameToPasswordless sets cluster_auth_preference connectorName to `passwordless` when the first cloud user chooses passwordless as the authentication method. +// This simplifies UX for cloud users, as they will not need to select a passwordless connector when logging in. +func (h *Handler) trySettingConnectorNameToPasswordless(ctx context.Context, sessCtx *SessionContext, req changeUserAuthenticationRequest) error { + // We use the presence of a WebAuthn response, along with the absence of a + // password, as a proxy to determine that a passwordless registration took + // place, as it is not possible to infer that just from the WebAuthn response. + isPasswordlessRegistration := req.WebauthnCreationResponse != nil && len(req.Password) == 0 + if !isPasswordlessRegistration { + return nil + } + + if !h.ClusterFeatures.GetCloud() { + return nil + } + + authPreference, err := sessCtx.clt.GetAuthPreference(ctx) + if err != nil { + return nil + } + + if connector := authPreference.GetConnectorName(); connector != "" && connector != constants.LocalConnector { + return nil + } + + users, err := h.cfg.ProxyClient.GetUsers(false) + if err != nil { + return trace.Wrap(err) + } + + if len(users) != 1 { + return nil + } + + authPreference.SetConnectorName(constants.PasswordlessConnector) + + err = sessCtx.clt.SetAuthPreference(ctx, authPreference) + return trace.Wrap(err) +} + // createResetPasswordToken allows a UI user to reset a user's password. // This handler is also required for after creating new users. func (h *Handler) createResetPasswordToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params, ctx *SessionContext) (interface{}, error) { @@ -1945,6 +1989,7 @@ func (h *Handler) mfaLoginFinishSession(w http.ResponseWriter, r *http.Request, if err != nil { return nil, trace.AccessDenied("need auth") } + return newSessionResponse(ctx) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 7df4150ae6e97..009fca11f44c4 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -148,6 +148,15 @@ func TestMain(m *testing.M) { } func newWebSuite(t *testing.T) *WebSuite { + return newWebSuiteWithConfig(t, webSuiteConfig{}) +} + +type webSuiteConfig struct { + // AuthPreferenceSpec is custom initial AuthPreference spec for the test. + authPreferenceSpec *types.AuthPreferenceSpecV2 +} + +func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { mockU2F, err := mocku2f.Create() require.NoError(t, err) require.NotNil(t, mockU2F) @@ -175,6 +184,7 @@ func newWebSuite(t *testing.T) *WebSuite { Dir: t.TempDir(), Clock: s.clock, ClusterNetworkingConfig: networkingConfig, + AuthPreferenceSpec: cfg.authPreferenceSpec, }, }) require.NoError(t, err) @@ -345,7 +355,9 @@ func newWebSuite(t *testing.T) *WebSuite { var sessionLingeringThreshold time.Duration fs, err := NewDebugFileSystem("../../webassets/teleport") require.NoError(t, err) + handler, err := NewHandler(Config{ + ClusterFeatures: *modules.GetModules().Features().ToProto(), // safe to dereference because ToProto creates a struct and return a pointer to it Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), DomainName: s.server.ClusterName(), @@ -4204,6 +4216,158 @@ func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) { require.True(t, apiRes.PrivateKeyPolicyEnabled) } +func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing.T) { + tt := []struct { + name string + cloud bool + numberOfUsers int + password []byte + authPreferenceType string + initialConnectorName string + resultConnectorName string + }{{ + name: "first cloud sign-in changes connector to `passwordless`", + cloud: true, + numberOfUsers: 1, + authPreferenceType: constants.Local, + initialConnectorName: "", + resultConnectorName: constants.PasswordlessConnector, + }, { + name: "first non-cloud sign-in doesn't change the connector", + cloud: false, + numberOfUsers: 1, + authPreferenceType: constants.Local, + initialConnectorName: "", + resultConnectorName: "", + }, { + name: "second cloud sign-in doesn't change the connector", + cloud: true, + numberOfUsers: 2, + authPreferenceType: constants.Local, + initialConnectorName: "", + resultConnectorName: "", + }, { + name: "first cloud sign-in does not change custom connector", + cloud: true, + numberOfUsers: 1, + authPreferenceType: constants.OIDC, + initialConnectorName: "custom", + resultConnectorName: "custom", + }, { + name: "first cloud sign-in with password does not change connector", + cloud: true, + numberOfUsers: 1, + password: []byte("abc123"), + authPreferenceType: constants.Local, + initialConnectorName: "", + resultConnectorName: "", + }} + + for _, tc := range tt { + modules.SetTestModules(t, &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: tc.cloud, + }, + }) + + const RPID = "localhost" + + s := newWebSuiteWithConfig(t, webSuiteConfig{ + authPreferenceSpec: &types.AuthPreferenceSpecV2{ + Type: tc.authPreferenceType, + ConnectorName: tc.initialConnectorName, + SecondFactor: constants.SecondFactorOn, + Webauthn: &types.Webauthn{ + RPID: RPID, + }, + }, + }) + + // user and role + users := make([]types.User, tc.numberOfUsers) + + for i := 0; i < tc.numberOfUsers; i++ { + user, err := types.NewUser(fmt.Sprintf("test_user_%v", i)) + require.NoError(t, err) + + user.SetCreatedBy(types.CreatedBy{ + User: types.UserRef{Name: "other_user"}, + }) + + role := services.RoleForUser(user) + + err = s.server.Auth().UpsertRole(s.ctx, role) + require.NoError(t, err) + + user.AddRole(role.GetName()) + + err = s.server.Auth().CreateUser(s.ctx, user) + require.NoError(t, err) + + users[i] = user + } + + initialUser := users[0] + + clt := s.client() + + // create register challenge + token, err := s.server.Auth().CreateResetPasswordToken(s.ctx, auth.CreateUserTokenRequest{ + Name: initialUser.GetName(), + }) + require.NoError(t, err) + + res, err := s.server.Auth().CreateRegisterChallenge(s.ctx, &authproto.CreateRegisterChallengeRequest{ + TokenID: token.GetName(), + DeviceType: authproto.DeviceType_DEVICE_TYPE_WEBAUTHN, + DeviceUsage: authproto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, + }) + require.NoError(t, err) + + cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) + + // use passwordless as auth method + device, err := mocku2f.Create() + require.NoError(t, err) + + device.SetPasswordless() + + ccr, err := device.SignCredentialCreation("https://"+RPID, cc) + require.NoError(t, err) + + // send sign-in response to server + body, err := json.Marshal(changeUserAuthenticationRequest{ + WebauthnCreationResponse: ccr, + TokenID: token.GetName(), + DeviceName: "passwordless-device", + Password: tc.password, + }) + require.NoError(t, err) + + req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body)) + require.NoError(t, err) + + csrfToken, err := csrf.GenerateToken() + require.NoError(t, err) + addCSRFCookieToReq(req, csrfToken) + req.Header.Set(csrf.HeaderName, csrfToken) + req.Header.Set("Content-Type", "application/json") + + re, err := clt.Client.RoundTrip(func() (*http.Response, error) { + return clt.Client.HTTPClient().Do(req) + }) + + require.NoError(t, err) + require.Equal(t, re.Code(), http.StatusOK) + + // check if auth preference connectorName is set + authPreference, err := s.server.Auth().GetAuthPreference(s.ctx) + require.NoError(t, err) + + require.Equal(t, authPreference.GetConnectorName(), tc.resultConnectorName, "Found unexpected auth connector name") + } +} + func TestParseSSORequestParams(t *testing.T) { t.Parallel()