Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3093,7 +3093,7 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR
}
username = token.GetUser()

case req.ExistingMFAResponse != nil: // Authenticated user without token, tsh.
default: // Authenticated user without token, tsh.
var err error
username, err = authz.GetClientUsername(ctx)
if err != nil {
Expand All @@ -3116,9 +3116,6 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR
if err != nil {
return nil, trace.Wrap(err)
}

default:
return nil, trace.BadParameter("either a token or an MFA response are required")
}

regChal, err := a.createRegisterChallenge(ctx, &newRegisterChallengeRequest{
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func TestCreateRegisterChallenge(t *testing.T) {
DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA,
})
assert.ErrorContains(t, err, "token or an MFA response")
assert.ErrorContains(t, err, "second factor authentication required")

// Acquire and solve an authn challenge.
authnChal, err := authClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{
Expand Down
36 changes: 33 additions & 3 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -6335,17 +6335,35 @@ func (a *ServerWithRoles) CreateAuthenticateChallenge(ctx context.Context, req *

// CreatePrivilegeToken is implemented by AuthService.CreatePrivilegeToken.
func (a *ServerWithRoles) CreatePrivilegeToken(ctx context.Context, req *proto.CreatePrivilegeTokenRequest) (*types.UserTokenV3, error) {
// Device trust: authorize device before issuing a privileged token without an MFA response.
//
// This is an exceptional case for that that results in a "privilege_exception" token, which can
// used to register a user's first MFA device thorugh the WebUI. Since a register challenge can
// be created on behalf of the user using this token (e.g. by the Proxy Service), we must enforce
// the device trust requirement seen in [CreatePrivilegeToken] here instead.
if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil {
if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil {
return nil, trace.Wrap(err, "device trust is required for users to create a privileged token without an MFA check")
}
}

return a.authServer.CreatePrivilegeToken(ctx, req)
}

// CreateRegisterChallenge is implemented by AuthService.CreateRegisterChallenge.
func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *proto.CreateRegisterChallengeRequest) (*proto.MFARegisterChallenge, error) {
switch {
case req.TokenID != "":
case req.ExistingMFAResponse != nil:
if req.TokenID == "" {
if !authz.IsLocalOrRemoteUser(a.context) {
return nil, trace.BadParameter("only end users are allowed issue registration challenges without a privilege token")
}

// Device trust: authorize device before issuing a register challenge without an MFA response or privilege token.
// This is an exceptional case for users registering their first MFA challenge through `tsh`.
if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil {
if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil {
return nil, trace.Wrap(err, "device trust is required for users to register their first MFA device")
}
}
}

// The following serve as means of authentication for this RPC:
Expand All @@ -6354,6 +6372,18 @@ func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *prot
return a.authServer.CreateRegisterChallenge(ctx, req)
}

// enforceGlobalModeTrustedDevice is used to enforce global device trust requirements
// for key endpoints.
func (a *ServerWithRoles) enforceGlobalModeTrustedDevice(ctx context.Context) error {
authPref, err := a.GetAuthPreference(ctx)
if err != nil {
return trace.Wrap(err)
}

err = dtauthz.VerifyTLSUser(authPref.GetDeviceTrust(), a.context.Identity.GetIdentity())
return trace.Wrap(err)
}

// GetAccountRecoveryCodes is implemented by AuthService.GetAccountRecoveryCodes.
func (a *ServerWithRoles) GetAccountRecoveryCodes(ctx context.Context, req *proto.GetAccountRecoveryCodesRequest) (*proto.RecoveryCodes, error) {
// User in context can retrieve their own recovery codes.
Expand Down
114 changes: 114 additions & 0 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import (
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/defaults"
dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
Expand Down Expand Up @@ -977,6 +978,119 @@ func TestGenerateUserCerts_deviceAuthz(t *testing.T) {
}
}

// Test that device trust is required for a user registering their first MFA device.
func TestRegisterFirstDevice_deviceAuthz(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{
TestBuildType: modules.BuildEnterprise, // required for Device Trust.
})

testServer := newTestTLSServer(t)

ctx := context.Background()
authServer := testServer.Auth()

// Create a user for testing.
user, _, err := CreateUserAndRole(testServer.Auth(), "llama", []string{"llama"}, nil)
require.NoError(t, err, "CreateUserAndRole failed")
username := user.GetName()

// Create clients with and without device extensions.
clientWithoutDevice, err := testServer.NewClient(TestUser(username))
require.NoError(t, err, "NewClient failed")

clientWithDevice, err := testServer.NewClient(
TestUserWithDeviceExtensions(username, tlsca.DeviceExtensions{
DeviceID: "deviceid1",
AssetTag: "assettag1",
CredentialID: "credentialid1",
}))
require.NoError(t, err, "NewClient failed")

// updateAuthPref is a helper used throughout the test.
updateAuthPref := func(t *testing.T, modify func(ap types.AuthPreference)) {
authPref, err := authServer.GetAuthPreference(ctx)
require.NoError(t, err, "GetAuthPreference failed")

modify(authPref)

require.NoError(t,
authServer.SetAuthPreference(ctx, authPref),
"SetAuthPreference failed")
}

// Enable webauthn
updateAuthPref(t, func(authPref types.AuthPreference) {
authPref.SetSecondFactor(constants.SecondFactorOptional)
authPref.SetWebauthn(&types.Webauthn{
RPID: "localhost",
})
})

assertSuccess := func(t *testing.T, err error) {
assert.NoError(t, err)
}
assertAccessDenied := func(t *testing.T, err error) {
assert.True(t, trace.IsAccessDenied(err), "expected access denied error but got %v", err)
assert.ErrorContains(t, err, dtauthz.ErrTrustedDeviceRequired.Error())
}

tests := []struct {
name string
clusterDeviceMode string
client *Client
skipLoginCerts bool // aka non-MFA issuance.
skipSingleUseCerts bool // aka MFA/streaming issuance.
assertErr func(t *testing.T, err error)
}{
{
name: "mode=optional without extensions",
clusterDeviceMode: constants.DeviceTrustModeOptional,
client: clientWithoutDevice,
assertErr: assertSuccess,
},
{
name: "mode=optional with extensions",
clusterDeviceMode: constants.DeviceTrustModeOptional,
client: clientWithDevice,
assertErr: assertSuccess,
},
{
name: "nok: mode=required without extensions",
clusterDeviceMode: constants.DeviceTrustModeRequired,
client: clientWithoutDevice,
assertErr: assertAccessDenied,
},
{
name: "mode=required with extensions",
clusterDeviceMode: constants.DeviceTrustModeRequired,
client: clientWithDevice,
assertErr: assertSuccess,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
updateAuthPref(t, func(ap types.AuthPreference) {
ap.SetDeviceTrust(&types.DeviceTrust{
Mode: test.clusterDeviceMode,
})
})

t.Run("CreatePrivilegeTokenRequest", func(t *testing.T) {
_, err := test.client.CreatePrivilegeToken(ctx, &proto.CreatePrivilegeTokenRequest{})
test.assertErr(t, err)
})

t.Run("CreateRegisterChallenge", func(t *testing.T) {
_, err := test.client.CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{
DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA,
})
test.assertErr(t, err)
})
})
}
}

func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.DatabaseV3 {
database, err := types.NewDatabaseV3(
types.Metadata{
Expand Down