diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 65dd76312c46a..2282ba33de695 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -4015,7 +4015,7 @@ func TestCAGeneration(t *testing.T) { keyStoreManager, err := keystore.NewManager(t.Context(), &servicecfg.KeystoreConfig{}, &keystore.Options{ ClusterName: &types.ClusterNameV2{Metadata: types.Metadata{Name: clusterName}}, AuthPreferenceGetter: &fakeAuthPreferenceGetter{}, - RSAKeyPairSource: func() (priv []byte, pub []byte, err error) { + RSAKeyPairSource: func(alg cryptosuites.Algorithm) (priv []byte, pub []byte, err error) { return privKey, pubKey, nil }, }) diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index d9063bd1f9d41..5ee89f716051d 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -600,9 +600,21 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { }, kmsClient: testGCPKMSClient, clockworkOverride: clock, + RSAKeyPairSource: func(alg cryptosuites.Algorithm) ([]byte, []byte, error) { + switch alg { + case cryptosuites.RSA2048: + return testRSA2048PrivateKeyPEM, nil, nil + case cryptosuites.RSA4096: + return testRSA4096PrivateKeyPEM, nil, nil + } + + return nil, nil, trace.Errorf("unexpected algorithm: %v", alg) + }, } - softwareBackend := newSoftwareKeyStore(&softwareConfig{}) + softwareBackend := newSoftwareKeyStore(&softwareConfig{ + rsaKeyPairSource: baseOpts.RSAKeyPairSource, + }) backends = append(backends, &backendDesc{ name: "software", config: servicecfg.KeystoreConfig{}, diff --git a/lib/auth/keystore/software.go b/lib/auth/keystore/software.go index 18058bc24b2c2..01b28bc68035c 100644 --- a/lib/auth/keystore/software.go +++ b/lib/auth/keystore/software.go @@ -38,7 +38,7 @@ type softwareKeyStore struct { } // RSAKeyPairSource is a function type which returns new RSA keypairs. -type RSAKeyPairSource func() (priv []byte, pub []byte, err error) +type RSAKeyPairSource func(algo cryptosuites.Algorithm) (priv []byte, pub []byte, err error) type softwareConfig struct { rsaKeyPairSource RSAKeyPairSource @@ -65,7 +65,7 @@ func (s *softwareKeyStore) keyTypeDescription() string { // an equivalent crypto.Signer. func (s *softwareKeyStore) generateSigner(ctx context.Context, alg cryptosuites.Algorithm) ([]byte, crypto.Signer, error) { if alg == cryptosuites.RSA2048 && s.rsaKeyPairSource != nil { - privateKeyPEM, _, err := s.rsaKeyPairSource() + privateKeyPEM, _, err := s.rsaKeyPairSource(alg) if err != nil { return nil, nil, err } @@ -110,6 +110,20 @@ func (d oaepDecrypter) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.De // identifier for softwareKeyStore is a pem-encoded private key, and can be passed to getDecrypter later to get // an equivalent crypto.Decrypter. func (s *softwareKeyStore) generateDecrypter(ctx context.Context, alg cryptosuites.Algorithm) ([]byte, crypto.Decrypter, crypto.Hash, error) { + if alg == cryptosuites.RSA4096 && s.rsaKeyPairSource != nil { + privateKeyPEM, _, err := s.rsaKeyPairSource(alg) + if err != nil { + return nil, nil, softwareHash, trace.Wrap(err) + } + privateKey, err := keys.ParsePrivateKey(privateKeyPEM) + decrypter, ok := privateKey.Signer.(crypto.Decrypter) + if !ok { + return nil, nil, softwareHash, trace.Errorf("could not type assert crypto.Decrypter") + } + + return privateKeyPEM, newOAEPDecrypter(softwareHash, decrypter), softwareHash, trace.Wrap(err) + } + key, err := cryptosuites.GenerateDecrypterWithAlgorithm(alg) if err != nil { return nil, nil, softwareHash, trace.Wrap(err)