Skip to content

Commit f444a46

Browse files
authored
fix(sdk): Retrieve required obligations from correct metadata (#2809)
### Proposed Changes 1.) Fix `kas_client` to retrieve obligations from each kao metadata 2.) Dedup returned obligations 3.) Fix tests ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 05a41dd commit f444a46

File tree

6 files changed

+578
-287
lines changed

6 files changed

+578
-287
lines changed

sdk/bulk.go

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type BulkDecryptRequest struct {
3333
type BulkDecryptPrepared struct {
3434
PolicyTDF map[string]*BulkTDF
3535
tdfDecryptors map[string]decryptor
36-
allRewrapResp map[string]policyResult
36+
allRewrapResp map[string][]kaoResult
3737
}
3838

3939
// BulkErrors List of Errors that Failed during Bulk Decryption
@@ -181,9 +181,9 @@ func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest)
181181
}
182182

183183
// performRewraps executes all rewrap requests with KAS servers
184-
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string]policyResult, error) {
184+
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string][]kaoResult, error) {
185185
kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations)
186-
allRewrapResp := make(map[string]policyResult)
186+
allRewrapResp := make(map[string][]kaoResult)
187187
var err error
188188

189189
for kasurl, rewrapRequests := range kasRewrapRequests {
@@ -194,21 +194,16 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka
194194
for _, req := range rewrapRequests {
195195
id := req.GetPolicy().GetId()
196196
for _, kao := range req.GetKeyAccessObjects() {
197-
policyRewrapResp, ok := allRewrapResp[id]
198-
if !ok {
199-
policyRewrapResp = policyResult{policyID: id, obligations: []string{}, kaoRes: []kaoResult{}}
200-
}
201-
policyRewrapResp.kaoRes = append(policyRewrapResp.kaoRes, kaoResult{
197+
allRewrapResp[id] = append(allRewrapResp[id], kaoResult{
202198
Error: fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasurl),
203199
KeyAccessObjectID: kao.GetKeyAccessObjectId(),
204200
})
205-
allRewrapResp[id] = policyRewrapResp
206201
}
207202
}
208203
continue
209204
}
210205

211-
var rewrapResp map[string]policyResult
206+
var rewrapResp map[string][]kaoResult
212207
switch bulkReq.TDFType {
213208
case Nano:
214209
rewrapResp, err = kasClient.nanoUnwrap(ctx, rewrapRequests...)
@@ -217,16 +212,7 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka
217212
}
218213

219214
for id, res := range rewrapResp {
220-
// ! It's possible that we already created a policyResult for the policy above for a specific KAS URL.
221-
// ! Meaning for another kas url of the same policy we will end up with an empty list of obligations.
222-
// ! This should be fine since we will error out anyways.
223-
if existingResp, ok := allRewrapResp[id]; !ok {
224-
allRewrapResp[id] = res
225-
} else {
226-
// ! Should not need to append obligations since they should be the same for all TDFs under a policy
227-
existingResp.kaoRes = append(existingResp.kaoRes, res.kaoRes...)
228-
allRewrapResp[id] = existingResp
229-
}
215+
allRewrapResp[id] = append(allRewrapResp[id], res...)
230216
}
231217
}
232218

@@ -274,7 +260,7 @@ func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption)
274260
tdf.Error = errors.New("rewrap did not create a response for this TDF")
275261
continue
276262
}
277-
tdf.TriggeredObligations = Obligations{FQNs: policyRes.obligations}
263+
tdf.TriggeredObligations = Obligations{FQNs: dedupRequiredObligations(policyRes)}
278264
}
279265

280266
return &BulkDecryptPrepared{
@@ -289,14 +275,14 @@ func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error {
289275
var errList []error
290276
var err error
291277
for id, tdf := range bp.PolicyTDF {
292-
policyRes, ok := bp.allRewrapResp[id]
278+
kaoRes, ok := bp.allRewrapResp[id]
293279
if !ok {
294280
tdf.Error = errors.New("rewrap did not create a response for this TDF")
295281
errList = append(errList, tdf.Error)
296282
continue
297283
}
298284
decryptor := bp.tdfDecryptors[id]
299-
if _, err = decryptor.Decrypt(ctx, policyRes.kaoRes); err != nil {
285+
if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil {
300286
tdf.Error = err
301287
errList = append(errList, tdf.Error)
302288
continue

sdk/kas_client.go

Lines changed: 52 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,10 @@ type KASClient struct {
4444
}
4545

4646
type kaoResult struct {
47-
SymmetricKey []byte
48-
Error error
49-
KeyAccessObjectID string
50-
}
51-
52-
type policyResult struct {
53-
policyID string
54-
obligations []string
55-
kaoRes []kaoResult
47+
SymmetricKey []byte
48+
Error error
49+
KeyAccessObjectID string
50+
RequiredObligations []string
5651
}
5752

5853
type decryptor interface {
@@ -175,7 +170,7 @@ func upgradeRewrapErrorV1(err error, requests []*kas.UnsignedRewrapRequest_WithP
175170
}, nil
176171
}
177172

178-
func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) {
173+
func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) {
179174
keypair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1)
180175
if err != nil {
181176
return nil, fmt.Errorf("ocrypto.NewECKeyPair failed :%w", err)
@@ -198,22 +193,19 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
198193
// If the session key is empty, all responses are errors
199194
spk := response.GetSessionPublicKey()
200195
if spk == "" {
201-
policyResults := make(map[string]policyResult)
196+
policyResults := make(map[string][]kaoResult)
202197
err = errors.New("nanoUnwrap: session public key is empty")
203198
for _, results := range response.GetResponses() {
204-
policyRes, ok := policyResults[results.GetPolicyId()]
205-
if !ok {
206-
policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}}
207-
}
199+
var kaoKeys []kaoResult
208200
for _, kao := range results.GetResults() {
201+
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
209202
if kao.GetStatus() == statusPermit {
210-
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
203+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
211204
} else {
212-
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
205+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
213206
}
214207
}
215-
policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
216-
policyResults[results.GetPolicyId()] = policyRes
208+
policyResults[results.GetPolicyId()] = kaoKeys
217209
}
218210

219211
return policyResults, nil
@@ -234,33 +226,30 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
234226
return nil, fmt.Errorf("nanoUnwrap: ocrypto.NewAESGcm failed:%w", err)
235227
}
236228

237-
policyResults := make(map[string]policyResult)
229+
policyResults := make(map[string][]kaoResult)
238230
for _, results := range response.GetResponses() {
239-
policyRes, ok := policyResults[results.GetPolicyId()]
240-
if !ok {
241-
policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}}
242-
}
231+
var kaoKeys []kaoResult
243232
for _, kao := range results.GetResults() {
233+
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
244234
if kao.GetStatus() == statusPermit {
245235
wrappedKey := kao.GetKasWrappedKey()
246236
key, err := aesGcm.Decrypt(wrappedKey)
247237
if err != nil {
248-
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
238+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
249239
} else {
250-
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
240+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
251241
}
252242
} else {
253-
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
243+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
254244
}
255245
}
256-
policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
257-
policyResults[results.GetPolicyId()] = policyRes
246+
policyResults[results.GetPolicyId()] = kaoKeys
258247
}
259248

260249
return policyResults, nil
261250
}
262251

