diff --git a/lib/auth/init.go b/lib/auth/init.go index 6da166eb14fdd..7bdade01952e4 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -90,7 +90,7 @@ type VersionStorage interface { // operations. type RecordingEncryptionManager interface { services.RecordingEncryption - recordingencryption.DecryptionKeyFinder + recordingencryption.KeyUnwrapper SetCache(cache recordingencryption.Cache) } diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index 8259f6393d01c..2d2de9243e538 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -253,6 +253,8 @@ func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) { switch alg { case cryptosuites.RSA2048: return kmstypes.KeySpecRsa2048, nil + case cryptosuites.RSA4096: + return kmstypes.KeySpecRsa4096, nil case cryptosuites.ECDSAP256: return kmstypes.KeySpecEccNistP256, nil } diff --git a/lib/auth/keystore/aws_kms_test.go b/lib/auth/keystore/aws_kms_test.go index 36ba9d376e6ea..6eb767656f008 100644 --- a/lib/auth/keystore/aws_kms_test.go +++ b/lib/auth/keystore/aws_kms_test.go @@ -395,6 +395,8 @@ func (f *fakeAWSKMSService) CreateKey(_ context.Context, input *kms.CreateKeyInp switch input.KeySpec { case kmstypes.KeySpecRsa2048: privKeyPEM = testRSA2048PrivateKeyPEM + case kmstypes.KeySpecRsa4096: + privKeyPEM = testRSA4096PrivateKeyPEM case kmstypes.KeySpecEccNistP256: signer, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) if err != nil { diff --git a/lib/auth/keystore/gcp_kms.go b/lib/auth/keystore/gcp_kms.go index 50f208e0a8c19..4550879d7289d 100644 --- a/lib/auth/keystore/gcp_kms.go +++ b/lib/auth/keystore/gcp_kms.go @@ -112,7 +112,7 @@ func (g *gcpKMSKeyStore) keyTypeDescription() string { } func (g *gcpKMSKeyStore) generateKey(ctx context.Context, algorithm cryptosuites.Algorithm, usage keyUsage) (gcpKMSKeyID, error) { - alg, err := gcpAlgorithm(algorithm) + alg, err := gcpAlgorithm(usage, algorithm) if err != nil { return gcpKMSKeyID{}, trace.Wrap(err) } @@ -176,12 +176,29 @@ func (g *gcpKMSKeyStore) generateDecrypter(ctx context.Context, algorithm crypto return keyID.marshal(), decrypter, gcpOAEPHash, nil } -func gcpAlgorithm(alg cryptosuites.Algorithm) (kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm, error) { +func gcpAlgorithm(usage keyUsage, alg cryptosuites.Algorithm) (kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm, error) { switch alg { case cryptosuites.RSA2048: - return kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256, nil + switch usage { + case keyUsageSign: + return kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256, nil + case keyUsageDecrypt: + return kmspb.CryptoKeyVersion_RSA_DECRYPT_OAEP_2048_SHA256, nil + } + case cryptosuites.RSA4096: + switch usage { + case keyUsageSign: + return kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256, nil + case keyUsageDecrypt: + return kmspb.CryptoKeyVersion_RSA_DECRYPT_OAEP_4096_SHA256, nil + } case cryptosuites.ECDSAP256: - return kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256, nil + switch usage { + case keyUsageSign: + return kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256, nil + case keyUsageDecrypt: + return kmspb.CryptoKeyVersion_CRYPTO_KEY_VERSION_ALGORITHM_UNSPECIFIED, trace.BadParameter("unsupported algorithm for decryption: %v", alg) + } } return kmspb.CryptoKeyVersion_CRYPTO_KEY_VERSION_ALGORITHM_UNSPECIFIED, trace.BadParameter("unsupported algorithm: %v", alg) } diff --git a/lib/auth/keystore/gcp_kms_test.go b/lib/auth/keystore/gcp_kms_test.go index 800b50feb2556..cb64b3c038d2a 100644 --- a/lib/auth/keystore/gcp_kms_test.go +++ b/lib/auth/keystore/gcp_kms_test.go @@ -112,6 +112,8 @@ func (f *fakeGCPKMSServer) CreateCryptoKey(ctx context.Context, req *kmspb.Creat switch cryptoKey.VersionTemplate.Algorithm { case kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256, kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: pem = testRSA2048PrivateKeyPEM + case kmspb.CryptoKeyVersion_RSA_DECRYPT_OAEP_4096_SHA256: + pem = testRSA4096PrivateKeyPEM case kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: signer, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) if err != nil { diff --git a/lib/auth/keystore/pkcs11.go b/lib/auth/keystore/pkcs11.go index 5477f7d1b6e52..3489890cd8dd6 100644 --- a/lib/auth/keystore/pkcs11.go +++ b/lib/auth/keystore/pkcs11.go @@ -167,6 +167,9 @@ func (p *pkcs11KeyStore) generateSigner(ctx context.Context, alg cryptosuites.Al case cryptosuites.RSA2048: signer, err := p.generateRSA2048(rawPKCS11ID, label) return rawTeleportID, signer, trace.Wrap(err, "generating RSA2048 key") + case cryptosuites.RSA4096: + signer, err := p.generateRSA4096(rawPKCS11ID, label) + return rawTeleportID, signer, trace.Wrap(err, "generating RSA4096 key") case cryptosuites.ECDSAP256: signer, err := p.generateECDSAP256(rawPKCS11ID, label) return rawTeleportID, signer, trace.Wrap(err, "generating ECDSAP256 key") @@ -208,9 +211,9 @@ func (p *pkcs11KeyStore) generateDecrypter(ctx context.Context, alg cryptosuites label := []byte(p.hostUUID) switch alg { - case cryptosuites.RSA2048: - decrypter, err := p.generateRSA2048(rawPKCS11ID, label) - return rawTeleportID, newOAEPDecrypter(p.oaepHash, decrypter), p.oaepHash, trace.Wrap(err, "generating RSA2048 key") + case cryptosuites.RSA4096: + decrypter, err := p.generateRSA4096(rawPKCS11ID, label) + return rawTeleportID, newOAEPDecrypter(p.oaepHash, decrypter), p.oaepHash, trace.Wrap(err, "generating RSA4096 key") default: return nil, nil, p.oaepHash, trace.BadParameter("unsupported key algorithm for PKCS#11 HSM decryption: %v", alg) } @@ -221,6 +224,11 @@ func (p *pkcs11KeyStore) generateRSA2048(ckaID, label []byte) (crypto11.SignerDe return signer, trace.Wrap(err) } +func (p *pkcs11KeyStore) generateRSA4096(ckaID, label []byte) (crypto11.SignerDecrypter, error) { + signer, err := p.ctx.GenerateRSAKeyPairWithLabel(ckaID, label, 4096) + return signer, trace.Wrap(err) +} + func (p *pkcs11KeyStore) generateECDSAP256(ckaID, label []byte) (crypto.Signer, error) { signer, err := p.ctx.GenerateECDSAKeyPairWithLabel(ckaID, label, elliptic.P256()) return signer, trace.Wrap(err) diff --git a/lib/auth/keystore/software.go b/lib/auth/keystore/software.go index 3392790b86aae..b8162f6c7eb9f 100644 --- a/lib/auth/keystore/software.go +++ b/lib/auth/keystore/software.go @@ -115,17 +115,12 @@ func (s *softwareKeyStore) generateDecrypter(ctx context.Context, alg cryptosuit return nil, nil, softwareHash, trace.Wrap(err) } - decrypter := key - if alg == cryptosuites.RSA2048 { - decrypter = newOAEPDecrypter(softwareHash, decrypter) - } - privateKeyPEM, err := keys.MarshalDecrypter(key) if err != nil { return nil, nil, softwareHash, trace.Wrap(err) } - return privateKeyPEM, decrypter, softwareHash, trace.Wrap(err) + return privateKeyPEM, newOAEPDecrypter(softwareHash, key), softwareHash, trace.Wrap(err) } // getSigner returns a crypto.Signer for the given pem-encoded private key. diff --git a/lib/auth/recordingencryption/age.go b/lib/auth/recordingencryption/age.go index b4e6ac495c548..2c6d08257989a 100644 --- a/lib/auth/recordingencryption/age.go +++ b/lib/auth/recordingencryption/age.go @@ -18,44 +18,62 @@ package recordingencryption import ( "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "io" "filippo.io/age" "github.com/gravitational/trace" - "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" ) -// X25519Stanza is the default stanza type used by age. -const X25519Stanza = "X25519" - // RecordingStanza is the type used for the identifying stanza added by RecordingRecipient. -const RecordingStanza = "Recording-X25519" +const RecordingStanza = "teleport-recording-rsa4096" + +// oaepLabel must be present during encryption and decryption. +const oaepLabel = "teleport/v1/rsa" + +// UnwrapInput represents a request to decrypt a wrapped file key. +type UnwrapInput struct { + // Fingerprint of the public key used to find the related private key. + Fingerprint string + // WrappedKey is the encrypted file key in an encrypted recording stanza. + WrappedKey []byte + + // Rand reader to pass to use during decryption. + Rand io.Reader + // Opts that should be used during decryption. + Opts crypto.DecrypterOpts +} -// DecryptionKeyFinder returns an EncryptionKeyPair related to at least one of the given public keys to be used -// for file key unwrapping. -type DecryptionKeyFinder interface { - FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) +// KeyUnwrapper returns an unwrapped file key given a wrapped key and a fingerprint of the encryption key. +type KeyUnwrapper interface { + UnwrapKey(ctx context.Context, in UnwrapInput) ([]byte, error) } -// RecordingIdentity removes public keys from stanzas and passes the unwrap call to the default -// age.X25519Identity. +// RecordingIdentity unwraps file keys using the configured [KeyUnwrapper] and the recording stanzas +// included in the age header. type RecordingIdentity struct { ctx context.Context - keyFinder DecryptionKeyFinder + unwrapper KeyUnwrapper } -// NewRecordingIdentity returns a RecordingIdentity that will use the given DecryptionKeyFinder in order to facilitate +// NewRecordingIdentity returns a new RecordingIdentity using the given [KeyUnwrapper] // file key unwrapping. -func NewRecordingIdentity(ctx context.Context, keyFinder DecryptionKeyFinder) *RecordingIdentity { +func NewRecordingIdentity(ctx context.Context, unwrapper KeyUnwrapper) *RecordingIdentity { return &RecordingIdentity{ ctx: ctx, - keyFinder: keyFinder, + unwrapper: unwrapper, } } -// Unwrap uses the additional stanzas added by RecordingRecipient.Wrap in order to find a matching X25519 identity. +// Unwrap uses the additional stanzas added by [RecordingRecipient.Wrap] in order to find a matching RSA 4096 +// private key. func (i *RecordingIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { - var publicKeys [][]byte + var errs []error for _, stanza := range stanzas { if stanza.Type != RecordingStanza { continue @@ -65,55 +83,69 @@ func (i *RecordingIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { continue } - publicKeys = append(publicKeys, []byte(stanza.Args[0])) - } + fileKey, err := i.unwrapper.UnwrapKey(i.ctx, UnwrapInput{ + Rand: rand.Reader, + WrappedKey: stanza.Body, + Fingerprint: stanza.Args[0], + Opts: &rsa.OAEPOptions{ + Hash: crypto.SHA256, + Label: []byte(oaepLabel), + }, + }) + if err != nil { + if !trace.IsNotFound(err) { + errs = append(errs, err) + } + continue + } - pair, err := i.keyFinder.FindDecryptionKey(i.ctx, publicKeys...) - if err != nil { - return nil, trace.Wrap(err) + return fileKey, nil } - identity, err := age.ParseX25519Identity(string(pair.PrivateKey)) - if err != nil { - return nil, trace.Wrap(err) + if len(errs) == 0 { + return nil, trace.Errorf("could not find an accessible decrypter for unwrapping") } - - return identity.Unwrap(stanzas) + return nil, trace.NewAggregate(errs...) } -// RecordingRecipient adds the public key to the stanzas generated by the default age.X25519Recipient +// RecordingRecipient wraps file keys using an RSA 40960public key. type RecordingRecipient struct { - *age.X25519Recipient + *rsa.PublicKey } -// ParseRecordingRecipient parses an Bech32 encoded age X25519 public key into a RecordingRecipient. -func ParseRecordingRecipient(s string) (*RecordingRecipient, error) { - recipient, err := age.ParseX25519Recipient(s) +// ParseRecordingRecipient parses a PEM encoded RSA 4096 public key into a RecordingRecipient. +func ParseRecordingRecipient(in []byte) (*RecordingRecipient, error) { + pubKey, err := keys.ParsePublicKey(in) if err != nil { return nil, trace.Wrap(err) } - return &RecordingRecipient{X25519Recipient: recipient}, nil + rsaKey, ok := pubKey.(*rsa.PublicKey) + if !ok { + return nil, trace.BadParameter("recording encryption key must be a public RSA 4096") + } + + return &RecordingRecipient{PublicKey: rsaKey}, nil } -// Wrap a fileKey using the wrapped X2519Recipient. An additional stanza containing the bech32 encoded X25519 -// public key will be created to enable lookups during Unwrap. +// Wrap a fileKey using an RSA public key. The fingerprint of the key will be included in the stanza +// to aid in fetching the correct private key during [Unwrap]. func (r *RecordingRecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { - stanzas, err := r.X25519Recipient.Wrap(fileKey) + cipher, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, r.PublicKey, fileKey, []byte(oaepLabel)) if err != nil { return nil, trace.Wrap(err) } - // a new stanza has to be added because modifying the original stanza and returning it to "normal" during - // Unwrap fails due to MAC errors - for _, stanza := range stanzas { - if stanza.Type == X25519Stanza { - stanzas = append(stanzas, &age.Stanza{ - Type: RecordingStanza, - Args: []string{r.String()}, - }) - } + fp, err := Fingerprint(r.PublicKey) + if err != nil { + return nil, trace.Wrap(err) } - return stanzas, nil + return []*age.Stanza{ + { + Type: RecordingStanza, + Args: []string{fp}, + Body: cipher, + }, + }, nil } diff --git a/lib/auth/recordingencryption/age_test.go b/lib/auth/recordingencryption/age_test.go index a7329e22e0b1d..773e5c2324f96 100644 --- a/lib/auth/recordingencryption/age_test.go +++ b/lib/auth/recordingencryption/age_test.go @@ -18,8 +18,6 @@ package recordingencryption_test import ( "bytes" - "context" - "errors" "io" "testing" @@ -32,13 +30,13 @@ import ( func TestRecordingAgePlugin(t *testing.T) { ctx := t.Context() - keyFinder := newFakeKeyFinder() - recordingIdentity := recordingencryption.NewRecordingIdentity(ctx, keyFinder) + keyStore := newFakeKeyStore(types.PrivateKeyType_RAW) + recordingIdentity := recordingencryption.NewRecordingIdentity(ctx, keyStore) - ident, err := keyFinder.generateIdentity() + _, pubKey, err := keyStore.createKey() require.NoError(t, err) - recipient, err := recordingencryption.ParseRecordingRecipient(ident.Recipient().String()) + recipient, err := recordingencryption.ParseRecordingRecipient(pubKey) require.NoError(t, err) out := bytes.NewBuffer(nil) @@ -53,6 +51,7 @@ func TestRecordingAgePlugin(t *testing.T) { err = writer.Close() require.NoError(t, err) + // decrypted text should match original msg reader, err := age.Decrypt(out, recordingIdentity) require.NoError(t, err) plaintext, err := io.ReadAll(reader) @@ -60,11 +59,14 @@ func TestRecordingAgePlugin(t *testing.T) { require.Equal(t, msg, plaintext) - // running the same test with the raw recipient should fail because the - // the extra stanza added by RecordingRecipient won't be present and - // the private key won't be found + // running the same test with an unknown public key should fail + _, pubKey, err = keyStore.genKeys() + require.NoError(t, err) + + recipient, err = recordingencryption.ParseRecordingRecipient(pubKey) + require.NoError(t, err) out.Reset() - writer, err = age.Encrypt(out, ident.Recipient()) + writer, err = age.Encrypt(out, recipient) require.NoError(t, err) _, err = writer.Write(msg) require.NoError(t, err) @@ -73,39 +75,3 @@ func TestRecordingAgePlugin(t *testing.T) { _, err = age.Decrypt(out, recordingIdentity) require.Error(t, err) } - -type fakeKeyFinder struct { - keys map[string]string -} - -func newFakeKeyFinder() *fakeKeyFinder { - return &fakeKeyFinder{ - keys: make(map[string]string), - } -} - -func (f *fakeKeyFinder) FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) { - for _, pubKey := range publicKeys { - key, ok := f.keys[string(pubKey)] - if !ok { - continue - } - - return &types.EncryptionKeyPair{ - PrivateKey: []byte(key), - PublicKey: pubKey, - }, nil - } - - return nil, errors.New("no accessible decryption key found") -} - -func (f *fakeKeyFinder) generateIdentity() (*age.X25519Identity, error) { - ident, err := age.GenerateX25519Identity() - if err != nil { - return nil, err - } - - f.keys[ident.Recipient().String()] = ident.String() - return ident, nil -} diff --git a/lib/auth/recordingencryption/encryptedio.go b/lib/auth/recordingencryption/encryptedio.go index e441d93c19b9a..676e957246259 100644 --- a/lib/auth/recordingencryption/encryptedio.go +++ b/lib/auth/recordingencryption/encryptedio.go @@ -36,21 +36,21 @@ type SessionRecordingConfigGetter interface { // to provide encryption and decryption wrapping backed by cluster resources type EncryptedIO struct { srcGetter SessionRecordingConfigGetter - keyFinder DecryptionKeyFinder + unwrapper KeyUnwrapper } // NewEncryptedIO returns an EncryptedIO configured with the given SessionRecordingConfigGetter and // recordingencryption.DecryptionKeyFinder -func NewEncryptedIO(srcGetter SessionRecordingConfigGetter, decryptionKeyGetter DecryptionKeyFinder) (*EncryptedIO, error) { +func NewEncryptedIO(srcGetter SessionRecordingConfigGetter, unwrapper KeyUnwrapper) (*EncryptedIO, error) { switch { case srcGetter == nil: return nil, trace.BadParameter("SessionRecordingConfigGetter is required for EncryptedIO") - case decryptionKeyGetter == nil: + case unwrapper == nil: return nil, trace.BadParameter("DecryptionKeyFinder is required for EncryptedIO") } return &EncryptedIO{ srcGetter: srcGetter, - keyFinder: decryptionKeyGetter, + unwrapper: unwrapper, }, nil } @@ -71,7 +71,7 @@ func (e *EncryptedIO) WithEncryption(ctx context.Context, writer io.WriteCloser) // will dynamically search for an accessible decryption key using the provided recordingencryption.DecryptionKeyFinder // in order to perform decryption func (e *EncryptedIO) WithDecryption(ctx context.Context, reader io.Reader) (io.Reader, error) { - ident := NewRecordingIdentity(ctx, e.keyFinder) + ident := NewRecordingIdentity(ctx, e.unwrapper) r, err := age.Decrypt(reader, ident) if err != nil { return nil, trace.Wrap(err) @@ -105,7 +105,7 @@ func (s *EncryptionWrapper) WithEncryption(ctx context.Context, writer io.WriteC var recipients []age.Recipient for _, key := range s.config.GetEncryptionKeys() { - recipient, err := ParseRecordingRecipient(string(key.PublicKey)) + recipient, err := ParseRecordingRecipient(key.PublicKey) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/recordingencryption/encryptedio_test.go b/lib/auth/recordingencryption/encryptedio_test.go index b274c2aeaac27..3ace3a06b1868 100644 --- a/lib/auth/recordingencryption/encryptedio_test.go +++ b/lib/auth/recordingencryption/encryptedio_test.go @@ -31,18 +31,18 @@ import ( func TestEncryptedIO(t *testing.T) { ctx := t.Context() - keyFinder := newFakeKeyFinder() - ident, err := keyFinder.generateIdentity() + keyStore := newFakeKeyStore(types.PrivateKeyType_RAW) + _, publicKey, err := keyStore.createKey() require.NoError(t, err) srcGetter, err := newFakeSRCGetter(true, []*types.AgeEncryptionKey{ { - PublicKey: []byte(ident.Recipient().String()), + PublicKey: publicKey, }, }) require.NoError(t, err) - encryptedIO, err := recordingencryption.NewEncryptedIO(srcGetter, keyFinder) + encryptedIO, err := recordingencryption.NewEncryptedIO(srcGetter, keyStore) require.NoError(t, err) out := bytes.NewBuffer(nil) @@ -74,7 +74,7 @@ func TestEncryptedIO(t *testing.T) { // wrapping encryption when encryption is disabled should return an ErrEncryptionDisabled srcGetter, err = newFakeSRCGetter(false, nil) require.NoError(t, err) - encryptedIO, err = recordingencryption.NewEncryptedIO(srcGetter, keyFinder) + encryptedIO, err = recordingencryption.NewEncryptedIO(srcGetter, keyStore) require.NoError(t, err) _, err = encryptedIO.WithEncryption(ctx, &writeCloser{Writer: out}) diff --git a/lib/auth/recordingencryption/manager.go b/lib/auth/recordingencryption/manager.go index 3527493d59566..e6c11b380cde6 100644 --- a/lib/auth/recordingencryption/manager.go +++ b/lib/auth/recordingencryption/manager.go @@ -19,21 +19,19 @@ package recordingencryption import ( "context" "crypto" - "crypto/rand" "crypto/sha256" - "encoding/hex" - "errors" + "crypto/x509" + "encoding/base64" "iter" "log/slog" - "slices" "time" - "filippo.io/age" "github.com/gravitational/trace" "github.com/gravitational/teleport" recordingencryptionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/cryptosuites" @@ -102,12 +100,13 @@ type Manager struct { keyStore KeyStore } -// CreateSessionRecordingConfig creates a new session recording configuration. If encryption is enabled then the -// recording encryption resource will also be resolved. +// CreateSessionRecordingConfig creates a new session recording configuration. If encryption is enabled then an +// accessible encryption key pair will be confirmed. Either creating one if none exists, doing nothing if one is +// accessible, or returning an error if none are accessible. func (m *Manager) CreateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { if cfg.GetEncrypted() { - encryption, err := m.resolveRecordingEncryption(ctx) + encryption, err := m.ensureRecordingEncryptionKey(ctx) if err != nil { return err } @@ -126,12 +125,13 @@ func (m *Manager) CreateSessionRecordingConfig(ctx context.Context, cfg types.Se return sessionRecordingConfig, trace.Wrap(err) } -// UpdateSessionRecordingConfig updates an existing session recording configuration. If encryption is enabled then -// the recording encryption resource will also be resolved. +// UpdateSessionRecordingConfig updates an existing session recording configuration. If encryption is enabled +// then an accessible encryption key pair will be confirmed. Either creating one if none exists, doing nothing +// if one is accessible, or returning an error if none are accessible. func (m *Manager) UpdateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { if cfg.GetEncrypted() { - encryption, err := m.resolveRecordingEncryption(ctx) + encryption, err := m.ensureRecordingEncryptionKey(ctx) if err != nil { return err } @@ -151,11 +151,12 @@ func (m *Manager) UpdateSessionRecordingConfig(ctx context.Context, cfg types.Se } // UpsertSessionRecordingConfig creates a new session recording configuration or overwrites an existing one. If -// encryption is enabled then the recording encryption resource will also be resolved. +// encryption is enabled then an accessible encryption key pair will be confirmed. Either creating one if none +// exists, doing nothing if one is accessible, or returning an error if none are accessible. func (m *Manager) UpsertSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { if cfg.GetEncrypted() { - encryption, err := m.resolveRecordingEncryption(ctx) + encryption, err := m.ensureRecordingEncryptionKey(ctx) if err != nil { return err } @@ -179,15 +180,15 @@ func (m *Manager) SetCache(cache Cache) { m.cache = cache } -// ensureActiveRecordingEncryption returns the configured RecordingEncryption resource if it exists with active keys. If it does not, -// then the resource will be created or updated with a new active keypair. The bool return value indicates whether or not -// a new pair was provisioned. -func (m *Manager) ensureActiveRecordingEncryption(ctx context.Context) (*recordingencryptionv1.RecordingEncryption, bool, error) { +// ensureRecordingEncryptionKey returns the configured RecordingEncryption resource if it exists with an +// accessible key. If no keys exist, a new key pair will be provisioned. An error is returned if keys exist +// but none are accessible. +func (m *Manager) ensureRecordingEncryptionKey(ctx context.Context) (*recordingencryptionv1.RecordingEncryption, error) { persistFn := m.RecordingEncryption.UpdateRecordingEncryption encryption, err := m.RecordingEncryption.GetRecordingEncryption(ctx) if err != nil { if !trace.IsNotFound(err) { - return encryption, false, trace.Wrap(err) + return encryption, trace.Wrap(err) } encryption = &recordingencryptionv1.RecordingEncryption{ Spec: &recordingencryptionv1.RecordingEncryptionSpec{}, @@ -196,240 +197,81 @@ func (m *Manager) ensureActiveRecordingEncryption(ctx context.Context) (*recordi } activeKeys := encryption.GetSpec().ActiveKeys - - // no keys present, need to generate the initial active keypair if len(activeKeys) > 0 { - return encryption, false, nil - } - - keyEncryptionPair, err := m.keyStore.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) - if err != nil { - return encryption, false, trace.Wrap(err, "generating wrapping key") - } + for _, key := range activeKeys { + // fetch the decrypter to ensure we have access to it + if _, err := m.keyStore.GetDecrypter(ctx, key.RecordingEncryptionPair); err != nil { + fp, _ := fingerprintPEM(key.RecordingEncryptionPair.PublicKey) + m.logger.DebugContext(ctx, "key not accessible", "fingerprint", fp) + continue + } + return encryption, nil + } - ident, err := age.GenerateX25519Identity() - if err != nil { - return encryption, false, trace.Wrap(err, "generating age encryption key") + return nil, trace.AccessDenied("active key not accessible: %v", err) } - encryptedIdent, err := keyEncryptionPair.EncryptOAEP([]byte(ident.String())) + // no keys present, need to generate the initial active keypair + encryptionPair, err := m.keyStore.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) if err != nil { - return encryption, false, trace.Wrap(err, "wrapping encryption key") + return nil, trace.Wrap(err, "generating wrapping key") } wrappedKey := recordingencryptionv1.WrappedKey{ - KeyEncryptionPair: keyEncryptionPair, - RecordingEncryptionPair: &types.EncryptionKeyPair{ - PrivateKeyType: types.PrivateKeyType_RAW, - PrivateKey: encryptedIdent, - PublicKey: []byte(ident.Recipient().String()), - }, + RecordingEncryptionPair: encryptionPair, } encryption.Spec.ActiveKeys = []*recordingencryptionv1.WrappedKey{&wrappedKey} encryption, err = persistFn(ctx, encryption) if err != nil { - return encryption, false, trace.Wrap(err) - } - fp := sha256.Sum256(wrappedKey.RecordingEncryptionPair.PublicKey) - m.logger.InfoContext(ctx, "no active keys, generated initial recording encryption pair", "public_fingerprint", hex.EncodeToString(fp[:])) - return encryption, true, nil -} - -var errWaitingForKey = errors.New("waiting for key to be fulfilled") - -// getRecordingEncryptionKey returns the first active recording encryption key accessible to the configured key store. -func (m *Manager) getRecordingEncryptionKeyPair(ctx context.Context, keys []*recordingencryptionv1.WrappedKey) (*types.EncryptionKeyPair, error) { - var foundUnfulfilledKey bool - for _, key := range keys { - decrypter, err := m.keyStore.GetDecrypter(ctx, key.KeyEncryptionPair) - if err != nil { - continue - } - - // if we make it to this section the key is accessible to the current auth server - if key.RecordingEncryptionPair == nil { - foundUnfulfilledKey = true - continue - } - - decryptionKey, err := decrypter.Decrypt(rand.Reader, key.RecordingEncryptionPair.PrivateKey, nil) - if err != nil { - return nil, trace.Wrap(err, "decrypting known key") - } - - return &types.EncryptionKeyPair{ - PrivateKey: decryptionKey, - PublicKey: key.RecordingEncryptionPair.PublicKey, - }, nil - } - - if foundUnfulfilledKey { - return nil, trace.Wrap(errWaitingForKey) + return nil, trace.Wrap(err) } - - return nil, trace.NotFound("no accessible recording encryption pair found") + fp, _ := fingerprintPEM(encryptionPair.PublicKey) + m.logger.InfoContext(ctx, "no active keys, generated initial recording encryption pair", "public_fingerprint", fp) + return encryption, nil } -// resolveRecordingEncryption examines the current state of the RescordingEncryption resource and advances it to the -// next state on behalf of the current auth server. -// -// When no active recording encryption key pairs exist, the first pair will be generated and wrapped using a new key -// encryption pair generated by the Manager's keystore. -// -// When at least one active keypair exists but none are accessible to the Manager's keystore, a new key encryption pair -// will be generated and saved without a key encryption pair. This is an unfulfilled key that some other instance of -// Manager on another auth server will need to fulfill asynchronously. -// -// If at least one active key is accessible to the Manager's keystore, then unfulfilled keys (identified by missing -// recording encryption key pairs) will be fulfilled using their public key encryption keys. -// -// If there are no unfulfilled keys present, then nothing should be done. -func (m *Manager) resolveRecordingEncryption(ctx context.Context) (*recordingencryptionv1.RecordingEncryption, error) { - encryption, generatedKey, err := m.ensureActiveRecordingEncryption(ctx) +// UnwrapKey searches for the private key compatible with the provided public key fingerprint and uses it to unwrap +// a wrapped file key. +func (m *Manager) UnwrapKey(ctx context.Context, in UnwrapInput) ([]byte, error) { + encryption, err := m.cache.GetRecordingEncryption(ctx) if err != nil { return nil, trace.Wrap(err) } - if generatedKey { - m.logger.DebugContext(ctx, "created initial recording encryption key") - return encryption, nil - } - + // TODO (eriktate): search rotated keys as well once rotation is implemented activeKeys := encryption.GetSpec().ActiveKeys - recordingEncryptionPair, err := m.getRecordingEncryptionKeyPair(ctx, activeKeys) - if err != nil { - if errors.Is(err, errWaitingForKey) { - // do nothing - return encryption, nil - } - - if trace.IsNotFound(err) { - m.logger.InfoContext(ctx, "no accessible recording encryption keys, posting new key to be fulfilled") - keypair, err := m.keyStore.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) - if err != nil { - return nil, trace.Wrap(err, "generating keypair for new wrapped key") - } - encryption.GetSpec().ActiveKeys = append(activeKeys, &recordingencryptionv1.WrappedKey{ - KeyEncryptionPair: keypair, - }) - - encryption, err = m.RecordingEncryption.UpdateRecordingEncryption(ctx, encryption) - return encryption, trace.Wrap(err, "updating session recording config") - } - - return nil, trace.Wrap(err) - } - - var shouldUpdate bool for _, key := range activeKeys { - if key.RecordingEncryptionPair != nil { + if key.GetRecordingEncryptionPair() == nil { continue } - encryptedKey, err := key.KeyEncryptionPair.EncryptOAEP(recordingEncryptionPair.PrivateKey) - if err != nil { - return encryption, trace.Wrap(err, "reencrypting decryption key") - } - - key.RecordingEncryptionPair = &types.EncryptionKeyPair{ - PrivateKey: encryptedKey, - PublicKey: recordingEncryptionPair.PublicKey, - } - - shouldUpdate = true - } - - if shouldUpdate { - m.logger.DebugContext(ctx, "fulfilling empty keys") - encryption, err = m.RecordingEncryption.UpdateRecordingEncryption(ctx, encryption) + activeFP, err := fingerprintPEM(key.RecordingEncryptionPair.PublicKey) if err != nil { - return encryption, trace.Wrap(err, "updating session recording config") - } - } - - return encryption, nil -} - -func (m *Manager) searchActiveKeys(ctx context.Context, activeKeys []*recordingencryptionv1.WrappedKey, publicKey []byte) (*types.EncryptionKeyPair, error) { - for _, key := range activeKeys { - if key.GetRecordingEncryptionPair() == nil { + m.logger.ErrorContext(ctx, "failed to fingerprint active public key", "error", err) continue } - if !slices.Equal(key.RecordingEncryptionPair.PublicKey, publicKey) { + if activeFP != in.Fingerprint { continue } - decrypter, err := m.keyStore.GetDecrypter(ctx, key.KeyEncryptionPair) + decrypter, err := m.keyStore.GetDecrypter(ctx, key.RecordingEncryptionPair) if err != nil { continue } - privateKey, err := decrypter.Decrypt(rand.Reader, key.RecordingEncryptionPair.PrivateKey, nil) + fileKey, err := decrypter.Decrypt(in.Rand, in.WrappedKey, in.Opts) if err != nil { return nil, trace.Wrap(err) } - return &types.EncryptionKeyPair{ - PrivateKey: privateKey, - PublicKey: key.RecordingEncryptionPair.PublicKey, - PrivateKeyType: key.RecordingEncryptionPair.PrivateKeyType, - }, nil - } - - return nil, trace.NotFound("no accessible decryption key found") -} - -// FindDecryptionKey returns the first accessible decryption key that matches one of the given public keys. -func (m *Manager) FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) { - encryption, err := m.cache.GetRecordingEncryption(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - // TODO (eriktate): search rotated keys as well once rotation is implemented - activeKeys := encryption.GetSpec().ActiveKeys - if len(publicKeys) == 0 { - return m.searchActiveKeys(ctx, activeKeys, nil) - } - - for _, publicKey := range publicKeys { - found, err := m.searchActiveKeys(ctx, activeKeys, publicKey) - if err != nil { - if trace.IsNotFound(err) { - continue - } - - if !slices.Equal(found.PublicKey, publicKey) { - continue - } - - decrypter, err := m.keyStore.GetDecrypter(ctx, found) - if err != nil { - if !trace.IsNotFound(err) { - m.logger.ErrorContext(ctx, "could not get decrypter from key store", "error", err) - } - continue - } - - privateKey, err := decrypter.Decrypt(rand.Reader, found.PrivateKey, nil) - if err != nil { - return nil, trace.Wrap(err) - } - - return &types.EncryptionKeyPair{ - PrivateKey: privateKey, - PublicKey: found.PublicKey, - PrivateKeyType: found.PrivateKeyType, - }, nil - } - - return found, nil + return fileKey, nil } - return nil, trace.NotFound("no accessible decryption key found") + return nil, trace.NotFound("no accessible decrypter found") } +// Watch for changes in the recording_encryption resource and respond by ensuring access to keys. func (m *Manager) Watch(ctx context.Context, events types.Events) (err error) { // shouldRetryAfterJitterFn waits at most 5 seconds and returns a bool specifying whether or not // execution should continue @@ -504,17 +346,23 @@ func (m *Manager) handleEvent(ctx context.Context, ev types.Event, shouldRetryFn sessionRecordingConfig, err := m.GetSessionRecordingConfig(ctx) if err != nil { m.logger.ErrorContext(ctx, "failed to retrieve session_recording_config, retrying", "error", err) - return err + return trace.Wrap(err) } if !sessionRecordingConfig.GetEncrypted() { return nil } - if _, err := m.resolveRecordingEncryption(ctx); err != nil { + encryption, err := m.ensureRecordingEncryptionKey(ctx) + if err != nil { m.logger.ErrorContext(ctx, "failed to resolve recording encryption keys, retrying", "retry", retry, "retries_left", retries-retry, "error", err) + return trace.Wrap(err) + } - return err + if sessionRecordingConfig.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys)) { + if _, err := m.ClusterConfigurationInternal.UpdateSessionRecordingConfig(ctx, sessionRecordingConfig); err != nil { + return trace.Wrap(err) + } } return nil @@ -546,3 +394,24 @@ func getAgeEncryptionKeys(keys []*recordingencryptionv1.WrappedKey) iter.Seq[*ty } } } + +// Fingerprint a public key for use in logging and as a cache key. +func Fingerprint(pubKey crypto.PublicKey) (string, error) { + derPub, err := x509.MarshalPKIXPublicKey(pubKey) + if err != nil { + return "", trace.Wrap(err) + } + + fp := sha256.Sum256(derPub) + return base64.StdEncoding.EncodeToString(fp[:]), nil +} + +// fingerprints a public RSA key encoded as PEM-wrapped PKIX. +func fingerprintPEM(pubKeyPEM []byte) (string, error) { + pubKey, err := keys.ParsePublicKey(pubKeyPEM) + if err != nil { + return "", trace.Wrap(err) + } + + return Fingerprint(pubKey) +} diff --git a/lib/auth/recordingencryption/manager_test.go b/lib/auth/recordingencryption/manager_test.go index 4b67f254bb6d2..6c746c1ba7a4e 100644 --- a/lib/auth/recordingencryption/manager_test.go +++ b/lib/auth/recordingencryption/manager_test.go @@ -19,16 +19,15 @@ package recordingencryption_test import ( "context" "crypto" + "crypto/rand" "crypto/rsa" - "crypto/x509" - "encoding/pem" + "crypto/sha256" "errors" "io" "sync" "testing" "time" - "filippo.io/age" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -45,23 +44,81 @@ import ( "github.com/gravitational/teleport/lib/utils/log/logtest" ) +// It takes forever to generate RSA4096 keys so we generate and cache a few to be used by the fakeKeyStore +// instead of actually generating a new key every time a test needs one. This cuts down flaky test execution +// time to ~20-30s instead of timing out at >10m +var cachedDecrypters = initDecrypters() + +func initDecrypters() []crypto.Decrypter { + var decrypters []crypto.Decrypter + for range 10 { + decrypter, err := cryptosuites.GenerateDecrypterWithAlgorithm(cryptosuites.RSA4096) + if err != nil { + panic("failed to generate RSA 4096 key") + } + + decrypters = append(decrypters, decrypter) + } + + return decrypters +} + type oaepDecrypter struct { crypto.Decrypter hash crypto.Hash } func (d oaepDecrypter) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) ([]byte, error) { - return d.Decrypter.Decrypt(rand, msg, &rsa.OAEPOptions{ - Hash: d.hash, - }) + return d.Decrypter.Decrypt(rand, msg, opts) } type fakeKeyStore struct { keyType types.PrivateKeyType // abusing this field as a way to simulate different auth servers + keys map[string]crypto.Decrypter + + cacheIdx int +} + +func newFakeKeyStore(keyType types.PrivateKeyType) *fakeKeyStore { + return &fakeKeyStore{ + keys: make(map[string]crypto.Decrypter), + keyType: keyType, + } +} + +func (f *fakeKeyStore) genKeys() (crypto.Decrypter, []byte, error) { + decrypter := cachedDecrypters[f.cacheIdx] + f.cacheIdx += 1 + if f.cacheIdx >= len(cachedDecrypters) { + f.cacheIdx = 0 + } + + publicKey, err := keys.MarshalPublicKey(decrypter.Public()) + if err != nil { + return nil, nil, err + } + + return decrypter, publicKey, nil +} + +func (f *fakeKeyStore) createKey() (crypto.Decrypter, []byte, error) { + decrypter, publicKey, err := f.genKeys() + if err != nil { + return nil, nil, err + } + + fp, err := recordingencryption.Fingerprint(decrypter.Public()) + if err != nil { + return nil, nil, err + } + + f.keys[fp] = decrypter + + return decrypter, publicKey, nil } func (f *fakeKeyStore) NewEncryptionKeyPair(ctx context.Context, purpose cryptosuites.KeyPurpose) (*types.EncryptionKeyPair, error) { - decrypter, err := cryptosuites.GenerateDecrypterWithAlgorithm(cryptosuites.RSA2048) + decrypter, pubPEM, err := f.createKey() if err != nil { return nil, err } @@ -71,19 +128,14 @@ func (f *fakeKeyStore) NewEncryptionKeyPair(ctx context.Context, purpose cryptos return nil, errors.New("expected RSA private key") } - privatePEM := pem.EncodeToMemory(&pem.Block{ - Type: keys.PKCS1PrivateKeyType, - Bytes: x509.MarshalPKCS1PrivateKey(private), - }) - - publicPEM := pem.EncodeToMemory(&pem.Block{ - Type: keys.PKCS1PublicKeyType, - Bytes: x509.MarshalPKCS1PublicKey(&private.PublicKey), - }) + privatePEM, err := keys.MarshalDecrypter(private) + if err != nil { + return nil, err + } return &types.EncryptionKeyPair{ PrivateKey: privatePEM, - PublicKey: publicPEM, + PublicKey: pubPEM, PrivateKeyType: f.keyType, Hash: uint32(crypto.SHA256), }, nil @@ -94,14 +146,30 @@ func (f *fakeKeyStore) GetDecrypter(ctx context.Context, keyPair *types.Encrypti return nil, errors.New("could not access decrypter") } - block, _ := pem.Decode(keyPair.PrivateKey) + private, err := keys.ParsePrivateKey(keyPair.PrivateKey) + if err != nil { + return nil, err + } - private, err := x509.ParsePKCS1PrivateKey(block.Bytes) + decrypter, ok := private.Signer.(crypto.Decrypter) + if !ok { + return nil, errors.New("private key should have been a decrypter") + } + return oaepDecrypter{Decrypter: decrypter, hash: crypto.Hash(keyPair.Hash)}, nil +} + +func (f *fakeKeyStore) UnwrapKey(ctx context.Context, in recordingencryption.UnwrapInput) ([]byte, error) { + decrypter, ok := f.keys[in.Fingerprint] + if !ok { + return nil, trace.NotFound("no accessible decryption key found") + } + + fileKey, err := decrypter.Decrypt(in.Rand, in.WrappedKey, in.Opts) if err != nil { return nil, err } - return oaepDecrypter{Decrypter: private, hash: crypto.Hash(keyPair.Hash)}, nil + return fileKey, nil } func newLocalBackend( @@ -136,14 +204,14 @@ func newManagerConfig(t *testing.T, bk backend.Backend, keyType types.PrivateKey Backend: recordingEncryptionService, Cache: recordingEncryptionService, ClusterConfig: clusterConfigService, - KeyStore: &fakeKeyStore{keyType: keyType}, + KeyStore: newFakeKeyStore(keyType), Logger: logtest.NewLogger(), LockConfig: backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ Backend: bk, LockNameComponents: []string{"recording_encryption"}, - TTL: 5 * time.Second, - RetryInterval: 10 * time.Millisecond, + TTL: 10 * time.Second, + RetryInterval: 100 * time.Millisecond, }, }, } @@ -201,9 +269,6 @@ func TestCreateUpdateSessionRecordingConfig(t *testing.T) { require.NotNil(t, activeKeys[0].RecordingEncryptionPair) require.NotEmpty(t, activeKeys[0].RecordingEncryptionPair.PrivateKey) require.NotEmpty(t, activeKeys[0].RecordingEncryptionPair.PublicKey) - require.NotNil(t, activeKeys[0].KeyEncryptionPair) - require.NotEmpty(t, activeKeys[0].KeyEncryptionPair.PrivateKey) - require.NotEmpty(t, activeKeys[0].KeyEncryptionPair.PublicKey) // update should change nothing src, err = manager.UpdateSessionRecordingConfig(ctx, src) @@ -221,12 +286,13 @@ func TestResolveRecordingEncryption(t *testing.T) { // SETUP ctx, bk := newLocalBackend(t) - managerAType := types.PrivateKeyType_RAW - managerBType := types.PrivateKeyType_AWS_KMS + managerABType := types.PrivateKeyType_AWS_KMS + managerCType := types.PrivateKeyType_GCP_KMS - configA := newManagerConfig(t, bk, managerAType) + configA := newManagerConfig(t, bk, managerABType) configB := configA - configB.KeyStore = &fakeKeyStore{managerBType} + configC := configA + configC.KeyStore = newFakeKeyStore(managerCType) managerA, err := recordingencryption.NewManager(configA) require.NoError(t, err) @@ -234,89 +300,52 @@ func TestResolveRecordingEncryption(t *testing.T) { managerB, err := recordingencryption.NewManager(configB) require.NoError(t, err) + managerC, err := recordingencryption.NewManager(configC) + require.NoError(t, err) + service := configA.Backend // TEST // CASE: service A first evaluation initializes recording encryption resource encryption, src, err := resolve(ctx, service, managerA) require.NoError(t, err) - activeKeys := encryption.GetSpec().GetActiveKeys() + initialKeys := encryption.GetSpec().GetActiveKeys() - require.Len(t, activeKeys, 1) + require.Len(t, initialKeys, 1) require.Len(t, src.GetEncryptionKeys(), 1) - firstKey := activeKeys[0] - - // should generate a wrapped key with the initial recording encryption pair - require.NotNil(t, firstKey.KeyEncryptionPair) - require.NotNil(t, firstKey.RecordingEncryptionPair) + key := initialKeys[0] + require.Equal(t, key.RecordingEncryptionPair.PublicKey, src.GetEncryptionKeys()[0].PublicKey) + require.NotNil(t, key.RecordingEncryptionPair) - // CASE: service B should generate an unfulfilled key since there's an existing recording encryption resource + // CASE: service B should have access to the same key encryption, src, err = resolve(ctx, service, managerB) require.NoError(t, err) - activeKeys = encryption.GetSpec().ActiveKeys - require.Len(t, activeKeys, 2) - require.Len(t, src.GetEncryptionKeys(), 1) - for _, key := range activeKeys { - require.NotNil(t, key.KeyEncryptionPair) - if key.KeyEncryptionPair.PrivateKeyType == managerAType { - require.NotNil(t, key.RecordingEncryptionPair) - } else { - require.Nil(t, key.RecordingEncryptionPair) - } - } - - // service B re-evaluting with an unfulfilled key should do nothing - encryption, src, err = resolve(ctx, service, managerB) - require.NoError(t, err) - activeKeys = encryption.GetSpec().ActiveKeys - require.Len(t, activeKeys, 2) + activeKeys := encryption.GetSpec().ActiveKeys require.Len(t, src.GetEncryptionKeys(), 1) - for _, key := range activeKeys { - require.NotNil(t, key.KeyEncryptionPair) - if key.KeyEncryptionPair.PrivateKeyType == managerAType { - require.NotNil(t, key.RecordingEncryptionPair) - } else { - require.Nil(t, key.RecordingEncryptionPair) - } - } + require.Equal(t, key.RecordingEncryptionPair.PublicKey, src.GetEncryptionKeys()[0].PublicKey) + require.ElementsMatch(t, initialKeys, activeKeys) - // CASE: service A evaluation should fulfill service B's key - encryption, src, err = resolve(ctx, service, managerA) - require.NoError(t, err) - activeKeys = encryption.GetSpec().ActiveKeys - require.Len(t, activeKeys, 2) - require.Len(t, src.GetEncryptionKeys(), 1) - for _, key := range activeKeys { - require.NotNil(t, key.KeyEncryptionPair) - require.NotNil(t, key.RecordingEncryptionPair) - } + // service C should error without access to the current key + _, _, err = resolve(ctx, service, managerC) + require.Error(t, err) } func TestResolveRecordingEncryptionConcurrent(t *testing.T) { // SETUP ctx, bk := newLocalBackend(t) - managerAType := types.PrivateKeyType_RAW - managerBType := types.PrivateKeyType_AWS_KMS - serviceCType := types.PrivateKeyType_GCP_KMS - - configA := newManagerConfig(t, bk, managerAType) - configB := configA - configB.KeyStore = &fakeKeyStore{managerBType} - configC := configA - configC.KeyStore = &fakeKeyStore{serviceCType} - recordingEncryptionService := configA.Backend - managerA, err := recordingencryption.NewManager(configA) + config := newManagerConfig(t, bk, types.PrivateKeyType_RAW) + managerA, err := recordingencryption.NewManager(config) require.NoError(t, err) - managerB, err := recordingencryption.NewManager(configB) + managerB, err := recordingencryption.NewManager(config) require.NoError(t, err) - serviceC, err := recordingencryption.NewManager(configC) + serviceC, err := recordingencryption.NewManager(config) require.NoError(t, err) - service := configA.Backend + service := config.Backend resolveFn := func(manager *recordingencryption.Manager, wg *sync.WaitGroup) { wg.Add(1) go func() { @@ -326,78 +355,70 @@ func TestResolveRecordingEncryptionConcurrent(t *testing.T) { }() } + // it should be safe for multiple services to resolve encryption keys concurrently wg := sync.WaitGroup{} resolveFn(managerA, &wg) resolveFn(managerB, &wg) resolveFn(serviceC, &wg) wg.Wait() - encryption, err := recordingEncryptionService.GetRecordingEncryption(ctx) + encryption, err := service.GetRecordingEncryption(ctx) require.NoError(t, err) activeKeys := encryption.GetSpec().ActiveKeys - // each service should have an active wrapped key - require.Len(t, activeKeys, 3) - var fulfilledKeys int - for _, key := range activeKeys { - // all wrapped keys should have KeyEncryptionPairs - require.NotNil(t, key.KeyEncryptionPair) - require.NotEmpty(t, key.KeyEncryptionPair.PublicKey) - require.NotEmpty(t, key.KeyEncryptionPair.PrivateKey) - - if key.RecordingEncryptionPair != nil { - fulfilledKeys += 1 - } - } - - // only the first service to run should have a fulfilled wrapped key - require.Equal(t, 1, fulfilledKeys) + // each service should share a single active key + require.Len(t, activeKeys, 1) + require.NotNil(t, activeKeys[0].RecordingEncryptionPair) + require.NotEmpty(t, activeKeys[0].RecordingEncryptionPair.PrivateKey) + require.NotEmpty(t, activeKeys[0].RecordingEncryptionPair.PublicKey) + require.Equal(t, types.PrivateKeyType_RAW, activeKeys[0].RecordingEncryptionPair.PrivateKeyType) } -func TestFindDecryptionKeyFromActiveKeys(t *testing.T) { +func TestUnwrapKey(t *testing.T) { // SETUP ctx, bk := newLocalBackend(t) - keyTypeA := types.PrivateKeyType_RAW - keyTypeB := types.PrivateKeyType_AWS_KMS + keyType := types.PrivateKeyType_RAW - configA := newManagerConfig(t, bk, keyTypeA) - configB := configA - configB.KeyStore = &fakeKeyStore{keyTypeB} - managerA, err := recordingencryption.NewManager(configA) + config := newManagerConfig(t, bk, keyType) + manager, err := recordingencryption.NewManager(config) require.NoError(t, err) - managerB, err := recordingencryption.NewManager(configB) + service := config.Backend + _, _, err = resolve(ctx, service, manager) require.NoError(t, err) - service := configA.Backend - _, _, err = resolve(ctx, service, managerA) + src, err := manager.GetSessionRecordingConfig(ctx) require.NoError(t, err) - encryption, _, err := resolve(ctx, service, managerB) - require.NoError(t, err) + encryptionKeys := src.GetEncryptionKeys() + require.Len(t, encryptionKeys, 1) + pubKeyPEM := encryptionKeys[0].PublicKey - activeKeys := encryption.GetSpec().ActiveKeys - require.Len(t, activeKeys, 2) - pubKey := activeKeys[0].RecordingEncryptionPair.PublicKey + pubKey, err := keys.ParsePublicKey(pubKeyPEM) + require.NoError(t, err) - // fail to find private key for manager B because it is waiting for key fulfillment - _, err = managerB.FindDecryptionKey(ctx, pubKey) - require.Error(t, err) + rsaPubKey, ok := pubKey.(*rsa.PublicKey) + require.True(t, ok) - _, _, err = resolve(ctx, service, managerA) + fileKey := []byte("test_file_key") + label := []byte("test_label") + wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPubKey, fileKey, label) require.NoError(t, err) - // find private key for manager A because it provisioned the key - decryptionPair, err := managerA.FindDecryptionKey(ctx, pubKey) + fp, err := recordingencryption.Fingerprint(pubKey) require.NoError(t, err) - ident, err := age.ParseX25519Identity(string(decryptionPair.PrivateKey)) - require.NoError(t, err) - require.Equal(t, ident.Recipient().String(), string(pubKey)) - // find private key for manager B after fulfillment - decryptionPair, err = managerB.FindDecryptionKey(ctx, pubKey) - require.NoError(t, err) - ident, err = age.ParseX25519Identity(string(decryptionPair.PrivateKey)) + unwrapInput := recordingencryption.UnwrapInput{ + Fingerprint: fp, + WrappedKey: wrappedKey, + Rand: rand.Reader, + Opts: &rsa.OAEPOptions{ + Hash: crypto.SHA256, + Label: label, + }, + } + unwrappedKey, err := manager.UnwrapKey(ctx, unwrapInput) require.NoError(t, err) - require.Equal(t, ident.Recipient().String(), string(pubKey)) + + require.Equal(t, fileKey, unwrappedKey) } diff --git a/lib/cryptosuites/internal/rsa/rsa.go b/lib/cryptosuites/internal/rsa/rsa.go index 2203bd22bf05e..5c7b8cb915f05 100644 --- a/lib/cryptosuites/internal/rsa/rsa.go +++ b/lib/cryptosuites/internal/rsa/rsa.go @@ -42,12 +42,19 @@ func GenerateKey() (*rsa.PrivateKey, error) { return getOrGenerateRSAPrivateKey() } +// GenerateKey4096 generates a 4096-bit RSA private key meant for use in asymmetric encryption use cases such as +// encrypted session recordings. It is exposed as a separate function from [GenerateKey] so that the precomputed +// keys optimization used for sign/verify use cases does not have to be extended to support mixed key sizes. +func GenerateKey4096() (*rsa.PrivateKey, error) { + return generateRSAPrivateKey(4096) +} + func getOrGenerateRSAPrivateKey() (*rsa.PrivateKey, error) { select { case k := <-precomputedKeys: return k, nil default: - rsaKeyPair, err := generateRSAPrivateKey() + rsaKeyPair, err := generateRSAPrivateKey(constants.RSAKeySize) if err != nil { return nil, err } @@ -55,15 +62,15 @@ func getOrGenerateRSAPrivateKey() (*rsa.PrivateKey, error) { } } -func generateRSAPrivateKey() (*rsa.PrivateKey, error) { +func generateRSAPrivateKey(bits int) (*rsa.PrivateKey, error) { //nolint:forbidigo // This is the one function allowed to generate RSA keys. - return rsa.GenerateKey(rand.Reader, constants.RSAKeySize) + return rsa.GenerateKey(rand.Reader, bits) } func precomputeKeys() { const backoff = time.Second * 30 for { - rsaPrivateKey, err := generateRSAPrivateKey() + rsaPrivateKey, err := generateRSAPrivateKey(constants.RSAKeySize) if err != nil { log.ErrorContext(context.Background(), "Failed to precompute key pair, retrying (this might be a bug).", slog.Any("error", err), slog.Duration("backoff", backoff)) @@ -98,7 +105,7 @@ func generateTestKeys() <-chan *rsa.PrivateKey { // Generate each key in a separate goroutine to take advantage of // multiple cores if possible. go func() { - private, err := generateRSAPrivateKey() + private, err := generateRSAPrivateKey(constants.RSAKeySize) if err != nil { // Use only in tests. Safe to panic. panic(err) diff --git a/lib/cryptosuites/suites.go b/lib/cryptosuites/suites.go index b49f3f92a0a6c..ae8b6bf638560 100644 --- a/lib/cryptosuites/suites.go +++ b/lib/cryptosuites/suites.go @@ -144,6 +144,8 @@ const ( // RSA2048 represents RSA 2048-bit keys. RSA2048 + // RSA4096 represents RSA 4096-bit keys. + RSA4096 // ECDSAP256 represents ECDSA keys using NIST curve P-256. ECDSAP256 // Ed25519 represents Ed25519 keys. @@ -159,6 +161,8 @@ func (a Algorithm) String() string { return "algorithm unspecified" case RSA2048: return "RSA2048" + case RSA4096: + return "RSA4096" case ECDSAP256: return "ECDSAP256" case Ed25519: @@ -212,7 +216,7 @@ var ( AWSRACATLS: ECDSAP256, BoundKeypairJoining: Ed25519, BoundKeypairCAJWT: ECDSAP256, - RecordingKeyWrapping: RSA2048, + RecordingKeyWrapping: RSA4096, } // balancedV1 strikes a balance between security, compatibility, and @@ -247,7 +251,7 @@ var ( AWSRACATLS: ECDSAP256, BoundKeypairJoining: Ed25519, BoundKeypairCAJWT: Ed25519, - RecordingKeyWrapping: RSA2048, + RecordingKeyWrapping: RSA4096, } // fipsv1 is an algorithm suite tailored for FIPS compliance. It is based on @@ -283,7 +287,7 @@ var ( AWSRACATLS: ECDSAP256, BoundKeypairJoining: ECDSAP256, BoundKeypairCAJWT: ECDSAP256, - RecordingKeyWrapping: RSA2048, + RecordingKeyWrapping: RSA4096, } // hsmv1 in an algorithm suite tailored for clusters using an HSM or KMS @@ -321,7 +325,7 @@ var ( AWSRACATLS: ECDSAP256, BoundKeypairJoining: Ed25519, BoundKeypairCAJWT: ECDSAP256, - RecordingKeyWrapping: RSA2048, + RecordingKeyWrapping: RSA4096, } allSuites = map[types.SignatureAlgorithmSuite]suite{ @@ -470,6 +474,8 @@ func GenerateKeyWithAlgorithm(alg Algorithm) (crypto.Signer, error) { switch alg { case RSA2048: return generateRSA2048() + case RSA4096: + return generateRSA4096() case ECDSAP256: return generateECDSAP256() case Ed25519: @@ -482,10 +488,10 @@ func GenerateKeyWithAlgorithm(alg Algorithm) (crypto.Signer, error) { // GenerateDecrypterWithAlgorithm generates a new cryptographic keypair with the given algorithm meant for decryption. func GenerateDecrypterWithAlgorithm(alg Algorithm) (crypto.Decrypter, error) { switch alg { - case RSA2048: - return generateRSA2048() + case RSA4096: + return generateRSA4096() default: - return nil, trace.BadParameter("unsupported key algorithm %v", alg) + return nil, trace.BadParameter("unsupported decryption key algorithm %v", alg) } } @@ -505,6 +511,11 @@ func generateRSA2048() (*rsa.PrivateKey, error) { return key, trace.Wrap(err) } +func generateRSA4096() (*rsa.PrivateKey, error) { + key, err := internalrsa.GenerateKey4096() + return key, trace.Wrap(err) +} + func generateECDSAP256() (*ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) return key, trace.Wrap(err)