diff --git a/service/internal/security/in_process_provider.go b/service/internal/security/in_process_provider.go index 694665caee..bf8b1b9dbb 100644 --- a/service/internal/security/in_process_provider.go +++ b/service/internal/security/in_process_provider.go @@ -268,6 +268,10 @@ func (a *InProcessProvider) FindKeyByID(_ context.Context, id trust.KeyIdentifie // ListKeys lists all available keys func (a *InProcessProvider) ListKeys(ctx context.Context) ([]trust.KeyDetails, error) { + return a.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: false}) +} + +func (a *InProcessProvider) ListKeysWith(ctx context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) { // This is a limited implementation as CryptoProvider doesn't expose a list of all keys var keys []trust.KeyDetails @@ -275,6 +279,9 @@ func (a *InProcessProvider) ListKeys(ctx context.Context) ([]trust.KeyDetails, e for _, alg := range []string{AlgorithmRSA2048, AlgorithmECP256R1} { if kids, err := a.cryptoProvider.ListKIDsByAlgorithm(alg); err == nil && len(kids) > 0 { for _, kid := range kids { + if opts.LegacyOnly && !a.legacyKeys[kid] { + continue // Skip non-legacy keys if LegacyOnly is true + } keys = append(keys, &KeyDetailsAdapter{ id: trust.KeyIdentifier(kid), algorithm: ocrypto.KeyType(alg), diff --git a/service/kas/access/publicKey_test.go b/service/kas/access/publicKey_test.go index 728c3ae53f..5adf3b529d 100644 --- a/service/kas/access/publicKey_test.go +++ b/service/kas/access/publicKey_test.go @@ -127,6 +127,17 @@ func (m *MockSecurityProvider) ListKeys(_ context.Context) ([]trust.KeyDetails, return keys, nil } +func (m *MockSecurityProvider) ListKeysWith(_ context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) { + var keys []trust.KeyDetails + for _, key := range m.keys { + if opts.LegacyOnly && !key.IsLegacy() { + continue + } + keys = append(keys, key) + } + return keys, nil +} + func (m *MockSecurityProvider) Decrypt(_ context.Context, _ trust.KeyDetails, _, _ []byte) (trust.ProtectedKey, error) { return nil, errors.New("not implemented for tests") } diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index d4e528a349..763dc76b90 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -652,7 +652,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier { return kidsToCheck } - k, err := p.KeyDelegator.ListKeys(ctx) + k, err := p.KeyDelegator.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: true}) if err != nil { p.Logger.WarnContext(ctx, "checkpoint KeyIndex.ListKeys failed", slog.Any("error", err)) } else { diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index 85983c2b16..15f141ec16 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -68,7 +68,23 @@ func (f *fakeKeyIndex) FindKeyByAlgorithm(context.Context, string, bool) (trust. func (f *fakeKeyIndex) FindKeyByID(context.Context, trust.KeyIdentifier) (trust.KeyDetails, error) { return nil, errors.New("not implemented") } -func (f *fakeKeyIndex) ListKeys(context.Context) ([]trust.KeyDetails, error) { return f.keys, f.err } + +func (f *fakeKeyIndex) ListKeys(context.Context) ([]trust.KeyDetails, error) { + return f.keys, f.err +} + +func (f *fakeKeyIndex) ListKeysWith(_ context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) { + if opts.LegacyOnly { + var legacyKeys []trust.KeyDetails + for _, key := range f.keys { + if key.IsLegacy() { + legacyKeys = append(legacyKeys, key) + } + } + return legacyKeys, f.err + } + return f.keys, f.err +} func TestListLegacyKeys_KeyringPopulated(t *testing.T) { testLogger := logger.CreateTestLogger() diff --git a/service/kas/key_indexer.go b/service/kas/key_indexer.go index 3dd8cecc8c..7435ebf174 100644 --- a/service/kas/key_indexer.go +++ b/service/kas/key_indexer.go @@ -81,17 +81,23 @@ func convertAlgToEnum(alg string) (policy.Algorithm, error) { } } -func (p *KeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, _ bool) (trust.KeyDetails, error) { +func (p *KeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, includeLegacy bool) (trust.KeyDetails, error) { alg, err := convertAlgToEnum(algorithm) if err != nil { return nil, err } + var legacy *bool + if !includeLegacy { + legacy = &includeLegacy + } + req := &kasregistry.ListKeysRequest{ KeyAlgorithm: alg, KasFilter: &kasregistry.ListKeysRequest_KasUri{ KasUri: p.kasURI, }, + Legacy: legacy, } resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req) if err != nil { @@ -140,10 +146,20 @@ func (p *KeyIndexer) FindKeyByID(ctx context.Context, id trust.KeyIdentifier) (t } func (p *KeyIndexer) ListKeys(ctx context.Context) ([]trust.KeyDetails, error) { + return p.ListKeysWith(ctx, trust.ListKeyOptions{LegacyOnly: false}) +} + +func (p *KeyIndexer) ListKeysWith(ctx context.Context, opts trust.ListKeyOptions) ([]trust.KeyDetails, error) { + var legacyOnly *bool + if opts.LegacyOnly { + legacyOnly = &opts.LegacyOnly + } + req := &kasregistry.ListKeysRequest{ KasFilter: &kasregistry.ListKeysRequest_KasUri{ KasUri: p.kasURI, }, + Legacy: legacyOnly, } resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req) if err != nil { @@ -171,7 +187,7 @@ func (p *KeyAdapter) Algorithm() ocrypto.KeyType { } func (p *KeyAdapter) IsLegacy() bool { - return false + return p.key.GetKey().GetLegacy() } // This will point to the correct "manager" diff --git a/service/kas/key_indexer_test.go b/service/kas/key_indexer_test.go index a0367cc1c3..9605bea533 100644 --- a/service/kas/key_indexer_test.go +++ b/service/kas/key_indexer_test.go @@ -2,15 +2,89 @@ package kas import ( "context" + "errors" "testing" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/policy" + "github.com/opentdf/platform/protocol/go/policy/kasregistry" + "github.com/opentdf/platform/sdk" "github.com/opentdf/platform/service/trust" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) +type MockKeyAccessServerRegistryClient struct { + mock.Mock +} + +func (m *MockKeyAccessServerRegistryClient) CreateKeyAccessServer(context.Context, *kasregistry.CreateKeyAccessServerRequest) (*kasregistry.CreateKeyAccessServerResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) GetKeyAccessServer(context.Context, *kasregistry.GetKeyAccessServerRequest) (*kasregistry.GetKeyAccessServerResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) ListKeyAccessServers(context.Context, *kasregistry.ListKeyAccessServersRequest) (*kasregistry.ListKeyAccessServersResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) UpdateKeyAccessServer(context.Context, *kasregistry.UpdateKeyAccessServerRequest) (*kasregistry.UpdateKeyAccessServerResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) DeleteKeyAccessServer(context.Context, *kasregistry.DeleteKeyAccessServerRequest) (*kasregistry.DeleteKeyAccessServerResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) ListKeyAccessServerGrants(context.Context, *kasregistry.ListKeyAccessServerGrantsRequest) (*kasregistry.ListKeyAccessServerGrantsResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) CreateKey(context.Context, *kasregistry.CreateKeyRequest) (*kasregistry.CreateKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) GetKey(context.Context, *kasregistry.GetKeyRequest) (*kasregistry.GetKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) ListKeys(ctx context.Context, req *kasregistry.ListKeysRequest) (*kasregistry.ListKeysResponse, error) { + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + var resp *kasregistry.ListKeysResponse + var ok bool + if resp, ok = args.Get(0).(*kasregistry.ListKeysResponse); !ok { + return nil, args.Error(1) + } + return resp, args.Error(1) +} + +func (m *MockKeyAccessServerRegistryClient) UpdateKey(context.Context, *kasregistry.UpdateKeyRequest) (*kasregistry.UpdateKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) RotateKey(context.Context, *kasregistry.RotateKeyRequest) (*kasregistry.RotateKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) SetBaseKey(context.Context, *kasregistry.SetBaseKeyRequest) (*kasregistry.SetBaseKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) GetBaseKey(context.Context, *kasregistry.GetBaseKeyRequest) (*kasregistry.GetBaseKeyResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *MockKeyAccessServerRegistryClient) ListKeyMappings(context.Context, *kasregistry.ListKeyMappingsRequest) (*kasregistry.ListKeyMappingsResponse, error) { + return nil, errors.New("not implemented") +} + type KeyIndexTestSuite struct { suite.Suite rsaKey trust.KeyDetails @@ -73,6 +147,161 @@ func (s *KeyIndexTestSuite) TestKeyExportPublicKey_PKCSFormat() { s.Equal(pubCtx.GetPem(), string(base64Pem)) } +func (s *KeyIndexTestSuite) TestKeyDetails_Legacy() { + legacyKey := &KeyAdapter{ + key: &policy.KasKey{ + KasId: "test-kas-id", + Key: &policy.AsymmetricKey{ + Id: "test-id-legacy", + KeyId: "test-key-id-legacy", + KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048, + KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE, + KeyMode: policy.KeyMode_KEY_MODE_CONFIG_ROOT_KEY, + Legacy: true, // Mark as legacy + PublicKeyCtx: &policy.PublicKeyCtx{ + Pem: "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF3SEw0TkVrOFpDa0JzNjZXQVpWagpIS3NseDRseWdmaXN3aW42RUx5OU9OczZLVDRYa1crRGxsdExtck14bHZkbzVRaDg1UmFZS01mWUdDTWtPM0dGCkFsK0JOeWFOM1kwa0N1QjNPU2ErTzdyMURhNVZteVVuaEJNbFBrYnVPY1Y0cjlLMUhOSGd3eDl2UFp3RjRpQW8KQStEY1VBcWFEeHlvYjV6enNGZ0hUNjJHLzdLdEtiZ2hYT1dCanRUYUl1ZHpsK2FaSjFPemY0U1RkOXhST2QrMQordVo2VG1ocmFEUm9zdDUrTTZUN0toL2lGWk40TTFUY2hwWXU1TDhKR2tVaG9YaEdZcHUrMGczSzlqYlh6RVh5CnpJU3VXN2d6SGRWYUxvcnBkQlNkRHpOWkNvTFVoL0U1T3d5TFZFQkNKaDZJVUtvdWJ5WHVucnIxQnJmK2tLbEsKeHdJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==", + }, + }, + }, + } + s.True(legacyKey.IsLegacy()) +} + +func (s *KeyIndexTestSuite) TestListKeysWith() { + mockClient := new(MockKeyAccessServerRegistryClient) + keyIndexer := &KeyIndexer{ + sdk: &sdk.SDK{ + KeyAccessServerRegistry: mockClient, + }, + } + + // Mock the ListKeys function to return a specific key based on the legacy flag + mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool { + return req.GetLegacy() + })).Return(&kasregistry.ListKeysResponse{ + KasKeys: []*policy.KasKey{ + { + Key: &policy.AsymmetricKey{ + KeyId: "legacy-key-id", + }, + }, + }, + }, nil) + + mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool { + //nolint:staticcheck // Legacy optional flag should not be set + return req.Legacy == nil + })).Return(&kasregistry.ListKeysResponse{ + KasKeys: []*policy.KasKey{ + { + Key: &policy.AsymmetricKey{ + KeyId: "non-legacy-key-id", + }, + }, + { + Key: &policy.AsymmetricKey{ + KeyId: "legacy-key-id", + }, + }, + }, + }, nil) + + // Test with legacy flag set to true + keys, err := keyIndexer.ListKeysWith(context.Background(), trust.ListKeyOptions{LegacyOnly: true}) + s.Require().NoError(err) + s.Len(keys, 1) + s.Equal("legacy-key-id", string(keys[0].ID())) + + // Test with legacy flag set to false + keys, err = keyIndexer.ListKeysWith(context.Background(), trust.ListKeyOptions{LegacyOnly: false}) + s.Require().NoError(err) + s.Len(keys, 2) + s.Equal("non-legacy-key-id", string(keys[0].ID())) + s.Equal("legacy-key-id", string(keys[1].ID())) +} + +func (s *KeyIndexTestSuite) TestListKeys() { + mockClient := new(MockKeyAccessServerRegistryClient) + keyIndexer := &KeyIndexer{ + sdk: &sdk.SDK{ + KeyAccessServerRegistry: mockClient, + }, + } + + mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool { + return !req.GetLegacy() + })).Return(&kasregistry.ListKeysResponse{ + KasKeys: []*policy.KasKey{ + { + Key: &policy.AsymmetricKey{ + KeyId: "test-key-id", + }, + }, + }, + }, nil) + + keys, err := keyIndexer.ListKeys(context.Background()) + s.Require().NoError(err) + s.Len(keys, 1) + s.Equal("test-key-id", string(keys[0].ID())) +} + +func (s *KeyIndexTestSuite) TestFindKeyByAlgorithm() { + mockClient := new(MockKeyAccessServerRegistryClient) + keyIndexer := &KeyIndexer{ + sdk: &sdk.SDK{ + KeyAccessServerRegistry: mockClient, + }, + } + + mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool { + //nolint:staticcheck // Legacy optional flag should not be set + return req.GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 && (req.Legacy != nil && req.GetLegacy() == false) + })).Return(&kasregistry.ListKeysResponse{ + KasKeys: []*policy.KasKey{ + { + Key: &policy.AsymmetricKey{ + KeyId: "test-key-id", + KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048, + KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE, + }, + }, + }, + }, nil) + + mockClient.On("ListKeys", mock.Anything, mock.MatchedBy(func(req *kasregistry.ListKeysRequest) bool { + //nolint:staticcheck // Legacy optional flag should not be set + return req.GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 && req.Legacy == nil + })).Return(&kasregistry.ListKeysResponse{ + KasKeys: []*policy.KasKey{ + { + Key: &policy.AsymmetricKey{ + KeyId: "test-legacy-key-id", + KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048, + KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE, + }, + }, + { + Key: &policy.AsymmetricKey{ + KeyId: "test-key-id", + KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048, + KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE, + }, + }, + }, + }, nil) + + key, err := keyIndexer.FindKeyByAlgorithm(context.Background(), string(ocrypto.RSA2048Key), false) + s.Require().NoError(err) + s.NotNil(key) + s.Equal("test-key-id", string(key.ID())) + + key, err = keyIndexer.FindKeyByAlgorithm(context.Background(), string(ocrypto.RSA2048Key), true) + s.Require().NoError(err) + s.NotNil(key) + s.Equal("test-legacy-key-id", string(key.ID())) +} + func TestNewPlatformKeyIndexTestSuite(t *testing.T) { suite.Run(t, new(KeyIndexTestSuite)) } diff --git a/service/trust/delegating_key_service.go b/service/trust/delegating_key_service.go index 7e37c17648..2b5b02244a 100644 --- a/service/trust/delegating_key_service.go +++ b/service/trust/delegating_key_service.go @@ -78,6 +78,10 @@ func (d *DelegatingKeyService) ListKeys(ctx context.Context) ([]KeyDetails, erro return d.index.ListKeys(ctx) } +func (d *DelegatingKeyService) ListKeysWith(ctx context.Context, opts ListKeyOptions) ([]KeyDetails, error) { + return d.index.ListKeysWith(ctx, opts) +} + // Implementing KeyManager methods func (d *DelegatingKeyService) Name() string { return "DelegatingKeyService" diff --git a/service/trust/delegating_key_service_test.go b/service/trust/delegating_key_service_test.go index d07380a7b1..7871a75210 100644 --- a/service/trust/delegating_key_service_test.go +++ b/service/trust/delegating_key_service_test.go @@ -78,6 +78,14 @@ func (m *MockKeyIndex) ListKeys(ctx context.Context) ([]KeyDetails, error) { return nil, args.Error(1) } +func (m *MockKeyIndex) ListKeysWith(ctx context.Context, opts ListKeyOptions) ([]KeyDetails, error) { + args := m.Called(ctx, opts) + if a0, ok := args.Get(0).([]KeyDetails); ok { + return a0, args.Error(1) + } + return nil, args.Error(1) +} + // MockKeyDetails is a mock implementation of the KeyDetails interface type MockKeyDetails struct { mock.Mock @@ -226,6 +234,21 @@ func (suite *DelegatingKeyServiceTestSuite) TestListKeys() { suite.Len(keys, 1) } +func (suite *DelegatingKeyServiceTestSuite) TestListKeysWith_Legacy() { + legacyKey := &MockKeyDetails{} + legacyKey.On("IsLegacy").Return(true) + + nonLegacyKey := &MockKeyDetails{} + nonLegacyKey.On("IsLegacy").Return(false) + + suite.mockIndex.On("ListKeysWith", mock.Anything, ListKeyOptions{LegacyOnly: true}).Return([]KeyDetails{legacyKey}, nil) + + keys, err := suite.service.ListKeysWith(context.Background(), ListKeyOptions{LegacyOnly: true}) + suite.Require().NoError(err) + suite.Len(keys, 1) + suite.True(keys[0].IsLegacy()) +} + func (suite *DelegatingKeyServiceTestSuite) TestDecrypt() { mockKeyDetails := &MockKeyDetails{} mockKeyDetails.On("System").Return("mockManager") diff --git a/service/trust/key_index.go b/service/trust/key_index.go index 6ada58349d..41c81f68c9 100644 --- a/service/trust/key_index.go +++ b/service/trust/key_index.go @@ -10,6 +10,12 @@ import ( // KeyType represents the format in which a key can be exported type KeyType int +// Key Options to pass into ListKeysWith +// when filtering keys +type ListKeyOptions struct { + LegacyOnly bool +} + const ( // KeyTypeJWK represents a key in JWK format KeyTypeJWK KeyType = iota @@ -66,4 +72,7 @@ type KeyIndex interface { // ListKeys returns all available keys ListKeys(ctx context.Context) ([]KeyDetails, error) + + // List keys with options + ListKeysWith(ctx context.Context, opts ListKeyOptions) ([]KeyDetails, error) }