Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 40 additions & 46 deletions service/internal/security/standard_crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ type StandardECCrypto struct {
type keylist map[string]any

type StandardCrypto struct {
// Lists of keys first sorted by algorithm
keys map[string]keylist
// Lists of keysByAlg first sorted by algorithm
keysByAlg map[string]keylist

// Lists all keys by identifier.
keysByID keylist
}

// NewStandardCrypto Create a new instance of standard crypto
Expand All @@ -83,20 +86,26 @@ func NewStandardCrypto(cfg StandardConfig) (*StandardCrypto, error) {
}

func loadKeys(ks []KeyPairInfo) (*StandardCrypto, error) {
keys := make(map[string]keylist)
keysByAlg := make(map[string]keylist)
keysByID := make(keylist)
for _, k := range ks {
slog.Info("crypto cfg loading", "id", k.KID, "alg", k.Algorithm)
if _, ok := keys[k.Algorithm]; !ok {
keys[k.Algorithm] = make(map[string]any)
if _, ok := keysByID[k.KID]; ok {
return nil, fmt.Errorf("duplicate key identifier [%s]", k.KID)
}
if _, ok := keysByAlg[k.Algorithm]; !ok {
keysByAlg[k.Algorithm] = make(map[string]any)
}
loadedKey, err := loadKey(k)
if err != nil {
return nil, err
}
keys[k.Algorithm][k.KID] = loadedKey
keysByAlg[k.Algorithm][k.KID] = loadedKey
keysByID[k.KID] = loadedKey
}
return &StandardCrypto{
keys: keys,
keysByAlg: keysByAlg,
keysByID: keysByID,
}, nil
}

Expand Down Expand Up @@ -139,13 +148,14 @@ func loadKey(k KeyPairInfo) (any, error) {
}

func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]StandardKeyInfo) (*StandardCrypto, error) {
keys := make(map[string]keylist)
keysByAlg := make(map[string]keylist)
keysByID := make(keylist)

if len(ecKeys) > 0 {
keys[AlgorithmECP256R1] = make(map[string]any)
keysByAlg[AlgorithmECP256R1] = make(map[string]any)
}
if len(rsaKeys) > 0 {
keys[AlgorithmRSA2048] = make(map[string]any)
keysByAlg[AlgorithmRSA2048] = make(map[string]any)
}

for id, kasInfo := range rsaKeys {
Expand All @@ -169,7 +179,7 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
return nil, fmt.Errorf("ocrypto.NewAsymEncryption failed: %w", err)
}

keys[AlgorithmRSA2048][id] = StandardRSACrypto{
k := StandardRSACrypto{
KeyPairInfo: KeyPairInfo{
Algorithm: AlgorithmRSA2048,
KID: id,
Expand All @@ -179,6 +189,8 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
asymDecryption: asymDecryption,
asymEncryption: asymEncryption,
}
keysByAlg[AlgorithmRSA2048][id] = k
keysByID[id] = k
}
for id, kasInfo := range ecKeys {
slog.Info("cfg.ECKeys", "id", id, "kasInfo", kasInfo)
Expand All @@ -192,7 +204,7 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
if err != nil {
return nil, fmt.Errorf("failed to EC certificate file: %w", err)
}
keys[AlgorithmECP256R1][id] = StandardECCrypto{
k := StandardECCrypto{
KeyPairInfo: KeyPairInfo{
Algorithm: AlgorithmRSA2048,
KID: id,
Expand All @@ -202,15 +214,18 @@ func loadDeprecatedKeys(rsaKeys map[string]StandardKeyInfo, ecKeys map[string]St
ecPrivateKeyPem: string(privatePemData),
ecCertificatePEM: string(ecCertificatePEM),
}
keysByAlg[AlgorithmECP256R1][id] = k
keysByID[id] = k
}

return &StandardCrypto{
keys: keys,
keysByAlg: keysByAlg,
keysByID: keysByID,
}, nil
}

func (s StandardCrypto) FindKID(alg string) string {
if ks, ok := s.keys[alg]; ok && len(ks) > 0 {
if ks, ok := s.keysByAlg[alg]; ok && len(ks) > 0 {
for kid := range ks {
return kid
}
Expand All @@ -219,17 +234,13 @@ func (s StandardCrypto) FindKID(alg string) string {
}

func (s StandardCrypto) RSAPublicKey(kid string) (string, error) {
rsaKeys, ok := s.keys[AlgorithmRSA2048]
if !ok || len(rsaKeys) == 0 {
return "", ErrCertNotFound
}
k, ok := rsaKeys[kid]
k, ok := s.keysByID[kid]
if !ok {
return "", ErrCertNotFound
return "", fmt.Errorf("no rsa key with id [%s]: %w", kid, ErrCertNotFound)
}
rsa, ok := k.(StandardRSACrypto)
if !ok {
return "", ErrCertNotFound
return "", fmt.Errorf("key with id [%s] is not an RSA key: %w", kid, ErrCertNotFound)
}

pem, err := rsa.asymEncryption.PublicKeyInPemFormat()
Expand All @@ -241,27 +252,19 @@ func (s StandardCrypto) RSAPublicKey(kid string) (string, error) {
}

func (s StandardCrypto) ECCertificate(kid string) (string, error) {
ecKeys, ok := s.keys[AlgorithmECP256R1]
if !ok || len(ecKeys) == 0 {
return "", ErrCertNotFound
}
k, ok := ecKeys[kid]
k, ok := s.keysByID[kid]
if !ok {
return "", ErrCertNotFound
return "", fmt.Errorf("no ec key with id [%s]: %w", kid, ErrCertNotFound)
}
ec, ok := k.(StandardECCrypto)
if !ok {
return "", ErrCertNotFound
return "", fmt.Errorf("key with id [%s] is not an EC key: %w", kid, ErrCertNotFound)
}
return ec.ecCertificatePEM, nil
}

func (s StandardCrypto) ECPublicKey(kid string) (string, error) {
ecKeys, ok := s.keys[AlgorithmECP256R1]
if !ok || len(ecKeys) == 0 {
return "", ErrCertNotFound
}
k, ok := ecKeys[kid]
k, ok := s.keysByID[kid]
if !ok {
return "", ErrCertNotFound
}
Expand Down Expand Up @@ -293,11 +296,7 @@ func (s StandardCrypto) ECPublicKey(kid string) (string, error) {
}

func (s StandardCrypto) RSADecrypt(_ crypto.Hash, kid string, _ string, ciphertext []byte) ([]byte, error) {
rsaKeys, ok := s.keys[AlgorithmRSA2048]
if !ok || len(rsaKeys) == 0 {
return nil, ErrCertNotFound
}
k, ok := rsaKeys[kid]
k, ok := s.keysByID[kid]
if !ok {
return nil, ErrCertNotFound
}
Expand All @@ -315,11 +314,10 @@ func (s StandardCrypto) RSADecrypt(_ crypto.Hash, kid string, _ string, cipherte
}

func (s StandardCrypto) RSAPublicKeyAsJSON(kid string) (string, error) {
rsaKeys, ok := s.keys[AlgorithmRSA2048]
if !ok || len(rsaKeys) == 0 {
k, ok := s.keysByID[kid]
if !ok {
return "", ErrCertNotFound
}
k, ok := rsaKeys[kid]
if !ok {
return "", ErrCertNotFound
}
Expand Down Expand Up @@ -357,11 +355,7 @@ func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPubl
}
ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock)

ecKeys, ok := s.keys[AlgorithmECP256R1]
if !ok || len(ecKeys) == 0 {
return nil, ErrNoKeys
}
k, ok := ecKeys[kasKID]
k, ok := s.keysByID[kasKID]
if !ok {
return nil, ErrKeyPairInfoNotFound
}
Expand Down
46 changes: 46 additions & 0 deletions service/kas/access/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,49 @@ func (p *Provider) IsReady(ctx context.Context) error {
p.Logger.TraceContext(ctx, "checking readiness of kas service")
return nil
}

func (kasCfg *KASConfig) UpgradeMapToKeyring(c security.CryptoProvider) {
switch {
case kasCfg.ECCertID != "" && len(kasCfg.Keyring) > 0:
panic("invalid kas cfg: please specify keyring or eccertid, not both")
case len(kasCfg.Keyring) == 0:
deprecatedOrDefault := func(kid, alg string) {
if kid == "" {
kid = c.FindKID(alg)
}
if kid == "" {
// no known key for this algorithm type
return
}
kasCfg.Keyring = append(kasCfg.Keyring, CurrentKeyFor{
Algorithm: alg,
KID: kid,
})
kasCfg.Keyring = append(kasCfg.Keyring, CurrentKeyFor{
Algorithm: alg,
KID: kid,
Legacy: true,
})
}
deprecatedOrDefault(kasCfg.ECCertID, security.AlgorithmECP256R1)
deprecatedOrDefault(kasCfg.RSACertID, security.AlgorithmRSA2048)
default:
kasCfg.Keyring = append(kasCfg.Keyring, inferLegacyKeys(kasCfg.Keyring)...)
}
}

// If there exists *any* legacy keys, returns empty list.
// Otherwise, create a copy with legacy=true for all values
func inferLegacyKeys(keys []CurrentKeyFor) []CurrentKeyFor {
for _, k := range keys {
if k.Legacy {
return nil
}
}
l := make([]CurrentKeyFor, len(keys))
for i, k := range keys {
l[i] = k
l[i].Legacy = true
}
return l
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package kas
package access

import (
"testing"

"github.com/opentdf/platform/service/internal/security"
"github.com/opentdf/platform/service/kas/access"
"github.com/stretchr/testify/assert"
)

Expand All @@ -13,14 +12,14 @@ func TestInferLegacyKeys_empty(t *testing.T) {
}

func TestInferLegacyKeys_singles(t *testing.T) {
one := []access.CurrentKeyFor{
one := []CurrentKeyFor{
{
Algorithm: security.AlgorithmRSA2048,
KID: "rsa",
},
}

oneLegacy := []access.CurrentKeyFor{
oneLegacy := []CurrentKeyFor{
{
Algorithm: security.AlgorithmRSA2048,
KID: "rsa",
Expand All @@ -34,7 +33,7 @@ func TestInferLegacyKeys_singles(t *testing.T) {
}

func TestInferLegacyKeys_Mixed(t *testing.T) {
in := []access.CurrentKeyFor{
in := []CurrentKeyFor{
{
Algorithm: security.AlgorithmRSA2048,
KID: "a",
Expand Down
4 changes: 2 additions & 2 deletions service/kas/access/publicKey.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func (p Provider) PublicKey(ctx context.Context, req *connect.Request[kaspb.Publ
r := func(value, kid string, err error) (*connect.Response[kaspb.PublicKeyResponse], error) {
if errors.Is(err, security.ErrCertNotFound) {
p.Logger.ErrorContext(ctx, "no key found for", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt)
return nil, connect.NewError(connect.CodeNotFound, err)
return nil, connect.NewError(connect.CodeNotFound, security.ErrCertNotFound)
} else if err != nil {
p.Logger.ErrorContext(ctx, "configuration error for key lookup", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt)
return nil, connect.NewError(connect.CodeInternal, err)
return nil, connect.NewError(connect.CodeInternal, ErrInternal)
}
if req.Msg.GetV() == "1" {
p.Logger.WarnContext(ctx, "hiding kid in public key response for legacy client", "kid", kid, "v", req.Msg.GetV())
Expand Down
Loading
Loading