diff --git a/api/utils/keys/privatekey.go b/api/utils/keys/privatekey.go index cc7339ecc7bf1..c1fd3470c9967 100644 --- a/api/utils/keys/privatekey.go +++ b/api/utils/keys/privatekey.go @@ -244,6 +244,10 @@ func LoadPrivateKey(keyFile string) (*PrivateKey, error) { priv, err := ParsePrivateKey(keyPEM) if err != nil { + // Treat malformed keys the same as missing keys. + if trace.IsBadParameter(err) { + return nil, trace.NotFound("%s", err.Error()) + } return nil, trace.Wrap(err) } return priv, nil @@ -307,14 +311,14 @@ func ParsePrivateKey(keyPEM []byte, opts ...ParsePrivateKeyOpt) (*PrivateKey, er hwSigner, err := hardwarekey.DecodeSigner(block.Bytes, hwks, appliedOpts.ContextualKeyInfo) if err != nil { - return nil, trace.Wrap(err, "failed to parse hardware key signer") + return nil, trace.BadParameter("failed to parse hardware key signer: %s", err.Error()) } return newPrivateKeyWithKeyPEM(hwSigner, keyPEM) case OpenSSHPrivateKeyType: priv, err := ssh.ParseRawPrivateKey(keyPEM) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.BadParameter("%s", err.Error()) } cryptoSigner, ok := priv.(crypto.Signer) if !ok { @@ -355,7 +359,7 @@ func ParsePrivateKey(keyPEM []byte, opts ...ParsePrivateKeyOpt) (*PrivateKey, er // If all three parse functions returned an error, preferedErr is // guaranteed to be set to the error from the parse function that // usually matches the PEM block type. - return nil, trace.Wrap(preferredErr, "parsing private key PEM") + return nil, trace.BadParameter("parsing private key PEM: %s", preferredErr.Error()) default: return nil, trace.BadParameter("unexpected private key PEM type %q", block.Type) } @@ -425,6 +429,10 @@ func LoadKeyPair(privFile, sshPubFile string, opts ...ParsePrivateKeyOpt) (*Priv priv, err := ParseKeyPair(privPEM, marshaledSSHPub, opts...) if err != nil { + // Treat malformed keys the same as missing keys. + if trace.IsBadParameter(err) { + return nil, trace.NotFound("%s", err.Error()) + } return nil, trace.Wrap(err) } return priv, nil @@ -460,6 +468,10 @@ func LoadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) { tlsCert, err := X509KeyPair(certPEMBlock, keyPEMBlock) if err != nil { + // Treat malformed keys the same as missing keys. + if trace.IsBadParameter(err) { + return tls.Certificate{}, trace.NotFound("%s", err.Error()) + } return tls.Certificate{}, trace.Wrap(err) } diff --git a/api/utils/keys/privatekey_test.go b/api/utils/keys/privatekey_test.go index 4be10d55e6d51..c9bea6ac52a96 100644 --- a/api/utils/keys/privatekey_test.go +++ b/api/utils/keys/privatekey_test.go @@ -33,6 +33,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -172,26 +173,56 @@ func TestParseMismatchedPEMHeader(t *testing.T) { // that the preferredErr logic in Parse(Private|Public)Key returns an error for // each PEM type. func TestParseCorruptedKey(t *testing.T) { - for _, tc := range []string{ - "RSA PRIVATE KEY", - "PRIVATE KEY", - "EC PRIVATE KEY", - } { - t.Run(tc, func(t *testing.T) { - b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")}) - _, err := keys.ParsePrivateKey(b) - require.Error(t, err) + t.Parallel() + privateKeyTests := []struct { + name string + pemData []byte + }{ + { + name: "PRIVATE KEY", + pemData: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: []byte("foo")}), + }, + { + name: "RSA PRIVATE KEY", + pemData: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: []byte("foo")}), + }, + { + name: "EC PRIVATE KEY", + pemData: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("foo")}), + }, + { + name: "not a private key pem file", + pemData: []byte("foo"), + }, + } + for _, tc := range privateKeyTests { + t.Run(tc.name, func(t *testing.T) { + _, err := keys.ParsePrivateKey(tc.pemData) + require.True(t, trace.IsBadParameter(err), "wanted BadParameter, got: %v", err) }) } - for _, tc := range []string{ - "RSA PUBLIC KEY", - "PUBLIC KEY", - } { - t.Run(tc, func(t *testing.T) { - b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")}) - _, err := keys.ParsePublicKey(b) - require.Error(t, err) + publicKeyTests := []struct { + name string + pemData []byte + }{ + { + name: "RSA PUBLIC KEY", + pemData: pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: []byte("foo")}), + }, + { + name: "PUBLIC KEY", + pemData: pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: []byte("foo")}), + }, + { + name: "not a public key pem file", + pemData: []byte("foo"), + }, + } + for _, tc := range publicKeyTests { + t.Run(tc.name, func(t *testing.T) { + _, err := keys.ParsePublicKey(tc.pemData) + require.True(t, trace.IsBadParameter(err), "wanted BadParameter, got: %v", err) }) } } diff --git a/api/utils/keys/publickey.go b/api/utils/keys/publickey.go index 0979caf266c60..77a4489786bcb 100644 --- a/api/utils/keys/publickey.go +++ b/api/utils/keys/publickey.go @@ -89,5 +89,5 @@ func ParsePublicKey(keyPEM []byte) (crypto.PublicKey, error) { // If both parse functions returned an error, preferedErr is guaranteed to // be set to the error from the parse function that usually matches the PEM // block type. - return nil, trace.Wrap(preferredErr, "parsing public key PEM") + return nil, trace.BadParameter("parsing public key PEM: %s", preferredErr) } diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 06f930286705b..638eff57526fa 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -426,10 +426,9 @@ func (fs *FSKeyStore) DeleteKeyRing(idx KeyRingIndex) error { fs.publicKeyPath(idx), fs.tlsCertPath(idx), } + var deleteErrs []error for _, fn := range files { - if err := utils.RemoveSecure(fn); err != nil { - return trace.ConvertSystemError(err) - } + deleteErrs = append(deleteErrs, trace.ConvertSystemError(utils.RemoveSecure(fn))) } // we also need to delete the extra PuTTY-formatted .ppk file when running on Windows, // but it may not exist when upgrading from v9 -> v10 and logging into an existing cluster. @@ -446,7 +445,8 @@ func (fs *FSKeyStore) DeleteKeyRing(idx KeyRingIndex) error { // Clear ClusterName to delete the user certs stored for all clusters. idx.ClusterName = "" - return fs.DeleteUserCerts(idx, WithAllCerts...) + deleteErrs = append(deleteErrs, fs.DeleteUserCerts(idx, WithAllCerts...)) + return trace.NewAggregate(deleteErrs...) } // DeleteUserCerts deletes only the specified parts of the user's keyring, diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index b891fa9859041..11d1242ff115f 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -338,6 +338,27 @@ func TestProtectedDirsNotDeleted(t *testing.T) { require.NoDirExists(t, filepath.Join(keyStore.KeyDir, "keys")) } +// TestDeleteKeyRingContinueOnError verifies that an issue deleting one file +// does not prevent deleting the others. +func TestDeleteKeyRingContinueOnError(t *testing.T) { + t.Parallel() + auth := newTestAuthority(t) + keyStore := newTestFSKeyStore(t) + idx := KeyRingIndex{"host.a", "bob", "root"} + require.NoError(t, keyStore.AddKeyRing(auth.makeSignedKeyRing(t, idx, false))) + + require.NoError(t, os.Remove(keyStore.userSSHKeyPath(idx))) + require.Error(t, keyStore.DeleteKeyRing(idx)) + for _, file := range []string{ + keyStore.userSSHKeyPath(idx), + keyStore.userTLSKeyPath(idx), + keyStore.publicKeyPath(idx), + keyStore.tlsCertPath(idx), + } { + require.NoFileExists(t, file) + } +} + func assertEqualKeyRings(t *testing.T, expected, actual *KeyRing) { t.Helper() // Ignore differences in unexported private key fields, for example keyPEM diff --git a/lib/utils/fs.go b/lib/utils/fs.go index 6831a4a482b71..a510c6e6a30b3 100644 --- a/lib/utils/fs.go +++ b/lib/utils/fs.go @@ -337,15 +337,9 @@ func removeSecure(filePath string, fi os.FileInfo) error { defer f.Close() if runtime.GOOS == "windows" { - // Windows can't unlink the file before overwriting. - if f != nil { - for i := 0; i < 3; i++ { - if err := overwriteFile(f, fi); err != nil { - break - } - } - } - // The file should be closed before removing it on Windows. + // On windows, os.Remove() will fail if there are any open handles to the + // file, including in other processes. Skip overwrite to avoid leaving + // files in a broken state. closeErr := trace.ConvertSystemError(f.Close()) removeErr := trace.ConvertSystemError(os.Remove(filePath)) return trace.NewAggregate(closeErr, removeErr)