Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions api/utils/keys/privatekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ func LoadPrivateKey(keyFile string) (*PrivateKey, error) {

priv, err := ParsePrivateKey(keyPEM)
if err != nil {
// Treat malformed keys the same as missing keys.
if trace.IsBadParameter(err) {
return nil, trace.NotFound("%s", err.Error())
}
return nil, trace.Wrap(err)
}
return priv, nil
Expand Down Expand Up @@ -307,14 +311,14 @@ func ParsePrivateKey(keyPEM []byte, opts ...ParsePrivateKeyOpt) (*PrivateKey, er

hwSigner, err := hardwarekey.DecodeSigner(block.Bytes, hwks, appliedOpts.ContextualKeyInfo)
if err != nil {
return nil, trace.Wrap(err, "failed to parse hardware key signer")
return nil, trace.BadParameter("failed to parse hardware key signer: %s", err.Error())
}

return newPrivateKeyWithKeyPEM(hwSigner, keyPEM)
case OpenSSHPrivateKeyType:
priv, err := ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, trace.Wrap(err)
return nil, trace.BadParameter("%s", err.Error())
}
cryptoSigner, ok := priv.(crypto.Signer)
if !ok {
Expand Down Expand Up @@ -355,7 +359,7 @@ func ParsePrivateKey(keyPEM []byte, opts ...ParsePrivateKeyOpt) (*PrivateKey, er
// If all three parse functions returned an error, preferedErr is
// guaranteed to be set to the error from the parse function that
// usually matches the PEM block type.
return nil, trace.Wrap(preferredErr, "parsing private key PEM")
return nil, trace.BadParameter("parsing private key PEM: %s", preferredErr.Error())
default:
return nil, trace.BadParameter("unexpected private key PEM type %q", block.Type)
}
Expand Down Expand Up @@ -425,6 +429,10 @@ func LoadKeyPair(privFile, sshPubFile string, opts ...ParsePrivateKeyOpt) (*Priv

priv, err := ParseKeyPair(privPEM, marshaledSSHPub, opts...)
if err != nil {
// Treat malformed keys the same as missing keys.
if trace.IsBadParameter(err) {
return nil, trace.NotFound("%s", err.Error())
}
return nil, trace.Wrap(err)
}
return priv, nil
Expand Down Expand Up @@ -460,6 +468,10 @@ func LoadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) {

tlsCert, err := X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
// Treat malformed keys the same as missing keys.
if trace.IsBadParameter(err) {
return tls.Certificate{}, trace.NotFound("%s", err.Error())
}
return tls.Certificate{}, trace.Wrap(err)
}

Expand Down
65 changes: 48 additions & 17 deletions api/utils/keys/privatekey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -172,26 +173,56 @@ func TestParseMismatchedPEMHeader(t *testing.T) {
// that the preferredErr logic in Parse(Private|Public)Key returns an error for
// each PEM type.
func TestParseCorruptedKey(t *testing.T) {
for _, tc := range []string{
"RSA PRIVATE KEY",
"PRIVATE KEY",
"EC PRIVATE KEY",
} {
t.Run(tc, func(t *testing.T) {
b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")})
_, err := keys.ParsePrivateKey(b)
require.Error(t, err)
t.Parallel()
privateKeyTests := []struct {
name string
pemData []byte
}{
{
name: "PRIVATE KEY",
pemData: pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: []byte("foo")}),
},
{
name: "RSA PRIVATE KEY",
pemData: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: []byte("foo")}),
},
{
name: "EC PRIVATE KEY",
pemData: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("foo")}),
},
{
name: "not a private key pem file",
pemData: []byte("foo"),
},
}
for _, tc := range privateKeyTests {
t.Run(tc.name, func(t *testing.T) {
_, err := keys.ParsePrivateKey(tc.pemData)
require.True(t, trace.IsBadParameter(err), "wanted BadParameter, got: %v", err)
})
}

for _, tc := range []string{
"RSA PUBLIC KEY",
"PUBLIC KEY",
} {
t.Run(tc, func(t *testing.T) {
b := pem.EncodeToMemory(&pem.Block{Type: tc, Bytes: []byte("foo")})
_, err := keys.ParsePublicKey(b)
require.Error(t, err)
publicKeyTests := []struct {
name string
pemData []byte
}{
{
name: "RSA PUBLIC KEY",
pemData: pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: []byte("foo")}),
},
{
name: "PUBLIC KEY",
pemData: pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: []byte("foo")}),
},
{
name: "not a public key pem file",
pemData: []byte("foo"),
},
}
for _, tc := range publicKeyTests {
t.Run(tc.name, func(t *testing.T) {
_, err := keys.ParsePublicKey(tc.pemData)
require.True(t, trace.IsBadParameter(err), "wanted BadParameter, got: %v", err)
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion api/utils/keys/publickey.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ func ParsePublicKey(keyPEM []byte) (crypto.PublicKey, error) {
// If both parse functions returned an error, preferedErr is guaranteed to
// be set to the error from the parse function that usually matches the PEM
// block type.
return nil, trace.Wrap(preferredErr, "parsing public key PEM")
return nil, trace.BadParameter("parsing public key PEM: %s", preferredErr)
}
8 changes: 4 additions & 4 deletions lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,9 @@ func (fs *FSKeyStore) DeleteKeyRing(idx KeyRingIndex) error {
fs.publicKeyPath(idx),
fs.tlsCertPath(idx),
}
var deleteErrs []error
for _, fn := range files {
if err := utils.RemoveSecure(fn); err != nil {
return trace.ConvertSystemError(err)
}
deleteErrs = append(deleteErrs, trace.ConvertSystemError(utils.RemoveSecure(fn)))
}
// we also need to delete the extra PuTTY-formatted .ppk file when running on Windows,
// but it may not exist when upgrading from v9 -> v10 and logging into an existing cluster.
Expand All @@ -446,7 +445,8 @@ func (fs *FSKeyStore) DeleteKeyRing(idx KeyRingIndex) error {

// Clear ClusterName to delete the user certs stored for all clusters.
idx.ClusterName = ""
return fs.DeleteUserCerts(idx, WithAllCerts...)
deleteErrs = append(deleteErrs, fs.DeleteUserCerts(idx, WithAllCerts...))
return trace.NewAggregate(deleteErrs...)
}

// DeleteUserCerts deletes only the specified parts of the user's keyring,
Expand Down
21 changes: 21 additions & 0 deletions lib/client/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,27 @@ func TestProtectedDirsNotDeleted(t *testing.T) {
require.NoDirExists(t, filepath.Join(keyStore.KeyDir, "keys"))
}

// TestDeleteKeyRingContinueOnError verifies that an issue deleting one file
// does not prevent deleting the others.
func TestDeleteKeyRingContinueOnError(t *testing.T) {
t.Parallel()
auth := newTestAuthority(t)
keyStore := newTestFSKeyStore(t)
idx := KeyRingIndex{"host.a", "bob", "root"}
require.NoError(t, keyStore.AddKeyRing(auth.makeSignedKeyRing(t, idx, false)))

require.NoError(t, os.Remove(keyStore.userSSHKeyPath(idx)))
require.Error(t, keyStore.DeleteKeyRing(idx))
for _, file := range []string{
keyStore.userSSHKeyPath(idx),
keyStore.userTLSKeyPath(idx),
keyStore.publicKeyPath(idx),
keyStore.tlsCertPath(idx),
} {
require.NoFileExists(t, file)
}
}

func assertEqualKeyRings(t *testing.T, expected, actual *KeyRing) {
t.Helper()
// Ignore differences in unexported private key fields, for example keyPEM
Expand Down
12 changes: 3 additions & 9 deletions lib/utils/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,15 +337,9 @@ func removeSecure(filePath string, fi os.FileInfo) error {
defer f.Close()

if runtime.GOOS == "windows" {
// Windows can't unlink the file before overwriting.
if f != nil {
for i := 0; i < 3; i++ {
if err := overwriteFile(f, fi); err != nil {
break
}
}
}
// The file should be closed before removing it on Windows.
// On windows, os.Remove() will fail if there are any open handles to the
// file, including in other processes. Skip overwrite to avoid leaving
// files in a broken state.
closeErr := trace.ConvertSystemError(f.Close())
removeErr := trace.ConvertSystemError(os.Remove(filePath))
return trace.NewAggregate(closeErr, removeErr)
Expand Down
Loading