Skip to content
4 changes: 4 additions & 0 deletions service/internal/security/crypto_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type CryptoProvider interface {
// Gets some KID associated with a given algorithm.
// Returns empty string if none are found.
FindKID(alg string) string

// Gets all KIDs associated with a given algorithm.
ListKeysByAlg(alg string) ([]string, error)

RSAPublicKey(keyID string) (string, error)
RSAPublicKeyAsJSON(keyID string) (string, error)
RSADecrypt(hash crypto.Hash, keyID string, keyLabel string, ciphertext []byte) ([]byte, error)
Expand Down
61 changes: 44 additions & 17 deletions service/internal/security/in_process_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ func convertPEMToJWK(_ string) (string, error) {
type InProcessProvider struct {
cryptoProvider CryptoProvider
logger *slog.Logger
defaultKeys map[string]bool
legacyKeys map[string]bool
}

// KeyDetailsAdapter adapts CryptoProvider to KeyDetails
Expand Down Expand Up @@ -174,10 +176,22 @@ func (k *KeyDetailsAdapter) ExportCertificate(_ context.Context) (string, error)
}

// NewSecurityProviderAdapter creates a new adapter that implements SecurityProvider using a CryptoProvider
func NewSecurityProviderAdapter(cryptoProvider CryptoProvider) trust.KeyService {
func NewSecurityProviderAdapter(cryptoProvider CryptoProvider, defaultKeys, legacyKeys []string) trust.KeyService {
legacyKeysMap := make(map[string]bool, len(legacyKeys))
for _, key := range legacyKeys {
legacyKeysMap[key] = true
}

defaultKeysMap := make(map[string]bool, len(defaultKeys))
for _, key := range defaultKeys {
defaultKeysMap[key] = true
}

return &InProcessProvider{
cryptoProvider: cryptoProvider,
logger: slog.Default(),
defaultKeys: defaultKeysMap,
legacyKeys: legacyKeysMap,
}
}

Expand All @@ -193,17 +207,23 @@ func (a *InProcessProvider) WithLogger(logger *slog.Logger) *InProcessProvider {
}

// FindKeyByAlgorithm finds a key by algorithm using the underlying CryptoProvider
func (a *InProcessProvider) FindKeyByAlgorithm(_ context.Context, algorithm string, _ bool) (trust.KeyDetails, error) {
func (a *InProcessProvider) FindKeyByAlgorithm(_ context.Context, algorithm string, legacy bool) (trust.KeyDetails, error) {
// Get the key ID for this algorithm
kid := a.cryptoProvider.FindKID(algorithm)
if kid == "" {
kids, err := a.cryptoProvider.ListKeysByAlg(algorithm)
if err != nil || len(kids) == 0 {
return nil, ErrCertNotFound
}
return &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: algorithm,
cryptoProvider: a.cryptoProvider,
}, nil
for _, kid := range kids {
if legacy && a.legacyKeys[kid] || !legacy && a.defaultKeys[kid] {
return &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: algorithm,
cryptoProvider: a.cryptoProvider,
legacy: legacy,
}, nil
}
}
return nil, ErrCertNotFound
}

// FindKeyByID finds a key by ID
Expand All @@ -216,7 +236,7 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie
return &KeyDetailsAdapter{
id: id,
algorithm: alg,
legacy: false,
legacy: a.legacyKeys[string(id)],
cryptoProvider: a.cryptoProvider,
}, nil
}
Expand All @@ -225,7 +245,7 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie
return &KeyDetailsAdapter{
id: id,
algorithm: alg,
legacy: false,
legacy: a.legacyKeys[string(id)],
cryptoProvider: a.cryptoProvider,
}, nil
}
Expand All @@ -241,12 +261,19 @@ func (a *InProcessProvider) ListKeys(_ context.Context) ([]trust.KeyDetails, err

// Try to find keys for known algorithms
for _, alg := range []string{AlgorithmRSA2048, AlgorithmECP256R1} {
if kid := a.cryptoProvider.FindKID(alg); kid != "" {
keys = append(keys, &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: alg,
cryptoProvider: a.cryptoProvider,
})
if kids, err := a.cryptoProvider.ListKeysByAlg(alg); err == nil && len(kids) > 0 {
for _, kid := range kids {
keys = append(keys, &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: alg,
cryptoProvider: a.cryptoProvider,
legacy: a.legacyKeys[kid],
})
}
} else if err != nil {
if a.logger != nil {
a.logger.Warn("failed to list keys by algorithm", "algorithm", alg, "error", err)
}
}
}