263-
func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) {
252+
func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) {
264253
if k.sessionKey == nil {
265254
return nil, errors.New("session key is nil")
266255
}
@@ -279,7 +268,7 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR
279268
return k.handleRSAKeyResponse(response)
280269
}
281270

282-
func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) {
271+
func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) {
283272
kasEphemeralPublicKey := response.GetSessionPublicKey()
284273
clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat()
285274
if err != nil {
@@ -306,60 +295,58 @@ func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[strin
306295
return k.processECResponse(response, aesGcm)
307296
}
308297

309-
func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string]policyResult, error) {
310-
policyResults := make(map[string]policyResult)
298+
func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string][]kaoResult, error) {
299+
policyResults := make(map[string][]kaoResult)
311300
for _, results := range response.GetResponses() {
312301
var kaoKeys []kaoResult
313302
for _, kao := range results.GetResults() {
303+
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
314304
if kao.GetStatus() == statusPermit {
315305
key, err := aesGcm.Decrypt(kao.GetKasWrappedKey())
316306
if err != nil {
317-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
307+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
318308
} else {
319-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
309+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
320310
}
321311
} else {
322-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
312+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
323313
}
324314
}
325-
requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
326-
policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations}
315+
policyResults[results.GetPolicyId()] = kaoKeys
327316
}
328317
return policyResults, nil
329318
}
330319

331-
func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value, policyID string) []string {
332-
var triggeredFQNs []string
320+
/*
321+
Metadata will be in the following form, per kao:
333322
334-
if metadata == nil {
335-
return triggeredFQNs
323+
{
324+
"metadata": {
325+
"X-Required-Obligations": [<Required obligation FQNs>]
326+
}
336327
}
328+
*/
329+
func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value) []string {
330+
var requiredObligations []string
337331

338-
triggerOblsValue, ok := metadata[triggeredObligationsHeader]
339-
if !ok {
340-
return triggeredFQNs
341-
}
342-
343-
fields := triggerOblsValue.GetStructValue().GetFields()
344-
if fields == nil {
345-
return triggeredFQNs
332+
if metadata == nil {
333+
return requiredObligations
346334
}
347335

348-
policyOblsValue, ok := fields[policyID]
336+
triggerOblsValue, ok := metadata[triggeredObligationsHeader]
349337
if !ok {
350-
return triggeredFQNs
338+
return requiredObligations
351339
}
352340

353-
values := policyOblsValue.GetListValue().GetValues()
354-
355-
for _, v := range values {
356-
triggeredFQNs = append(triggeredFQNs, v.GetStringValue())
341+
triggerOblsList := triggerOblsValue.GetListValue().GetValues()
342+
for _, v := range triggerOblsList {
343+
requiredObligations = append(requiredObligations, v.GetStringValue())
357344
}
358345

359-
return triggeredFQNs
346+
return requiredObligations
360347
}
361348

362-
func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) {
349+
func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) {
363350
clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat()
364351
if err != nil {
365352
return nil, fmt.Errorf("ocrypto.PrivateKeyInPemFormat failed: %w", err)
@@ -373,24 +360,24 @@ func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[stri
373360
return k.processRSAResponse(response, asymDecryption)
374361
}
375362

376-
func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string]policyResult, error) {
377-
policyResults := make(map[string]policyResult)
363+
func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string][]kaoResult, error) {
364+
policyResults := make(map[string][]kaoResult)
378365
for _, results := range response.GetResponses() {
379366
var kaoKeys []kaoResult
380367
for _, kao := range results.GetResults() {
368+
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
381369
if kao.GetStatus() == statusPermit {
382370
key, err := asymDecryption.Decrypt(kao.GetKasWrappedKey())
383371
if err != nil {
384-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
372+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
385373
} else {
386-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
374+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
387375
}
388376
} else {
389-
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
377+
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
390378
}
391379
}
392-
requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
393-
policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations}
380+
policyResults[results.GetPolicyId()] = kaoKeys
394381
}
395382
return policyResults, nil
396383
}

0 commit comments

Comments
 (0)