Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ type TestAuthServerConfig struct {
AuditLog events.IAuditLog
// TraceClient allows a test to configure the trace client
TraceClient otlptrace.Client
// AuthPreferenceSpec is custom initial AuthPreference spec for the test.
AuthPreferenceSpec *types.AuthPreferenceSpecV2
}

// CheckAndSetDefaults checks and sets defaults
Expand All @@ -92,6 +94,12 @@ func (cfg *TestAuthServerConfig) CheckAndSetDefaults() error {
if len(cfg.CipherSuites) == 0 {
cfg.CipherSuites = utils.DefaultCipherSuites()
}
if cfg.AuthPreferenceSpec == nil {
cfg.AuthPreferenceSpec = &types.AuthPreferenceSpecV2{
Type: constants.Local,
SecondFactor: constants.SecondFactorOff,
}
}
return nil
}

Expand Down Expand Up @@ -295,10 +303,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
return nil, trace.Wrap(err)
}

authPreference, err := types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{
Type: constants.Local,
SecondFactor: constants.SecondFactorOff,
})
authPreference, err := types.NewAuthPreferenceFromConfigFile(*cfg.AuthPreferenceSpec)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 6 additions & 1 deletion lib/httplib/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,19 @@ const (
defaultMaxAge = 0
)

// GenerateToken generates a random CSRF token.
func GenerateToken() (string, error) {
return utils.CryptoRandomHex(tokenLenBytes)
}

// AddCSRFProtection adds CSRF token into the user session via secure cookie,
// it implements "double submit cookie" approach to check against CSRF attacks
// https://www.owasp.org/index.php/Cross-Site_Request_Forgery_%28CSRF%29_Prevention_Cheat_Sheet#Double_Submit_Cookie
func AddCSRFProtection(w http.ResponseWriter, r *http.Request) (string, error) {
token, err := ExtractTokenFromCookie(r)
// if there was an error retrieving the token, the token doesn't exist
if err != nil || len(token) == 0 {
token, err = utils.CryptoRandomHex(tokenLenBytes)
token, err = GenerateToken()
if err != nil {
return "", trace.Wrap(err)
}
Expand Down
45 changes: 45 additions & 0 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1892,6 +1892,11 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques
return nil, trace.Wrap(err)
}

err = h.trySettingConnectorNameToPasswordless(r.Context(), ctx, req)
if err != nil {
h.log.WithError(err).Error("Failed to set passwordless as connector name.")
}

if err := SetSessionCookie(w, sess.GetUser(), sess.GetName()); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1913,6 +1918,45 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques
}, nil
}

// trySettingConnectorNameToPasswordless sets cluster_auth_preference connectorName to `passwordless` when the first cloud user chooses passwordless as the authentication method.
// This simplifies UX for cloud users, as they will not need to select a passwordless connector when logging in.
func (h *Handler) trySettingConnectorNameToPasswordless(ctx context.Context, sessCtx *SessionContext, req changeUserAuthenticationRequest) error {
// We use the presence of a WebAuthn response, along with the absence of a
// password, as a proxy to determine that a passwordless registration took
// place, as it is not possible to infer that just from the WebAuthn response.
isPasswordlessRegistration := req.WebauthnCreationResponse != nil && len(req.Password) == 0
if !isPasswordlessRegistration {
return nil
}

if !h.ClusterFeatures.GetCloud() {
return nil
}

authPreference, err := sessCtx.clt.GetAuthPreference(ctx)
if err != nil {
return nil
}

if connector := authPreference.GetConnectorName(); connector != "" && connector != constants.LocalConnector {
return nil
}

users, err := h.cfg.ProxyClient.GetUsers(false)
if err != nil {
return trace.Wrap(err)
}

if len(users) != 1 {
return nil
}

authPreference.SetConnectorName(constants.PasswordlessConnector)

err = sessCtx.clt.SetAuthPreference(ctx, authPreference)
return trace.Wrap(err)
}