Expand Down
39 changes: 28 additions & 11 deletions service/internal/security/standard_crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ func NewStandardCrypto(cfg StandardConfig) (*StandardCrypto, error) {
}
}

func (s StandardCrypto) ListKeysByAlg(alg string) ([]string, error) {
k, ok := s.keysByAlg[alg]
if !ok {
return nil, fmt.Errorf("no key found with algorithm [%s]: %w", alg, ErrCertNotFound)
}
keys := make([]string, 0, len(k))
for kid := range k {
keys = append(keys, kid)
}
return keys, nil
}

func loadKeys(ks []KeyPairInfo) (*StandardCrypto, error) {
keysByAlg := make(map[string]keylist)
keysByID := make(keylist)
Expand Down Expand Up @@ -351,7 +363,21 @@ func (s StandardCrypto) RSAPublicKeyAsJSON(kid string) (string, error) {
}

func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) ([]byte, error) {
ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, ephemeralPublicKeyBytes)
k, ok := s.keysByID[kasKID]
if !ok {
return nil, ErrKeyPairInfoNotFound
}
ec, ok := k.(StandardECCrypto)
if !ok {
return nil, ErrKeyPairInfoMalformed
}
privateKeyPEM := []byte(ec.ecPrivateKeyPem)

return DeriveNanoTDFSymmetricKey(curve, ephemeralPublicKeyBytes, privateKeyPEM)
}

func DeriveNanoTDFSymmetricKey(curve elliptic.Curve, clientEphemera []byte, privateKeyPEM []byte) ([]byte, error) {
ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, clientEphemera)
if err != nil {
return nil, err
}
Expand All @@ -366,16 +392,7 @@ func (s StandardCrypto) GenerateNanoTDFSymmetricKey(kasKID string, ephemeralPubl
}
ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock)

k, ok := s.keysByID[kasKID]
if !ok {
return nil, ErrKeyPairInfoNotFound
}
ec, ok := k.(StandardECCrypto)
if !ok {
return nil, ErrKeyPairInfoMalformed
}

