diff --git a/api/utils/keypaths/keypaths.go b/api/utils/keypaths/keypaths.go index ff619e49ebd8a..cd42d96bc7aa3 100644 --- a/api/utils/keypaths/keypaths.go +++ b/api/utils/keypaths/keypaths.go @@ -29,7 +29,7 @@ const ( // sessionKeyDir is a sub-directory where session keys are stored sessionKeyDir = "keys" // sshDirSuffix is the suffix of a sub-directory where SSH certificates are stored. - sshDirSuffix = "-ssh" + SSHDirSuffix = "-ssh" // fileNameKnownHosts is a file where known hosts are stored. fileNameKnownHosts = "known_hosts" // FileExtTLSCertLegacy is the legacy suffix/extension of a file where a TLS cert is stored. @@ -257,7 +257,7 @@ func TLSCAsPathCluster(baseDir, proxy, cluster string) string { // // /keys//-ssh func SSHDir(baseDir, proxy, username string) string { - return filepath.Join(ProxyKeyDir(baseDir, proxy), username+sshDirSuffix) + return filepath.Join(ProxyKeyDir(baseDir, proxy), username+SSHDirSuffix) } // PPKFilePath returns the path to the user's PuTTY PPK-formatted keypair diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 06f930286705b..5c3fbe1ac9c8e 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -25,9 +25,11 @@ import ( "fmt" iofs "io/fs" "log/slog" + "maps" "os" "path/filepath" "runtime" + "slices" "strings" "time" @@ -89,6 +91,10 @@ type KeyStore interface { // GetSSHCertificates gets all certificates signed for the given user and proxy, // including certificates for trusted clusters. GetSSHCertificates(proxyHost, username string) ([]*ssh.Certificate, error) + + // GetIdentities returns the usernames associated to signed user certificates + // for the given proxy in the keystore. + GetIdentities(proxyHost string) ([]string, error) } // FSKeyStore is an on-disk implementation of the KeyStore interface. @@ -113,6 +119,11 @@ func NewFSKeyStore(dirPath string) *FSKeyStore { } } +// proxyKeyDir returns the path to the given proxy's keys directory. +func (fs *FSKeyStore) proxyKeyDir(proxy string) string { + return keypaths.ProxyKeyDir(fs.KeyDir, proxy) +} + // userSSHKeyPath returns the SSH private key path for the given KeyRingIndex. func (fs *FSKeyStore) userSSHKeyPath(idx KeyRingIndex) string { return keypaths.UserSSHKeyPath(fs.KeyDir, idx.ProxyHost, idx.Username) @@ -603,6 +614,27 @@ func (fs *FSKeyStore) GetSSHCertificates(proxyHost, username string) ([]*ssh.Cer return sshCerts, nil } +// GetIdentities returns the usernames associated to signed user certificates +// for the given proxy in the keystore. +func (fs *FSKeyStore) GetIdentities(proxyHost string) ([]string, error) { + proxyDir := fs.proxyKeyDir(proxyHost) + files, err := os.ReadDir(proxyDir) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + + var identities []string + for _, file := range files { + // we seek the files corresponding to user SSH certificates, which are + // stored in [user]-ssh subdirectory. These are generated on successful logins. + if file.IsDir() && strings.HasSuffix(file.Name(), keypaths.SSHDirSuffix) { + username := strings.TrimSuffix(file.Name(), keypaths.SSHDirSuffix) + identities = append(identities, username) + } + } + return identities, nil +} + func getCredentialsByName(credentialDir string, opts ...keys.ParsePrivateKeyOpt) (map[string]TLSCredential, error) { files, err := os.ReadDir(credentialDir) if err != nil { @@ -949,3 +981,9 @@ func (ms *MemKeyStore) GetSSHCertificates(proxyHost, username string) ([]*ssh.Ce return sshCerts, nil } + +// GetIdentities returns the usernames associated to signed user certificates +// for the given proxy in the keystore. +func (ms *MemKeyStore) GetIdentities(proxyHost string) ([]string, error) { + return slices.Collect(maps.Keys(ms.keyRings[proxyHost])), nil +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 9b6d5bc5cd1b6..3503de59669b2 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -2549,10 +2549,10 @@ func onLogout(cf *CLIConf) error { active, available, err := cf.FullProfileStatus() if err != nil && !trace.IsCompareFailed(err) { if trace.IsNotFound(err) { - fmt.Printf("All users logged out.\n") + fmt.Fprintf(cf.Stdout(), "All users logged out.\n") return nil } else if trace.IsAccessDenied(err) { - fmt.Printf("%v: Logged in user does not have the correct permissions\n", err) + fmt.Fprintf(cf.Stdout(), "%v: Logged in user does not have the correct permissions\n", err) return nil } return trace.Wrap(err) @@ -2570,7 +2570,33 @@ func onLogout(cf *CLIConf) error { switch { // Proxy and username for key to remove. - case proxyHost != "" && cf.Username != "": + case proxyHost != "": + // In the event --user flag is not supplied, and there is only one identity, + // we can simply log out the single identity. + if cf.Username == "" { + clientStore := cf.getClientStore() + usernames, err := clientStore.GetIdentities(proxyHost) + if err != nil { + return trace.Wrap(err) + } + + if len(usernames) == 0 { + fmt.Fprintf(cf.Stdout(), "All users logged out.\n") + return nil + } + + logger.DebugContext(cf.Context, "No --user flag provided, but identities found for proxy", + "proxy_host", proxyHost, + "users", usernames) + + if len(usernames) > 1 { + fmt.Fprintf(cf.Stdout(), "Specify --user to log out a specific user from %q or remove the --proxy flag to log out all users from all proxies.\n", proxyHost) + return nil + } + + cf.Username = usernames[0] + } + tc, err := makeClient(cf) if err != nil { return trace.Wrap(err) @@ -2599,7 +2625,7 @@ func onLogout(cf *CLIConf) error { err = tc.Logout() if err != nil { if trace.IsNotFound(err) { - fmt.Printf("User %v already logged out from %v.\n", cf.Username, proxyHost) + fmt.Fprintf(cf.Stdout(), "User %v already logged out from %v.\n", cf.Username, proxyHost) return trace.Wrap(&common.ExitCodeError{Code: 1}) } return trace.Wrap(err) @@ -2612,7 +2638,7 @@ func onLogout(cf *CLIConf) error { return trace.Wrap(err) } - fmt.Printf("Logged out %v from %v.\n", cf.Username, proxyHost) + fmt.Fprintf(cf.Stdout(), "Logged out %v from %v.\n", cf.Username, proxyHost) // Remove all keys. case proxyHost == "" && cf.Username == "": tc, err := makeClient(cf) @@ -2667,7 +2693,7 @@ func onLogout(cf *CLIConf) error { return trace.Wrap(tc.SAMLSingleLogout(ctx, sloURL)) }) if err != nil { - fmt.Printf("We were unable to log you out of your SAML identity provider: %v", err) + fmt.Fprintf(cf.Stdout(), "We were unable to log you out of your SAML identity provider: %v\n", err) } // Remove all keys from disk and the running agent. @@ -2676,11 +2702,9 @@ func onLogout(cf *CLIConf) error { return trace.Wrap(err) } - fmt.Printf("Logged out all users from all proxies.\n") - case proxyHost != "" && cf.Username == "": - fmt.Printf("Specify --user to log out a specific user from %q or remove the --proxy flag to log out all users from all proxies.\n", proxyHost) + fmt.Fprintf(cf.Stdout(), "Logged out all users from all proxies.\n") case proxyHost == "" && cf.Username != "": - fmt.Printf("Specify --proxy to log out user %q from a specific proxy or remove the --user flag to log out all users from all proxies.\n", cf.Username) + fmt.Fprintf(cf.Stdout(), "Specify --proxy to log out user %q from a specific proxy or remove the --user flag to log out all users from all proxies.\n", cf.Username) } return nil } diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 9408c4a5c4fb7..caf01c525656b 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -7910,3 +7910,69 @@ func TestDebugVersionOutput(t *testing.T) { require.Contains(t, string(output), v) } } + +func TestLogoutOneIdentity(t *testing.T) { + tmpHomePath := t.TempDir() + connector := mockConnector(t) + + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetRoles([]string{"access"}) + + rootServer, err := testserver.NewTeleportProcess( + t.TempDir(), + testserver.WithBootstrap(connector, alice)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, rootServer.Close()) + require.NoError(t, rootServer.Wait()) + }) + + authServer := rootServer.GetAuthServer() + require.NotNil(t, authServer) + proxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + tests := []struct { + name string + command []string + envMap map[string]string + }{ + { + name: "--proxy flag set", + command: []string{"logout", "--proxy", proxyAddr.String()}, + }, + { + name: "TELEPORT_PROXY set", + command: []string{"logout"}, + envMap: map[string]string{ + proxyEnvVar: proxyAddr.String(), + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + for k, v := range tc.envMap { + t.Setenv(k, v) + } + + err = Run(context.Background(), []string{ + "login", + "--insecure", + "--proxy", proxyAddr.String()}, + setHomePath(tmpHomePath), + setMockSSOLogin(authServer, alice, connector.GetName())) + require.NoError(t, err) + + buf := bytes.NewBuffer([]byte{}) + err := Run(context.Background(), tc.command, + setHomePath(tmpHomePath), + func(cf *CLIConf) error { + cf.OverrideStdout = buf + return nil + }) + require.NoError(t, err) + require.Contains(t, buf.String(), fmt.Sprintf("Logged out %v from %v.\n", alice.GetName(), proxyAddr.Host())) + }) + } +}