diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index e96915963689d..38402a01d9955 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -2439,20 +2439,30 @@ func (g *GRPCServer) GenerateUserSingleUseCerts(stream proto.AuthService_Generat } mfaRequired := proto.MFARequired_MFA_REQUIRED_UNSPECIFIED - if required, err := isMFARequiredForSingleUseCertRequest(ctx, actx, initReq); err == nil { - // If MFA is not required to gain access to the resource then let the client - // know and abort the ceremony. - if !required { - return trace.Wrap(stream.Send(&proto.UserSingleUseCertsResponse{ - Response: &proto.UserSingleUseCertsResponse_MFAChallenge{ - MFAChallenge: &proto.MFAAuthenticateChallenge{ - MFARequired: proto.MFARequired_MFA_REQUIRED_NO, + clusterName, err := actx.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + // Only check if MFA is required for resources within the current cluster. Determining if + // MFA is required for a resource in a leaf cluster will result in a not found error and + // prevent users from accessing resources in leaf clusters. + if initReq.RouteToCluster == "" || clusterName.GetClusterName() == initReq.RouteToCluster { + if required, err := isMFARequiredForSingleUseCertRequest(ctx, actx, initReq); err == nil { + // If MFA is not required to gain access to the resource then let the client + // know and abort the ceremony. + if !required { + return trace.Wrap(stream.Send(&proto.UserSingleUseCertsResponse{ + Response: &proto.UserSingleUseCertsResponse_MFAChallenge{ + MFAChallenge: &proto.MFAAuthenticateChallenge{ + MFARequired: proto.MFARequired_MFA_REQUIRED_NO, + }, }, - }, - })) - } + })) + } - mfaRequired = proto.MFARequired_MFA_REQUIRED_YES + mfaRequired = proto.MFARequired_MFA_REQUIRED_YES + } } // 2. send MFAChallenge diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 7cec547dfb7b2..0387da95cd2a4 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -1227,6 +1227,20 @@ func TestGenerateUserSingleUseCert(t *testing.T) { _, err = srv.Auth().UpsertDatabaseServer(ctx, db) require.NoError(t, err) + desktop, err := types.NewWindowsDesktopV3("desktop", nil, types.WindowsDesktopSpecV3{ + Addr: "localhost", + HostID: "test", + }) + require.NoError(t, err) + + require.NoError(t, srv.Auth().CreateWindowsDesktop(ctx, desktop)) + + leaf, err := types.NewRemoteCluster("leaf") + require.NoError(t, err) + + // create remote cluster + require.NoError(t, srv.Auth().CreateRemoteCluster(leaf)) + // Create a fake user. user, role, err := CreateUserAndRole(srv.Auth(), "mfa-user", []string{"role"}, nil) require.NoError(t, err) @@ -1236,6 +1250,8 @@ func TestGenerateUserSingleUseCert(t *testing.T) { role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) role.SetDatabaseNames(types.Allow, []string{types.Wildcard}) + role.SetWindowsLogins(types.Allow, []string{"role"}) + role.SetWindowsDesktopLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) role.SetOptions(roleOpt) err = srv.Auth().UpsertRole(ctx, role) require.NoError(t, err) @@ -1492,6 +1508,44 @@ func TestGenerateUserSingleUseCert(t *testing.T) { }, }, }, + { + desc: "desktops", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_WindowsDesktop, + RouteToWindowsDesktop: proto.RouteToWindowsDesktop{ + WindowsDesktop: "desktop", + Login: "role", + }, + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_YES, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + crt := c.GetTLS() + require.NotEmpty(t, crt) + + cert, err := tlsca.ParseCertificatePEM(crt) + require.NoError(t, err) + require.Equal(t, cert.NotAfter, clock.Now().Add(teleport.UserSingleUseCertTTL)) + + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.Equal(t, webDevID, identity.MFAVerified) + require.Equal(t, userCertExpires, identity.PreviousIdentityExpires) + require.True(t, net.ParseIP(identity.LoginIP).IsLoopback()) + require.Equal(t, []string{teleport.UsageWindowsDesktopOnly}, identity.Usage) + }, + }, + }, { desc: "fail - wrong usage", opts: generateUserSingleUseCertTestOpts{ @@ -1669,6 +1723,129 @@ func TestGenerateUserSingleUseCert(t *testing.T) { checkAuthErr: require.Error, }, }, + { + desc: "k8s in leaf cluster", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_Kubernetes, + KubernetesCluster: "kube-b", + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + crt := c.GetTLS() + require.NotEmpty(t, crt) + + cert, err := tlsca.ParseCertificatePEM(crt) + require.NoError(t, err) + require.Equal(t, cert.NotAfter, clock.Now().Add(teleport.UserSingleUseCertTTL)) + + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.Equal(t, webDevID, identity.MFAVerified) + require.Equal(t, userCertExpires, identity.PreviousIdentityExpires) + require.True(t, net.ParseIP(identity.LoginIP).IsLoopback()) + require.Equal(t, []string{teleport.UsageKubeOnly}, identity.Usage) + require.Equal(t, "kube-b", identity.KubernetesCluster) + }, + }, + }, + { + desc: "db in leaf cluster", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_Database, + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: "db-b", + Database: "db-b", + }, + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + crt := c.GetTLS() + require.NotEmpty(t, crt) + + cert, err := tlsca.ParseCertificatePEM(crt) + require.NoError(t, err) + require.Equal(t, clock.Now().Add(teleport.UserSingleUseCertTTL), cert.NotAfter) + + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.Equal(t, webDevID, identity.MFAVerified) + require.Equal(t, userCertExpires, identity.PreviousIdentityExpires) + require.True(t, net.ParseIP(identity.LoginIP).IsLoopback()) + require.Equal(t, []string{teleport.UsageDatabaseOnly}, identity.Usage) + require.Equal(t, identity.RouteToDatabase.ServiceName, "db-b") + }, + }, + }, + { + desc: "ssh in leaf node", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_SSH, + NodeName: "node-b", + SSHLogin: "role", + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + sshCertBytes := c.GetSSH() + require.NotEmpty(t, sshCertBytes) + + cert, err := sshutils.ParseCertificate(sshCertBytes) + require.NoError(t, err) + + require.Equal(t, webDevID, cert.Extensions[teleport.CertExtensionMFAVerified]) + require.Equal(t, userCertExpires.Format(time.RFC3339), cert.Extensions[teleport.CertExtensionPreviousIdentityExpires]) + require.True(t, net.ParseIP(cert.Extensions[teleport.CertExtensionLoginIP]).IsLoopback()) + require.Equal(t, uint64(clock.Now().Add(teleport.UserSingleUseCertTTL).Unix()), cert.ValidBefore) + }, + }, + }, + { + desc: "fail - app access not supported", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + Expires: clock.Now().Add(teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_App, + }, + checkInitErr: require.Error, + }, + }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { diff --git a/lib/client/api.go b/lib/client/api.go index b8df6b13c5025..dab5328f0384b 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1605,14 +1605,22 @@ func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, mfaCancel() directCancel() - // Only return the error from connecting with mfa if the error - // originates from the mfa ceremony. If mfa is not required then - // the error from the direct connection to the node must be returned. - if mfaErr != nil && !errors.Is(mfaErr, io.EOF) && !errors.Is(mfaErr, MFARequiredUnknownErr{}) && !errors.Is(mfaErr, services.ErrSessionMFANotRequired) { + switch { + // No MFA errors, return any errors from the direct connection + case mfaErr == nil: + return nil, trace.Wrap(directErr) + // Any direct connection errors other than access denied, which should be returned + // if MFA is required, take precedent over MFA errors due to users not having any + // enrolled devices. + case !trace.IsAccessDenied(directErr) && errors.Is(mfaErr, auth.ErrNoMFADevices): + return nil, trace.Wrap(directErr) + case !errors.Is(mfaErr, io.EOF) && // Ignore any errors from MFA due to locks being enforced, the direct error will be friendlier + !errors.Is(mfaErr, MFARequiredUnknownErr{}) && // Ignore any failures that occurred before determining if MFA was required + !errors.Is(mfaErr, services.ErrSessionMFANotRequired): // Ignore any errors caused by attempting the MFA ceremony when MFA will not grant access return nil, trace.Wrap(mfaErr) + default: + return nil, trace.Wrap(directErr) } - - return nil, trace.Wrap(directErr) } // MFARequiredUnknownErr indicates that connections to an instance failed @@ -2817,7 +2825,12 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, cluster = connected } - aclt, err := auth.NewClient(pclt.ClientConfig(ctx, cluster)) + cltConfig := pclt.ClientConfig(ctx, cluster) + cltConfig.DialOpts = append(cltConfig.DialOpts, + grpc.WithStreamInterceptor(utils.GRPCClientStreamErrorInterceptor), + grpc.WithUnaryInterceptor(utils.GRPCClientUnaryErrorInterceptor), + ) + aclt, err := auth.NewClient(cltConfig) if err != nil { return nil, trace.NewAggregate(err, pclt.Close()) } diff --git a/lib/client/client.go b/lib/client/client.go index 3c00d0d6c06d6..18806994aa37e 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -31,7 +31,6 @@ import ( "time" "github.com/gravitational/trace" - "github.com/gravitational/trace/trail" "github.com/moby/term" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" @@ -576,7 +575,7 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis // challenge and will terminate the stream with an auth.ErrNoMFADevices error. // In this case for all protocols other than SSH fall back to reissuing // certs without MFA. - if errors.Is(trail.FromGRPC(err), auth.ErrNoMFADevices) { + if errors.Is(err, auth.ErrNoMFADevices) { if params.usage() != proto.UserCertsRequest_SSH { return proxy.reissueUserCerts(ctx, CertCacheKeep, params) } diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index 68040bd324ec4..134b6a485af66 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -21,12 +21,14 @@ import ( "go.opentelemetry.io/otel/attribute" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" + "google.golang.org/grpc" "github.com/gravitational/teleport/api/client/proto" proxyclient "github.com/gravitational/teleport/api/client/proxy" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" ) // ClusterClient facilitates communicating with both the @@ -85,7 +87,13 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe mfaClt := c if target.Cluster != rootClusterName { - aclt, err := auth.NewClient(c.ProxyClient.ClientConfig(ctx, rootClusterName)) + cltConfig := c.ProxyClient.ClientConfig(ctx, rootClusterName) + cltConfig.DialOpts = append(cltConfig.DialOpts, + grpc.WithStreamInterceptor(utils.GRPCClientStreamErrorInterceptor), + grpc.WithUnaryInterceptor(utils.GRPCClientUnaryErrorInterceptor), + ) + + aclt, err := auth.NewClient(cltConfig) if err != nil { return nil, trace.Wrap(MFARequiredUnknown(err)) } @@ -102,7 +110,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe } log.Debug("Attempting to issue a single-use user certificate with an MFA check.") - key, err = performMFACeremony(ctx, mfaClt, + key, err = c.performMFACeremony(ctx, mfaClt, ReissueParams{ NodeName: nodeName(target.Addr), RouteToCluster: target.Cluster, @@ -236,7 +244,7 @@ func (c *ClusterClient) prepareUserCertsRequest(params ReissueParams, key *Key) // performMFACeremony runs the mfa ceremony to completion. If successful the returned // [Key] will be authorized to connect to the target. -func performMFACeremony(ctx context.Context, clt *ClusterClient, params ReissueParams, key *Key) (*Key, error) { +func (c *ClusterClient) performMFACeremony(ctx context.Context, clt *ClusterClient, params ReissueParams, key *Key) (*Key, error) { stream, err := clt.AuthClient.GenerateUserSingleUseCerts(ctx) if err != nil { if trace.IsNotImplemented(err) { @@ -285,7 +293,9 @@ func performMFACeremony(ctx context.Context, clt *ClusterClient, params ReissueP case proto.MFARequired_MFA_REQUIRED_NO: return nil, trace.Wrap(services.ErrSessionMFANotRequired) case proto.MFARequired_MFA_REQUIRED_UNSPECIFIED: - check, err := clt.AuthClient.IsMFARequired(ctx, params.isMFARequiredRequest(clt.tc.HostLogin)) + // check if MFA is required with the auth client for this cluster and + // not the root client + check, err := c.AuthClient.IsMFARequired(ctx, params.isMFARequiredRequest(clt.tc.HostLogin)) if err != nil { return nil, trace.Wrap(MFARequiredUnknown(err)) } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 34e1daefa27f1..23bd821ac32f7 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/agentless" + "github.com/gravitational/teleport/lib/auth" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -732,14 +733,22 @@ func (t *TerminalHandler) connectToHost(ctx context.Context, ws *websocket.Conn, mfaCancel() directCancel() - // Only return the error from connecting with mfa if the error - // originates from the mfa ceremony. If mfa is not required then - // the error from the direct connection to the node must be returned. - if mfaErr != nil && !errors.Is(mfaErr, io.EOF) && !errors.Is(mfaErr, client.MFARequiredUnknownErr{}) && !errors.Is(mfaErr, services.ErrSessionMFANotRequired) { + switch { + // No MFA errors, return any errors from the direct connection + case mfaErr == nil: + return nil, trace.Wrap(directErr) + // Any direct connection errors other than access denied, which should be returned + // if MFA is required, take precedent over MFA errors due to users not having any + // enrolled devices. + case !trace.IsAccessDenied(directErr) && errors.Is(mfaErr, auth.ErrNoMFADevices): + return nil, trace.Wrap(directErr) + case !errors.Is(mfaErr, io.EOF) && // Ignore any errors from MFA due to locks being enforced, the direct error will be friendlier + !errors.Is(mfaErr, client.MFARequiredUnknownErr{}) && // Ignore any failures that occurred before determining if MFA was required + !errors.Is(mfaErr, services.ErrSessionMFANotRequired): // Ignore any errors caused by attempting the MFA ceremony when MFA will not grant access return nil, trace.Wrap(mfaErr) + default: + return nil, trace.Wrap(directErr) } - - return nil, trace.Wrap(directErr) } // streamTerminal opens a SSH connection to the remote host and streams diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index 5419f8a31c218..6696c57dc247d 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -947,7 +947,15 @@ func TestSSHOnMultipleNodes(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const origin = "https://localhost" + isInsecure := lib.IsInsecureDevMode() + lib.SetInsecureDevMode(true) + t.Cleanup(func() { + lib.SetInsecureDevMode(isInsecure) + }) + + origin := func(cluster string) string { + return fmt.Sprintf("https://%s", cluster) + } connector := mockConnector(t) user, err := user.Current() @@ -991,27 +999,69 @@ func TestSSHOnMultipleNodes(t *testing.T) { rootAuth, rootProxy := makeTestServers(t, withBootstrap(connector, alice, noAccessRole, sshLoginRole, perSessionMFARole)) - authAddr, err := rootAuth.AuthAddr() + rootAuthAddr, err := rootAuth.AuthAddr() require.NoError(t, err) - proxyAddr, err := rootProxy.ProxyWebAddr() + rootProxyAddr, err := rootProxy.ProxyWebAddr() + require.NoError(t, err) + rootTunnelAddr, err := rootProxy.ProxyTunnelAddr() + require.NoError(t, err) + + trustedCluster, err := types.NewTrustedCluster("localhost", types.TrustedClusterSpecV2{ + Enabled: true, + Roles: []string{}, + Token: staticToken, + ProxyAddress: rootProxyAddr.String(), + ReverseTunnelAddress: rootTunnelAddr.String(), + RoleMap: []types.RoleMapping{ + { + Remote: "access", + Local: []string{"access"}, + }, + { + Remote: perSessionMFARole.GetName(), + Local: []string{perSessionMFARole.GetName()}, + }, + { + Remote: sshLoginRole.GetName(), + Local: []string{sshLoginRole.GetName()}, + }, + }, + }) + require.NoError(t, err) + + leafAuth, leafProxy := makeTestServers(t, withClusterName(t, "leafcluster"), withBootstrap(connector, alice, sshLoginRole, perSessionMFARole)) + tryCreateTrustedCluster(t, leafAuth.GetAuthServer(), trustedCluster) + + leafAuthAddr, err := leafAuth.AuthAddr() require.NoError(t, err) + require.Eventually(t, func() bool { + conns, err := rootAuth.GetAuthServer().GetTunnelConnections("leafcluster") + return err == nil && len(conns) == 1 + }, 10*time.Second, 100*time.Millisecond, "leaf cluster never heart beated") + + leafProxyAddr := leafProxy.Config.Proxy.WebAddr.String() + stage1Hostname := "test-stage-1" - node := makeTestSSHNode(t, authAddr, withHostname(stage1Hostname), withSSHLabel("env", "stage")) + node := makeTestSSHNode(t, rootAuthAddr, withHostname(stage1Hostname), withSSHLabel("env", "stage")) sshHostID := node.Config.HostUUID stage2Hostname := "test-stage-2" - node2 := makeTestSSHNode(t, authAddr, withHostname(stage2Hostname), withSSHLabel("env", "stage")) + node2 := makeTestSSHNode(t, rootAuthAddr, withHostname(stage2Hostname), withSSHLabel("env", "stage")) sshHostID2 := node2.Config.HostUUID prodHostname := "test-prod-1" - nodeProd := makeTestSSHNode(t, authAddr, withHostname(prodHostname), withSSHLabel("env", "prod")) + nodeProd := makeTestSSHNode(t, rootAuthAddr, withHostname(prodHostname), withSSHLabel("env", "prod")) sshHostID3 := nodeProd.Config.HostUUID - hasNodes := func(hostIDs ...string) func() bool { + leafHostname := "leaf-node" + leafNode := makeTestSSHNode(t, leafAuthAddr, withHostname(leafHostname), withSSHLabel("animal", "llama")) + sshLeafHostID := leafNode.Config.HostUUID + + hasNodes := func(asrv *auth.Server, hostIDs ...string) func() bool { return func() bool { - nodes, err := rootAuth.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) + nodes, err := asrv.GetNodes(ctx, apidefaults.Namespace) if err != nil { return false } @@ -1027,76 +1077,92 @@ func TestSSHOnMultipleNodes(t *testing.T) { } // wait for auth to see nodes - require.Eventually(t, hasNodes(sshHostID, sshHostID2, sshHostID3), - 5*time.Second, 100*time.Millisecond, "nodes never joined cluster") + require.Eventually(t, hasNodes(rootAuth.GetAuthServer(), sshHostID, sshHostID2, sshHostID3), + 5*time.Second, 100*time.Millisecond, "nodes never joined root cluster") + + require.Eventually(t, hasNodes(leafAuth.GetAuthServer(), sshLeafHostID), + 5*time.Second, 100*time.Millisecond, "nodes never joined leaf cluster") defaultPreference, err := rootAuth.GetAuthServer().GetAuthPreference(ctx) require.NoError(t, err) - // set the default auth preference - webauthnPreference := &types.AuthPreferenceV2{ - Spec: types.AuthPreferenceSpecV2{ - Type: constants.Local, - SecondFactor: constants.SecondFactorOptional, - Webauthn: &types.Webauthn{ - RPID: "localhost", + webauthnPreference := func(cluster string) *types.AuthPreferenceV2 { + return &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOptional, + Webauthn: &types.Webauthn{ + RPID: cluster, + }, }, - }, + } } - err = rootAuth.GetAuthServer().SetAuthPreference(ctx, webauthnPreference) - require.NoError(t, err) - token, err := rootAuth.GetAuthServer().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ - Name: "alice", - }) - require.NoError(t, err) - tokenID := token.GetName() - res, err := rootAuth.GetAuthServer().CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ - TokenID: tokenID, - DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, - DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, - }) - require.NoError(t, err) - cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) + setupUser := func(cluster string, asrv *auth.Server) { + // set the default auth preference + err = asrv.SetAuthPreference(ctx, webauthnPreference(cluster)) + require.NoError(t, err) - ccr, err := device.SignCredentialCreation(origin, cc) - require.NoError(t, err) - _, err = rootAuth.GetAuthServer().ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ - TokenID: tokenID, - NewPassword: []byte(password), - NewMFARegisterResponse: &proto.MFARegisterResponse{ - Response: &proto.MFARegisterResponse_Webauthn{ - Webauthn: wanlib.CredentialCreationResponseToProto(ccr), - }, - }, - }) - require.NoError(t, err) + token, err := asrv.CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ + Name: "alice", + }) + require.NoError(t, err) + tokenID := token.GetName() + res, err := asrv.CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ + TokenID: tokenID, + DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, + DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, + }) + require.NoError(t, err) + cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) - successfulChallenge := func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - car, err := device.SignAssertion(origin, assertion) // use the fake origin to prevent a mismatch - if err != nil { - return nil, "", err - } - return &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_Webauthn{ - Webauthn: wanlib.CredentialAssertionResponseToProto(car), + ccr, err := device.SignCredentialCreation(origin(cluster), cc) + require.NoError(t, err) + _, err = asrv.ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ + TokenID: tokenID, + NewPassword: []byte(password), + NewMFARegisterResponse: &proto.MFARegisterResponse{ + Response: &proto.MFARegisterResponse_Webauthn{ + Webauthn: wanlib.CredentialCreationResponseToProto(ccr), + }, }, - }, "", nil + }) + require.NoError(t, err) } - failedChallenge := func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - car, err := device.SignAssertion(origin, assertion) // use the fake origin to prevent a mismatch - if err != nil { - return nil, "", err + setupUser("localhost", rootAuth.GetAuthServer()) + setupUser("leafcluster", leafAuth.GetAuthServer()) + + successfulChallenge := func(cluster string) func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + car, err := device.SignAssertion(origin(cluster), assertion) // use the fake origin to prevent a mismatch + if err != nil { + return nil, "", err + } + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: wanlib.CredentialAssertionResponseToProto(car), + }, + }, "", nil } - carProto := wanlib.CredentialAssertionResponseToProto(car) - carProto.Type = "NOT A VALID TYPE" // set to an invalid type so the ceremony fails + } - return &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_Webauthn{ - Webauthn: carProto, - }, - }, "", nil + failedChallenge := func(cluster string) func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + + car, err := device.SignAssertion(origin(cluster), assertion) // use the fake origin to prevent a mismatch + if err != nil { + return nil, "", err + } + carProto := wanlib.CredentialAssertionResponseToProto(car) + carProto.Type = "NOT A VALID TYPE" // set to an invalid type so the ceremony fails + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: carProto, + }, + }, "", nil + } } type mfaPrompt = func(ctx context.Context, origin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) @@ -1129,10 +1195,15 @@ func TestSSHOnMultipleNodes(t *testing.T) { stdoutAssertion require.ValueAssertionFunc mfaPromptCount int headless bool + proxyAddr string + auth *auth.Server + cluster string }{ { name: "default auth preference runs commands on multiple nodes without mfa", authPreference: defaultPreference, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) @@ -1140,16 +1211,20 @@ func TestSSHOnMultipleNodes(t *testing.T) { errAssertion: require.NoError, }, { - name: "webauthn auth preference runs commands on multiple matches without mfa", - target: "env=stage", + name: "webauthn auth preference runs commands on multiple matches without mfa", + target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, errAssertion: require.NoError, }, { - name: "webauthn auth preference runs commands on a single match without mfa", - target: "env=prod", + name: "webauthn auth preference runs commands on a single match without mfa", + target: "env=prod", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1158,6 +1233,8 @@ func TestSSHOnMultipleNodes(t *testing.T) { { name: "no matching hosts", target: "env=dev", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), errAssertion: require.Error, stdoutAssertion: require.Empty, }, @@ -1173,8 +1250,10 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), - target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, @@ -1193,8 +1272,10 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), - target: "env=prod", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=prod", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1213,7 +1294,9 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), target: "env=dev", errAssertion: require.Error, stdoutAssertion: require.Empty, @@ -1229,9 +1312,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, - roles: []string{"access", sshLoginRole.GetName(), perSessionMFARole.GetName()}, - setup: setupChallengeSolver(successfulChallenge), - target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{"access", sshLoginRole.GetName(), perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, @@ -1239,9 +1324,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { errAssertion: require.NoError, }, { - name: "role permits access without mfa", - target: sshHostID, - roles: []string{sshLoginRole.GetName()}, + name: "role permits access without mfa", + target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{sshLoginRole.GetName()}, stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1250,15 +1337,19 @@ func TestSSHOnMultipleNodes(t *testing.T) { { name: "role prevents access", target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), roles: []string{noAccessRole.GetName()}, stdoutAssertion: require.Empty, errAssertion: require.Error, }, { - name: "command runs on a hostname with mfa set via role", - target: sshHostID, - roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(successfulChallenge), + name: "command runs on a hostname with mfa set via role", + target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1276,9 +1367,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), target: sshHostID, roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(failedChallenge), + setup: setupChallengeSolver(failedChallenge("localhost")), stdoutAssertion: require.Empty, mfaPromptCount: 1, errAssertion: require.Error, @@ -1294,25 +1387,80 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, - target: sshHostID, - roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(failedChallenge), + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + target: sshHostID, + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(failedChallenge("localhost")), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, errAssertion: require.NoError, headless: true, }, + { + name: "command runs on a leaf node with mfa set via role", + target: sshLeafHostID, + proxyAddr: leafProxyAddr, + auth: leafAuth.GetAuthServer(), + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("leafcluster")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + mfaPromptCount: 1, + errAssertion: require.NoError, + }, { + name: "command runs on a leaf node via root without mfa", + target: sshLeafHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + cluster: "leafcluster", + roles: []string{sshLoginRole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + errAssertion: require.NoError, + }, + { + name: "command runs on a leaf node without mfa", + target: sshLeafHostID, + proxyAddr: leafProxyAddr, + auth: leafAuth.GetAuthServer(), + roles: []string{sshLoginRole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("leafcluster")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + errAssertion: require.NoError, + }, { + name: "command runs on a leaf node via root with mfa set via role", + target: sshLeafHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + cluster: "leafcluster", + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + mfaPromptCount: 1, + errAssertion: require.NoError, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { tmpHomePath := t.TempDir() + clusterName, err := tt.auth.GetClusterName() + require.NoError(t, err) + if tt.authPreference != nil { - require.NoError(t, rootAuth.GetAuthServer().SetAuthPreference(ctx, tt.authPreference)) + require.NoError(t, tt.auth.SetAuthPreference(ctx, tt.authPreference)) t.Cleanup(func() { - require.NoError(t, rootAuth.GetAuthServer().SetAuthPreference(ctx, webauthnPreference)) + require.NoError(t, tt.auth.SetAuthPreference(ctx, webauthnPreference(clusterName.GetClusterName()))) }) } @@ -1324,21 +1472,23 @@ func TestSSHOnMultipleNodes(t *testing.T) { roles := alice.GetRoles() t.Cleanup(func() { alice.SetRoles(roles) - require.NoError(t, rootAuth.GetAuthServer().UpsertUser(alice)) + require.NoError(t, tt.auth.UpsertUser(alice)) }) alice.SetRoles(tt.roles) - require.NoError(t, rootAuth.GetAuthServer().UpsertUser(alice)) + require.NoError(t, tt.auth.UpsertUser(alice)) } err = Run(ctx, []string{ "login", + "-d", "--insecure", "--auth", connector.GetName(), - "--proxy", proxyAddr.String(), + "--proxy", tt.proxyAddr, "--user", "alice", + tt.cluster, }, setHomePath(tmpHomePath), func(cf *CLIConf) error { - cf.mockSSOLogin = mockSSOLogin(t, rootAuth.GetAuthServer(), alice) + cf.mockSSOLogin = mockSSOLogin(t, tt.auth, alice) return nil }, ) @@ -1349,9 +1499,9 @@ func TestSSHOnMultipleNodes(t *testing.T) { // so we can assert how many times sign was called. device.SetCounter(0) - args := []string{"ssh", "--insecure"} + args := []string{"ssh", "-d", "--insecure"} if tt.headless { - args = append(args, "--headless", "--proxy", proxyAddr.String(), "--user", alice.GetName()) + args = append(args, "--headless", "--proxy", tt.proxyAddr, "--user", alice.GetName()) } args = append(args, tt.target, "echo", "test") @@ -1361,7 +1511,7 @@ func TestSSHOnMultipleNodes(t *testing.T) { func(conf *CLIConf) error { conf.overrideStdin = &bytes.Buffer{} conf.overrideStdout = stdout - conf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + conf.mockHeadlessLogin = mockHeadlessLogin(t, tt.auth, alice) return nil }, )