symmetricKey, err := ocrypto.ComputeECDHKey([]byte(ec.ecPrivateKeyPem), ephemeralECDSAPublicKeyPEM)
symmetricKey, err := ocrypto.ComputeECDHKey(privateKeyPEM, ephemeralECDSAPublicKeyPEM)
if err != nil {
return nil, fmt.Errorf("ocrypto.ComputeECDHKey failed: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion service/internal/security/standard_only.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package security
import "log/slog"

type Config struct {
Type string `mapstructure:"type" json:"type" default:"standard"`
Type string `mapstructure:"type" json:"type"`
// StandardConfig is the configuration for the standard key provider
StandardConfig StandardConfig `mapstructure:"standard" json:"standard"`
}
Expand Down
64 changes: 41 additions & 23 deletions service/kas/access/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,13 @@ type Provider struct {

// GetSecurityProvider returns the SecurityProvider
func (p *Provider) GetSecurityProvider() trust.KeyManager {
// If SecurityProvider is explicitly set, use it
if p.KeyManager != nil {
return p.KeyManager
}

// Otherwise, create an adapter from CryptoProvider if available
if p.CryptoProvider != nil {
return security.NewSecurityProviderAdapter(p.CryptoProvider)
}

// This shouldn't happen in normal operation
p.Logger.Error("no security provider available")
return nil
p.initSecurityProviderAdapter()
return p.KeyManager
}

func (p *Provider) GetKeyIndex() trust.KeyIndex {
if p.KeyIndex != nil {
return p.KeyIndex
}

if p.CryptoProvider != nil {
return security.NewSecurityProviderAdapter(p.CryptoProvider)
}

p.Logger.Error("no key index available")
return nil
p.initSecurityProviderAdapter()
return p.KeyIndex
}

type KASConfig struct {
Expand Down Expand Up @@ -122,6 +103,43 @@ func (kasCfg *KASConfig) UpgradeMapToKeyring(c security.CryptoProvider) {
}
}

func (p *Provider) initSecurityProviderAdapter() {
// If the CryptoProvider is set, create a SecurityProviderAdapter
if p.CryptoProvider == nil || p.KeyManager != nil && p.KeyIndex != nil {
return
}
var defaults []string
var legacies []string
for _, key := range p.KASConfig.Keyring {
if key.Legacy {
legacies = append(legacies, key.KID)
} else {
defaults = append(defaults, key.KID)
}
}
if len(defaults) == 0 && len(legacies) == 0 {
for _, alg := range []string{security.AlgorithmECP256R1, security.AlgorithmRSA2048} {
kid := p.CryptoProvider.FindKID(alg)
if kid != "" {
defaults = append(defaults, kid)
} else {
p.Logger.Warn("no default key found for algorithm", "algorithm", alg)
}
}
}

inProcessService := security.NewSecurityProviderAdapter(p.CryptoProvider, defaults, legacies)

if p.KeyIndex == nil {
p.Logger.Warn("fallback to in-process key index")
p.KeyIndex = inProcessService
}
if p.KeyManager == nil {
p.Logger.Error("fallback to in-process manager")
p.KeyManager = inProcessService
}
}

// If there exists *any* legacy keys, returns empty list.
// Otherwise, create a copy with legacy=true for all values
func inferLegacyKeys(keys []CurrentKeyFor) []CurrentKeyFor {
Expand Down
9 changes: 9 additions & 0 deletions service/kas/access/publicKey.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ const (
)

func (p *Provider) lookupKid(ctx context.Context, algorithm string) (string, error) {
keyIdx := p.GetKeyIndex()
if keyIdx != nil {
k, err := keyIdx.FindKeyByAlgorithm(ctx, algorithm, false)
if err == nil {
return string(k.ID()), nil
}
p.Logger.WarnContext(ctx, "KeyIndex.FindKeyByAlgorithm failed", "err", err)
}

if len(p.Keyring) == 0 {
p.Logger.WarnContext(ctx, "no default keys found", "algorithm", algorithm)
return "", connect.NewError(connect.CodeNotFound, errors.Join(ErrConfig, errors.New("no default keys configured")))
Expand Down
2 changes: 1 addition & 1 deletion service/kas/access/publicKey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func TestStandardCertificateHandlerEmpty(t *testing.T) {

kas := Provider{
URI: *kasURI,
KeyManager: security.NewSecurityProviderAdapter(c),
KeyManager: security.NewSecurityProviderAdapter(c, nil, nil),
Logger: logger.CreateTestLogger(),
Tracer: noop.NewTracerProvider().Tracer(""),
}
Expand Down
33 changes: 27 additions & 6 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,7 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
kid := trust.KeyIdentifier(kao.GetKeyAccessObject().GetKid())
kidsToCheck = []trust.KeyIdentifier{kid}
} else {
p.Logger.InfoContext(ctx, "kid free kao")
for _, k := range p.Keyring {
if k.Algorithm == security.AlgorithmRSA2048 && k.Legacy {
kidsToCheck = append(kidsToCheck, trust.KeyIdentifier(k.KID))
}
}
kidsToCheck = p.listLegacyKeys(ctx)
if len(kidsToCheck) == 0 {
p.Logger.WarnContext(ctx, "failure to find legacy kids for rsa")
failedKAORewrap(results, kao, err400("bad request"))
Expand Down Expand Up @@ -603,6 +598,32 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
return policy, results, nil
}

func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier {
var kidsToCheck []trust.KeyIdentifier
p.Logger.InfoContext(ctx, "kid free kao")
if len(p.Keyring) > 0 {
// Using deprecated 'keyring' feature for lookup
for _, k := range p.Keyring {
if k.Algorithm == security.AlgorithmRSA2048 && k.Legacy {
kidsToCheck = append(kidsToCheck, trust.KeyIdentifier(k.KID))
}
}
return kidsToCheck
}

k, err := p.GetKeyIndex().ListKeys(ctx)
if err != nil {
p.Logger.WarnContext(ctx, "KeyIndex.ListKeys failed", "err", err)
} else {
for _, key := range k {
if key.Algorithm() == security.AlgorithmRSA2048 && key.IsLegacy() {
kidsToCheck = append(kidsToCheck, key.ID())
}
}
}
return kidsToCheck
}

func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entity *entityInfo) (string, policyKAOResults) {
if p.Tracer != nil {
var span trace.Span
Expand Down
Loading
Loading