Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/utils/keypaths/keypaths.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
// sessionKeyDir is a sub-directory where session keys are stored
sessionKeyDir = "keys"
// sshDirSuffix is the suffix of a sub-directory where SSH certificates are stored.
sshDirSuffix = "-ssh"
SSHDirSuffix = "-ssh"
// fileNameKnownHosts is a file where known hosts are stored.
fileNameKnownHosts = "known_hosts"
// FileExtTLSCertLegacy is the legacy suffix/extension of a file where a TLS cert is stored.
Expand Down Expand Up @@ -257,7 +257,7 @@ func TLSCAsPathCluster(baseDir, proxy, cluster string) string {
//
// <baseDir>/keys/<proxy>/<username>-ssh
func SSHDir(baseDir, proxy, username string) string {
return filepath.Join(ProxyKeyDir(baseDir, proxy), username+sshDirSuffix)
return filepath.Join(ProxyKeyDir(baseDir, proxy), username+SSHDirSuffix)
}

// PPKFilePath returns the path to the user's PuTTY PPK-formatted keypair
Expand Down
38 changes: 38 additions & 0 deletions lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import (
"fmt"
iofs "io/fs"
"log/slog"
"maps"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -89,6 +91,10 @@ type KeyStore interface {
// GetSSHCertificates gets all certificates signed for the given user and proxy,
// including certificates for trusted clusters.
GetSSHCertificates(proxyHost, username string) ([]*ssh.Certificate, error)

// GetIdentities returns the usernames associated to signed user certificates
// for the given proxy in the keystore.
GetIdentities(proxyHost string) ([]string, error)
}

// FSKeyStore is an on-disk implementation of the KeyStore interface.
Expand All @@ -113,6 +119,11 @@ func NewFSKeyStore(dirPath string) *FSKeyStore {
}
}

// proxyKeyDir returns the path to the given proxy's keys directory.
func (fs *FSKeyStore) proxyKeyDir(proxy string) string {
return keypaths.ProxyKeyDir(fs.KeyDir, proxy)
}

// userSSHKeyPath returns the SSH private key path for the given KeyRingIndex.
func (fs *FSKeyStore) userSSHKeyPath(idx KeyRingIndex) string {
return keypaths.UserSSHKeyPath(fs.KeyDir, idx.ProxyHost, idx.Username)
Expand Down Expand Up @@ -603,6 +614,27 @@ func (fs *FSKeyStore) GetSSHCertificates(proxyHost, username string) ([]*ssh.Cer
return sshCerts, nil
}

// GetIdentities returns the usernames associated to signed user certificates
// for the given proxy in the keystore.
func (fs *FSKeyStore) GetIdentities(proxyHost string) ([]string, error) {
proxyDir := fs.proxyKeyDir(proxyHost)
files, err := os.ReadDir(proxyDir)
if err != nil {
return nil, trace.ConvertSystemError(err)
}

var identities []string
for _, file := range files {
// we seek the files corresponding to user SSH certificates, which are
// stored in [user]-ssh subdirectory. These are generated on successful logins.
if file.IsDir() && strings.HasSuffix(file.Name(), keypaths.SSHDirSuffix) {
username := strings.TrimSuffix(file.Name(), keypaths.SSHDirSuffix)
identities = append(identities, username)
}
}
return identities, nil
}

func getCredentialsByName(credentialDir string, opts ...keys.ParsePrivateKeyOpt) (map[string]TLSCredential, error) {
files, err := os.ReadDir(credentialDir)
if err != nil {
Expand Down Expand Up @@ -949,3 +981,9 @@ func (ms *MemKeyStore) GetSSHCertificates(proxyHost, username string) ([]*ssh.Ce

return sshCerts, nil
}

// GetIdentities returns the usernames associated to signed user certificates
// for the given proxy in the keystore.
func (ms *MemKeyStore) GetIdentities(proxyHost string) ([]string, error) {
return slices.Collect(maps.Keys(ms.keyRings[proxyHost])), nil
}
44 changes: 34 additions & 10 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2549,10 +2549,10 @@ func onLogout(cf *CLIConf) error {
active, available, err := cf.FullProfileStatus()
if err != nil && !trace.IsCompareFailed(err) {
if trace.IsNotFound(err) {
fmt.Printf("All users logged out.\n")
fmt.Fprintf(cf.Stdout(), "All users logged out.\n")
return nil
} else if trace.IsAccessDenied(err) {
fmt.Printf("%v: Logged in user does not have the correct permissions\n", err)
fmt.Fprintf(cf.Stdout(), "%v: Logged in user does not have the correct permissions\n", err)
return nil
}
return trace.Wrap(err)
Expand All @@ -2570,7 +2570,33 @@ func onLogout(cf *CLIConf) error {

switch {
// Proxy and username for key to remove.
case proxyHost != "" && cf.Username != "":
case proxyHost != "":
// In the event --user flag is not supplied, and there is only one identity,
// we can simply log out the single identity.
if cf.Username == "" {
clientStore := cf.getClientStore()
usernames, err := clientStore.GetIdentities(proxyHost)
if err != nil {
return trace.Wrap(err)
}

if len(usernames) == 0 {
fmt.Fprintf(cf.Stdout(), "All users logged out.\n")
return nil
}

logger.DebugContext(cf.Context, "No --user flag provided, but identities found for proxy",
"proxy_host", proxyHost,
"users", usernames)

if len(usernames) > 1 {
fmt.Fprintf(cf.Stdout(), "Specify --user to log out a specific user from %q or remove the --proxy flag to log out all users from all proxies.\n", proxyHost)
return nil
}

cf.Username = usernames[0]
}

tc, err := makeClient(cf)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2599,7 +2625,7 @@ func onLogout(cf *CLIConf) error {
err = tc.Logout()
if err != nil {
if trace.IsNotFound(err) {
fmt.Printf("User %v already logged out from %v.\n", cf.Username, proxyHost)
fmt.Fprintf(cf.Stdout(), "User %v already logged out from %v.\n", cf.Username, proxyHost)
return trace.Wrap(&common.ExitCodeError{Code: 1})
}
return trace.Wrap(err)
Expand All @@ -2612,7 +2638,7 @@ func onLogout(cf *CLIConf) error {
return trace.Wrap(err)
}

fmt.Printf("Logged out %v from %v.\n", cf.Username, proxyHost)
fmt.Fprintf(cf.Stdout(), "Logged out %v from %v.\n", cf.Username, proxyHost)
// Remove all keys.
case proxyHost == "" && cf.Username == "":
tc, err := makeClient(cf)
Expand Down Expand Up @@ -2667,7 +2693,7 @@ func onLogout(cf *CLIConf) error {
return trace.Wrap(tc.SAMLSingleLogout(ctx, sloURL))
})
if err != nil {
fmt.Printf("We were unable to log you out of your SAML identity provider: %v", err)
fmt.Fprintf(cf.Stdout(), "We were unable to log you out of your SAML identity provider: %v\n", err)
}

// Remove all keys from disk and the running agent.
Expand All @@ -2676,11 +2702,9 @@ func onLogout(cf *CLIConf) error {
return trace.Wrap(err)
}

fmt.Printf("Logged out all users from all proxies.\n")
case proxyHost != "" && cf.Username == "":
fmt.Printf("Specify --user to log out a specific user from %q or remove the --proxy flag to log out all users from all proxies.\n", proxyHost)
fmt.Fprintf(cf.Stdout(), "Logged out all users from all proxies.\n")
case proxyHost == "" && cf.Username != "":
fmt.Printf("Specify --proxy to log out user %q from a specific proxy or remove the --user flag to log out all users from all proxies.\n", cf.Username)
fmt.Fprintf(cf.Stdout(), "Specify --proxy to log out user %q from a specific proxy or remove the --user flag to log out all users from all proxies.\n", cf.Username)
}
return nil
}
Expand Down
66 changes: 66 additions & 0 deletions tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7910,3 +7910,69 @@ func TestDebugVersionOutput(t *testing.T) {
require.Contains(t, string(output), v)
}
}

func TestLogoutOneIdentity(t *testing.T) {
tmpHomePath := t.TempDir()
connector := mockConnector(t)

alice, err := types.NewUser("alice@example.com")
require.NoError(t, err)
alice.SetRoles([]string{"access"})

rootServer, err := testserver.NewTeleportProcess(
t.TempDir(),
testserver.WithBootstrap(connector, alice))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rootServer.Close())
require.NoError(t, rootServer.Wait())
})

authServer := rootServer.GetAuthServer()
require.NotNil(t, authServer)
proxyAddr, err := rootServer.ProxyWebAddr()
require.NoError(t, err)

tests := []struct {
name string
command []string
envMap map[string]string
}{
{
name: "--proxy flag set",
command: []string{"logout", "--proxy", proxyAddr.String()},
},
{
name: "TELEPORT_PROXY set",
command: []string{"logout"},
envMap: map[string]string{
proxyEnvVar: proxyAddr.String(),
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
for k, v := range tc.envMap {
t.Setenv(k, v)
}

err = Run(context.Background(), []string{
"login",
"--insecure",
"--proxy", proxyAddr.String()},
setHomePath(tmpHomePath),
setMockSSOLogin(authServer, alice, connector.GetName()))
require.NoError(t, err)

buf := bytes.NewBuffer([]byte{})
err := Run(context.Background(), tc.command,
setHomePath(tmpHomePath),
func(cf *CLIConf) error {
cf.OverrideStdout = buf
return nil
})
require.NoError(t, err)
require.Contains(t, buf.String(), fmt.Sprintf("Logged out %v from %v.\n", alice.GetName(), proxyAddr.Host()))
})
}
}
Loading