Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion api/profile/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (p *Profile) TLSConfig() (*tls.Config, error) {

// Expiry returns the credential expiry.
func (p *Profile) Expiry() (time.Time, bool) {
certPEMBlock, err := os.ReadFile(p.TLSCertPath())
certPEMBlock, err := p.TLSCert()
if err != nil {
return time.Time{}, false
}
Expand All @@ -188,6 +188,12 @@ func (p *Profile) Expiry() (time.Time, bool) {
return cert.NotAfter, true
}

// TLSCert returns the raw PEM-encoded TLS certificate stored for this
// profile, read from the path reported by TLSCertPath.
func (p *Profile) TLSCert() ([]byte, error) {
	pemBytes, readErr := os.ReadFile(p.TLSCertPath())
	return pemBytes, trace.Wrap(readErr)
}

// RequireKubeLocalProxy returns true if this profile indicates a local proxy
// is required for kube access.
func (p *Profile) RequireKubeLocalProxy() bool {
Expand Down
5 changes: 5 additions & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ type Config struct {

// SSOHost is the host of the SSO provider used to log in.
SSOHost string

// Dir is the directory of the profile.
Dir string
}

// CachePolicy defines cache policy for local clients
Expand Down Expand Up @@ -953,6 +956,7 @@ func (c *Config) LoadProfile(proxyAddr string) error {
c.SAMLSingleLogoutEnabled = profile.SAMLSingleLogoutEnabled
c.SSHDialTimeout = profile.SSHDialTimeout
c.SSOHost = profile.SSOHost
c.Dir = profile.Dir

c.AuthenticatorAttachment, err = parseMFAMode(profile.MFAMode)
if err != nil {
Expand Down Expand Up @@ -1021,6 +1025,7 @@ func (c *Config) Profile() *profile.Profile {
SAMLSingleLogoutEnabled: c.SAMLSingleLogoutEnabled,
SSHDialTimeout: c.SSHDialTimeout,
SSOHost: c.SSOHost,
Dir: c.Dir,
}
}

Expand Down
74 changes: 64 additions & 10 deletions lib/client/clientcache/clientcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package clientcache

import (
"bytes"
"context"
"log/slog"
"slices"
Expand All @@ -36,11 +37,30 @@ type Cache struct {
cfg Config
mu sync.RWMutex
// clients keeps a mapping from key (profile name and leaf cluster name) to cluster client.
clients map[key]*client.ClusterClient
clients map[key]*clientWithCert
// group prevents duplicate requests to create clients for a given cluster.
group singleflight.Group
}

type clientWithCert struct {
// client is the cluster client.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can drop this comment hehe.

client *client.ClusterClient
// coreTLSCert is the cert used in TeleportClient.ConnectToCluster to create the client.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// coreTLSCert is the cert used in TeleportClient.ConnectToCluster to create the client.
// coreTLSCert is the cert used in [client.TeleportClient.ConnectToCluster] to create the client.

or maybe even

Suggested change
// coreTLSCert is the cert used in TeleportClient.ConnectToCluster to create the client.
// coreTLSCert is the contents of the cert at the time of creating the client.

coreTLSCert []byte
// readCoreTLSCert reads a fresh cert from disk.
readCoreTLSCert func() ([]byte, error)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's very Go-like, I assume the more idiomatic way to do it is to create a profile interface with TLSCert method.

}

// isCoreTLSCertUnchanged reports whether the TLS certificate currently on
// disk (as returned by readCoreTLSCert) still matches the certificate that
// was captured when this cached client was created.
func (c *clientWithCert) isCoreTLSCertUnchanged() (bool, error) {
	freshCert, err := c.readCoreTLSCert()
	if err != nil {
		return false, trace.Wrap(err)
	}
	return bytes.Equal(c.coreTLSCert, freshCert), nil
}

// NewClientFunc is a function that will return a new [*client.TeleportClient] for a given profile and leaf
// cluster. [leafClusterName] may be empty, in which case implementations should return a client for the root cluster.
type NewClientFunc func(ctx context.Context, profileName, leafClusterName string) (*client.TeleportClient, error)
Expand Down Expand Up @@ -89,7 +109,7 @@ func New(c Config) (*Cache, error) {

return &Cache{
cfg: c,
clients: make(map[key]*client.ClusterClient),
clients: make(map[key]*clientWithCert),
}, nil
}

Expand All @@ -99,8 +119,19 @@ func (c *Cache) Get(ctx context.Context, profileName, leafClusterName string) (*
k := key{profile: profileName, leafCluster: leafClusterName}
groupClt, err, _ := c.group.Do(k.String(), func() (any, error) {
if fromCache := c.getFromCache(k); fromCache != nil {
c.cfg.Logger.DebugContext(ctx, "Retrieved client from cache", "cluster", k)
return fromCache, nil
unchanged, err := fromCache.isCoreTLSCertUnchanged()
if err != nil {
c.cfg.Logger.WarnContext(ctx, "Failed to validate TLS certificate, removing from cache", "cluster", k, "error", err)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
c.cfg.Logger.WarnContext(ctx, "Failed to validate TLS certificate, removing from cache", "cluster", k, "error", err)
c.cfg.Logger.WarnContext(ctx, "Failed to check if TLS certificate has changed, removing client from cache", "cluster", k, "error", err)

} else if unchanged {
c.cfg.Logger.DebugContext(ctx, "Retrieved client from cache", "cluster", k)
return fromCache.client, nil
} else {
c.cfg.Logger.DebugContext(ctx, "TLS certificate for cached client has changed, removing from cache", "cluster", k)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
c.cfg.Logger.DebugContext(ctx, "TLS certificate for cached client has changed, removing from cache", "cluster", k)
c.cfg.Logger.DebugContext(ctx, "TLS certificate for cached client has changed, removing client from cache", "cluster", k)

}

if err := c.clearForKey(k); err != nil {
return nil, trace.Wrap(err)
}
}

tc, err := c.cfg.NewClientFunc(ctx, profileName, leafClusterName)
Expand All @@ -120,8 +151,17 @@ func (c *Cache) Get(ctx context.Context, profileName, leafClusterName string) (*
return nil, trace.Wrap(err)
}

keyRing, err := tc.LocalAgent().GetCoreKeyRing()
if err != nil {
return nil, trace.Wrap(err)
}

// Save the client in the cache, so we don't have to build a new connection next time.
c.addToCache(k, newClient)
c.addToCache(k, &clientWithCert{
client: newClient,
coreTLSCert: keyRing.TLSCert,
readCoreTLSCert: tc.Profile().TLSCert,
})

c.cfg.Logger.InfoContext(ctx, "Added client to cache", "cluster", k)

Expand Down Expand Up @@ -151,7 +191,7 @@ func (c *Cache) ClearForRoot(profileName string) error {

for k, clt := range c.clients {
if k.profile == profileName {
if err := clt.Close(); err != nil {
if err := clt.client.Close(); err != nil {
errors = append(errors, err)
}
deleted = append(deleted, k.String())
Expand All @@ -175,7 +215,7 @@ func (c *Cache) Clear() error {

var errors []error
for _, clt := range c.clients {
if err := clt.Close(); err != nil {
if err := clt.client.Close(); err != nil {
errors = append(errors, err)
}
}
Expand All @@ -184,14 +224,28 @@ func (c *Cache) Clear() error {
return trace.NewAggregate(errors...)
}

func (c *Cache) addToCache(k key, clusterClient *client.ClusterClient) {
// clearForKey removes the client cached under k, if any, and closes it.
// It is a no-op when no client is cached for k.
func (c *Cache) clearForKey(k key) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	cached, found := c.clients[k]
	if !found {
		return nil
	}

	delete(c.clients, k)
	return trace.Wrap(cached.client.Close())
}

// addToCache stores entry in the cache under cacheKey, overwriting any
// client previously stored for that key.
func (c *Cache) addToCache(cacheKey key, entry *clientWithCert) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.clients[cacheKey] = entry
}

func (c *Cache) getFromCache(k key) *client.ClusterClient {
func (c *Cache) getFromCache(k key) *clientWithCert {
c.mu.RLock()
defer c.mu.RUnlock()

Expand Down
8 changes: 0 additions & 8 deletions lib/teleterm/apiserver/handler/handler_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ func (s *Handler) Login(ctx context.Context, req *api.LoginRequest) (*api.EmptyR
return nil, trace.BadParameter("cluster URI must be a root URI")
}

if err = s.DaemonService.ClearCachedClientsForRoot(cluster.URI); err != nil {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the removal of these constitute a significant change in behavior?

In the previous version, after logging in or assuming a role, all clients created so far would be closed. With the new behavior, no client will be closed after one of those operations until another RPC attempts to get a client from the cache.

Does this have any significance? It feels like it mostly could affect leaf clients as those might continue to not be closed well after a relogin / role assumption.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #59760, I tried re-enabling the client cache in tests. Currently there's a test which just straight up fails (#59760 (comment)). I think it's because in the real world, we depend on the cache being cleared during the login RPCs. In tests however we circumvent those RPCs completely (because we don't have access to user credentials so we depend on test helpers to generate new certs on disk).

The test passes when I manually clear the cache on relogin (see the diff). I think once this PR is merged and when the cache no longer depends on being manually cleared, then we should be able to re-enable it in integration tests without such workarounds like the one in the diff.

Diff
diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go
index ee19d22c1d2..02b5112a9cb 100644
--- a/integration/proxy/teleterm_test.go
+++ b/integration/proxy/teleterm_test.go
@@ -264,6 +264,7 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer
 	t.Cleanup(func() {
 		daemonService.Stop()
 	})
+	tshdEventsService.daemonService = daemonService
 
 	// Connect the daemon to the tshd events service, like it would
 	// during normal initialization of the app.
@@ -320,6 +321,7 @@ type mockTSHDEventsService struct {
 	sendNotificationCallCount atomic.Uint32
 	promptMFACallCount        atomic.Uint32
 	generateAndSetupUserCreds generateAndSetupUserCredsFunc
+	daemonService             *daemon.Service
 }
 
 func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient, generateAndSetupUserCreds generateAndSetupUserCredsFunc) (service *mockTSHDEventsService) {
@@ -360,11 +362,20 @@ func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient,
 
 // Relogin simulates the act of the user logging in again in the Electron app by replacing the user
 // cert on disk with a valid one.
-func (c *mockTSHDEventsService) Relogin(context.Context, *api.ReloginRequest) (*api.ReloginResponse, error) {
+func (c *mockTSHDEventsService) Relogin(ctx context.Context, req *api.ReloginRequest) (*api.ReloginResponse, error) {
 	c.reloginCallCount.Add(1)
 
 	// Generate valid certs with the default TTL.
 	c.generateAndSetupUserCreds(c.t, c.tc, 0 /* ttl */)
+	if c.daemonService != nil {
+		clusterURI, err := uri.Parse(req.RootClusterUri)
+		if err != nil {
+			return nil, err
+		}
+		if err := c.daemonService.ClearCachedClientsForRoot(clusterURI); err != nil {
+			return nil, err
+		}
+	}
 
 	return &api.ReloginResponse{}, nil
 }

return nil, trace.Wrap(err)
}

if req.Params == nil {
return nil, trace.BadParameter("missing login parameters")
}
Expand Down Expand Up @@ -91,10 +87,6 @@ func (s *Handler) LoginPasswordless(stream api.TerminalService_LoginPasswordless
// daemon.Service.ResolveClusterURI.
clusterClient.MFAPromptConstructor = nil

if err := s.DaemonService.ClearCachedClientsForRoot(cluster.URI); err != nil {
return trace.Wrap(err)
}

// Start the prompt flow.
if err := cluster.PasswordlessLogin(stream.Context(), stream); err != nil {
return trace.Wrap(err)
Expand Down
3 changes: 1 addition & 2 deletions lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,7 @@ func (s *Service) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) er
kubeGw.ClearCerts()
}

// We have to reconnect using the updated cert.
return trace.Wrap(s.ClearCachedClientsForRoot(cluster.URI))
return nil
}

// ListKubernetesResourcesRequest defines a request to retrieve kube resources paginated.
Expand Down
9 changes: 0 additions & 9 deletions tool/tsh/common/vnet_client_application.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/clientcache"
libhwk "github.com/gravitational/teleport/lib/hardwarekey"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/vnet"
)

Expand Down Expand Up @@ -173,11 +172,6 @@ func (p *vnetClientApplication) getRootClusterCACertPoolPEM(ctx context.Context,
}

func (p *vnetClientApplication) retryWithRelogin(ctx context.Context, tc *client.TeleportClient, fn func() error, opts ...client.RetryWithReloginOption) error {
profileName, err := utils.Host(tc.WebProxyAddr)
if err != nil {
return trace.Wrap(err)
}

// Make sure the release the login mutex if we end up acquiring it.
didLock := false
defer func() {
Expand All @@ -200,9 +194,6 @@ func (p *vnetClientApplication) retryWithRelogin(ctx context.Context, tc *client
fmt.Printf("Login for cluster %s expired, attempting to log in again.\n", tc.SiteName)
return nil
}),
client.WithAfterLoginHook(func() error {
return trace.Wrap(p.clientCache.ClearForRoot(profileName), "clearing client cache after relogin")
}),
client.WithMakeCurrentProfile(false),
)
return client.RetryWithRelogin(ctx, tc, fn, opts...)
Expand Down
Loading