diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 2cba0d5810825..6de10d6929516 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -59,6 +59,10 @@ const ( // tshAzureDirName is the name of the directory containing the // az cli app-specific profiles. tshAzureDirName = "azure" + + // tshBin is the name of the directory containing the + // updated binaries of client tools. + tshBin = "bin" ) // KeyStore is a storage interface for client session keys and certificates. @@ -480,13 +484,11 @@ func (fs *FSKeyStore) DeleteKeys() error { if err != nil { return trace.ConvertSystemError(err) } + ignoreDirs := map[string]struct{}{tshConfigFileName: {}, tshAzureDirName: {}, tshBin: {}} for _, file := range files { - // Don't delete 'config' and 'azure' directories. + // Don't delete 'config', 'azure' and 'bin' directories. // TODO: this is hackish and really shouldn't be needed, but fs.KeyDir is `~/.tsh` while it probably should be `~/.tsh/keys` instead. - if file.IsDir() && file.Name() == tshConfigFileName { - continue - } - if file.IsDir() && file.Name() == tshAzureDirName { + if _, ok := ignoreDirs[file.Name()]; ok && file.IsDir() { continue } if file.IsDir() { diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index 620ec8f687981..b78820ace94ce 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -281,17 +281,31 @@ func TestAddKey_withoutSSHCert(t *testing.T) { require.Len(t, keyCopy.DBTLSCredentials, 1) } -func TestConfigDirNotDeleted(t *testing.T) { +func TestProtectedDirsNotDeleted(t *testing.T) { t.Parallel() auth := newTestAuthority(t) keyStore := newTestFSKeyStore(t) idx := KeyRingIndex{"host.a", "bob", "root"} keyStore.AddKeyRing(auth.makeSignedKeyRing(t, idx, false)) + configPath := filepath.Join(keyStore.KeyDir, "config") require.NoError(t, os.Mkdir(configPath, 0700)) + + azurePath := filepath.Join(keyStore.KeyDir, "azure") + require.NoError(t, os.Mkdir(azurePath, 0700)) + + binPath := filepath.Join(keyStore.KeyDir, "bin") + require.NoError(t, os.Mkdir(binPath, 0700)) + + testPath := filepath.Join(keyStore.KeyDir, "test") + require.NoError(t, os.Mkdir(testPath, 0700)) + require.NoError(t, keyStore.DeleteKeys()) require.DirExists(t, configPath) + require.DirExists(t, azurePath) + require.DirExists(t, binPath) + require.NoDirExists(t, testPath) require.NoDirExists(t, filepath.Join(keyStore.KeyDir, "keys")) }