Skip to content
Merged
15 changes: 15 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5909,6 +5909,21 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck
return nil, trace.BadParameter("empty Login field")
}

// state.MFARequired is "per-role", so if the user is joining
// a session, MFA is required no matter what node they are
// connecting to. We don't preform an RBAC check like we do
// below when users are starting a session to selectively
// require MFA because we don't know what session the user
// is joining, nor do we know what role allowed the session
// creator to start the session that is attempting to be joined.
// We need this info to be able to selectively skip MFA in
// this case.
if t.Node.Login == teleport.SSHSessionJoinPrincipal {
return &proto.IsMFARequiredResponse{
MFARequired: proto.MFARequired_MFA_REQUIRED_YES,
}, nil
}

// Find the target node and check whether MFA is required.
matches, err := client.GetResourcesWithFilters(ctx, a, proto.ListResourcesRequest{
ResourceType: types.KindNode,
Expand Down
141 changes: 121 additions & 20 deletions lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
Expand Down Expand Up @@ -395,24 +396,78 @@ func TestCreateAuthenticateChallenge_mfaVerification(t *testing.T) {
prodRole, err = adminClient.UpsertRole(ctx, prodRole)
require.NoError(t, err, "UpsertRole(%q)", prodRole.GetName())

// Create a user with MFA devices...
userCreds, err := createUserWithSecondFactors(testServer)
require.NoError(t, err, "createUserWithSecondFactors")
username := userCreds.username
// Create a role that requires MFA when joining sessions
joinMFARole, err := types.NewRole("mfa_session_join", types.RoleSpecV6{
Options: types.RoleOptions{
RequireMFAType: types.RequireMFAType_SESSION,
},
Allow: types.RoleConditions{
Logins: []string{"{{internal.logins}}"},
NodeLabels: types.Labels{
"env": []string{"*"},
},
JoinSessions: []*types.SessionJoinPolicy{
{
Name: "session_join",
Roles: []string{"access"},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionPeerMode)},
},
},
},
})
require.NoError(t, err, "NewRole(joinMFA)")
joinMFARole, err = adminClient.UpsertRole(ctx, joinMFARole)
require.NoError(t, err, "UpsertRole(%q)", joinMFARole.GetName())

// ...and assign the user a sane unix login, plus the prod role.
user, err := adminClient.GetUser(ctx, username, false /* withSecrets */)
require.NoError(t, err, "GetUser(%q)", username)
const login = "llama"
user.SetLogins(append(user.GetLogins(), login))
user.AddRole(prodRole.GetName())
_, err = adminClient.UpdateUser(ctx, user.(*types.UserV2))
require.NoError(t, err, "UpdateUser(%q)", username)
// Create a role that doesn't require MFA when joining sessions
joinNoMFARole, err := types.NewRole("no_mfa_session_join", types.RoleSpecV6{
Allow: types.RoleConditions{
Logins: []string{"{{internal.logins}}"},
NodeLabels: types.Labels{
"env": []string{"*"},
},
JoinSessions: []*types.SessionJoinPolicy{
{
Name: "session_join",
Roles: []string{"access"},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionPeerMode)},
},
},
},
})
require.NoError(t, err, "NewRole(joinNoMFA)")
joinNoMFARole, err = adminClient.UpsertRole(ctx, joinNoMFARole)
require.NoError(t, err, "UpsertRole(%q)", joinNoMFARole.GetName())

const normalLogin = "llama"
createUser := func(role types.Role) *Client {
// Create a user with MFA devices...
userCreds, err := createUserWithSecondFactors(testServer)
require.NoError(t, err, "createUserWithSecondFactors")
username := userCreds.username

// ...and assign the user a sane unix login, plus the specified role.
user, err := adminClient.GetUser(ctx, username, false /* withSecrets */)
require.NoError(t, err, "GetUser(%q)", username)

user.SetLogins(append(user.GetLogins(), normalLogin))
user.AddRole(role.GetName())
_, err = adminClient.UpdateUser(ctx, user.(*types.UserV2))
require.NoError(t, err, "UpdateUser(%q)", username)

userClient, err := testServer.NewClient(TestUser(username))
require.NoError(t, err, "NewClient(%q)", username)

return userClient
}

userClient, err := testServer.NewClient(TestUser(username))
require.NoError(t, err, "NewClient(%q)", username)
prodAccessClient := createUser(prodRole)
joinMFAClient := createUser(joinMFARole)
joinNoMFAClient := createUser(joinNoMFARole)

createReqForNode := func(node string) *proto.CreateAuthenticateChallengeRequest {
createReqForNode := func(node, login string) *proto.CreateAuthenticateChallengeRequest {
return &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
Expand All @@ -430,25 +485,71 @@ func TestCreateAuthenticateChallenge_mfaVerification(t *testing.T) {

tests := []struct {
name string
userClient *Client
req *proto.CreateAuthenticateChallengeRequest
wantMFARequired proto.MFARequired
wantChallenges bool
}{
{
name: "MFA not required, no challenges issued",
req: createReqForNode(devNode),
name: "MFA not required to start session, no challenges issued",
userClient: prodAccessClient,
req: createReqForNode(devNode, normalLogin),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO,
},
{
name: "MFA required",
req: createReqForNode(prodNode),
name: "MFA required to start session",
userClient: prodAccessClient,
req: createReqForNode(prodNode, normalLogin),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES,
wantChallenges: true,
},
{
name: "MFA required to join session on prod node (prod role)",
userClient: prodAccessClient,
req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES,
wantChallenges: true,
},
{
name: "MFA required to join session on dev node (prod role)",
userClient: prodAccessClient,
req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES,
wantChallenges: true,
},
{
name: "MFA required to join session on prod node (join MFA role)",
userClient: joinMFAClient,
req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES,
wantChallenges: true,
},
{
name: "MFA required to join session dev node (join MFA role)",
userClient: joinMFAClient,
req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES,
wantChallenges: true,
},
{
name: "MFA not required to join session, no challenges issued on dev node (join no MFA role)",
userClient: joinNoMFAClient,
req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO,
},
{
name: "MFA not required to join session, no challenges issued on prod node (join no MFA role)",
userClient: joinNoMFAClient,
req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal),
wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
resp, err := userClient.CreateAuthenticateChallenge(ctx, test.req)
t.Parallel()

resp, err := test.userClient.CreateAuthenticateChallenge(ctx, test.req)
require.NoError(t, err, "CreateAuthenticateChallenge")

assert.Equal(t, test.wantMFARequired, resp.MFARequired, "resp.MFARequired mismatch")
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ func (a *ahLoginChecker) canLoginWithRBAC(cert *ssh.Certificate, ca types.CertAu
auth.RoleSupportsModeratedSessions(accessChecker.Roles()) {

// allow joining if cluster wide MFA is not required
if state.MFARequired != services.MFARequiredAlways {
if state.MFARequired == services.MFARequiredNever {
return nil
}

Expand Down
Loading