diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 1ea206cbee792..c46fc008eaff7 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -5909,6 +5909,21 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck return nil, trace.BadParameter("empty Login field") } + // state.MFARequired is "per-role", so if the user is joining + // a session, MFA is required no matter what node they are + // connecting to. We don't preform an RBAC check like we do + // below when users are starting a session to selectively + // require MFA because we don't know what session the user + // is joining, nor do we know what role allowed the session + // creator to start the session that is attempting to be joined. + // We need this info to be able to selectively skip MFA in + // this case. + if t.Node.Login == teleport.SSHSessionJoinPrincipal { + return &proto.IsMFARequiredResponse{ + MFARequired: proto.MFARequired_MFA_REQUIRED_YES, + }, nil + } + // Find the target node and check whether MFA is required. matches, err := client.GetResourcesWithFilters(ctx, a, proto.ListResourcesRequest{ ResourceType: types.KindNode, diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 2cf52eedf8dab..a9f10ef583ad1 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1" @@ -395,24 +396,78 @@ func TestCreateAuthenticateChallenge_mfaVerification(t *testing.T) { prodRole, err = adminClient.UpsertRole(ctx, prodRole) require.NoError(t, err, "UpsertRole(%q)", prodRole.GetName()) - // Create a user with MFA devices... - userCreds, err := createUserWithSecondFactors(testServer) - require.NoError(t, err, "createUserWithSecondFactors") - username := userCreds.username + // Create a role that requires MFA when joining sessions + joinMFARole, err := types.NewRole("mfa_session_join", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + Logins: []string{"{{internal.logins}}"}, + NodeLabels: types.Labels{ + "env": []string{"*"}, + }, + JoinSessions: []*types.SessionJoinPolicy{ + { + Name: "session_join", + Roles: []string{"access"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }, + }, + }, + }) + require.NoError(t, err, "NewRole(joinMFA)") + joinMFARole, err = adminClient.UpsertRole(ctx, joinMFARole) + require.NoError(t, err, "UpsertRole(%q)", joinMFARole.GetName()) - // ...and assign the user a sane unix login, plus the prod role. - user, err := adminClient.GetUser(ctx, username, false /* withSecrets */) - require.NoError(t, err, "GetUser(%q)", username) - const login = "llama" - user.SetLogins(append(user.GetLogins(), login)) - user.AddRole(prodRole.GetName()) - _, err = adminClient.UpdateUser(ctx, user.(*types.UserV2)) - require.NoError(t, err, "UpdateUser(%q)", username) + // Create a role that doesn't require MFA when joining sessions + joinNoMFARole, err := types.NewRole("no_mfa_session_join", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{"{{internal.logins}}"}, + NodeLabels: types.Labels{ + "env": []string{"*"}, + }, + JoinSessions: []*types.SessionJoinPolicy{ + { + Name: "session_join", + Roles: []string{"access"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }, + }, + }, + }) + require.NoError(t, err, "NewRole(joinNoMFA)") + joinNoMFARole, err = adminClient.UpsertRole(ctx, joinNoMFARole) + require.NoError(t, err, "UpsertRole(%q)", joinNoMFARole.GetName()) + + const normalLogin = "llama" + createUser := func(role types.Role) *Client { + // Create a user with MFA devices... + userCreds, err := createUserWithSecondFactors(testServer) + require.NoError(t, err, "createUserWithSecondFactors") + username := userCreds.username + + // ...and assign the user a sane unix login, plus the specified role. + user, err := adminClient.GetUser(ctx, username, false /* withSecrets */) + require.NoError(t, err, "GetUser(%q)", username) + + user.SetLogins(append(user.GetLogins(), normalLogin)) + user.AddRole(role.GetName()) + _, err = adminClient.UpdateUser(ctx, user.(*types.UserV2)) + require.NoError(t, err, "UpdateUser(%q)", username) + + userClient, err := testServer.NewClient(TestUser(username)) + require.NoError(t, err, "NewClient(%q)", username) + + return userClient + } - userClient, err := testServer.NewClient(TestUser(username)) - require.NoError(t, err, "NewClient(%q)", username) + prodAccessClient := createUser(prodRole) + joinMFAClient := createUser(joinMFARole) + joinNoMFAClient := createUser(joinNoMFARole) - createReqForNode := func(node string) *proto.CreateAuthenticateChallengeRequest { + createReqForNode := func(node, login string) *proto.CreateAuthenticateChallengeRequest { return &proto.CreateAuthenticateChallengeRequest{ Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ ContextUser: &proto.ContextUser{}, @@ -430,25 +485,71 @@ func TestCreateAuthenticateChallenge_mfaVerification(t *testing.T) { tests := []struct { name string + userClient *Client req *proto.CreateAuthenticateChallengeRequest wantMFARequired proto.MFARequired wantChallenges bool }{ { - name: "MFA not required, no challenges issued", - req: createReqForNode(devNode), + name: "MFA not required to start session, no challenges issued", + userClient: prodAccessClient, + req: createReqForNode(devNode, normalLogin), wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO, }, { - name: "MFA required", - req: createReqForNode(prodNode), + name: "MFA required to start session", + userClient: prodAccessClient, + req: createReqForNode(prodNode, normalLogin), wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES, wantChallenges: true, }, + { + name: "MFA required to join session on prod node (prod role)", + userClient: prodAccessClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES, + wantChallenges: true, + }, + { + name: "MFA required to join session on dev node (prod role)", + userClient: prodAccessClient, + req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES, + wantChallenges: true, + }, + { + name: "MFA required to join session on prod node (join MFA role)", + userClient: joinMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES, + wantChallenges: true, + }, + { + name: "MFA required to join session dev node (join MFA role)", + userClient: joinMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_YES, + wantChallenges: true, + }, + { + name: "MFA not required to join session, no challenges issued on dev node (join no MFA role)", + userClient: joinNoMFAClient, + req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO, + }, + { + name: "MFA not required to join session, no challenges issued on prod node (join no MFA role)", + userClient: joinNoMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: proto.MFARequired_MFA_REQUIRED_NO, + }, } for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { - resp, err := userClient.CreateAuthenticateChallenge(ctx, test.req) + t.Parallel() + + resp, err := test.userClient.CreateAuthenticateChallenge(ctx, test.req) require.NoError(t, err, "CreateAuthenticateChallenge") assert.Equal(t, test.wantMFARequired, resp.MFARequired, "resp.MFARequired mismatch") diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 3e7dbac19c801..8404b02fc0ce9 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -646,7 +646,7 @@ func (a *ahLoginChecker) canLoginWithRBAC(cert *ssh.Certificate, ca types.CertAu auth.RoleSupportsModeratedSessions(accessChecker.Roles()) { // allow joining if cluster wide MFA is not required - if state.MFARequired != services.MFARequiredAlways { + if state.MFARequired == services.MFARequiredNever { return nil } diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 45e9f5de5f5ea..750ebaf522770 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -28,7 +28,9 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/testauthority" @@ -36,13 +38,18 @@ import ( "github.com/gravitational/teleport/lib/services" ) -type mockCAGetter struct { +type mockCAandAuthPrefGetter struct { AccessPoint - cas map[types.CertAuthType]types.CertAuthority + authPref types.AuthPreference + cas map[types.CertAuthType]types.CertAuthority } -func (m mockCAGetter) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) { +func (m mockCAandAuthPrefGetter) GetAuthPreference(s_12345678 context.Context) (types.AuthPreference, error) { + return m.authPref, nil +} + +func (m mockCAandAuthPrefGetter) GetCertAuthorities(_ context.Context, caType types.CertAuthType, _ bool) ([]types.CertAuthority, error) { ca, ok := m.cas[caType] if !ok { return nil, trace.NotFound("CA not found") @@ -160,8 +167,9 @@ func TestRBAC(t *testing.T) { err = server.auth.SetClusterName(clusterName) require.NoError(t, err) - accessPoint := mockCAGetter{ + accessPoint := mockCAandAuthPrefGetter{ AccessPoint: server.auth, + authPref: types.DefaultAuthPreference(), cas: map[types.CertAuthType]types.CertAuthority{ types.UserCA: userCA, }, @@ -225,3 +233,178 @@ func TestRBAC(t *testing.T) { }) } } + +// TestRBACJoinMFA tests that MFA is enforced correctly when joining +// sessions depending on the cluster auth preference and roles presented. +func TestRBACJoinMFA(t *testing.T) { + t.Parallel() + + const clusterName = "localhost" + const username = "testuser" + + // create User CA + userTA := testauthority.New() + userCAPriv, err := userTA.GeneratePrivateKey() + require.NoError(t, err) + userCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.UserCA, + ClusterName: clusterName, + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{ + { + PublicKey: userCAPriv.MarshalSSHPublicKey(), + PrivateKey: userCAPriv.PrivateKeyPEM(), + PrivateKeyType: types.PrivateKeyType_RAW, + }, + }, + }, + }) + require.NoError(t, err) + + // create mock SSH server and add a cluster name + server := newMockServer(t) + cn, err := types.NewClusterName(types.ClusterNameSpecV2{ + ClusterName: clusterName, + ClusterID: "cluster_id", + }) + require.NoError(t, err) + err = server.auth.SetClusterName(cn) + require.NoError(t, err) + ctx := context.Background() + + accessPoint := &mockCAandAuthPrefGetter{ + AccessPoint: server.auth, + cas: map[types.CertAuthType]types.CertAuthority{ + types.UserCA: userCA, + }, + } + + // create auth handler and dummy node + config := &AuthHandlerConfig{ + Server: server, + Emitter: &eventstest.MockRecorderEmitter{}, + AccessPoint: accessPoint, + } + ah, err := NewAuthHandlers(config) + require.NoError(t, err) + + node, err := types.NewServer("testie_node", types.KindNode, types.ServerSpecV2{ + Addr: "1.2.3.4:22", + Hostname: "testie", + Version: types.V2, + }) + require.NoError(t, err) + + mfaAuthPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + SecondFactor: constants.SecondFactorOTP, + RequireMFAType: types.RequireMFAType_HARDWARE_KEY_TOUCH, + }) + require.NoError(t, err) + + noMFAAuthPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + SecondFactor: constants.SecondFactorOTP, + RequireMFAType: types.RequireMFAType_OFF, + }) + require.NoError(t, err) + + // create roles + joinMFARole, err := types.NewRole("joinMFA", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) + require.NoError(t, err) + _, err = server.auth.CreateRole(ctx, joinMFARole) + require.NoError(t, err) + + joinRole, err := types.NewRole("join", types.RoleSpecV6{ + Allow: types.RoleConditions{ + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) + require.NoError(t, err) + _, err = server.auth.CreateRole(ctx, joinRole) + require.NoError(t, err) + + tests := []struct { + name string + authPref types.AuthPreference + role string + testError func(t *testing.T, err error) + }{ + { + name: "MFA cluster auth, MFA role", + authPref: mfaAuthPref, + role: joinMFARole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "MFA cluster auth, no MFA role", + authPref: mfaAuthPref, + role: joinRole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "no MFA cluster auth, MFA role", + authPref: noMFAAuthPref, + role: joinMFARole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "no MFA cluster auth, no MFA role", + authPref: noMFAAuthPref, + role: joinRole.GetName(), + testError: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + accessPoint.authPref = tt.authPref + + // create SSH certificate + caSigner, err := ssh.NewSignerFromKey(userCAPriv) + require.NoError(t, err) + keygen := testauthority.New() + privateKey, err := native.GeneratePrivateKey() + require.NoError(t, err) + + c, err := keygen.GenerateUserCert(services.UserCertParams{ + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), + Username: username, + AllowedLogins: []string{username}, + Traits: wrappers.Traits{ + teleport.TraitInternalPrefix: []string{""}, + }, + Roles: []string{tt.role}, + CertificateFormat: constants.CertificateFormatStandard, + }) + require.NoError(t, err) + + cert, err := sshutils.ParseCertificate(c) + require.NoError(t, err) + + err = ah.canLoginWithRBAC(cert, userCA, clusterName, node, username, teleport.SSHSessionJoinPrincipal) + tt.testError(t, err) + }) + } +}