diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 9d300abebefff..b0ebfb35561a7 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -2587,20 +2587,30 @@ func (g *GRPCServer) GenerateUserSingleUseCerts(stream proto.AuthService_Generat } mfaRequired := proto.MFARequired_MFA_REQUIRED_UNSPECIFIED - if required, err := isMFARequiredForSingleUseCertRequest(ctx, actx, initReq); err == nil { - // If MFA is not required to gain access to the resource then let the client - // know and abort the ceremony. - if !required { - return trace.Wrap(stream.Send(&proto.UserSingleUseCertsResponse{ - Response: &proto.UserSingleUseCertsResponse_MFAChallenge{ - MFAChallenge: &proto.MFAAuthenticateChallenge{ - MFARequired: proto.MFARequired_MFA_REQUIRED_NO, + clusterName, err := actx.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + // Only check if MFA is required for resources within the current cluster. Determining if + // MFA is required for a resource in a leaf cluster will result in a not found error and + // prevent users from accessing resources in leaf clusters. + if initReq.RouteToCluster == "" || clusterName.GetClusterName() == initReq.RouteToCluster { + if required, err := isMFARequiredForSingleUseCertRequest(ctx, actx, initReq); err == nil { + // If MFA is not required to gain access to the resource then let the client + // know and abort the ceremony. + if !required { + return trace.Wrap(stream.Send(&proto.UserSingleUseCertsResponse{ + Response: &proto.UserSingleUseCertsResponse_MFAChallenge{ + MFAChallenge: &proto.MFAAuthenticateChallenge{ + MFARequired: proto.MFARequired_MFA_REQUIRED_NO, + }, }, - }, - })) - } + })) + } - mfaRequired = proto.MFARequired_MFA_REQUIRED_YES + mfaRequired = proto.MFARequired_MFA_REQUIRED_YES + } } // 2. send MFAChallenge diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 7a9279c260f45..39864cd5994f6 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -1223,6 +1223,20 @@ func TestGenerateUserSingleUseCert(t *testing.T) { _, err = srv.Auth().UpsertDatabaseServer(ctx, db) require.NoError(t, err) + desktop, err := types.NewWindowsDesktopV3("desktop", nil, types.WindowsDesktopSpecV3{ + Addr: "localhost", + HostID: "test", + }) + require.NoError(t, err) + + require.NoError(t, srv.Auth().CreateWindowsDesktop(ctx, desktop)) + + leaf, err := types.NewRemoteCluster("leaf") + require.NoError(t, err) + + // create remote cluster + require.NoError(t, srv.Auth().CreateRemoteCluster(leaf)) + // Create a fake user. user, role, err := CreateUserAndRole(srv.Auth(), "mfa-user", []string{"role"}) require.NoError(t, err) @@ -1232,6 +1246,8 @@ func TestGenerateUserSingleUseCert(t *testing.T) { role.SetDatabaseUsers(types.Allow, []string{types.Wildcard}) role.SetDatabaseLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) role.SetDatabaseNames(types.Allow, []string{types.Wildcard}) + role.SetWindowsLogins(types.Allow, []string{"role"}) + role.SetWindowsDesktopLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) role.SetOptions(roleOpt) err = srv.Auth().UpsertRole(ctx, role) require.NoError(t, err) @@ -1626,6 +1642,129 @@ func TestGenerateUserSingleUseCert(t *testing.T) { checkAuthErr: require.Error, }, }, + { + desc: "k8s in leaf cluster", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_Kubernetes, + KubernetesCluster: "kube-b", + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + crt := c.GetTLS() + require.NotEmpty(t, crt) + + cert, err := tlsca.ParseCertificatePEM(crt) + require.NoError(t, err) + require.Equal(t, cert.NotAfter, clock.Now().Add(teleport.UserSingleUseCertTTL)) + + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.Equal(t, webDevID, identity.MFAVerified) + require.Equal(t, userCertExpires, identity.PreviousIdentityExpires) + require.True(t, net.ParseIP(identity.LoginIP).IsLoopback()) + require.Equal(t, []string{teleport.UsageKubeOnly}, identity.Usage) + require.Equal(t, "kube-b", identity.KubernetesCluster) + }, + }, + }, + { + desc: "db in leaf cluster", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_Database, + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: "db-b", + Database: "db-b", + }, + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + crt := c.GetTLS() + require.NotEmpty(t, crt) + + cert, err := tlsca.ParseCertificatePEM(crt) + require.NoError(t, err) + require.Equal(t, clock.Now().Add(teleport.UserSingleUseCertTTL), cert.NotAfter) + + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.Equal(t, webDevID, identity.MFAVerified) + require.Equal(t, userCertExpires, identity.PreviousIdentityExpires) + require.True(t, net.ParseIP(identity.LoginIP).IsLoopback()) + require.Equal(t, []string{teleport.UsageDatabaseOnly}, identity.Usage) + require.Equal(t, identity.RouteToDatabase.ServiceName, "db-b") + }, + }, + }, + { + desc: "ssh in leaf node", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + // This expiry is longer than allowed, should be + // automatically adjusted. + Expires: clock.Now().Add(2 * teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_SSH, + NodeName: "node-b", + SSHLogin: "role", + RouteToCluster: "leaf", + }, + checkInitErr: require.NoError, + mfaRequiredHandler: func(t *testing.T, required proto.MFARequired) { + require.Equal(t, proto.MFARequired_MFA_REQUIRED_UNSPECIFIED, required) + }, + authHandler: registered.webAuthHandler, + checkAuthErr: require.NoError, + validateCert: func(t *testing.T, c *proto.SingleUseUserCert) { + sshCertBytes := c.GetSSH() + require.NotEmpty(t, sshCertBytes) + + cert, err := sshutils.ParseCertificate(sshCertBytes) + require.NoError(t, err) + + require.Equal(t, webDevID, cert.Extensions[teleport.CertExtensionMFAVerified]) + require.Equal(t, userCertExpires.Format(time.RFC3339), cert.Extensions[teleport.CertExtensionPreviousIdentityExpires]) + require.True(t, net.ParseIP(cert.Extensions[teleport.CertExtensionLoginIP]).IsLoopback()) + require.Equal(t, uint64(clock.Now().Add(teleport.UserSingleUseCertTTL).Unix()), cert.ValidBefore) + }, + }, + }, + { + desc: "fail - app access not supported", + opts: generateUserSingleUseCertTestOpts{ + initReq: &proto.UserCertsRequest{ + PublicKey: pub, + Username: user.GetName(), + Expires: clock.Now().Add(teleport.UserSingleUseCertTTL), + Usage: proto.UserCertsRequest_App, + }, + checkInitErr: require.Error, + }, + }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { diff --git a/lib/client/api.go b/lib/client/api.go index e91a09a332920..d9c498f30cf98 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1546,14 +1546,22 @@ func (tc *TeleportClient) ConnectToNode(ctx context.Context, proxyClient *ProxyC mfaCancel() directCancel() - // Only return the error from connecting with mfa if the error - // originates from the mfa ceremony. If mfa is not required then - // the error from the direct connection to the node must be returned. - if mfaErr != nil && !errors.Is(mfaErr, io.EOF) && !errors.Is(mfaErr, MFARequiredUnknownErr{}) && !errors.Is(mfaErr, services.ErrSessionMFANotRequired) { + switch { + // No MFA errors, return any errors from the direct connection + case mfaErr == nil: + return nil, trace.Wrap(directErr) + // Any direct connection errors other than access denied, which should be returned + // if MFA is required, take precedent over MFA errors due to users not having any + // enrolled devices. + case !trace.IsAccessDenied(directErr) && errors.Is(mfaErr, auth.ErrNoMFADevices): + return nil, trace.Wrap(directErr) + case !errors.Is(mfaErr, io.EOF) && // Ignore any errors from MFA due to locks being enforced, the direct error will be friendlier + !errors.Is(mfaErr, MFARequiredUnknownErr{}) && // Ignore any failures that occurred before determining if MFA was required + !errors.Is(mfaErr, services.ErrSessionMFANotRequired): // Ignore any errors caused by attempting the MFA ceremony when MFA will not grant access return nil, trace.Wrap(mfaErr) + default: + return nil, trace.Wrap(directErr) } - - return nil, trace.Wrap(directErr) } // MFARequiredUnknownErr indicates that connections to an instance failed diff --git a/lib/client/client.go b/lib/client/client.go index fafc25f9fe913..917f6d7d06069 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -31,13 +31,13 @@ import ( "time" "github.com/gravitational/trace" - "github.com/gravitational/trace/trail" "github.com/moby/term" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" + "google.golang.org/grpc" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" @@ -443,19 +443,17 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis } } - var clt auth.ClientI + clt, mfaClt := params.AuthClient, params.AuthClient // Connect to the target cluster (root or leaf) to check whether MFA is // required or if we know from param that it's required, connect because // it will be needed to do MFA check. - if params.AuthClient != nil { - clt = params.AuthClient - } else { + if clt == nil { authClt, err := proxy.ConnectToCluster(ctx, params.RouteToCluster) if err != nil { return nil, trace.Wrap(MFARequiredUnknown(err)) } - clt = authClt - defer clt.Close() + clt, mfaClt = authClt, authClt + defer authClt.Close() } // Always connect to root for getting new credentials, but attempt to reuse @@ -465,12 +463,11 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis return nil, trace.Wrap(err) } if params.RouteToCluster != rootClusterName { - clt.Close() rootClusterProxy := proxy if jumpHost := proxy.teleportClient.JumpHosts; jumpHost != nil { // In case of MFA connect to root teleport proxy instead of JumpHost to request // MFA certificates. - proxy.teleportClient.WithoutJumpHosts(func(tcNoJump *TeleportClient) error { + err := proxy.teleportClient.WithoutJumpHosts(func(tcNoJump *TeleportClient) error { rootClusterProxy, err = tcNoJump.ConnectToProxy(ctx) return trace.Wrap(err) }) @@ -479,17 +476,17 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis } defer rootClusterProxy.Close() } - clt, err = rootClusterProxy.ConnectToCluster(ctx, rootClusterName) + mfaClt, err = rootClusterProxy.ConnectToCluster(ctx, rootClusterName) if err != nil { return nil, trace.Wrap(err) } - defer clt.Close() + defer mfaClt.Close() } - params.AuthClient = clt + params.AuthClient = mfaClt log.Debug("Attempting to issue a single-use user certificate with an MFA check.") - stream, err := clt.GenerateUserSingleUseCerts(ctx) + stream, err := mfaClt.GenerateUserSingleUseCerts(ctx) if err != nil { if trace.IsNotImplemented(err) { // Probably talking to an older server, use the old non-MFA endpoint. @@ -527,7 +524,7 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis // challenge and will terminate the stream with an auth.ErrNoMFADevices error. // In this case for all protocols other than SSH fall back to reissuing // certs without MFA. - if errors.Is(trail.FromGRPC(err), auth.ErrNoMFADevices) { + if errors.Is(err, auth.ErrNoMFADevices) { if params.usage() != proto.UserCertsRequest_SSH { return proxy.reissueUserCerts(ctx, CertCacheKeep, params) } @@ -1163,6 +1160,10 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co }, ALPNSNIAuthDialClusterName: clusterName, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + DialOpts: []grpc.DialOption{ + grpc.WithStreamInterceptor(utils.GRPCClientStreamErrorInterceptor), + grpc.WithUnaryInterceptor(utils.GRPCClientUnaryErrorInterceptor), + }, }) if err != nil { return nil, trace.Wrap(err) @@ -1241,6 +1242,10 @@ func (proxy *ProxyClient) ConnectToCluster(ctx context.Context, clusterName stri client.LoadTLS(tlsConfig), }, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + DialOpts: []grpc.DialOption{ + grpc.WithStreamInterceptor(utils.GRPCClientStreamErrorInterceptor), + grpc.WithUnaryInterceptor(utils.GRPCClientUnaryErrorInterceptor), + }, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 45237db76125b..a09b997bf1800 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -45,6 +45,7 @@ import ( tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/auth" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -709,14 +710,22 @@ func (t *TerminalHandler) connectToHost(ctx context.Context, ws *websocket.Conn, mfaCancel() directCancel() - // Only return the error from connecting with mfa if the error - // originates from the mfa ceremony. If mfa is not required then - // the error from the direct connection to the node must be returned. - if mfaErr != nil && !errors.Is(mfaErr, io.EOF) && !errors.Is(mfaErr, client.MFARequiredUnknownErr{}) && !errors.Is(mfaErr, services.ErrSessionMFANotRequired) { + switch { + // No MFA errors, return any errors from the direct connection + case mfaErr == nil: + return nil, trace.Wrap(directErr) + // Any direct connection errors other than access denied, which should be returned + // if MFA is required, take precedent over MFA errors due to users not having any + // enrolled devices. + case !trace.IsAccessDenied(directErr) && errors.Is(mfaErr, auth.ErrNoMFADevices): + return nil, trace.Wrap(directErr) + case !errors.Is(mfaErr, io.EOF) && // Ignore any errors from MFA due to locks being enforced, the direct error will be friendlier + !errors.Is(mfaErr, client.MFARequiredUnknownErr{}) && // Ignore any failures that occurred before determining if MFA was required + !errors.Is(mfaErr, services.ErrSessionMFANotRequired): // Ignore any errors caused by attempting the MFA ceremony when MFA will not grant access return nil, trace.Wrap(mfaErr) + default: + return nil, trace.Wrap(directErr) } - - return nil, trace.Wrap(directErr) } // streamTerminal opens a SSH connection to the remote host and streams diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index ed2989677d114..e888e353939ca 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -946,7 +946,15 @@ func TestSSHOnMultipleNodes(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const origin = "https://localhost" + isInsecure := lib.IsInsecureDevMode() + lib.SetInsecureDevMode(true) + t.Cleanup(func() { + lib.SetInsecureDevMode(isInsecure) + }) + + origin := func(cluster string) string { + return fmt.Sprintf("https://%s", cluster) + } connector := mockConnector(t) user, err := user.Current() @@ -990,27 +998,69 @@ func TestSSHOnMultipleNodes(t *testing.T) { rootAuth, rootProxy := makeTestServers(t, withBootstrap(connector, alice, noAccessRole, sshLoginRole, perSessionMFARole)) - authAddr, err := rootAuth.AuthAddr() + rootAuthAddr, err := rootAuth.AuthAddr() require.NoError(t, err) - proxyAddr, err := rootProxy.ProxyWebAddr() + rootProxyAddr, err := rootProxy.ProxyWebAddr() + require.NoError(t, err) + rootTunnelAddr, err := rootProxy.ProxyTunnelAddr() require.NoError(t, err) + trustedCluster, err := types.NewTrustedCluster("localhost", types.TrustedClusterSpecV2{ + Enabled: true, + Roles: []string{}, + Token: staticToken, + ProxyAddress: rootProxyAddr.String(), + ReverseTunnelAddress: rootTunnelAddr.String(), + RoleMap: []types.RoleMapping{ + { + Remote: "access", + Local: []string{"access"}, + }, + { + Remote: perSessionMFARole.GetName(), + Local: []string{perSessionMFARole.GetName()}, + }, + { + Remote: sshLoginRole.GetName(), + Local: []string{sshLoginRole.GetName()}, + }, + }, + }) + require.NoError(t, err) + + leafAuth, leafProxy := makeTestServers(t, withClusterName(t, "leafcluster"), withBootstrap(connector, alice, sshLoginRole, perSessionMFARole)) + tryCreateTrustedCluster(t, leafAuth.GetAuthServer(), trustedCluster) + + leafAuthAddr, err := leafAuth.AuthAddr() + require.NoError(t, err) + + require.Eventually(t, func() bool { + conns, err := rootAuth.GetAuthServer().GetTunnelConnections("leafcluster") + return err == nil && len(conns) == 1 + }, 10*time.Second, 100*time.Millisecond, "leaf cluster never heart beated") + + leafProxyAddr := leafProxy.Config.Proxy.WebAddr.String() + stage1Hostname := "test-stage-1" - node := makeTestSSHNode(t, authAddr, withHostname(stage1Hostname), withSSHLabel("env", "stage")) + node := makeTestSSHNode(t, rootAuthAddr, withHostname(stage1Hostname), withSSHLabel("env", "stage")) sshHostID := node.Config.HostUUID stage2Hostname := "test-stage-2" - node2 := makeTestSSHNode(t, authAddr, withHostname(stage2Hostname), withSSHLabel("env", "stage")) + node2 := makeTestSSHNode(t, rootAuthAddr, withHostname(stage2Hostname), withSSHLabel("env", "stage")) sshHostID2 := node2.Config.HostUUID prodHostname := "test-prod-1" - nodeProd := makeTestSSHNode(t, authAddr, withHostname(prodHostname), withSSHLabel("env", "prod")) + nodeProd := makeTestSSHNode(t, rootAuthAddr, withHostname(prodHostname), withSSHLabel("env", "prod")) sshHostID3 := nodeProd.Config.HostUUID - hasNodes := func(hostIDs ...string) func() bool { + leafHostname := "leaf-node" + leafNode := makeTestSSHNode(t, leafAuthAddr, withHostname(leafHostname), withSSHLabel("animal", "llama")) + sshLeafHostID := leafNode.Config.HostUUID + + hasNodes := func(asrv *auth.Server, hostIDs ...string) func() bool { return func() bool { - nodes, err := rootAuth.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) + nodes, err := asrv.GetNodes(ctx, apidefaults.Namespace) if err != nil { return false } @@ -1026,76 +1076,92 @@ func TestSSHOnMultipleNodes(t *testing.T) { } // wait for auth to see nodes - require.Eventually(t, hasNodes(sshHostID, sshHostID2, sshHostID3), - 5*time.Second, 100*time.Millisecond, "nodes never joined cluster") + require.Eventually(t, hasNodes(rootAuth.GetAuthServer(), sshHostID, sshHostID2, sshHostID3), + 5*time.Second, 100*time.Millisecond, "nodes never joined root cluster") + + require.Eventually(t, hasNodes(leafAuth.GetAuthServer(), sshLeafHostID), + 5*time.Second, 100*time.Millisecond, "nodes never joined leaf cluster") defaultPreference, err := rootAuth.GetAuthServer().GetAuthPreference(ctx) require.NoError(t, err) - // set the default auth preference - webauthnPreference := &types.AuthPreferenceV2{ - Spec: types.AuthPreferenceSpecV2{ - Type: constants.Local, - SecondFactor: constants.SecondFactorOptional, - Webauthn: &types.Webauthn{ - RPID: "localhost", + webauthnPreference := func(cluster string) *types.AuthPreferenceV2 { + return &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOptional, + Webauthn: &types.Webauthn{ + RPID: cluster, + }, }, - }, + } } - err = rootAuth.GetAuthServer().SetAuthPreference(ctx, webauthnPreference) - require.NoError(t, err) - token, err := rootAuth.GetAuthServer().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ - Name: "alice", - }) - require.NoError(t, err) - tokenID := token.GetName() - res, err := rootAuth.GetAuthServer().CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ - TokenID: tokenID, - DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, - DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, - }) - require.NoError(t, err) - cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) + setupUser := func(cluster string, asrv *auth.Server) { + // set the default auth preference + err = asrv.SetAuthPreference(ctx, webauthnPreference(cluster)) + require.NoError(t, err) - ccr, err := device.SignCredentialCreation(origin, cc) - require.NoError(t, err) - _, err = rootAuth.GetAuthServer().ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ - TokenID: tokenID, - NewPassword: []byte(password), - NewMFARegisterResponse: &proto.MFARegisterResponse{ - Response: &proto.MFARegisterResponse_Webauthn{ - Webauthn: wanlib.CredentialCreationResponseToProto(ccr), - }, - }, - }) - require.NoError(t, err) + token, err := asrv.CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ + Name: "alice", + }) + require.NoError(t, err) + tokenID := token.GetName() + res, err := asrv.CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{ + TokenID: tokenID, + DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, + DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, + }) + require.NoError(t, err) + cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) - successfulChallenge := func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - car, err := device.SignAssertion(origin, assertion) // use the fake origin to prevent a mismatch - if err != nil { - return nil, "", err - } - return &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_Webauthn{ - Webauthn: wanlib.CredentialAssertionResponseToProto(car), + ccr, err := device.SignCredentialCreation(origin(cluster), cc) + require.NoError(t, err) + _, err = asrv.ChangeUserAuthentication(ctx, &proto.ChangeUserAuthenticationRequest{ + TokenID: tokenID, + NewPassword: []byte(password), + NewMFARegisterResponse: &proto.MFARegisterResponse{ + Response: &proto.MFARegisterResponse_Webauthn{ + Webauthn: wanlib.CredentialCreationResponseToProto(ccr), + }, }, - }, "", nil + }) + require.NoError(t, err) } - failedChallenge := func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - car, err := device.SignAssertion(origin, assertion) // use the fake origin to prevent a mismatch - if err != nil { - return nil, "", err + setupUser("localhost", rootAuth.GetAuthServer()) + setupUser("leafcluster", leafAuth.GetAuthServer()) + + successfulChallenge := func(cluster string) func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + car, err := device.SignAssertion(origin(cluster), assertion) // use the fake origin to prevent a mismatch + if err != nil { + return nil, "", err + } + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: wanlib.CredentialAssertionResponseToProto(car), + }, + }, "", nil } - carProto := wanlib.CredentialAssertionResponseToProto(car) - carProto.Type = "NOT A VALID TYPE" // set to an invalid type so the ceremony fails + } - return &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_Webauthn{ - Webauthn: carProto, - }, - }, "", nil + failedChallenge := func(cluster string) func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return func(ctx context.Context, realOrigin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + + car, err := device.SignAssertion(origin(cluster), assertion) // use the fake origin to prevent a mismatch + if err != nil { + return nil, "", err + } + carProto := wanlib.CredentialAssertionResponseToProto(car) + carProto.Type = "NOT A VALID TYPE" // set to an invalid type so the ceremony fails + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: carProto, + }, + }, "", nil + } } type mfaPrompt = func(ctx context.Context, origin string, assertion *wanlib.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) @@ -1128,10 +1194,15 @@ func TestSSHOnMultipleNodes(t *testing.T) { stdoutAssertion require.ValueAssertionFunc mfaPromptCount int headless bool + proxyAddr string + auth *auth.Server + cluster string }{ { name: "default auth preference runs commands on multiple nodes without mfa", authPreference: defaultPreference, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) @@ -1139,16 +1210,20 @@ func TestSSHOnMultipleNodes(t *testing.T) { errAssertion: require.NoError, }, { - name: "webauthn auth preference runs commands on multiple matches without mfa", - target: "env=stage", + name: "webauthn auth preference runs commands on multiple matches without mfa", + target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, errAssertion: require.NoError, }, { - name: "webauthn auth preference runs commands on a single match without mfa", - target: "env=prod", + name: "webauthn auth preference runs commands on a single match without mfa", + target: "env=prod", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1157,6 +1232,8 @@ func TestSSHOnMultipleNodes(t *testing.T) { { name: "no matching hosts", target: "env=dev", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), errAssertion: require.Error, stdoutAssertion: require.Empty, }, @@ -1172,8 +1249,10 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), - target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, @@ -1192,8 +1271,10 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), - target: "env=prod", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=prod", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1212,7 +1293,9 @@ func TestSSHOnMultipleNodes(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - setup: setupChallengeSolver(successfulChallenge), + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + setup: setupChallengeSolver(successfulChallenge("localhost")), target: "env=dev", errAssertion: require.Error, stdoutAssertion: require.Empty, @@ -1228,9 +1311,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, - roles: []string{"access", sshLoginRole.GetName(), perSessionMFARole.GetName()}, - setup: setupChallengeSolver(successfulChallenge), - target: "env=stage", + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{"access", sshLoginRole.GetName(), perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + target: "env=stage", stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\ntest\n", i, i2...) }, @@ -1238,9 +1323,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { errAssertion: require.NoError, }, { - name: "role permits access without mfa", - target: sshHostID, - roles: []string{sshLoginRole.GetName()}, + name: "role permits access without mfa", + target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{sshLoginRole.GetName()}, stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1249,15 +1336,19 @@ func TestSSHOnMultipleNodes(t *testing.T) { { name: "role prevents access", target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), roles: []string{noAccessRole.GetName()}, stdoutAssertion: require.Empty, errAssertion: require.Error, }, { - name: "command runs on a hostname with mfa set via role", - target: sshHostID, - roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(successfulChallenge), + name: "command runs on a hostname with mfa set via role", + target: sshHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, @@ -1275,9 +1366,11 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), target: sshHostID, roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(failedChallenge), + setup: setupChallengeSolver(failedChallenge("localhost")), stdoutAssertion: require.Empty, mfaPromptCount: 1, errAssertion: require.Error, @@ -1293,25 +1386,81 @@ func TestSSHOnMultipleNodes(t *testing.T) { }, }, }, - target: sshHostID, - roles: []string{perSessionMFARole.GetName()}, - setup: setupChallengeSolver(failedChallenge), + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + target: sshHostID, + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(failedChallenge("localhost")), stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { require.Equal(t, "test\n", i, i2...) }, errAssertion: require.NoError, headless: true, }, + { + name: "command runs on a leaf node with mfa set via role", + target: sshLeafHostID, + proxyAddr: leafProxyAddr, + auth: leafAuth.GetAuthServer(), + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("leafcluster")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + mfaPromptCount: 1, + errAssertion: require.NoError, + }, + { + name: "command runs on a leaf node via root without mfa", + target: sshLeafHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + cluster: "leafcluster", + roles: []string{sshLoginRole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + errAssertion: require.NoError, + }, + { + name: "command runs on a leaf node without mfa", + target: sshLeafHostID, + proxyAddr: leafProxyAddr, + auth: leafAuth.GetAuthServer(), + roles: []string{sshLoginRole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("leafcluster")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + errAssertion: require.NoError, + }, { + name: "command runs on a leaf node via root with mfa set via role", + target: sshLeafHostID, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + cluster: "leafcluster", + roles: []string{perSessionMFARole.GetName()}, + setup: setupChallengeSolver(successfulChallenge("localhost")), + stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.Equal(t, "test\n", i, i2...) + }, + mfaPromptCount: 1, + errAssertion: require.NoError, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { tmpHomePath := t.TempDir() + clusterName, err := tt.auth.GetClusterName() + require.NoError(t, err) + if tt.authPreference != nil { - require.NoError(t, rootAuth.GetAuthServer().SetAuthPreference(ctx, tt.authPreference)) + require.NoError(t, tt.auth.SetAuthPreference(ctx, tt.authPreference)) t.Cleanup(func() { - require.NoError(t, rootAuth.GetAuthServer().SetAuthPreference(ctx, webauthnPreference)) + require.NoError(t, tt.auth.SetAuthPreference(ctx, webauthnPreference(clusterName.GetClusterName()))) }) } @@ -1323,21 +1472,23 @@ func TestSSHOnMultipleNodes(t *testing.T) { roles := alice.GetRoles() t.Cleanup(func() { alice.SetRoles(roles) - require.NoError(t, rootAuth.GetAuthServer().UpsertUser(alice)) + require.NoError(t, tt.auth.UpsertUser(alice)) }) alice.SetRoles(tt.roles) - require.NoError(t, rootAuth.GetAuthServer().UpsertUser(alice)) + require.NoError(t, tt.auth.UpsertUser(alice)) } err = Run(ctx, []string{ "login", + "-d", "--insecure", "--auth", connector.GetName(), - "--proxy", proxyAddr.String(), + "--proxy", tt.proxyAddr, "--user", "alice", + tt.cluster, }, setHomePath(tmpHomePath), func(cf *CLIConf) error { - cf.mockSSOLogin = mockSSOLogin(t, rootAuth.GetAuthServer(), alice) + cf.mockSSOLogin = mockSSOLogin(t, tt.auth, alice) return nil }, ) @@ -1348,9 +1499,9 @@ func TestSSHOnMultipleNodes(t *testing.T) { // so we can assert how many times sign was called. device.SetCounter(0) - args := []string{"ssh", "--insecure"} + args := []string{"ssh", "-d", "--insecure"} if tt.headless { - args = append(args, "--headless", "--proxy", proxyAddr.String(), "--user", alice.GetName()) + args = append(args, "--headless", "--proxy", tt.proxyAddr, "--user", alice.GetName()) } args = append(args, tt.target, "echo", "test") @@ -1360,7 +1511,7 @@ func TestSSHOnMultipleNodes(t *testing.T) { func(conf *CLIConf) error { conf.overrideStdin = &bytes.Buffer{} conf.overrideStdout = stdout - conf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + conf.mockHeadlessLogin = mockHeadlessLogin(t, tt.auth, alice) return nil }, )