Skip to content
Merged
22 changes: 0 additions & 22 deletions service/internal/security/crypto_provider.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,8 @@
package security

import (
"crypto"
"crypto/elliptic"
)

const (
// Key agreement along P-256
AlgorithmECP256R1 = "ec:secp256r1"
// Used for encryption with RSA of the KAO
AlgorithmRSA2048 = "rsa:2048"
)

type CryptoProvider interface {
// Gets some KID associated with a given algorithm.
// Returns empty string if none are found.
FindKID(alg string) string
RSAPublicKey(keyID string) (string, error)
RSAPublicKeyAsJSON(keyID string) (string, error)
RSADecrypt(hash crypto.Hash, keyID string, keyLabel string, ciphertext []byte) ([]byte, error)
ECDecrypt(keyID string, ephemeralPublicKey, ciphertext []byte) ([]byte, error)

ECPublicKey(keyID string) (string, error)
ECCertificate(keyID string) (string, error)
GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) ([]byte, error)
GenerateEphemeralKasKeys() (any, []byte, error)
GenerateNanoTDFSessionKey(privateKeyHandle any, ephemeralPublicKey []byte) ([]byte, error)
Close()
}
71 changes: 50 additions & 21 deletions service/internal/security/in_process_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,18 @@ func convertPEMToJWK(_ string) (string, error) {

// InProcessProvider adapts a CryptoProvider to the SecurityProvider interface
type InProcessProvider struct {
cryptoProvider CryptoProvider
cryptoProvider *StandardCrypto
logger *slog.Logger
defaultKeys map[string]bool
legacyKeys map[string]bool
}

// KeyDetailsAdapter adapts CryptoProvider to KeyDetails
type KeyDetailsAdapter struct {
id trust.KeyIdentifier
algorithm string
legacy bool
cryptoProvider CryptoProvider
cryptoProvider *StandardCrypto
}

// Mode returns the mode of the key details
Expand Down Expand Up @@ -174,10 +176,22 @@ func (k *KeyDetailsAdapter) ExportCertificate(_ context.Context) (string, error)
}

// NewSecurityProviderAdapter creates a new adapter that implements SecurityProvider using a CryptoProvider
func NewSecurityProviderAdapter(cryptoProvider CryptoProvider) trust.KeyService {
func NewSecurityProviderAdapter(cryptoProvider *StandardCrypto, defaultKeys, legacyKeys []string) trust.KeyService {
legacyKeysMap := make(map[string]bool, len(legacyKeys))
for _, key := range legacyKeys {
legacyKeysMap[key] = true
}

defaultKeysMap := make(map[string]bool, len(defaultKeys))
for _, key := range defaultKeys {
defaultKeysMap[key] = true
}

return &InProcessProvider{
cryptoProvider: cryptoProvider,
logger: slog.Default(),
defaultKeys: defaultKeysMap,
legacyKeys: legacyKeysMap,
}
}

Expand All @@ -192,18 +206,26 @@ func (a *InProcessProvider) WithLogger(logger *slog.Logger) *InProcessProvider {
return a
}

// FindKeyByAlgorithm finds a key by algorithm using the underlying CryptoProvider
func (a *InProcessProvider) FindKeyByAlgorithm(_ context.Context, algorithm string, _ bool) (trust.KeyDetails, error) {
// FindKeyByAlgorithm finds a key by algorithm using the underlying CryptoProvider.
// This will only return default keys if legacy is false.
// If legacy is true, it will return the first legacy key found that matches the algorithm.
func (a *InProcessProvider) FindKeyByAlgorithm(_ context.Context, algorithm string, legacy bool) (trust.KeyDetails, error) {
// Get the key ID for this algorithm
kid := a.cryptoProvider.FindKID(algorithm)
if kid == "" {
kids, err := a.cryptoProvider.ListKIDsByAlgorithm(algorithm)
if err != nil || len(kids) == 0 {
return nil, ErrCertNotFound
}
return &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: algorithm,
cryptoProvider: a.cryptoProvider,
}, nil
for _, kid := range kids {
if legacy && a.legacyKeys[kid] || !legacy && a.defaultKeys[kid] {
return &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: algorithm,
cryptoProvider: a.cryptoProvider,
legacy: legacy,
}, nil
}
}
return nil, ErrCertNotFound
}

// FindKeyByID finds a key by ID
Expand All @@ -216,7 +238,7 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie
return &KeyDetailsAdapter{
id: id,
algorithm: alg,
legacy: false,
legacy: a.legacyKeys[string(id)],
cryptoProvider: a.cryptoProvider,
}, nil
}
Expand All @@ -225,7 +247,7 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie
return &KeyDetailsAdapter{
id: id,
algorithm: alg,
legacy: false,
legacy: a.legacyKeys[string(id)],
cryptoProvider: a.cryptoProvider,
}, nil
}
Expand All @@ -241,12 +263,19 @@ func (a *InProcessProvider) ListKeys(_ context.Context) ([]trust.KeyDetails, err

// Try to find keys for known algorithms
for _, alg := range []string{AlgorithmRSA2048, AlgorithmECP256R1} {
if kid := a.cryptoProvider.FindKID(alg); kid != "" {
keys = append(keys, &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: alg,
cryptoProvider: a.cryptoProvider,
})
if kids, err := a.cryptoProvider.ListKIDsByAlgorithm(alg); err == nil && len(kids) > 0 {
for _, kid := range kids {
keys = append(keys, &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: alg,
cryptoProvider: a.cryptoProvider,
legacy: a.legacyKeys[kid],
})
}
} else if err != nil {
if a.logger != nil {
a.logger.Warn("failed to list keys by algorithm", "algorithm", alg, "error", err)
}
}
}

