Skip to content

Commit 79971f5

Browse files
committed
feat(policy): Add platform key indexer.
remove comments. linting. linting.
1 parent 8822c45 commit 79971f5

File tree

2 files changed

+334
-0
lines changed

2 files changed

+334
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package trust
2+
3+
import (
4+
"context"
5+
"crypto/rsa"
6+
"crypto/x509"
7+
"encoding/base64"
8+
"encoding/json"
9+
"encoding/pem"
10+
"errors"
11+
"fmt"
12+
13+
"github.com/lestrrat-go/jwx/v2/jwk"
14+
"github.com/opentdf/platform/protocol/go/policy"
15+
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
16+
"github.com/opentdf/platform/sdk"
17+
"github.com/opentdf/platform/service/logger"
18+
)
19+
20+
var ErrNoActiveKeyForAlgorithm = errors.New("no active key found for specified algorithm")
21+
22+
// Used for reaching out to platform to get keys
23+
type PlatformKeyIndexer struct {
24+
// KeyIndex is the key index used to manage keys
25+
KeyIndex
26+
// SDK is the SDK instance used to interact with the platform
27+
sdk *sdk.SDK
28+
// KasURI
29+
kasURI string
30+
// Logger is the logger instance used for logging
31+
log *logger.Logger
32+
}
33+
34+
// platformKeyAdapter is an adapter for KeyDetails, where keys come from the platform
35+
type KasKeyAdapter struct {
36+
key *policy.KasKey
37+
log *logger.Logger
38+
}
39+
40+
func NewPlatformKeyIndexer(sdk *sdk.SDK, kasURI string, l *logger.Logger) *PlatformKeyIndexer {
41+
return &PlatformKeyIndexer{
42+
sdk: sdk,
43+
kasURI: kasURI,
44+
log: l,
45+
}
46+
}
47+
48+
func convertAlgToEnum(alg string) (policy.Algorithm, error) {
49+
switch alg {
50+
case "rsa:2048":
51+
return policy.Algorithm_ALGORITHM_RSA_2048, nil
52+
case "rsa:4096":
53+
return policy.Algorithm_ALGORITHM_RSA_4096, nil
54+
case "ec:secp256r1":
55+
return policy.Algorithm_ALGORITHM_EC_P256, nil
56+
case "ec:secp384r1":
57+
return policy.Algorithm_ALGORITHM_EC_P384, nil
58+
case "ec:secp521r1":
59+
return policy.Algorithm_ALGORITHM_EC_P521, nil
60+
default:
61+
return policy.Algorithm_ALGORITHM_UNSPECIFIED, errors.New("unsupported algorithm")
62+
}
63+
}
64+
65+
func (p *PlatformKeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, _ bool) (KeyDetails, error) {
66+
alg, err := convertAlgToEnum(algorithm)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
req := &kasregistry.ListKeysRequest{
72+
KeyAlgorithm: alg,
73+
KasFilter: &kasregistry.ListKeysRequest_KasUri{
74+
KasUri: p.kasURI,
75+
},
76+
}
77+
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
78+
if err != nil {
79+
return nil, err
80+
}
81+
82+
// Find active key.
83+
var activeKey *policy.KasKey
84+
for _, key := range resp.GetKasKeys() {
85+
if key.GetKey().GetKeyStatus() == policy.KeyStatus_KEY_STATUS_ACTIVE {
86+
activeKey = key
87+
break
88+
}
89+
}
90+
if activeKey == nil {
91+
return nil, ErrNoActiveKeyForAlgorithm
92+
}
93+
94+
return &KasKeyAdapter{
95+
key: activeKey,
96+
log: p.log,
97+
}, nil
98+
}
99+
100+
func (p *PlatformKeyIndexer) FindKeyByID(ctx context.Context, id KeyIdentifier) (KeyDetails, error) {
101+
req := &kasregistry.GetKeyRequest{
102+
Identifier: &kasregistry.GetKeyRequest_Key{
103+
Key: &kasregistry.KasKeyIdentifier{
104+
Identifier: &kasregistry.KasKeyIdentifier_Uri{
105+
Uri: p.kasURI,
106+
},
107+
Kid: string(id),
108+
},
109+
},
110+
}
111+
112+
resp, err := p.sdk.KeyAccessServerRegistry.GetKey(ctx, req)
113+
if err != nil {
114+
return nil, err
115+
}
116+
117+
return &KasKeyAdapter{
118+
key: resp.GetKasKey(),
119+
log: p.log,
120+
}, nil
121+
}
122+
123+
func (p *PlatformKeyIndexer) ListKeys(ctx context.Context) ([]KeyDetails, error) {
124+
req := &kasregistry.ListKeysRequest{
125+
KasFilter: &kasregistry.ListKeysRequest_KasUri{
126+
KasUri: p.kasURI,
127+
},
128+
}
129+
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
130+
if err != nil {
131+
return nil, err
132+
}
133+
134+
keys := make([]KeyDetails, len(resp.GetKasKeys()))
135+
for i, key := range resp.GetKasKeys() {
136+
keys[i] = &KasKeyAdapter{
137+
key: key,
138+
log: p.log,
139+
}
140+
}
141+
142+
return keys, nil
143+
}
144+
145+
func (p *KasKeyAdapter) ID() KeyIdentifier {
146+
return KeyIdentifier(p.key.GetKey().GetKeyId())
147+
}
148+
149+
// Might need to convert this to a standard format
150+
func (p *KasKeyAdapter) Algorithm() string {
151+
return p.key.GetKey().GetKeyAlgorithm().String()
152+
}
153+
154+
func (p *KasKeyAdapter) IsLegacy() bool {
155+
return false
156+
}
157+
158+
// This will point to the correct "manager"
159+
func (p *KasKeyAdapter) System() string {
160+
var mode string
161+
if p.key.GetKey().GetProviderConfig() != nil {
162+
mode = p.key.GetKey().GetProviderConfig().GetName()
163+
}
164+
return mode
165+
}
166+
167+
func pemToPublicKey(publicPEM string) (*rsa.PublicKey, error) {
168+
// Decode the PEM data
169+
block, _ := pem.Decode([]byte(publicPEM))
170+
if block == nil || block.Type != "PUBLIC KEY" {
171+
return nil, fmt.Errorf("failed to decode PEM block or incorrect PEM type")
172+
}
173+
174+
// Parse the public key
175+
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
176+
if err != nil {
177+
return nil, fmt.Errorf("failed to parse public key: %w", err)
178+
}
179+
180+
// Assert type and return
181+
rsaPub, ok := pub.(*rsa.PublicKey)
182+
if !ok {
183+
return nil, fmt.Errorf("not an RSA public key")
184+
}
185+
186+
return rsaPub, nil
187+
}
188+
189+
// Repurpose of the StandardCrypto function
190+
func rsaPublicKeyAsJSON(_ context.Context, publicPEM string) (string, error) {
191+
pubKey, err := pemToPublicKey(publicPEM)
192+
if err != nil {
193+
return "", err
194+
}
195+
196+
rsaPublicKeyJwk, err := jwk.FromRaw(pubKey)
197+
if err != nil {
198+
return "", fmt.Errorf("jwk.FromRaw: %w", err)
199+
}
200+
201+
// Convert the public key to JSON format
202+
pubKeyJSON, err := json.Marshal(rsaPublicKeyJwk)
203+
if err != nil {
204+
return "", err
205+
}
206+
207+
return string(pubKeyJSON), nil
208+
}
209+
210+
// Repurpose of the StandardCrypto function
211+
func convertPEMToJWK(_ string) (string, error) {
212+
return "", errors.New("convertPEMToJWK function is not implemented")
213+
}
214+
215+
func (p *KasKeyAdapter) ExportPublicKey(ctx context.Context, format KeyType) (string, error) {
216+
publicKeyCtx := p.key.GetKey().GetPublicKeyCtx()
217+
var pubKeyCtxMap map[string]any
218+
if err := json.Unmarshal(publicKeyCtx, &pubKeyCtxMap); err != nil {
219+
return "", err
220+
}
221+
222+
pubKey, ok := pubKeyCtxMap["pubKey"].(string)
223+
if !ok {
224+
return "", errors.New("public key is not a string")
225+
}
226+
// Decode the base64-encoded public key
227+
decodedPubKey, err := base64.StdEncoding.DecodeString(pubKey)
228+
if err != nil {
229+
return "", err
230+
}
231+
232+
switch format {
233+
case KeyTypeJWK:
234+
// For JWK format (currently only supported for RSA)
235+
if p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 ||
236+
p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_4096 {
237+
return rsaPublicKeyAsJSON(ctx, string(decodedPubKey))
238+
}
239+
// For EC keys, we return the public key in PEM format
240+
jwkKey, err := convertPEMToJWK(string(decodedPubKey))
241+
if err != nil {
242+
return "", err
243+
}
244+
245+
return jwkKey, nil
246+
case KeyTypePKCS8:
247+
return string(decodedPubKey), nil
248+
default:
249+
return "", errors.New("unsupported key type")
250+
}
251+
}
252+
253+
func (p *KasKeyAdapter) ExportCertificate(_ context.Context) (string, error) {
254+
return "", errors.New("not implemented")
255+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package trust
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"testing"
7+
8+
"github.com/lestrrat-go/jwx/v2/jwk"
9+
"github.com/opentdf/platform/lib/ocrypto"
10+
"github.com/opentdf/platform/protocol/go/policy"
11+
"github.com/stretchr/testify/suite"
12+
)
13+
14+
type PlatformKeyIndexTestSuite struct {
15+
suite.Suite
16+
rsaKey KeyDetails
17+
}
18+
19+
func (s *PlatformKeyIndexTestSuite) SetupTest() {
20+
s.rsaKey = &KasKeyAdapter{
21+
key: &policy.KasKey{
22+
KasId: "test-kas-id",
23+
Key: &policy.AsymmetricKey{
24+
Id: "test-key-id",
25+
KeyId: "test-key-id",
26+
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
27+
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
28+
KeyMode: policy.KeyMode_KEY_MODE_LOCAL,
29+
PublicKeyCtx: []byte(`{"pubKey": "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF3SEw0TkVrOFpDa0JzNjZXQVpWagpIS3NseDRseWdmaXN3aW42RUx5OU9OczZLVDRYa1crRGxsdExtck14bHZkbzVRaDg1UmFZS01mWUdDTWtPM0dGCkFsK0JOeWFOM1kwa0N1QjNPU2ErTzdyMURhNVZteVVuaEJNbFBrYnVPY1Y0cjlLMUhOSGd3eDl2UFp3RjRpQW8KQStEY1VBcWFEeHlvYjV6enNGZ0hUNjJHLzdLdEtiZ2hYT1dCanRUYUl1ZHpsK2FaSjFPemY0U1RkOXhST2QrMQordVo2VG1ocmFEUm9zdDUrTTZUN0toL2lGWk40TTFUY2hwWXU1TDhKR2tVaG9YaEdZcHUrMGczSzlqYlh6RVh5CnpJU3VXN2d6SGRWYUxvcnBkQlNkRHpOWkNvTFVoL0U1T3d5TFZFQkNKaDZJVUtvdWJ5WHVucnIxQnJmK2tLbEsKeHdJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg=="}`),
30+
ProviderConfig: &policy.KeyProviderConfig{
31+
Id: "test-provider-id",
32+
Name: "openbao",
33+
},
34+
},
35+
},
36+
}
37+
}
38+
func (s *PlatformKeyIndexTestSuite) TearDownTest() {}
39+
40+
func (s *PlatformKeyIndexTestSuite) TestKeyDetails() {
41+
s.Equal("test-key-id", s.rsaKey.ID())
42+
s.Equal("ALGORITHM_RSA_2048", s.rsaKey.Algorithm())
43+
s.False(s.rsaKey.IsLegacy())
44+
s.Equal("openbao", s.rsaKey.System())
45+
}
46+
47+
func (s *PlatformKeyIndexTestSuite) TestKeyExportPublicKey_JWKFormat() {
48+
// Export JWK format
49+
jwkString, err := s.rsaKey.ExportPublicKey(context.Background(), KeyTypeJWK)
50+
s.Require().NoError(err)
51+
s.Require().NotEmpty(jwkString)
52+
53+
rsaKey, err := jwk.ParseKey([]byte(jwkString))
54+
s.Require().NoError(err)
55+
s.Require().NotNil(rsaKey)
56+
}
57+
58+
func (s *PlatformKeyIndexTestSuite) TestKeyExportPublicKey_PKCSFormat() {
59+
// Export JWK format
60+
pem, err := s.rsaKey.ExportPublicKey(context.Background(), KeyTypePKCS8)
61+
s.Require().NoError(err)
62+
s.Require().NotEmpty(pem)
63+
64+
keyAdapter, ok := s.rsaKey.(*KasKeyAdapter)
65+
s.Require().True(ok)
66+
pubCtx := keyAdapter.key.GetKey().GetPublicKeyCtx()
67+
s.Require().NotEmpty(pubCtx)
68+
base64Pem := ocrypto.Base64Encode([]byte(pem))
69+
70+
var pubCtxMap map[string]interface{}
71+
err = json.Unmarshal(pubCtx, &pubCtxMap)
72+
s.Require().NoError(err)
73+
74+
s.Equal(pubCtxMap["pubKey"], string(base64Pem))
75+
}
76+
77+
func TestNewPlatformKeyIndexTestSuite(t *testing.T) {
78+
suite.Run(t, new(PlatformKeyIndexTestSuite))
79+
}

0 commit comments

Comments
 (0)