diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 4e703ee4b6068..c56b3809c48f1 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3093,7 +3093,7 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR } username = token.GetUser() - case req.ExistingMFAResponse != nil: // Authenticated user without token, tsh. + default: // Authenticated user without token, tsh. var err error username, err = authz.GetClientUsername(ctx) if err != nil { @@ -3116,9 +3116,6 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR if err != nil { return nil, trace.Wrap(err) } - - default: - return nil, trace.BadParameter("either a token or an MFA response are required") } regChal, err := a.createRegisterChallenge(ctx, &newRegisterChallengeRequest{ diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index e3b5f86575681..9c1b33c8dd4ce 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -532,7 +532,7 @@ func TestCreateRegisterChallenge(t *testing.T) { DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA, }) - assert.ErrorContains(t, err, "token or an MFA response") + assert.ErrorContains(t, err, "second factor authentication required") // Acquire and solve an authn challenge. authnChal, err := authClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index a873d1dbd9013..f11a924d661a4 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -6335,17 +6335,35 @@ func (a *ServerWithRoles) CreateAuthenticateChallenge(ctx context.Context, req * // CreatePrivilegeToken is implemented by AuthService.CreatePrivilegeToken. func (a *ServerWithRoles) CreatePrivilegeToken(ctx context.Context, req *proto.CreatePrivilegeTokenRequest) (*types.UserTokenV3, error) { + // Device trust: authorize device before issuing a privileged token without an MFA response. + // + // This is an exceptional case for that that results in a "privilege_exception" token, which can + // used to register a user's first MFA device thorugh the WebUI. Since a register challenge can + // be created on behalf of the user using this token (e.g. by the Proxy Service), we must enforce + // the device trust requirement seen in [CreatePrivilegeToken] here instead. + if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil { + if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil { + return nil, trace.Wrap(err, "device trust is required for users to create a privileged token without an MFA check") + } + } + return a.authServer.CreatePrivilegeToken(ctx, req) } // CreateRegisterChallenge is implemented by AuthService.CreateRegisterChallenge. func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *proto.CreateRegisterChallengeRequest) (*proto.MFARegisterChallenge, error) { - switch { - case req.TokenID != "": - case req.ExistingMFAResponse != nil: + if req.TokenID == "" { if !authz.IsLocalOrRemoteUser(a.context) { return nil, trace.BadParameter("only end users are allowed issue registration challenges without a privilege token") } + + // Device trust: authorize device before issuing a register challenge without an MFA response or privilege token. + // This is an exceptional case for users registering their first MFA challenge through `tsh`. + if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil { + if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil { + return nil, trace.Wrap(err, "device trust is required for users to register their first MFA device") + } + } } // The following serve as means of authentication for this RPC: @@ -6354,6 +6372,18 @@ func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *prot return a.authServer.CreateRegisterChallenge(ctx, req) } +// enforceGlobalModeTrustedDevice is used to enforce global device trust requirements +// for key endpoints. +func (a *ServerWithRoles) enforceGlobalModeTrustedDevice(ctx context.Context) error { + authPref, err := a.GetAuthPreference(ctx) + if err != nil { + return trace.Wrap(err) + } + + err = dtauthz.VerifyTLSUser(authPref.GetDeviceTrust(), a.context.Identity.GetIdentity()) + return trace.Wrap(err) +} + // GetAccountRecoveryCodes is implemented by AuthService.GetAccountRecoveryCodes. func (a *ServerWithRoles) GetAccountRecoveryCodes(ctx context.Context, req *proto.GetAccountRecoveryCodesRequest) (*proto.RecoveryCodes, error) { // User in context can retrieve their own recovery codes. diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 65b910e858455..7bc5d271be503 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -68,6 +68,7 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/defaults" + dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" @@ -977,6 +978,119 @@ func TestGenerateUserCerts_deviceAuthz(t *testing.T) { } } +// Test that device trust is required for a user registering their first MFA device. +func TestRegisterFirstDevice_deviceAuthz(t *testing.T) { + modules.SetTestModules(t, &modules.TestModules{ + TestBuildType: modules.BuildEnterprise, // required for Device Trust. + }) + + testServer := newTestTLSServer(t) + + ctx := context.Background() + authServer := testServer.Auth() + + // Create a user for testing. + user, _, err := CreateUserAndRole(testServer.Auth(), "llama", []string{"llama"}, nil) + require.NoError(t, err, "CreateUserAndRole failed") + username := user.GetName() + + // Create clients with and without device extensions. + clientWithoutDevice, err := testServer.NewClient(TestUser(username)) + require.NoError(t, err, "NewClient failed") + + clientWithDevice, err := testServer.NewClient( + TestUserWithDeviceExtensions(username, tlsca.DeviceExtensions{ + DeviceID: "deviceid1", + AssetTag: "assettag1", + CredentialID: "credentialid1", + })) + require.NoError(t, err, "NewClient failed") + + // updateAuthPref is a helper used throughout the test. + updateAuthPref := func(t *testing.T, modify func(ap types.AuthPreference)) { + authPref, err := authServer.GetAuthPreference(ctx) + require.NoError(t, err, "GetAuthPreference failed") + + modify(authPref) + + require.NoError(t, + authServer.SetAuthPreference(ctx, authPref), + "SetAuthPreference failed") + } + + // Enable webauthn + updateAuthPref(t, func(authPref types.AuthPreference) { + authPref.SetSecondFactor(constants.SecondFactorOptional) + authPref.SetWebauthn(&types.Webauthn{ + RPID: "localhost", + }) + }) + + assertSuccess := func(t *testing.T, err error) { + assert.NoError(t, err) + } + assertAccessDenied := func(t *testing.T, err error) { + assert.True(t, trace.IsAccessDenied(err), "expected access denied error but got %v", err) + assert.ErrorContains(t, err, dtauthz.ErrTrustedDeviceRequired.Error()) + } + + tests := []struct { + name string + clusterDeviceMode string + client *Client + skipLoginCerts bool // aka non-MFA issuance. + skipSingleUseCerts bool // aka MFA/streaming issuance. + assertErr func(t *testing.T, err error) + }{ + { + name: "mode=optional without extensions", + clusterDeviceMode: constants.DeviceTrustModeOptional, + client: clientWithoutDevice, + assertErr: assertSuccess, + }, + { + name: "mode=optional with extensions", + clusterDeviceMode: constants.DeviceTrustModeOptional, + client: clientWithDevice, + assertErr: assertSuccess, + }, + { + name: "nok: mode=required without extensions", + clusterDeviceMode: constants.DeviceTrustModeRequired, + client: clientWithoutDevice, + assertErr: assertAccessDenied, + }, + { + name: "mode=required with extensions", + clusterDeviceMode: constants.DeviceTrustModeRequired, + client: clientWithDevice, + assertErr: assertSuccess, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + updateAuthPref(t, func(ap types.AuthPreference) { + ap.SetDeviceTrust(&types.DeviceTrust{ + Mode: test.clusterDeviceMode, + }) + }) + + t.Run("CreatePrivilegeTokenRequest", func(t *testing.T) { + _, err := test.client.CreatePrivilegeToken(ctx, &proto.CreatePrivilegeTokenRequest{}) + test.assertErr(t, err) + }) + + t.Run("CreateRegisterChallenge", func(t *testing.T) { + _, err := test.client.CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ + DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, + DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA, + }) + test.assertErr(t, err) + }) + }) + } +} + func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.DatabaseV3 { database, err := types.NewDatabaseV3( types.Metadata{