Expand Down Expand Up @@ -275,7 +304,7 @@ func (a *InProcessProvider) Decrypt(ctx context.Context, keyID trust.KeyIdentifi
if len(ephemeralPublicKey) == 0 {
return nil, errors.New("ephemeral public key is required for EC decryption")
}
rawKey, err = a.cryptoProvider.ECDecrypt(kid, ephemeralPublicKey, ciphertext)
rawKey, err = a.cryptoProvider.ECDecrypt(ctx, kid, ephemeralPublicKey, ciphertext)

default:
return nil, errors.New("unsupported key algorithm")
Expand Down
87 changes: 32 additions & 55 deletions service/internal/security/standard_crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ func NewStandardCrypto(cfg StandardConfig) (*StandardCrypto, error) {
}
}

// ListKIDsByAlgorithm returns a list of key identifiers for the specified algorithm
// Errors if no keys are found of the requested algorithm.
func (s StandardCrypto) ListKIDsByAlgorithm(alg string) ([]string, error) {
k, ok := s.keysByAlg[alg]
if !ok {
return nil, fmt.Errorf("no key found with algorithm [%s]: %w", alg, ErrCertNotFound)
}
keys := make([]string, 0, len(k))
for kid := range k {
keys = append(keys, kid)
}
return keys, nil
}

func loadKeys(ks []KeyPairInfo) (*StandardCrypto, error) {
keysByAlg := make(map[string]keylist)
keysByID := make(keylist)
Expand Down Expand Up @@ -351,21 +365,6 @@ func (s StandardCrypto) RSAPublicKeyAsJSON(kid string) (string, error) {
}

func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) ([]byte, error) {
ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, ephemeralPublicKeyBytes)
if err != nil {
return nil, err
}

derBytes, err := x509.MarshalPKIXPublicKey(ephemeralECDSAPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal ECDSA public key: %w", err)
}
pemBlock := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock)

k, ok := s.keysByID[kasKID]
if !ok {
return nil, ErrKeyPairInfoNotFound
Expand All @@ -374,60 +373,38 @@ func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPubl
if !ok {
return nil, ErrKeyPairInfoMalformed
}
privateKeyPEM := []byte(ec.ecPrivateKeyPem)

symmetricKey, err := ocrypto.ComputeECDHKey([]byte(ec.ecPrivateKeyPem), ephemeralECDSAPublicKeyPEM)
if err != nil {
return nil, fmt.Errorf("ocrypto.ComputeECDHKey failed: %w", err)
}

key, err := ocrypto.CalculateHKDF(versionSalt(), symmetricKey)
if err != nil {
return nil, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err)
}

return key, nil
return DeriveNanoTDFSymmetricKey(curve, ephemeralPublicKeyBytes, privateKeyPEM)
}

