Skip to content

Commit c362a31

Browse files
committed
feat(policy): Add platform key indexer.
remove comments. linting.
1 parent 8822c45 commit c362a31

File tree

2 files changed

+335
-0
lines changed

2 files changed

+335
-0
lines changed
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
package trust
2+
3+
import (
4+
"context"
5+
"crypto/rsa"
6+
"crypto/x509"
7+
"encoding/base64"
8+
"encoding/json"
9+
"encoding/pem"
10+
"errors"
11+
"fmt"
12+
13+
"github.com/lestrrat-go/jwx/v2/jwk"
14+
"github.com/opentdf/platform/protocol/go/policy"
15+
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
16+
"github.com/opentdf/platform/sdk"
17+
"github.com/opentdf/platform/service/logger"
18+
)
19+
20+
var (
21+
ErrNoActiveKeyForAlgorithm = errors.New("no active key found for specified algorithm")
22+
)
23+
24+
// Used for reaching out to platform to get keys
25+
type PlatformKeyIndexer struct {
26+
// KeyIndex is the key index used to manage keys
27+
KeyIndex
28+
// SDK is the SDK instance used to interact with the platform
29+
sdk *sdk.SDK
30+
// KasURI
31+
kasURI string
32+
// Logger is the logger instance used for logging
33+
log *logger.Logger
34+
}
35+
36+
// platformKeyAdapter is an adapter for KeyDetails, where keys come from the platform
37+
type KasKeyAdapter struct {
38+
key *policy.KasKey
39+
log *logger.Logger
40+
}
41+
42+
func NewPlatformKeyIndexer(sdk *sdk.SDK, kasURI string, l *logger.Logger) *PlatformKeyIndexer {
43+
return &PlatformKeyIndexer{
44+
sdk: sdk,
45+
kasURI: kasURI,
46+
log: l,
47+
}
48+
}
49+
50+
func convertAlgToEnum(alg string) (policy.Algorithm, error) {
51+
switch alg {
52+
case "rsa:2048":
53+
return policy.Algorithm_ALGORITHM_RSA_2048, nil
54+
case "rsa:4096":
55+
return policy.Algorithm_ALGORITHM_RSA_4096, nil
56+
case "ec:secp256r1":
57+
return policy.Algorithm_ALGORITHM_EC_P256, nil
58+
case "ec:secp384r1":
59+
return policy.Algorithm_ALGORITHM_EC_P384, nil
60+
case "ec:secp521r1":
61+
return policy.Algorithm_ALGORITHM_EC_P521, nil
62+
default:
63+
return policy.Algorithm_ALGORITHM_UNSPECIFIED, errors.New("unsupported algorithm")
64+
}
65+
}
66+
67+
func (p *PlatformKeyIndexer) FindKeyByAlgorithm(ctx context.Context, algorithm string, _ bool) (KeyDetails, error) {
68+
alg, err := convertAlgToEnum(algorithm)
69+
if err != nil {
70+
return nil, err
71+
}
72+
73+
req := &kasregistry.ListKeysRequest{
74+
KeyAlgorithm: alg,
75+
KasFilter: &kasregistry.ListKeysRequest_KasUri{
76+
KasUri: p.kasURI,
77+
},
78+
}
79+
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
80+
if err != nil {
81+
return nil, err
82+
}
83+
84+
// Find active key.
85+
var activeKey *policy.KasKey
86+
for _, key := range resp.GetKasKeys() {
87+
if key.GetKey().GetKeyStatus() == policy.KeyStatus_KEY_STATUS_ACTIVE {
88+
activeKey = key
89+
break
90+
}
91+
}
92+
if activeKey == nil {
93+
return nil, ErrNoActiveKeyForAlgorithm
94+
}
95+
96+
return &KasKeyAdapter{
97+
key: activeKey,
98+
log: p.log,
99+
}, nil
100+
}
101+
102+
func (p *PlatformKeyIndexer) FindKeyByID(ctx context.Context, id KeyIdentifier) (KeyDetails, error) {
103+
req := &kasregistry.GetKeyRequest{
104+
Identifier: &kasregistry.GetKeyRequest_Key{
105+
Key: &kasregistry.KasKeyIdentifier{
106+
Identifier: &kasregistry.KasKeyIdentifier_Uri{
107+
Uri: p.kasURI,
108+
},
109+
Kid: string(id),
110+
},
111+
},
112+
}
113+
114+
resp, err := p.sdk.KeyAccessServerRegistry.GetKey(ctx, req)
115+
if err != nil {
116+
return nil, err
117+
}
118+
119+
return &KasKeyAdapter{
120+
key: resp.GetKasKey(),
121+
log: p.log,
122+
}, nil
123+
}
124+
125+
func (p *PlatformKeyIndexer) ListKeys(ctx context.Context) ([]KeyDetails, error) {
126+
req := &kasregistry.ListKeysRequest{
127+
KasFilter: &kasregistry.ListKeysRequest_KasUri{
128+
KasUri: p.kasURI,
129+
},
130+
}
131+
resp, err := p.sdk.KeyAccessServerRegistry.ListKeys(ctx, req)
132+
if err != nil {
133+
return nil, err
134+
}
135+
136+
keys := make([]KeyDetails, len(resp.GetKasKeys()))
137+
for i, key := range resp.GetKasKeys() {
138+
keys[i] = &KasKeyAdapter{
139+
key: key,
140+
log: p.log,
141+
}
142+
}
143+
144+
return keys, nil
145+
}
146+
147+
func (p *KasKeyAdapter) ID() KeyIdentifier {
148+
return KeyIdentifier(p.key.GetKey().GetKeyId())
149+
}
150+
151+
// Might need to convert this to a standard format
152+
func (p *KasKeyAdapter) Algorithm() string {
153+
return p.key.GetKey().GetKeyAlgorithm().String()
154+
}
155+
func (p *KasKeyAdapter) IsLegacy() bool {
156+
return false
157+
}
158+
159+
// This will point to the correct "manager"
160+
func (p *KasKeyAdapter) System() string {
161+
var mode string
162+
if p.key.GetKey().GetProviderConfig() != nil {
163+
mode = p.key.GetKey().GetProviderConfig().GetName()
164+
}
165+
return mode
166+
}
167+
168+
func pemToPublicKey(publicPEM string) (*rsa.PublicKey, error) {
169+
// Decode the PEM data
170+
block, _ := pem.Decode([]byte(publicPEM))
171+
if block == nil || block.Type != "PUBLIC KEY" {
172+
return nil, fmt.Errorf("failed to decode PEM block or incorrect PEM type")
173+
}
174+
175+
// Parse the public key
176+
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
177+
if err != nil {
178+
return nil, fmt.Errorf("failed to parse public key: %w", err)
179+
}
180+
181+
// Assert type and return
182+
rsaPub, ok := pub.(*rsa.PublicKey)
183+
if !ok {
184+
return nil, fmt.Errorf("not an RSA public key")
185+
}
186+
187+
return rsaPub, nil
188+
}
189+
190+
// Repurpose of the StandardCrypto function
191+
func rsaPublicKeyAsJSON(_ context.Context, publicPEM string) (string, error) {
192+
pubKey, err := pemToPublicKey(publicPEM)
193+
if err != nil {
194+
return "", err
195+
}
196+
197+
rsaPublicKeyJwk, err := jwk.FromRaw(pubKey)
198+
if err != nil {
199+
return "", fmt.Errorf("jwk.FromRaw: %w", err)
200+
}
201+
202+
// Convert the public key to JSON format
203+
pubKeyJSON, err := json.Marshal(rsaPublicKeyJwk)
204+
if err != nil {
205+
return "", err
206+
}
207+
208+
return string(pubKeyJSON), nil
209+
}
210+
211+
// Repurpose of the StandardCrypto function
212+
func convertPEMToJWK(_ string) (string, error) {
213+
return "", errors.New("convertPEMToJWK function is not implemented")
214+
}
215+
216+
func (p *KasKeyAdapter) ExportPublicKey(ctx context.Context, format KeyType) (string, error) {
217+
publicKeyCtx := p.key.GetKey().GetPublicKeyCtx()
218+
var pubKeyCtxMap map[string]any
219+
if err := json.Unmarshal(publicKeyCtx, &pubKeyCtxMap); err != nil {
220+
return "", err
221+
}
222+
223+
pubKey, ok := pubKeyCtxMap["pubKey"].(string)
224+
if !ok {
225+
return "", errors.New("public key is not a string")
226+
}
227+
// Decode the base64-encoded public key
228+
decodedPubKey, err := base64.StdEncoding.DecodeString(pubKey)
229+
if err != nil {
230+
return "", err
231+
}
232+
233+
switch format {
234+
case KeyTypeJWK:
235+
// For JWK format (currently only supported for RSA)
236+
if p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_2048 ||
237+
p.key.GetKey().GetKeyAlgorithm() == policy.Algorithm_ALGORITHM_RSA_4096 {
238+
return rsaPublicKeyAsJSON(ctx, string(decodedPubKey))
239+
}
240+
// For EC keys, we return the public key in PEM format
241+
jwkKey, err := convertPEMToJWK(string(decodedPubKey))
242+
if err != nil {
243+
return "", err
244+
}
245+
246+
return jwkKey, nil
247+
case KeyTypePKCS8:
248+
return string(decodedPubKey), nil
249+
default:
250+
return "", errors.New("unsupported key type")
251+
}
252+
}
253+
254+
func (p *KasKeyAdapter) ExportCertificate(_ context.Context) (string, error) {
255+
return "", errors.New("not implemented")
256+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package trust
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"testing"
7+
8+
"github.com/lestrrat-go/jwx/v2/jwk"
9+
"github.com/opentdf/platform/lib/ocrypto"
10+
"github.com/opentdf/platform/protocol/go/policy"
11+
"github.com/stretchr/testify/suite"
12+
)
13+
14+
type PlatformKeyIndexTestSuite struct {
15+
suite.Suite
16+
rsaKey KeyDetails
17+
}
18+
19+
func (s *PlatformKeyIndexTestSuite) SetupTest() {
20+
s.rsaKey = &KasKeyAdapter{
21+
key: &policy.KasKey{
22+
KasId: "test-kas-id",
23+
Key: &policy.AsymmetricKey{
24+
Id: "test-key-id",
25+
KeyId: "test-key-id",
26+
KeyAlgorithm: policy.Algorithm_ALGORITHM_RSA_2048,
27+
KeyStatus: policy.KeyStatus_KEY_STATUS_ACTIVE,
28+
KeyMode: policy.KeyMode_KEY_MODE_LOCAL,
29+
PublicKeyCtx: []byte(`{"pubKey": "LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF3SEw0TkVrOFpDa0JzNjZXQVpWagpIS3NseDRseWdmaXN3aW42RUx5OU9OczZLVDRYa1crRGxsdExtck14bHZkbzVRaDg1UmFZS01mWUdDTWtPM0dGCkFsK0JOeWFOM1kwa0N1QjNPU2ErTzdyMURhNVZteVVuaEJNbFBrYnVPY1Y0cjlLMUhOSGd3eDl2UFp3RjRpQW8KQStEY1VBcWFEeHlvYjV6enNGZ0hUNjJHLzdLdEtiZ2hYT1dCanRUYUl1ZHpsK2FaSjFPemY0U1RkOXhST2QrMQordVo2VG1ocmFEUm9zdDUrTTZUN0toL2lGWk40TTFUY2hwWXU1TDhKR2tVaG9YaEdZcHUrMGczSzlqYlh6RVh5CnpJU3VXN2d6SGRWYUxvcnBkQlNkRHpOWkNvTFVoL0U1T3d5TFZFQkNKaDZJVUtvdWJ5WHVucnIxQnJmK2tLbEsKeHdJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg=="}`),
30+
ProviderConfig: &policy.KeyProviderConfig{
31+
Id: "test-provider-id",
32+
Name: "openbao",
33+
},
34+
},
35+
},
36+
}
37+
}
38+
func (s *PlatformKeyIndexTestSuite) TearDownTest() {}
39+
40+
func (s *PlatformKeyIndexTestSuite) TestKeyDetails() {
41+
s.Equal("test-key-id", s.rsaKey.ID())
42+
s.Equal("ALGORITHM_RSA_2048", s.rsaKey.Algorithm())
43+
s.False(s.rsaKey.IsLegacy())
44+
s.Equal("openbao", s.rsaKey.System())
45+
}
46+
47+
func (s *PlatformKeyIndexTestSuite) TestKeyExportPublicKey_JWKFormat() {
48+
// Export JWK format
49+
jwkString, err := s.rsaKey.ExportPublicKey(context.Background(), KeyTypeJWK)
50+
s.Require().NoError(err)
51+
s.Require().NotEmpty(jwkString)
52+
53+
rsaKey, err := jwk.ParseKey([]byte(jwkString))
54+
s.Require().NoError(err)
55+
s.Require().NotNil(rsaKey)
56+
}
57+
58+
func (s *PlatformKeyIndexTestSuite) TestKeyExportPublicKey_PKCSFormat() {
59+
// Export JWK format
60+
pem, err := s.rsaKey.ExportPublicKey(context.Background(), KeyTypePKCS8)
61+
s.Require().NoError(err)
62+
s.Require().NotEmpty(pem)
63+
64+
keyAdapter, ok := s.rsaKey.(*KasKeyAdapter)
65+
s.Require().True(ok)
66+
pubCtx := keyAdapter.key.GetKey().GetPublicKeyCtx()
67+
s.Require().NotEmpty(pubCtx)
68+
base64Pem := ocrypto.Base64Encode([]byte(pem))
69+
70+
var pubCtxMap map[string]interface{}
71+
err = json.Unmarshal(pubCtx, &pubCtxMap)
72+
s.Require().NoError(err)
73+
74+
s.Equal(pubCtxMap["pubKey"], string(base64Pem))
75+
}
76+
77+
func TestNewPlatformKeyIndexTestSuite(t *testing.T) {
78+
suite.Run(t, new(PlatformKeyIndexTestSuite))
79+
}

0 commit comments

Comments
 (0)