Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions service/internal/security/in_process_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,20 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie

// ListKeys lists all available keys
func (a *InProcessProvider) ListKeys(ctx context.Context) ([]trust.KeyDetails, error) {
return a.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: false})
}

func (a *InProcessProvider) ListKeysWith(ctx context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) {
// This is a limited implementation as CryptoProvider doesn't expose a list of all keys
var keys []trust.KeyDetails

// Try to find keys for known algorithms
for _, alg := range []string{AlgorithmRSA2048, AlgorithmECP256R1} {
if kids, err := a.cryptoProvider.ListKIDsByAlgorithm(alg); err == nil && len(kids) > 0 {
for _, kid := range kids {
if opts.LegacyOnly && !a.legacyKeys[kid] {
continue // Skip non-legacy keys if LegacyOnly is true
}
keys = append(keys, &KeyDetailsAdapter{
id: trust.KeyIdentifier(kid),
algorithm: ocrypto.KeyType(alg),
Expand Down
11 changes: 11 additions & 0 deletions service/kas/access/publicKey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ func (m *MockSecurityProvider) ListKeys(_ context.Context) ([]trust.KeyDetails,
return keys, nil
}

func (m *MockSecurityProvider) ListKeysWith(_ context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) {
var keys []trust.KeyDetails
for _, key := range m.keys {
if opts.LegacyOnly && !key.IsLegacy() {
continue
}
keys = append(keys, key)
}
return keys, nil
}

func (m *MockSecurityProvider) Decrypt(_ context.Context, _ trust.KeyDetails, _, _ []byte) (trust.ProtectedKey, error) {
return nil, errors.New("not implemented for tests")
}
Expand Down
2 changes: 1 addition & 1 deletion service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier {
return kidsToCheck
}

k, err := p.KeyDelegator.ListKeys(ctx)
k, err := p.KeyDelegator.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: true})
if err != nil {
p.Logger.WarnContext(ctx, "checkpoint KeyIndex.ListKeys failed", slog.Any("error", err))
} else {
Expand Down
18 changes: 17 additions & 1 deletion service/kas/access/rewrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,23 @@ func (f *fakeKeyIndex) FindKeyByAlgorithm(context.Context, string, bool) (trust.
func (f *fakeKeyIndex) FindKeyByID(context.Context, trust.KeyIdentifier) (trust.KeyDetails, error) {
return nil, errors.New("not implemented")
}
func (f *fakeKeyIndex) ListKeys(context.Context) ([]trust.KeyDetails, error) { return f.keys, f.err }

func (f *fakeKeyIndex) ListKeys(context.Context) ([]trust.KeyDetails, error) {
return f.keys, f.err
}

func (f *fakeKeyIndex) ListKeysWith(_ context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) {
if opts.LegacyOnly {
var legacyKeys []trust.KeyDetails
for _, key := range f.keys {
if key.IsLegacy() {
legacyKeys = append(legacyKeys, key)
}
}
return legacyKeys, f.err
}
return f.keys, f.err
}

func TestListLegacyKeys_KeyringPopulated(t *testing.T) {
testLogger := logger.CreateTestLogger()
Expand Down
20 changes: 18 additions & 2 deletions service/kas/key_indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,23 @@ func convertAlgToEnum(alg string) (policy.Algorithm, error) {
}
}

func (p *KeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, _ bool) (trust.KeyDetails, error) {
func (p *KeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, includeLegacy bool) (trust.KeyDetails, error) {
alg, err := convertAlgToEnum(algorithm)
if err != nil {
return nil, err
}

var legacy *bool
if !includeLegacy {
legacy = &includeLegacy
}

req := &kasregistry.ListKeysRequest{
KeyAlgorithm: alg,
KasFilter: &kasregistry.ListKeysRequest_KasUri{
KasUri: p.kasURI,
},
Legacy: legacy,
}
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
if err != nil {
Expand Down Expand Up @@ -140,10 +146,20 @@ func (p *KeyIndexer) FindKeyByID(ctx context.Context, id trust.KeyIdentifier) (t
}

func (p *KeyIndexer) ListKeys(ctx context.Context) ([]trust.KeyDetails, error) {
return p.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: false})
}

func (p *KeyIndexer) ListKeysWith(ctx context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) {
var legacyOnly *bool
if opts.LegacyOnly {
legacyOnly = &opts.LegacyOnly
}

req := &kasregistry.ListKeysRequest{
KasFilter: &kasregistry.ListKeysRequest_KasUri{
KasUri: p.kasURI,
},
Legacy: legacyOnly,
}
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
if err != nil {
Expand Down Expand Up @@ -171,7 +187,7 @@ func (p *KeyAdapter) Algorithm() ocrypto.KeyType {
}

func (p *KeyAdapter) IsLegacy() bool {
return false
return p.key.GetKey().GetLegacy()
}

// This will point to the correct "manager"
Expand Down
229 changes: 229 additions & 0 deletions service/kas/key_indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,89 @@ package kas

import (
"context"
"errors"
"testing"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/protocol/go/policy"
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
"github.com/opentdf/platform/sdk"
"github.com/opentdf/platform/service/trust"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)

type MockKeyAccessServerRegistryClient struct {
mock.Mock
}

func (m *MockKeyAccessServerRegistryClient) CreateKeyAccessServer(context.Context, *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) GetKeyAccessServer(context.Context, *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) ListKeyAccessServers(context.Context, *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) UpdateKeyAccessServer(context.Context, *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) DeleteKeyAccessServer(context.Context, *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) ListKeyAccessServerGrants(context.Context, *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) CreateKey(context.Context, *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) GetKey(context.Context, *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) {
args := m.Called(ctx, req)
if args.Get(0) == nil {
return nil, args.Error(1)
}

var resp *kasregistry.ListKeysResponse
var ok bool
if resp, ok = args.Get(0).(*kasregistry.ListKeysResponse); !ok {
return nil, args.Error(1)
}
return resp, args.Error(1)
}

func (m *MockKeyAccessServerRegistryClient) UpdateKey(context.Context, *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) RotateKey(context.Context, *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) SetBaseKey(context.Context, *kasregistry.SetBaseKeyRequest) (*kasregistry.SetBaseKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) GetBaseKey(context.Context, *kasregistry.GetBaseKeyRequest) (*kasregistry.GetBaseKeyResponse, error) {
return nil, errors.New("not implemented")
}

func (m *MockKeyAccessServerRegistryClient) ListKeyMappings(context.Context, *kasregistry.ListKeyMappingsRequest) (*kasregistry.ListKeyMappingsResponse, error) {
return nil, errors.New("not implemented")
}

type KeyIndexTestSuite struct {
suite.Suite
rsaKey trust.KeyDetails
Expand Down Expand Up @@ -73,6 +147,161 @@ func (s *KeyIndexTestSuite) TestKeyExportPublicKey_PKCSFormat() {
s.Equal(pubCtx.GetPem(), string(base64Pem))
}

func (s *KeyIndexTestSuite) TestKeyDetails_Legacy() {
legacyKey := &KeyAdapter{
key: &policy.KasKey{
KasId: "test-kas-id",
Key: &policy.AsymmetricKey{
Id: "test-id-legacy",
KeyId: "test-key-id-legacy",
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
KeyMode: policy.KeyMode_KEY_MODE_CONFIG_ROOT_KEY,
Legacy: true, // Mark as legacy
PublicKeyCtx: &policy.PublicKeyCtx{
Pem: "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF3SEw0TkVrOFpDa0JzNjZXQVpWagpIS3NseDRseWdmaXN3aW42RUx5OU9OczZLVDRYa1crRGxsdExtck14bHZkbzVRaDg1UmFZS01mWUdDTWtPM0dGCkFsK0JOeWFOM1kwa0N1QjNPU2ErTzdyMURhNVZteVVuaEJNbFBrYnVPY1Y0cjlLMUhOSGd3eDl2UFp3RjRpQW8KQStEY1VBcWFEeHlvYjV6enNGZ0hUNjJHLzdLdEtiZ2hYT1dCanRUYUl1ZHpsK2FaSjFPemY0U1RkOXhST2QrMQordVo2VG1ocmFEUm9zdDUrTTZUN0toL2lGWk40TTFUY2hwWXU1TDhKR2tVaG9YaEdZcHUrMGczSzlqYlh6RVh5CnpJU3VXN2d6SGRWYUxvcnBkQlNkRHpOWkNvTFVoL0U1T3d5TFZFQkNKaDZJVUtvdWJ5WHVucnIxQnJmK2tLbEsKeHdJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==",
},
},
},
}
s.True(legacyKey.IsLegacy())
}

func (s *KeyIndexTestSuite) TestListKeysWith() {
mockClient := new(MockKeyAccessServerRegistryClient)
keyIndexer := &KeyIndexer{
sdk: &sdk.SDK{
KeyAccessServerRegistry: mockClient,
},
}

// Mock the ListKeys function to return a specific key based on the legacy flag
mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool {
return req.GetLegacy()
})).Return(&kasregistry.ListKeysResponse{
KasKeys: []*policy.KasKey{
{
Key: &policy.AsymmetricKey{
KeyId: "legacy-key-id",
},
},
},
}, nil)

mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool {
//nolint:staticcheck // Legacy optional flag should not be set
return req.Legacy == nil
})).Return(&kasregistry.ListKeysResponse{
KasKeys: []*policy.KasKey{
{
Key: &policy.AsymmetricKey{
KeyId: "non-legacy-key-id",
},
},
{
Key: &policy.AsymmetricKey{
KeyId: "legacy-key-id",
},
},
},
}, nil)

// Test with legacy flag set to true
keys, err := keyIndexer.ListKeysWith(context.Background(), trust.ListKeyOptions{LegacyOnly: true})
s.Require().NoError(err)
s.Len(keys, 1)
s.Equal("legacy-key-id", string(keys[0].ID()))

// Test with legacy flag set to false
keys, err = keyIndexer.ListKeysWith(context.Background(), trust.ListKeyOptions{LegacyOnly: false})
s.Require().NoError(err)
s.Len(keys, 2)
s.Equal("non-legacy-key-id", string(keys[0].ID()))
s.Equal("legacy-key-id", string(keys[1].ID()))
}

func (s *KeyIndexTestSuite) TestListKeys() {
mockClient := new(MockKeyAccessServerRegistryClient)
keyIndexer := &KeyIndexer{
sdk: &sdk.SDK{
KeyAccessServerRegistry: mockClient,
},
}

mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool {
return !req.GetLegacy()
})).Return(&kasregistry.ListKeysResponse{
KasKeys: []*policy.KasKey{
{
Key: &policy.AsymmetricKey{
KeyId: "test-key-id",
},
},
},
}, nil)

keys, err := keyIndexer.ListKeys(context.Background())
s.Require().NoError(err)
s.Len(keys, 1)
s.Equal("test-key-id", string(keys[0].ID()))
}

func (s *KeyIndexTestSuite) TestFindKeyByAlgorithm() {
mockClient := new(MockKeyAccessServerRegistryClient)
keyIndexer := &KeyIndexer{
sdk: &sdk.SDK{
KeyAccessServerRegistry: mockClient,
},
}

mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool {
//nolint:staticcheck // Legacy optional flag should not be set
return req.GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 && (req.Legacy != nil && req.GetLegacy() == false)
})).Return(&kasregistry.ListKeysResponse{
KasKeys: []*policy.KasKey{
{
Key: &policy.AsymmetricKey{
KeyId: "test-key-id",
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
},
},
},
}, nil)

mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool {
//nolint:staticcheck // Legacy optional flag should not be set
return req.GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 && req.Legacy == nil
})).Return(&kasregistry.ListKeysResponse{
KasKeys: []*policy.KasKey{
{
Key: &policy.AsymmetricKey{
KeyId: "test-legacy-key-id",
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
},
},
{
Key: &policy.AsymmetricKey{
KeyId: "test-key-id",
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
},
},
},
}, nil)

key, err := keyIndexer.FindKeyByAlgorithm(context.Background(), string(ocrypto.RSA2048Key), false)
s.Require().NoError(err)
s.NotNil(key)
s.Equal("test-key-id", string(key.ID()))

key, err = keyIndexer.FindKeyByAlgorithm(context.Background(), string(ocrypto.RSA2048Key), true)
s.Require().NoError(err)
s.NotNil(key)
s.Equal("test-legacy-key-id", string(key.ID()))
}

func TestNewPlatformKeyIndexTestSuite(t *testing.T) {
suite.Run(t, new(KeyIndexTestSuite))
}
4 changes: 4 additions & 0 deletions service/trust/delegating_key_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (d *DelegatingKeyService) ListKeys(ctx context.Context) ([]KeyDetails, erro
return d.index.ListKeys(ctx)
}

func (d *DelegatingKeyService) ListKeysWith(ctx context.Context, opts ListKeyOptions) ([]KeyDetails, error) {
return d.index.ListKeysWith(ctx, opts)
}

// Implementing KeyManager methods
func (d *DelegatingKeyService) Name() string {
return "DelegatingKeyService"
Expand Down
Loading
Loading