diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index 593dd705d6a6d..74d65b7aabd33 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -198,7 +198,9 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { return grpc.WithTransportCredentials(insecure.NewCredentials()), nil }, - ClientCache: clientcache.NewNoCache(storage), + CreateClientCacheFunc: func(resolveCluster daemon.ResolveClusterFunc) daemon.ClientCache { + return clientcache.NewNoCache(clientcache.ResolveClusterFunc(resolveCluster)) + }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), }) diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index 74c3c2146598b..2294639929555 100644 --- a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -38,6 +38,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" @@ -45,6 +47,9 @@ import ( dbhelpers "github.com/gravitational/teleport/integration/db" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/mocku2f" + wancli "github.com/gravitational/teleport/lib/auth/webauthncli" + wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -105,7 +110,7 @@ func TestTeleterm(t *testing.T) { t.Run("CreateConnectMyComputerToken", func(t *testing.T) { t.Parallel() - testCreateConnectMyComputerToken(t, pack) + testCreateConnectMyComputerToken(t, pack, nil /* setupUserMFA */) }) t.Run("WaitForConnectMyComputerNodeJoin", func(t *testing.T) { @@ -123,6 +128,115 @@ func TestTeleterm(t *testing.T) { testClientCache(t, pack, creds) }) + + t.Run("with MFA", func(t *testing.T) { + authServer := pack.Root.Cluster.Process.GetAuthServer() + rpID, _, err := net.SplitHostPort(pack.Root.Cluster.Web) + require.NoError(t, err) + + // Enforce MFA + _, err = authServer.UpsertAuthPreference(context.Background(), &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorWebauthn, + Webauthn: &types.Webauthn{ + RPID: rpID, + }, + }, + }) + require.NoError(t, err) + + // Remove MFA enforcement on cleanup. + t.Cleanup(func() { + _, err := authServer.UpsertAuthPreference(context.Background(), &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOff, + }, + }) + require.NoError(t, err) + }) + + setupUserMFA := func(t *testing.T, userName string, tshdEventsService *mockTSHDEventsService) client.WebauthnLoginFunc { + // Configure user account with an MFA device. + origin := fmt.Sprintf("https://%s", rpID) + device, err := mocku2f.Create() + require.NoError(t, err) + device.SetPasswordless() + + token, err := authServer.CreateResetPasswordToken(context.Background(), auth.CreateUserTokenRequest{ + Name: userName, + }) + require.NoError(t, err) + + tokenID := token.GetName() + res, err := authServer.CreateRegisterChallenge(context.Background(), &proto.CreateRegisterChallengeRequest{ + TokenID: tokenID, + DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN, + DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, + }) + require.NoError(t, err) + cc := wantypes.CredentialCreationFromProto(res.GetWebauthn()) + + ccr, err := device.SignCredentialCreation(origin, cc) + require.NoError(t, err) + _, err = authServer.ChangeUserAuthentication(context.Background(), &proto.ChangeUserAuthenticationRequest{ + TokenID: tokenID, + NewMFARegisterResponse: &proto.MFARegisterResponse{ + Response: &proto.MFARegisterResponse_Webauthn{ + Webauthn: wantypes.CredentialCreationResponseToProto(ccr), + }, + }, + }) + require.NoError(t, err) + + // Prepare a function which simulates key tap. + var webauthLoginCallCount atomic.Uint32 + webauthnLogin := func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + t.Helper() + updatedWebauthnLoginCallCount := webauthLoginCallCount.Add(1) + + // When daemon.mfaPrompt prompts for MFA, it spawns two goroutines. One calls PromptMFA on + // tshdEventService and expects OTP in response (if available). Another calls this function. + // Whichever returns a non-error response first wins. + // + // Since in this test we use Webauthn, this function can return ASAP without giving a chance + // to the other to call PromptMFA. This would cause race conditions, as we might want to + // verify later in the test that PromptMFA has indeed been called. + // + // To ensure that, this function waits until PromptMFA has been called before proceeding. + // This also simulates a flow where the user was notified about the need to tap the key + // through the UI and then taps the key. + assert.EventuallyWithT(t, func(t *assert.CollectT) { + // Each call to webauthnLogin should have an equivalent call to PromptMFA and there should + // be no multiple concurrent calls. + assert.Equal(t, updatedWebauthnLoginCallCount, tshdEventsService.promptMFACallCount.Load(), + "Expected each call to webauthnLogin to have an equivalent call to PromptMFA") + }, 5*time.Second, 50*time.Millisecond) + + car, err := device.SignAssertion(origin, assertion) + if err != nil { + return nil, "", err + } + + carProto := wantypes.CredentialAssertionResponseToProto(car) + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: carProto, + }, + }, "", nil + } + + return webauthnLogin + } + + t.Run("CreateConnectMyComputerToken", func(t *testing.T) { + t.Parallel() + + testCreateConnectMyComputerToken(t, pack, setupUserMFA) + }) + }) } func testAddingRootCluster(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { @@ -312,7 +426,6 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING // Start the tshd event service and connect the daemon to it. - tshdEventsService, addr := newMockTSHDEventsServiceServer(t) err = daemonService.UpdateAndDialTshdEventsServerAddress(addr) require.NoError(t, err) @@ -698,7 +811,7 @@ func testCreateConnectMyComputerRole(t *testing.T, pack *dbhelpers.DatabasePack) } } -func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack) { +func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack, setupUserMFA setupUserMFAFunc) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -721,6 +834,12 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack _, err = auth.CreateUser(ctx, authServer, userName, userRoles...) require.NoError(t, err) + tshdEventsService, addr := newMockTSHDEventsServiceServer(t) + var webauthnLogin client.WebauthnLoginFunc + if setupUserMFA != nil { + webauthnLogin = setupUserMFA(t, userName, tshdEventsService) + } + // Log in as the new user. creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ Process: pack.Root.Cluster.Process, @@ -736,6 +855,7 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack Dir: tc.KeysDir, InsecureSkipVerify: tc.InsecureSkipVerify, Clock: fakeClock, + WebauthnLogin: webauthnLogin, }) require.NoError(t, err) @@ -744,6 +864,9 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack Storage: storage, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }, }) require.NoError(t, err) t.Cleanup(func() { @@ -756,6 +879,9 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack ) require.NoError(t, err) + err = daemonService.UpdateAndDialTshdEventsServerAddress(addr) + require.NoError(t, err) + // Call CreateConnectMyComputerNodeToken. rootClusterName, _, err := net.SplitHostPort(pack.Root.Cluster.Web) require.NoError(t, err) @@ -774,6 +900,11 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack require.Equal(t, types.SystemRoles{types.RoleNode}, tokenFromAuthServer.GetRoles()) // ...and is valid for no longer than 5 minutes. require.LessOrEqual(t, tokenFromAuthServer.Expiry(), requestCreatedAt.Add(5*time.Minute)) + + if setupUserMFA != nil { + require.Equal(t, uint32(1), tshdEventsService.promptMFACallCount.Load(), + "Unexpected number of calls to TSHDEventsClient.PromptMFA") + } } func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { @@ -943,12 +1074,16 @@ func mustLogin(t *testing.T, userName string, pack *dbhelpers.DatabasePack, cred return tc } +type setupUserMFAFunc func(t *testing.T, userName string, tshdEventsService *mockTSHDEventsService) client.WebauthnLoginFunc + type mockTSHDEventsService struct { - *api.UnimplementedTshdEventsServiceServer + api.UnimplementedTshdEventsServiceServer sendPendingHeadlessAuthenticationCount atomic.Uint32 + promptMFACallCount atomic.Uint32 } func newMockTSHDEventsServiceServer(t *testing.T) (service *mockTSHDEventsService, addr string) { + t.Helper() tshdEventsService := &mockTSHDEventsService{} ls, err := net.Listen("tcp", "localhost:0") @@ -981,3 +1116,11 @@ func (c *mockTSHDEventsService) SendPendingHeadlessAuthentication(context.Contex c.sendPendingHeadlessAuthenticationCount.Add(1) return &api.SendPendingHeadlessAuthenticationResponse{}, nil } + +func (c *mockTSHDEventsService) PromptMFA(context.Context, *api.PromptMFARequest) (*api.PromptMFAResponse, error) { + c.promptMFACallCount.Add(1) + + // PromptMFAResponse returns the TOTP code, so PromptMFA itself + // needs to be implemented only once we need to test TOTP MFA. + return nil, trace.NotImplemented("mockTSHDEventsService does not implement PromptMFA") +} diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index f6fd1163c94ad..437c5675398a5 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -72,9 +72,12 @@ type Config struct { ConnectMyComputerNodeDelete *connectmycomputer.NodeDelete ConnectMyComputerNodeName *connectmycomputer.NodeName - ClientCache ClientCache + CreateClientCacheFunc func(resolver ResolveClusterFunc) ClientCache } +// ResolveClusterFunc returns a cluster by URI. +type ResolveClusterFunc func(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) + // ClientCache stores clients keyed by cluster URI. type ClientCache interface { // Get returns a client from the cache if there is one, @@ -157,11 +160,13 @@ func (c *Config) CheckAndSetDefaults() error { c.ConnectMyComputerNodeName = nodeName } - if c.ClientCache == nil { - c.ClientCache = clientcache.New(clientcache.Config{ - Log: c.Log, - Resolver: c.Storage, - }) + if c.CreateClientCacheFunc == nil { + c.CreateClientCacheFunc = func(resolver ResolveClusterFunc) ClientCache { + return clientcache.New(clientcache.Config{ + Log: c.Log, + ResolveClusterFunc: clientcache.ResolveClusterFunc(resolver), + }) + } } return nil diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index b1cc5d5446a06..f3ce7cf92c5a8 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -77,14 +77,24 @@ func New(cfg Config) (*Service, error) { go connectUsageReporter.Run(closeContext) - return &Service{ + service := &Service{ cfg: &cfg, closeContext: closeContext, cancel: cancel, gateways: make(map[string]gateway.Gateway), usageReporter: connectUsageReporter, headlessWatcherClosers: make(map[string]context.CancelFunc), - }, nil + } + + // TODO(gzdunek): The client cache should be created outside of daemon.New. + // Unfortunately, we have to do it here, because we need to pass + // Daemon.ResolveClusterURI as a cluster resolver. + // Why can't we pass Storage.GetByResourceURI? + // That's because Daemon.ResolveClusterURI sets a custom MFAPromptConstructor that + // shows an MFA prompt in Connect. + // At the level of Storage.ResolveClusterFunc we don't have access to it. + service.clientCache = cfg.CreateClientCacheFunc(service.ResolveClusterURI) + return service, nil } // relogin makes the Electron app display a login modal to trigger re-login. @@ -802,7 +812,7 @@ func (s *Service) Stop() { s.StopHeadlessWatchers() - if err := s.cfg.ClientCache.Clear(); err != nil { + if err := s.clientCache.Clear(); err != nil { s.cfg.Log.WithError(err).Error("Failed to close remote clients") } @@ -1084,14 +1094,14 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat // GetCachedClient returns a client from the cache if it exists, // otherwise it dials the remote server. func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) { - clt, err := s.cfg.ClientCache.Get(ctx, clusterURI) + clt, err := s.clientCache.Get(ctx, clusterURI) return clt, trace.Wrap(err) } // ClearCachedClientsForRoot closes and removes clients from the cache // for the root cluster and its leaf clusters. func (s *Service) ClearCachedClientsForRoot(clusterURI uri.ResourceURI) error { - return trace.Wrap(s.cfg.ClientCache.ClearForRoot(clusterURI)) + return trace.Wrap(s.clientCache.ClearForRoot(clusterURI)) } // Service is the daemon service @@ -1126,6 +1136,7 @@ type Service struct { // headlessWatcherClosers holds a map of root cluster URIs to headless watchers. headlessWatcherClosers map[string]context.CancelFunc headlessWatcherClosersMu sync.Mutex + clientCache ClientCache } type CreateGatewayParams struct { diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index ebda3c15a3706..836d39f928f61 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -272,7 +272,9 @@ func TestGatewayCRUD(t *testing.T) { GatewayCreator: mockGatewayCreator, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - ClientCache: fakeClientCache{}, + CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { + return fakeClientCache{} + }, }) require.NoError(t, err) @@ -451,7 +453,9 @@ func TestRetryWithRelogin(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - ClientCache: fakeClientCache{}, + CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { + return fakeClientCache{} + }, }) require.NoError(t, err) @@ -502,7 +506,9 @@ func TestImportantModalSemaphore(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - ClientCache: fakeClientCache{}, + CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { + return fakeClientCache{} + }, }) require.NoError(t, err) @@ -651,7 +657,9 @@ func TestGetGatewayCLICommand(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - ClientCache: fakeClientCache{}, + CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { + return fakeClientCache{} + }, }) require.NoError(t, err) diff --git a/lib/teleterm/services/clientcache/clientcache.go b/lib/teleterm/services/clientcache/clientcache.go index 8ab91c9a857e1..b99899a4b9da4 100644 --- a/lib/teleterm/services/clientcache/clientcache.go +++ b/lib/teleterm/services/clientcache/clientcache.go @@ -45,10 +45,12 @@ type Cache struct { group singleflight.Group } +type ResolveClusterFunc func(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) + // Config describes the client cache configuration. type Config struct { - Resolver clusters.Resolver - Log logrus.FieldLogger + ResolveClusterFunc ResolveClusterFunc + Log logrus.FieldLogger } func (c *Config) checkAndSetDefaults() { @@ -76,7 +78,7 @@ func (c *Cache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.Pr return fromCache, nil } - _, clusterClient, err := c.cfg.Resolver.ResolveCluster(clusterURI) + _, clusterClient, err := c.cfg.ResolveClusterFunc(clusterURI) if err != nil { return nil, trace.Wrap(err) } @@ -198,9 +200,9 @@ func (c *Cache) getFromCache(clusterURI uri.ResourceURI) *client.ProxyClient { // // ClearForRoot and Clear still work as expected. type NoCache struct { - mu sync.Mutex - resolver clusters.Resolver - clients []noCacheClient + mu sync.Mutex + resolveClusterFunc ResolveClusterFunc + clients []noCacheClient } type noCacheClient struct { @@ -208,14 +210,14 @@ type noCacheClient struct { client *client.ProxyClient } -func NewNoCache(resolver clusters.Resolver) *NoCache { +func NewNoCache(resolveClusterFunc ResolveClusterFunc) *NoCache { return &NoCache{ - resolver: resolver, + resolveClusterFunc: resolveClusterFunc, } } func (c *NoCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) { - _, clusterClient, err := c.resolver.ResolveCluster(clusterURI) + _, clusterClient, err := c.resolveClusterFunc(clusterURI) if err != nil { return nil, trace.Wrap(err) }