// createResetPasswordToken allows a UI user to reset a user's password.
// This handler is also required for after creating new users.
func (h *Handler) createResetPasswordToken(w http.ResponseWriter, r *http.Request, _ httprouter.Params, ctx *SessionContext) (interface{}, error) {
Expand Down Expand Up @@ -2082,6 +2126,7 @@ func (h *Handler) mfaLoginFinishSession(w http.ResponseWriter, r *http.Request,
if err != nil {
return nil, trace.AccessDenied("need auth")
}

return newSessionResponse(ctx)
}

Expand Down
164 changes: 164 additions & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ func TestMain(m *testing.M) {
}

func newWebSuite(t *testing.T) *WebSuite {
return newWebSuiteWithConfig(t, webSuiteConfig{})
}

type webSuiteConfig struct {
// AuthPreferenceSpec is custom initial AuthPreference spec for the test.
authPreferenceSpec *types.AuthPreferenceSpecV2
}

func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
mockU2F, err := mocku2f.Create()
require.NoError(t, err)
require.NotNil(t, mockU2F)
Expand Down Expand Up @@ -182,6 +191,7 @@ func newWebSuite(t *testing.T) *WebSuite {
Dir: t.TempDir(),
Clock: s.clock,
ClusterNetworkingConfig: networkingConfig,
AuthPreferenceSpec: cfg.authPreferenceSpec,
},
})
require.NoError(t, err)
Expand Down Expand Up @@ -352,7 +362,9 @@ func newWebSuite(t *testing.T) *WebSuite {
var sessionLingeringThreshold time.Duration
fs, err := NewDebugFileSystem("../../webassets/teleport")
require.NoError(t, err)

handler, err := NewHandler(Config{
ClusterFeatures: *modules.GetModules().Features().ToProto(), // safe to dereference because ToProto creates a struct and return a pointer to it
Proxy: revTunServer,
AuthServers: utils.FromAddr(s.server.TLS.Addr()),
DomainName: s.server.ClusterName(),
Expand Down Expand Up @@ -4159,6 +4171,158 @@ func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) {
require.True(t, apiRes.PrivateKeyPolicyEnabled)
}

func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing.T) {
tt := []struct {
name string
cloud bool
numberOfUsers int
password []byte
authPreferenceType string
initialConnectorName string
resultConnectorName string
}{{
name: "first cloud sign-in changes connector to `passwordless`",
cloud: true,
numberOfUsers: 1,
authPreferenceType: constants.Local,
initialConnectorName: "",
resultConnectorName: constants.PasswordlessConnector,
}, {
name: "first non-cloud sign-in doesn't change the connector",
cloud: false,
numberOfUsers: 1,
authPreferenceType: constants.Local,
initialConnectorName: "",
resultConnectorName: "",
}, {
name: "second cloud sign-in doesn't change the connector",
cloud: true,
numberOfUsers: 2,
authPreferenceType: constants.Local,
initialConnectorName: "",
resultConnectorName: "",
}, {
name: "first cloud sign-in does not change custom connector",
cloud: true,
numberOfUsers: 1,
authPreferenceType: constants.OIDC,
initialConnectorName: "custom",
resultConnectorName: "custom",
}, {
name: "first cloud sign-in with password does not change connector",
cloud: true,
numberOfUsers: 1,
password: []byte("abc123"),
authPreferenceType: constants.Local,
initialConnectorName: "",
resultConnectorName: "",
}}

for _, tc := range tt {
modules.SetTestModules(t, &modules.TestModules{
TestFeatures: modules.Features{
Cloud: tc.cloud,
},
})

const RPID = "localhost"

s := newWebSuiteWithConfig(t, webSuiteConfig{
authPreferenceSpec: &types.AuthPreferenceSpecV2{
Type: tc.authPreferenceType,
ConnectorName: tc.initialConnectorName,
SecondFactor: constants.SecondFactorOn,
Webauthn: &types.Webauthn{
RPID: RPID,
},
},
})

// user and role
users := make([]types.User, tc.numberOfUsers)

for i := 0; i < tc.numberOfUsers; i++ {
user, err := types.NewUser(fmt.Sprintf("test_user_%v", i))
require.NoError(t, err)

user.SetCreatedBy(types.CreatedBy{
User: types.UserRef{Name: "other_user"},
})

role := services.RoleForUser(user)

err = s.server.Auth().UpsertRole(s.ctx, role)
require.NoError(t, err)

user.AddRole(role.GetName())

err = s.server.Auth().CreateUser(s.ctx, user)
require.NoError(t, err)

users[i] = user
}

initialUser := users[0]

clt := s.client()

// create register challenge
token, err := s.server.Auth().CreateResetPasswordToken(s.ctx, auth.CreateUserTokenRequest{
Name: initialUser.GetName(),
})
require.NoError(t, err)

res, err := s.server.Auth().CreateRegisterChallenge(s.ctx, &authproto.CreateRegisterChallengeRequest{
TokenID: token.GetName(),
DeviceType: authproto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: authproto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS,
})
require.NoError(t, err)

cc := wanlib.CredentialCreationFromProto(res.GetWebauthn())

// use passwordless as auth method
device, err := mocku2f.Create()
require.NoError(t, err)

device.SetPasswordless()

ccr, err := device.SignCredentialCreation("https://"+RPID, cc)
require.NoError(t, err)

// send sign-in response to server
body, err := json.Marshal(changeUserAuthenticationRequest{
WebauthnCreationResponse: ccr,
TokenID: token.GetName(),
DeviceName: "passwordless-device",
Password: tc.password,
})
require.NoError(t, err)

req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body))
require.NoError(t, err)

csrfToken, err := csrf.GenerateToken()
require.NoError(t, err)
addCSRFCookieToReq(req, csrfToken)
req.Header.Set(csrf.HeaderName, csrfToken)
req.Header.Set("Content-Type", "application/json")

re, err := clt.Client.RoundTrip(func() (*http.Response, error) {
return clt.Client.HTTPClient().Do(req)
})

require.NoError(t, err)
require.Equal(t, re.Code(), http.StatusOK)

// check if auth preference connectorName is set
authPreference, err := s.server.Auth().GetAuthPreference(s.ctx)
require.NoError(t, err)

require.Equal(t, authPreference.GetConnectorName(), tc.resultConnectorName, "Found unexpected auth connector name")
}
}

func TestParseSSORequestParams(t *testing.T) {
t.Parallel()

Expand Down