diff --git a/api/utils/keys/privatekey.go b/api/utils/keys/privatekey.go index 356e9153cd8e8..e3ecc7881756d 100644 --- a/api/utils/keys/privatekey.go +++ b/api/utils/keys/privatekey.go @@ -20,6 +20,8 @@ package keys import ( "bytes" "crypto" + "crypto/ecdsa" + "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" @@ -212,6 +214,31 @@ func ParsePrivateKey(keyPEM []byte) (*PrivateKey, error) { } } +// MarshalPrivateKey will return a PEM encoded crypto.Signer. +// Only supports rsa, ecdsa, and ed25519 keys. +func MarshalPrivateKey(key crypto.Signer) ([]byte, error) { + switch privateKey := key.(type) { + case *rsa.PrivateKey: + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: PKCS1PrivateKeyType, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + return privPEM, nil + case *ecdsa.PrivateKey, *ed25519.PrivateKey: + der, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, trace.Wrap(err) + } + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: PKCS8PrivateKeyType, + Bytes: der, + }) + return privPEM, nil + default: + return nil, trace.BadParameter("unsupported private key type %T", key) + } +} + // LoadKeyPair returns the PrivateKey for the given private and public key files. func LoadKeyPair(privFile, sshPubFile string) (*PrivateKey, error) { privPEM, err := os.ReadFile(privFile) diff --git a/lib/client/interfaces.go b/lib/client/interfaces.go index 9a12d2b7ceabb..f612fa737073e 100644 --- a/lib/client/interfaces.go +++ b/lib/client/interfaces.go @@ -223,6 +223,9 @@ func (k *Key) authorizedHostKeys(hostnames ...string) ([]ssh.PublicKey, error) { // TeleportClientTLSConfig returns client TLS configuration used // to authenticate against API servers. func (k *Key) TeleportClientTLSConfig(cipherSuites []uint16, clusters []string) (*tls.Config, error) { + if len(k.TLSCert) == 0 { + return nil, trace.NotFound("TLS certificate not found") + } return k.clientTLSConfig(cipherSuites, k.TLSCert, clusters) } @@ -399,6 +402,9 @@ func canAddToSystemAgent(agentKey agent.AddedKey) bool { // TeleportTLSCertificate returns the parsed x509 certificate for // authentication against Teleport APIs. func (k *Key) TeleportTLSCertificate() (*x509.Certificate, error) { + if len(k.TLSCert) == 0 { + return nil, trace.NotFound("TLS certificate not found") + } return tlsca.ParseCertificatePEM(k.TLSCert) } @@ -491,7 +497,7 @@ func (k *Key) SSHSigner() (ssh.Signer, error) { // SSHCert returns parsed SSH certificate func (k *Key) SSHCert() (*ssh.Certificate, error) { if k.Cert == nil { - return nil, trace.NotFound("SSH cert not available") + return nil, trace.NotFound("SSH cert not found") } return sshutils.ParseCertificate(k.Cert) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 5de21c69d052d..f7bf8dbb3c87f 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -2030,7 +2030,7 @@ func onLogin(cf *CLIConf) error { func onLogout(cf *CLIConf) error { // Extract all clusters the user is currently logged into. active, available, err := cf.FullProfileStatus() - if err != nil { + if err != nil && !trace.IsCompareFailed(err) { if trace.IsNotFound(err) { fmt.Printf("All users logged out.\n") return nil @@ -2061,7 +2061,7 @@ func onLogout(cf *CLIConf) error { // Load profile for the requested proxy/user. profile, err := tc.ProfileStatus() - if err != nil && !trace.IsNotFound(err) { + if err != nil && !trace.IsNotFound(err) && !trace.IsCompareFailed(err) { return trace.Wrap(err) } @@ -3389,7 +3389,7 @@ func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, erro profile, profileError := c.GetProfile(c.ClientStore, proxy) if profileError == nil { if err := tc.LoadKeyForCluster(ctx, profile.SiteName); err != nil { - if !trace.IsNotFound(err) && !trace.IsConnectionProblem(err) { + if !trace.IsNotFound(err) && !trace.IsConnectionProblem(err) && !trace.IsCompareFailed(err) { return nil, trace.Wrap(err) } log.WithError(err).Infof("Could not load key for %s into the local agent.", cf.SiteName) diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index d30a4b8e54cf1..09a6e49c7dd24 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -21,6 +21,9 @@ import ( "bytes" "context" "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -44,6 +47,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" otlp "go.opentelemetry.io/proto/otlp/trace/v1" + "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" yamlv2 "gopkg.in/yaml.v2" @@ -5058,3 +5062,82 @@ func TestBenchmarkMySQL(t *testing.T) { }) } } + +func TestLogout(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + privPEM, err := keys.MarshalPrivateKey(key) + require.NoError(t, err) + privateKey, err := keys.NewPrivateKey(key, privPEM) + require.NoError(t, err) + clientKey := &client.Key{ + KeyIndex: client.KeyIndex{ + ProxyHost: "proxy", + Username: "user", + ClusterName: "cluster", + }, + PrivateKey: privateKey, + } + profile := &profile.Profile{ + WebProxyAddr: clientKey.ProxyHost, + Username: clientKey.Username, + SiteName: clientKey.ClusterName, + } + + for _, tt := range []struct { + name string + modifyKeyDir func(t *testing.T, homePath string) + }{ + { + name: "normal home dir", + modifyKeyDir: func(t *testing.T, homePath string) {}, + }, { + name: "public key missing", + modifyKeyDir: func(t *testing.T, homePath string) { + pubKeyPath := keypaths.PublicKeyPath(homePath, clientKey.ProxyHost, clientKey.Username) + require.NoError(t, os.Remove(pubKeyPath)) + }, + }, { + name: "private key missing", + modifyKeyDir: func(t *testing.T, homePath string) { + privKeyPath := keypaths.UserKeyPath(homePath, clientKey.ProxyHost, clientKey.Username) + require.NoError(t, os.Remove(privKeyPath)) + }, + }, { + name: "public key mismatch", + modifyKeyDir: func(t *testing.T, homePath string) { + newKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + sshPub, err := ssh.NewPublicKey(newKey.Public()) + require.NoError(t, err) + + pubKeyPath := keypaths.PublicKeyPath(homePath, clientKey.ProxyHost, clientKey.Username) + err = os.WriteFile(pubKeyPath, ssh.MarshalAuthorizedKey(sshPub), 0600) + require.NoError(t, err) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + tmpHomePath := t.TempDir() + + store := client.NewFSClientStore(tmpHomePath) + err = store.AddKey(clientKey) + require.NoError(t, err) + store.SaveProfile(profile, true) + + tt.modifyKeyDir(t, tmpHomePath) + + _, err := os.Lstat(tmpHomePath) + require.NoError(t, err) + + err = Run(context.Background(), []string{"logout"}, setHomePath(tmpHomePath)) + require.NoError(t, err) + + // direcory should be empty. + f, err := os.Open(tmpHomePath) + require.NoError(t, err) + _, err = f.Readdir(1) + require.ErrorIs(t, err, io.EOF) + }) + } +}