diff --git a/api/utils/keys/piv/service.go b/api/utils/keys/piv/service.go index c4b4e9d5cc1d0..b0c5daa0544bd 100644 --- a/api/utils/keys/piv/service.go +++ b/api/utils/keys/piv/service.go @@ -43,6 +43,9 @@ var yubiKeyServiceMu sync.Mutex // YubiKeyService is a YubiKey PIV implementation of [hardwarekey.Service]. type YubiKeyService struct { prompt hardwarekey.Prompt + // TODO(Joerger): Remove prompt mutex once there is no longer a shared global service + // that needs its protection. + promptMu sync.Mutex // signMu prevents prompting for PIN/touch repeatedly for concurrent signatures. // TODO(Joerger): Rather than preventing concurrent signatures, we can make the @@ -67,7 +70,7 @@ func NewYubiKeyService(customPrompt hardwarekey.Prompt) *YubiKeyService { if yubiKeyService != nil { // If a prompt is provided, prioritize it over the existing prompt value. if customPrompt != nil { - yubiKeyService.prompt = customPrompt + yubiKeyService.setPrompt(customPrompt) } return yubiKeyService } @@ -116,7 +119,7 @@ func (s *YubiKeyService) NewPrivateKey(ctx context.Context, config hardwarekey.P // If PIN is required, check that PIN and PUK are not the defaults. 
if config.Policy.PINRequired { - if err := y.checkOrSetPIN(ctx, s.prompt, config.ContextualKeyInfo, config.PINCacheTTL); err != nil { + if err := y.checkOrSetPIN(ctx, s.getPrompt(), config.ContextualKeyInfo, config.PINCacheTTL); err != nil { return nil, trace.Wrap(err) } } @@ -188,7 +191,7 @@ func (s *YubiKeyService) Sign(ctx context.Context, ref *hardwarekey.PrivateKeyRe s.signMu.Lock() defer s.signMu.Unlock() - return y.sign(ctx, ref, keyInfo, s.prompt, rand, digest, opts) + return y.sign(ctx, ref, keyInfo, s.getPrompt(), rand, digest, opts) } // TODO(Joerger): Re-attesting the key every time we decode a hardware key signer is very resource @@ -260,10 +263,22 @@ func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) { func (s *YubiKeyService) promptOverwriteSlot(ctx context.Context, msg string, keyInfo hardwarekey.ContextualKeyInfo) error { promptQuestion := fmt.Sprintf("%v\nWould you like to overwrite this slot's private key and certificate?", msg) - if confirmed, confirmErr := s.prompt.ConfirmSlotOverwrite(ctx, promptQuestion, keyInfo); confirmErr != nil { + if confirmed, confirmErr := s.getPrompt().ConfirmSlotOverwrite(ctx, promptQuestion, keyInfo); confirmErr != nil { return trace.Wrap(confirmErr) } else if !confirmed { return trace.Wrap(trace.CompareFailed(msg), "user declined to overwrite slot") } return nil } + +func (s *YubiKeyService) setPrompt(prompt hardwarekey.Prompt) { + s.promptMu.Lock() + defer s.promptMu.Unlock() + s.prompt = prompt +} + +func (s *YubiKeyService) getPrompt() hardwarekey.Prompt { + s.promptMu.Lock() + defer s.promptMu.Unlock() + return s.prompt +} diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index 4deec70fe6986..0c7731360ada5 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -242,6 +242,7 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer fakeClock := clockwork.NewFakeClockAt(time.Now()) storage, 
err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, // Inject a fake clock into clusters.Storage so we can control when the middleware thinks the // db cert has expired. @@ -250,12 +251,14 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer }) require.NoError(t, err) + tshdEventsClient := daemon.NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemonService, err := daemon.New(daemon.Config{ - Clock: fakeClock, - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, + Clock: fakeClock, + Storage: storage, + TshdEventsClient: tshdEventsClient, CreateClientCacheFunc: func(newClient clientcache.NewClientFunc) (daemon.ClientCache, error) { return clientcache.NewNoCache(newClient), nil }, @@ -877,14 +880,18 @@ func testTeletermAppGatewayTargetPortValidation(t *testing.T, pack *appaccess.Pa storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) + + tshdEventsClient := daemon.NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemonService, err := daemon.New(daemon.Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, + Storage: storage, + TshdEventsClient: tshdEventsClient, CreateClientCacheFunc: func(newClient clientcache.NewClientFunc) (daemon.ClientCache, error) { return clientcache.NewNoCache(newClient), nil }, diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index 25c25507829fc..656472ff86f29 100644 --- 
a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -254,8 +254,10 @@ func TestTeleterm(t *testing.T) { func testAddingRootCluster(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { t.Helper() + homeDir := t.TempDir() storage, err := clusters.NewStorage(clusters.Config{ - Dir: t.TempDir(), + Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) @@ -288,6 +290,7 @@ func testListRootClustersReturnsLoggedInUser(t *testing.T, pack *dbhelpers.Datab storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) @@ -370,6 +373,7 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) @@ -422,6 +426,7 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) @@ -429,13 +434,15 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help cluster, _, err := storage.Add(ctx, tc.WebProxyAddr) require.NoError(t, err) + tshdEventsClient := daemon.NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemonService, err := daemon.New(daemon.Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), }) 
require.NoError(t, err) t.Cleanup(func() { @@ -489,6 +496,7 @@ func testClientCache(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers. storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, Clock: storageFakeClock, InsecureSkipVerify: tc.InsecureSkipVerify, }) @@ -497,13 +505,15 @@ func testClientCache(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers. cluster, _, err := storage.Add(ctx, tc.WebProxyAddr) require.NoError(t, err) + tshdEventsClient := daemon.NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemonService, err := daemon.New(daemon.Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), }) require.NoError(t, err) t.Cleanup(func() { @@ -747,8 +757,10 @@ func testCreateConnectMyComputerRole(t *testing.T, pack *dbhelpers.DatabasePack) require.NoError(t, authServer.UpsertPassword(userName, []byte(userPassword))) // Prepare daemon.Service. + homeDir := t.TempDir() storage, err := clusters.NewStorage(clusters.Config{ - Dir: t.TempDir(), + Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) @@ -863,20 +875,23 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack // Prepare daemon.Service. 
storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, Clock: fakeClock, WebauthnLogin: webauthnLogin, }) require.NoError(t, err) + tshdEventsClient := daemon.NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemonService, err := daemon.New(daemon.Config{ - Clock: fakeClock, - Storage: storage, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, + Clock: fakeClock, + Storage: storage, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), + TshdEventsClient: tshdEventsClient, }) require.NoError(t, err) t.Cleanup(func() { @@ -925,6 +940,7 @@ func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.Database storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) @@ -1009,6 +1025,7 @@ func testDeleteConnectMyComputerNode(t *testing.T, pack *dbhelpers.DatabasePack) storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) @@ -1236,6 +1253,7 @@ func testListDatabaseUsers(t *testing.T, pack *dbhelpers.DatabasePack) { storage, err := clusters.NewStorage(clusters.Config{ Dir: tc.KeysDir, + ClientStore: tc.ClientStore, InsecureSkipVerify: tc.InsecureSkipVerify, }) require.NoError(t, err) diff --git a/lib/client/api.go b/lib/client/api.go index 8d0b1c1268beb..3e06947ef2c8f 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -499,11 +499,6 @@ type Config struct { // SSOMFACeremonyConstructor is a custom SSO MFA ceremony constructor. 
SSOMFACeremonyConstructor func(rd *sso.Redirector) mfa.SSOMFACeremony - // CustomHardwareKeyPrompt is a custom hardware key prompt to use when asking - // for a hardware key PIN, touch, etc. - // If empty, a default CLI prompt is used. - CustomHardwareKeyPrompt hardwarekey.Prompt - // DisableSSHResumption disables transparent SSH connection resumption. DisableSSHResumption bool @@ -1293,7 +1288,7 @@ func NewClient(c *Config) (tc *TeleportClient, err error) { } else { // TODO (Joerger): init hardware key service (and client store) earlier where it can // be properly shared. - hardwareKeyService := libhwk.NewService(context.TODO(), tc.CustomHardwareKeyPrompt) + hardwareKeyService := libhwk.NewService(context.TODO(), nil /*prompt*/) tc.ClientStore = NewFSClientStore(c.KeysDir, WithHardwareKeyService(hardwareKeyService)) if c.AddKeysToAgent == AddKeysToAgentOnly { // Store client keys in memory, but still save trusted certs and profile to disk. diff --git a/lib/teleterm/clusters/config.go b/lib/teleterm/clusters/config.go index d1410b0eb9890..df2518e137566 100644 --- a/lib/teleterm/clusters/config.go +++ b/lib/teleterm/clusters/config.go @@ -25,7 +25,6 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/utils/keys/hardwarekey" "github.com/gravitational/teleport/lib/client" ) @@ -44,9 +43,8 @@ type Config struct { WebauthnLogin client.WebauthnLoginFunc // AddKeysToAgent is passed to [client.Config]. AddKeysToAgent string - // CustomHardwareKeyPrompt is a custom hardware key prompt to use when asking - // for a hardware key PIN, touch, etc. - CustomHardwareKeyPrompt hardwarekey.Prompt + // ClientStore stores client data. 
+ ClientStore *client.Store } // CheckAndSetDefaults checks the configuration for its validity and sets default values if needed @@ -55,6 +53,10 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing working directory") } + if c.ClientStore == nil { + return trace.BadParameter("missing client store") + } + if c.Clock == nil { c.Clock = clockwork.NewRealClock() } diff --git a/lib/teleterm/clusters/storage.go b/lib/teleterm/clusters/storage.go index 7ababc194d611..54d6ede1b8f3a 100644 --- a/lib/teleterm/clusters/storage.go +++ b/lib/teleterm/clusters/storage.go @@ -42,9 +42,7 @@ func NewStorage(cfg Config) (*Storage, error) { // ListProfileNames returns just the names of profiles in s.Dir. func (s *Storage) ListProfileNames() ([]string, error) { - profileStore := client.NewFSProfileStore(s.Dir) - pfNames, err := profileStore.ListProfiles() - return pfNames, trace.Wrap(err) + return s.ClientStore.ListProfiles() } // ListRootClusters reads root clusters from profiles. 
@@ -161,7 +159,7 @@ func (s *Storage) addCluster(ctx context.Context, dir, webProxyAddress string) ( profileName := parseName(webProxyAddress) clusterURI := uri.NewClusterURI(profileName) - cfg := s.makeDefaultClientConfig(clusterURI) + cfg := s.makeClientConfig() cfg.WebProxyAddr = webProxyAddress clusterClient, err := client.NewClient(cfg) @@ -211,10 +209,8 @@ func (s *Storage) fromProfile(profileName, leafClusterName string) (*Cluster, *c clusterNameForKey := profileName clusterURI := uri.NewClusterURI(profileName) - profileStore := client.NewFSProfileStore(s.Dir) - - cfg := s.makeDefaultClientConfig(clusterURI) - if err := cfg.LoadProfile(profileStore, profileName); err != nil { + cfg := s.makeClientConfig() + if err := cfg.LoadProfile(s.ClientStore, profileName); err != nil { return nil, nil, trace.Wrap(err) } @@ -275,15 +271,14 @@ func (s *Storage) loadProfileStatusAndClusterKey(clusterClient *client.TeleportC return status, nil } -func (s *Storage) makeDefaultClientConfig(rootClusterURI uri.ResourceURI) *client.Config { +func (s *Storage) makeClientConfig() *client.Config { cfg := client.MakeDefaultConfig() - cfg.HomePath = s.Dir cfg.KeysDir = s.Dir cfg.InsecureSkipVerify = s.InsecureSkipVerify cfg.AddKeysToAgent = s.AddKeysToAgent cfg.WebauthnLogin = s.WebauthnLogin - cfg.CustomHardwareKeyPrompt = s.CustomHardwareKeyPrompt + cfg.ClientStore = s.ClientStore cfg.DTAuthnRunCeremony = dtauthn.NewCeremony().Run cfg.DTAutoEnroll = dtenroll.AutoEnroll return cfg diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index 3646b78f05e2b..a8b87de4af022 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -63,10 +63,13 @@ type Config struct { AgentsDir string GatewayCreator GatewayCreator - // CreateTshdEventsClientCredsFunc lazily creates creds for the tshd events server ran by the - // Electron app. 
This is to ensure that the server public key is written to the disk under the - // expected location by the time we get around to creating the client. - CreateTshdEventsClientCredsFunc CreateTshdEventsClientCredsFunc + + // TshdEventsClient holds a client to send events to the Electron App. + // + // The startup of the app is orchestrated so that the client is loaded before any other method on + // daemon.Service. This way all the other code in daemon.Service can assume that the tshd events + // client is available right from the beginning, without the need for nil checks. + TshdEventsClient *TshdEventsClient ConnectMyComputerRoleSetup *connectmycomputer.RoleSetup ConnectMyComputerTokenProvisioner *connectmycomputer.TokenProvisioner diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 7d27e6d82da8a..7e69dec4deb08 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -26,7 +26,6 @@ import ( "time" "github.com/gravitational/trace" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -89,6 +88,7 @@ func New(cfg Config) (*Service, error) { gateways: make(map[string]gateway.Gateway), usageReporter: connectUsageReporter, headlessWatcherClosers: make(map[string]context.CancelFunc), + headlessAuthSemaphore: newWaitSemaphore(maxConcurrentImportantModals, imporantModalWaitDuraiton), } // TODO(gzdunek): The client cache should be created outside of daemon.New. 
@@ -119,7 +119,7 @@ func (s *Service) relogin(ctx context.Context, req *api.ReloginRequest) error { timeoutCtx, cancelTshdEventsCtx := context.WithTimeout(ctx, reloginUserTimeout) defer cancelTshdEventsCtx() - if _, err := s.tshdEventsClient.Relogin(timeoutCtx, req); err != nil { + if _, err := s.cfg.TshdEventsClient.client.Relogin(timeoutCtx, req); err != nil { if status.Code(err) == codes.DeadlineExceeded { return trace.Wrap(err, "the user did not refresh the session within %s", reloginUserTimeout.String()) } @@ -308,8 +308,8 @@ func (s *Service) ClusterLogout(ctx context.Context, uri string) error { // CreateGateway creates a gateway to given targetURI func (s *Service) CreateGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.gatewaysMu.Lock() + defer s.gatewaysMu.Unlock() gateway, err := s.createGateway(ctx, params) if err != nil { @@ -432,8 +432,8 @@ func (s *Service) reissueGatewayCerts(ctx context.Context, g gateway.Gateway) (t // RemoveGateway removes cluster gateway func (s *Service) RemoveGateway(gatewayURI string) error { - s.mu.Lock() - defer s.mu.Unlock() + s.gatewaysMu.Lock() + defer s.gatewaysMu.Unlock() gateway, err := s.findGateway(gatewayURI) if err != nil { @@ -471,8 +471,8 @@ func (s *Service) findGateway(gatewayURI string) (gateway.Gateway, error) { // ListGateways lists gateways func (s *Service) ListGateways() []gateway.Gateway { - s.mu.RLock() - defer s.mu.RUnlock() + s.gatewaysMu.RLock() + defer s.gatewaysMu.RUnlock() gws := make([]gateway.Gateway, 0, len(s.gateways)) for _, gateway := range s.gateways { @@ -516,8 +516,8 @@ func (s *Service) GetGatewayCLICommand(ctx context.Context, gateway gateway.Gate // SetGatewayTargetSubresourceName updates the TargetSubresourceName field of a gateway stored in // s.gateways. 
func (s *Service) SetGatewayTargetSubresourceName(ctx context.Context, gatewayURI, targetSubresourceName string) (gateway.Gateway, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.gatewaysMu.Lock() + defer s.gatewaysMu.Unlock() gateway, err := s.findGateway(gatewayURI) if err != nil { @@ -568,8 +568,8 @@ func (s *Service) SetGatewayTargetSubresourceName(ctx context.Context, gatewayUR // // SetGatewayLocalPort is a noop if port is equal to the existing port. func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (gateway.Gateway, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.gatewaysMu.Lock() + defer s.gatewaysMu.Unlock() oldGateway, err := s.findGateway(gatewayURI) if err != nil { @@ -830,8 +830,8 @@ func (s *Service) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) er // // We don't know which gateways are affected by the access request, // so we need to clear certs for all of them. - s.mu.RLock() - defer s.mu.RUnlock() + s.gatewaysMu.RLock() + defer s.gatewaysMu.RUnlock() for _, gw := range s.gateways { targetURI := gw.TargetURI() if !(targetURI.IsKube() && targetURI.GetRootClusterURI() == cluster.URI) { @@ -894,8 +894,8 @@ func (s *Service) ReportUsageEvent(req *api.ReportUsageEventRequest) error { // Stop terminates all cluster open connections func (s *Service) Stop() { - s.mu.RLock() - defer s.mu.RUnlock() + s.gatewaysMu.RLock() + defer s.gatewaysMu.RUnlock() s.cfg.Logger.InfoContext(s.closeContext, "Stopping") @@ -924,43 +924,17 @@ func (s *Service) Stop() { // UpdateAndDialTshdEventsServerAddress allows the Electron app to provide the tshd events server // address. -// -// The startup of the app is orchestrated so that this method is called before any other method on -// daemon.Service. This way all the other code in daemon.Service can assume that the tshd events -// client is available right from the beginning, without the need for nil checks. 
func (s *Service) UpdateAndDialTshdEventsServerAddress(serverAddress string) error { - s.mu.Lock() - defer s.mu.Unlock() - - withCreds, err := s.cfg.CreateTshdEventsClientCredsFunc() - if err != nil { - return trace.Wrap(err) - } - - conn, err := grpc.Dial(serverAddress, withCreds) - if err != nil { - return trace.Wrap(err) - } - - client := api.NewTshdEventsServiceClient(conn) - - s.tshdEventsClient = client - s.headlessAuthSemaphore = newWaitSemaphore(maxConcurrentImportantModals, imporantModalWaitDuraiton) - - return nil + return s.cfg.TshdEventsClient.Connect(serverAddress) } // TshdEventsClient returns the client if it was initialized earlied by calling // UpdateAndDialTshdEventsServerAddress, otherwise it returns an error. // // The startup of Connect is orchestrated in a way that makes it safe to call this method from any -// RPC. Code inside daemon.Service should just use s.tshdEventsClient directly. -func (s *Service) TshdEventsClient() (api.TshdEventsServiceClient, error) { - if s.tshdEventsClient == nil { - return nil, trace.NotFound("tshd events client has not been initialized yet") - } - - return s.tshdEventsClient, nil +// RPC. Code inside daemon.Service should just use s.cfg.tshdEventsClient directly. +func (s *Service) TshdEventsClient(ctx context.Context) (api.TshdEventsServiceClient, error) { + return s.cfg.TshdEventsClient.GetClient(ctx) } // NotifyApp sends a notification (usually an error) to the Electron App. 
@@ -968,7 +942,7 @@ func (s *Service) NotifyApp(ctx context.Context, notification *api.SendNotificat tshdEventsCtx, cancelTshdEventsCtx := context.WithTimeout(ctx, tshdEventsTimeout) defer cancelTshdEventsCtx() - _, err := s.tshdEventsClient.SendNotification(tshdEventsCtx, notification) + _, err := s.cfg.TshdEventsClient.client.SendNotification(tshdEventsCtx, notification) return trace.Wrap(err) } @@ -1241,18 +1215,17 @@ func (s *Service) ClearCachedClientsForRoot(clusterURI uri.ResourceURI) error { // Service is the daemon service type Service struct { cfg *Config - // mu guards gateways and the creation of tshdEventsClient. - mu sync.RWMutex // closeContext is canceled when Service is getting stopped. It is used as a context for the calls // to the tshd events gRPC client. closeContext context.Context cancel context.CancelFunc + // gateways holds the long-running gateways for resources on different clusters. So far it's been // used mostly for database gateways but it has potential to be used for app access as well. gateways map[string]gateway.Gateway - // tshdEventsClient is a client to send events to the Electron App. - tshdEventsClient api.TshdEventsServiceClient + // gatewaysMu guards gateways. + gatewaysMu sync.RWMutex // The Electron App can display multiple important modals by showing the latest one and hiding the others. 
// However, we should be careful with it, and generally try to limit the number of prompts on the tshd side, diff --git a/lib/teleterm/daemon/daemon_headless.go b/lib/teleterm/daemon/daemon_headless.go index 310b853d229c3..2e28de1c3aa6c 100644 --- a/lib/teleterm/daemon/daemon_headless.go +++ b/lib/teleterm/daemon/daemon_headless.go @@ -282,7 +282,7 @@ func (s *Service) sendPendingHeadlessAuthentication(ctx context.Context, ha *typ } defer s.headlessAuthSemaphore.Release() - _, err := s.tshdEventsClient.SendPendingHeadlessAuthentication(ctx, req) + _, err := s.cfg.TshdEventsClient.client.SendPendingHeadlessAuthentication(ctx, req) return trace.Wrap(err) } diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index 725b224574aa9..f6e7f2c1576ec 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -346,6 +346,7 @@ func TestUpdateTshdEventsServerAddress(t *testing.T) { storage, err := clusters.NewStorage(clusters.Config{ Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) @@ -356,11 +357,12 @@ func TestUpdateTshdEventsServerAddress(t *testing.T) { return grpc.WithTransportCredentials(insecure.NewCredentials()), nil } + tshdEventsClient := NewTshdEventsClient(createTshdEventsClientCredsFunc) daemon, err := New(Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: createTshdEventsClientCredsFunc, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), }) require.NoError(t, err) @@ -370,7 +372,7 @@ func TestUpdateTshdEventsServerAddress(t *testing.T) { err = daemon.UpdateAndDialTshdEventsServerAddress(ls.Addr().String()) require.NoError(t, err) - require.NotNil(t, daemon.tshdEventsClient) + require.NotNil(t, daemon.cfg.TshdEventsClient) require.Equal(t, 1, createTshdEventsClientCredsFuncCallCount, "Expected 
createTshdEventsClientCredsFunc to be called exactly once") } @@ -380,6 +382,7 @@ func TestUpdateTshdEventsServerAddress_CredsErr(t *testing.T) { storage, err := clusters.NewStorage(clusters.Config{ Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) @@ -388,11 +391,12 @@ func TestUpdateTshdEventsServerAddress_CredsErr(t *testing.T) { return nil, trace.Errorf("Error while creating creds") } + tshdEventsClient := NewTshdEventsClient(createTshdEventsClientCredsFunc) daemon, err := New(Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: createTshdEventsClientCredsFunc, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), }) require.NoError(t, err) @@ -479,19 +483,23 @@ func TestRetryWithRelogin(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + homeDir := t.TempDir() storage, err := clusters.NewStorage(clusters.Config{ - Dir: t.TempDir(), + Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) + tshdEventsClient := NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemon, err := New(Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { return fakeClientCache{}, nil }, @@ -532,19 +540,23 @@ func TestConcurrentHeadlessAuthPrompts(t *testing.T) { t.Parallel() ctx := context.Background() + homeDir := t.TempDir() storage, err := 
clusters.NewStorage(clusters.Config{ - Dir: t.TempDir(), + Dir: homeDir, + ClientStore: client.NewFSClientStore(homeDir), InsecureSkipVerify: true, }) require.NoError(t, err) + tshdEventsClient := NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemon, err := New(Config{ - Storage: storage, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: storage, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { return fakeClientCache{}, nil }, @@ -688,13 +700,15 @@ func (c *mockTSHDEventsService) SendPendingHeadlessAuthentication(context.Contex func TestGetGatewayCLICommand(t *testing.T) { t.Parallel() + tshdEventsClient := NewTshdEventsClient(func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }) + daemon, err := New(Config{ - Storage: fakeStorage{}, - CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { - return grpc.WithTransportCredentials(insecure.NewCredentials()), nil - }, - KubeconfigsDir: t.TempDir(), - AgentsDir: t.TempDir(), + Storage: fakeStorage{}, + TshdEventsClient: tshdEventsClient, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { return fakeClientCache{}, nil }, diff --git a/lib/teleterm/daemon/events_client.go b/lib/teleterm/daemon/events_client.go new file mode 100644 index 0000000000000..ecbb994b6ece9 --- /dev/null +++ b/lib/teleterm/daemon/events_client.go @@ -0,0 +1,94 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package daemon + +import ( + "context" + "sync" + "time" + + "github.com/gravitational/trace" + "google.golang.org/grpc" + + api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" +) + +// TshdEventsClient holds a lazily loaded [api.TshdEventsServiceClient]. +type TshdEventsClient struct { + client api.TshdEventsServiceClient + // connectedChan is closed once the client is connected + connectedChan chan struct{} + // connectMu is used during connection to prevent a race between callers. + connectMu sync.Mutex + + // credsFn lazily creates creds for the tshd events server ran by the Electron app. + // This is to ensure that the server public key is written to the disk under the + // expected location by the time we get around to creating the client. + credsFn CreateTshdEventsClientCredsFunc +} + +func NewTshdEventsClient(credsFn CreateTshdEventsClientCredsFunc) *TshdEventsClient { + return &TshdEventsClient{ + credsFn: credsFn, + connectedChan: make(chan struct{}), + } +} + +// Connect connects to the given server address. +func (c *TshdEventsClient) Connect(serverAddress string) error { + c.connectMu.Lock() + defer c.connectMu.Unlock() + + select { + case <-c.connectedChan: + // already connected, no-op. 
+ return nil + default: + } + + withCreds, err := c.credsFn() + if err != nil { + return trace.Wrap(err) + } + + conn, err := grpc.NewClient(serverAddress, withCreds) + if err != nil { + return trace.Wrap(err) + } + + // Successfully connected set the client and signal to any waiters. + c.client = api.NewTshdEventsServiceClient(conn) + close(c.connectedChan) + return nil +} + +// GetClient retrieves the lazily loaded client. If the client is not yet loaded, +// this method will wait until it is loaded, the given context is closed, or it +// times out. +func (c *TshdEventsClient) GetClient(ctx context.Context) (api.TshdEventsServiceClient, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + select { + case <-c.connectedChan: + return c.client, nil + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err(), "tshd events client has not been initialized yet") + } +} diff --git a/lib/teleterm/daemon/events_client_test.go b/lib/teleterm/daemon/events_client_test.go new file mode 100644 index 0000000000000..550a4167b2545 --- /dev/null +++ b/lib/teleterm/daemon/events_client_test.go @@ -0,0 +1,88 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */
+
+package daemon
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+
+	api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
+)
+
+func TestTshdEventsClient(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	_, addr := newMockTSHDEventsServiceServer(t)
+
+	c := NewTshdEventsClient(func() (grpc.DialOption, error) {
+		return grpc.WithTransportCredentials(insecure.NewCredentials()), nil
+	})
+
+	// GetClient should timeout if client is not connected.
+	timeoutCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
+	defer cancel()
+
+	_, err := c.GetClient(timeoutCtx)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+
+	// Make 2 calls to GetClient to wait for a client connection.
+	timeoutCtx, cancel = context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+	type getClientRet struct {
+		clt api.TshdEventsServiceClient
+		err error
+	}
+	retC := make(chan getClientRet)
+	for range 2 {
+		go func() {
+			clt, err := c.GetClient(timeoutCtx)
+			retC <- getClientRet{
+				clt: clt,
+				err: err,
+			}
+		}()
+	}
+
+	// Connect client, GetClient calls should complete.
+	err = c.Connect(addr)
+	require.NoError(t, err)
+
+	for range 2 {
+		select {
+		case <-timeoutCtx.Done():
+			t.Error("timeout waiting for client connection")
+		case ret := <-retC:
+			require.NoError(t, ret.err)
+			require.NotNil(t, ret.clt)
+		}
+	}
+
+	// GetClient should complete immediately once connected.
+	timeoutCtx, cancel = context.WithTimeout(ctx, 500*time.Millisecond)
+	defer cancel()
+
+	_, err = c.GetClient(timeoutCtx)
+	require.NoError(t, err)
+}
diff --git a/lib/teleterm/daemon/hardwarekeyprompt.go b/lib/teleterm/daemon/hardwarekeyprompt.go
index ded69b1b32d66..035c59d09f47f 100644
--- a/lib/teleterm/daemon/hardwarekeyprompt.go
+++ b/lib/teleterm/daemon/hardwarekeyprompt.go
@@ -45,12 +45,12 @@ import (
 // Because the code in yubikey.go assumes you use a single key, we don't have any mutex here.
 // (unlike other modals triggered by tshd).
 // We don't expect receiving prompts from different hardware keys.
-func (s *Service) NewHardwareKeyPrompt() hardwarekey.Prompt {
-	return &hardwareKeyPrompter{s: s}
+func (c *TshdEventsClient) NewHardwareKeyPrompt() hardwarekey.Prompt {
+	return &hardwareKeyPrompter{c: c}
 }
 
 type hardwareKeyPrompter struct {
-	s *Service
+	c *TshdEventsClient
 }
 
 // Touch prompts the user to touch the hardware key.
@@ -60,7 +60,12 @@ func (h *hardwareKeyPrompter) Touch(ctx context.Context, keyInfo hardwarekey.Con
 		keyInfo.Command = ""
 	}
 
-	_, err := h.s.tshdEventsClient.PromptHardwareKeyTouch(ctx, &api.PromptHardwareKeyTouchRequest{
+	clt, err := h.c.GetClient(ctx)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	_, err = clt.PromptHardwareKeyTouch(ctx, &api.PromptHardwareKeyTouchRequest{
 		ProxyHostname: keyInfo.ProxyHost,
 		Command:       keyInfo.Command,
 	})
@@ -77,7 +82,12 @@ func (h *hardwareKeyPrompter) AskPIN(ctx context.Context, requirement hardwareke
 		keyInfo.Command = ""
 	}
 
-	res, err := h.s.tshdEventsClient.PromptHardwareKeyPIN(ctx, &api.PromptHardwareKeyPINRequest{
+	clt, err := h.c.GetClient(ctx)
+	if err != nil {
+		return "", trace.Wrap(err)
+	}
+
+	res, err := clt.PromptHardwareKeyPIN(ctx, &api.PromptHardwareKeyPINRequest{
 		ProxyHostname: keyInfo.ProxyHost,
 		PinOptional:   requirement == hardwarekey.PINOptional,
 		Command:       keyInfo.Command,
@@ -92,7 +102,12 @@ func (h *hardwareKeyPrompter) AskPIN(ctx context.Context, requirement hardwareke
 // The Electron app prompt must handle default values for PIN and PUK,
 // preventing the user from submitting empty/default values.
 func (h *hardwareKeyPrompter) ChangePIN(ctx context.Context, keyInfo hardwarekey.ContextualKeyInfo) (*hardwarekey.PINAndPUK, error) {
-	res, err := h.s.tshdEventsClient.PromptHardwareKeyPINChange(ctx, &api.PromptHardwareKeyPINChangeRequest{
+	clt, err := h.c.GetClient(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	res, err := clt.PromptHardwareKeyPINChange(ctx, &api.PromptHardwareKeyPINChangeRequest{
 		ProxyHostname: keyInfo.ProxyHost,
 	})
 	if err != nil {
@@ -107,7 +122,12 @@ func (h *hardwareKeyPrompter) ChangePIN(ctx context.Context, keyInfo hardwarekey
 
 // ConfirmSlotOverwrite asks the user if the slot's private key and certificate can be overridden.
 func (h *hardwareKeyPrompter) ConfirmSlotOverwrite(ctx context.Context, message string, keyInfo hardwarekey.ContextualKeyInfo) (bool, error) {
-	res, err := h.s.tshdEventsClient.ConfirmHardwareKeySlotOverwrite(ctx, &api.ConfirmHardwareKeySlotOverwriteRequest{
+	clt, err := h.c.GetClient(ctx)
+	if err != nil {
+		return false, trace.Wrap(err)
+	}
+
+	res, err := clt.ConfirmHardwareKeySlotOverwrite(ctx, &api.ConfirmHardwareKeySlotOverwriteRequest{
 		ProxyHostname: keyInfo.ProxyHost,
 		Message:       message,
 	})
diff --git a/lib/teleterm/daemon/mfaprompt.go b/lib/teleterm/daemon/mfaprompt.go
index 7975a23956597..012234ce45022 100644
--- a/lib/teleterm/daemon/mfaprompt.go
+++ b/lib/teleterm/daemon/mfaprompt.go
@@ -67,7 +67,12 @@ func (s *Service) promptAppMFA(ctx context.Context, in *api.PromptMFARequest) (*
 	s.mfaMu.Lock()
 	defer s.mfaMu.Unlock()
 
-	return s.tshdEventsClient.PromptMFA(ctx, in)
+	// Wait for the lazily-loaded client rather than reading the unexported
+	// field directly, which would race with Connect and could yield a nil
+	// client if MFA is prompted before the Electron app connects.
+	clt, err := s.cfg.TshdEventsClient.GetClient(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	return clt.PromptMFA(ctx, in)
 }
 
 // Run prompts the user to complete an MFA authentication challenge.
diff --git a/lib/teleterm/teleterm.go b/lib/teleterm/teleterm.go
index 75cffb256e238..4c039396f4437 100644
--- a/lib/teleterm/teleterm.go
+++ b/lib/teleterm/teleterm.go
@@ -33,6 +33,7 @@ import (
 	"google.golang.org/grpc/credentials/insecure"
 
 	"github.com/gravitational/teleport/api/utils/keys/piv"
+	"github.com/gravitational/teleport/lib/client"
 	libhwk "github.com/gravitational/teleport/lib/hardwarekey"
 	"github.com/gravitational/teleport/lib/teleterm/apiserver"
 	"github.com/gravitational/teleport/lib/teleterm/clusteridcache"
@@ -53,11 +54,18 @@ func Serve(ctx context.Context, cfg Config) error {
 
 	clock := clockwork.NewRealClock()
 
+	// Prepare tshdEventsClient with lazy loading.
+	tshdEventsClient := daemon.NewTshdEventsClient(grpcCredentials.tshdEvents)
+
+	// Always use the direct YubiKey PIV service since Connect provides the best UX.
+	hwks := piv.NewYubiKeyService(tshdEventsClient.NewHardwareKeyPrompt())
+
 	storage, err := clusters.NewStorage(clusters.Config{
 		Dir:                cfg.HomeDir,
 		Clock:              clock,
 		InsecureSkipVerify: cfg.InsecureSkipVerify,
 		AddKeysToAgent:     cfg.AddKeysToAgent,
+		ClientStore:        client.NewFSClientStore(cfg.HomeDir, client.WithHardwareKeyService(hwks)),
 	})
 	if err != nil {
 		return trace.Wrap(err)
 	}
@@ -66,21 +74,17 @@ func Serve(ctx context.Context, cfg Config) error {
 	clusterIDCache := &clusteridcache.Cache{}
 
 	daemonService, err := daemon.New(daemon.Config{
-		Storage:                         storage,
-		CreateTshdEventsClientCredsFunc: grpcCredentials.tshdEvents,
-		PrehogAddr:                      cfg.PrehogAddr,
-		KubeconfigsDir:                  cfg.KubeconfigsDir,
-		AgentsDir:                       cfg.AgentsDir,
-		ClusterIDCache:                  clusterIDCache,
+		Storage:          storage,
+		PrehogAddr:       cfg.PrehogAddr,
+		KubeconfigsDir:   cfg.KubeconfigsDir,
+		AgentsDir:        cfg.AgentsDir,
+		ClusterIDCache:   clusterIDCache,
+		TshdEventsClient: tshdEventsClient,
 	})
 	if err != nil {
 		return trace.Wrap(err)
 	}
 
-	// TODO(gzdunek): Move tshdEventsClient out of daemonService so that we can
-	// construct the prompt before creating Storage.
-	storage.CustomHardwareKeyPrompt = daemonService.NewHardwareKeyPrompt()
-
 	apiServer, err := apiserver.New(apiserver.Config{
 		HostAddr:           cfg.Addr,
 		InsecureSkipVerify: cfg.InsecureSkipVerify,
@@ -103,8 +107,7 @@ func Serve(ctx context.Context, cfg Config) error {
 
 	var hardwareKeyAgentServer *libhwk.Server
 	if cfg.HardwareKeyAgent {
-		hardwareKeyService := piv.NewYubiKeyService(daemonService.NewHardwareKeyPrompt())
-		hardwareKeyAgentServer, err = libhwk.NewAgentServer(ctx, hardwareKeyService, libhwk.DefaultAgentDir())
+		hardwareKeyAgentServer, err = libhwk.NewAgentServer(ctx, hwks, libhwk.DefaultAgentDir())
 		if err != nil {
 			slog.WarnContext(ctx, "failed to create the hardware key agent server", "err", err)
 		} else {
diff --git a/lib/teleterm/vnet/service.go b/lib/teleterm/vnet/service.go
index 75c5424fefa0c..a785e4a6ed039 100644
--- a/lib/teleterm/vnet/service.go
+++ b/lib/teleterm/vnet/service.go
@@ -332,7 +332,7 @@ func (s *Service) Close() error {
 }
 
 func (s *Service) isUsageReportingEnabled(ctx context.Context) (bool, error) {
-	tshdEventsClient, err := s.cfg.DaemonService.TshdEventsClient()
+	tshdEventsClient, err := s.cfg.DaemonService.TshdEventsClient(ctx)
 	if err != nil {
 		return false, trace.Wrap(err)
 	}
@@ -346,7 +346,7 @@ func (s *Service) isUsageReportingEnabled(ctx context.Context) (bool, error) {
 }
 
 func (s *Service) reportUnexpectedShutdown(ctx context.Context, shutdownErr error) error {
-	tshdEventsClient, err := s.cfg.DaemonService.TshdEventsClient()
+	tshdEventsClient, err := s.cfg.DaemonService.TshdEventsClient(ctx)
 	if err != nil {
 		return trace.Wrap(err, "obtaining tshd events client")
 	}