func (s StandardCrypto) GenerateEphemeralKasKeys() (any, []byte, error) {
ephemeralKeyPair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1)
func DeriveNanoTDFSymmetricKey(curve elliptic.Curve, clientEphemera []byte, privateKeyPEM []byte) ([]byte, error) {
ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, clientEphemera)
if err != nil {
return nil, nil, fmt.Errorf("ocrypto.NewECKeyPair failed: %w", err)
return nil, err
}

pubKeyInPem, err := ephemeralKeyPair.PublicKeyInPemFormat()
derBytes, err := x509.MarshalPKIXPublicKey(ephemeralECDSAPublicKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to get public key in PEM format: %w", err)
return nil, fmt.Errorf("failed to marshal ECDSA public key: %w", err)
}
pubKeyBytes := []byte(pubKeyInPem)

privKey, err := ocrypto.ConvertToECDHPrivateKey(ephemeralKeyPair.PrivateKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to convert provate key to ECDH: %w", err)
pemBlock := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
return privKey, pubKeyBytes, nil
}
ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock)

func (s StandardCrypto) GenerateNanoTDFSessionKey(privateKey any, ephemeralPublicKeyPEM []byte) ([]byte, error) {
ecdhKey, err := ocrypto.ConvertToECDHPrivateKey(privateKey)
symmetricKey, err := ocrypto.ComputeECDHKey(privateKeyPEM, ephemeralECDSAPublicKeyPEM)
if err != nil {
return nil, fmt.Errorf("GenerateNanoTDFSessionKey failed to ConvertToECDHPrivateKey: %w", err)
}
ephemeralECDHPublicKey, err := ocrypto.ECPubKeyFromPem(ephemeralPublicKeyPEM)
if err != nil {
return nil, fmt.Errorf("GenerateNanoTDFSessionKey failed to ocrypto.ECPubKeyFromPem: %w", err)
}
// shared secret
sessionKey, err := ecdhKey.ECDH(ephemeralECDHPublicKey)
if err != nil {
return nil, fmt.Errorf("GenerateNanoTDFSessionKey failed to ecdhKey.ECDH: %w", err)
return nil, fmt.Errorf("ocrypto.ComputeECDHKey failed: %w", err)
}

salt := versionSalt()
derivedKey, err := ocrypto.CalculateHKDF(salt, sessionKey)
key, err := ocrypto.CalculateHKDF(versionSalt(), symmetricKey)
if err != nil {
return nil, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err)
}
return derivedKey, nil

return key, nil
}

func (s StandardCrypto) Close() {
Expand All @@ -447,8 +424,8 @@ func versionSalt() []byte {
}

// ECDecrypt uses hybrid ECIES to decrypt the data.
func (s *StandardCrypto) ECDecrypt(keyID string, ephemeralPublicKey, ciphertext []byte) ([]byte, error) {
unwrappedKey, err := s.Decrypt(context.Background(), trust.KeyIdentifier(keyID), ciphertext, ephemeralPublicKey)
func (s *StandardCrypto) ECDecrypt(ctx context.Context, keyID string, ephemeralPublicKey, ciphertext []byte) ([]byte, error) {
unwrappedKey, err := s.Decrypt(ctx, trust.KeyIdentifier(keyID), ciphertext, ephemeralPublicKey)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions service/internal/security/standard_only.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package security
import "log/slog"

type Config struct {
Type string `mapstructure:"type" json:"type" default:"standard"`
Type string `mapstructure:"type" json:"type"`
// StandardConfig is the configuration for the standard key provider
StandardConfig StandardConfig `mapstructure:"standard" json:"standard"`
}
Expand All @@ -12,7 +12,7 @@ func (c Config) IsEmpty() bool {
return c.Type == "" && c.StandardConfig.IsEmpty()
}

func NewCryptoProvider(cfg Config) (CryptoProvider, error) {
func NewCryptoProvider(cfg Config) (*StandardCrypto, error) {
switch cfg.Type {
case "hsm":
slog.Error("opentdf hsm mode has been removed")
Expand Down
2 changes: 1 addition & 1 deletion service/internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ type OpenTDFServer struct {
TrustKeyManager trust.KeyManager

// To Deprecate: Use the TrustKeyIndex and TrustKeyManager instead
CryptoProvider security.CryptoProvider
CryptoProvider *security.StandardCrypto

logger *logger.Logger
}
Expand Down
Loading
Loading