@@ -44,15 +44,10 @@ type KASClient struct {
4444}
4545
4646type kaoResult struct {
47- SymmetricKey []byte
48- Error error
49- KeyAccessObjectID string
50- }
51-
52- type policyResult struct {
53- policyID string
54- obligations []string
55- kaoRes []kaoResult
47+ SymmetricKey []byte
48+ Error error
49+ KeyAccessObjectID string
50+ RequiredObligations []string
5651}
5752
5853type decryptor interface {
@@ -175,7 +170,7 @@ func upgradeRewrapErrorV1(err error, requests []*kas.UnsignedRewrapRequest_WithP
175170 }, nil
176171}
177172
178- func (k * KASClient ) nanoUnwrap (ctx context.Context , requests ... * kas.UnsignedRewrapRequest_WithPolicyRequest ) (map [string ]policyResult , error ) {
173+ func (k * KASClient ) nanoUnwrap (ctx context.Context , requests ... * kas.UnsignedRewrapRequest_WithPolicyRequest ) (map [string ][] kaoResult , error ) {
179174 keypair , err := ocrypto .NewECKeyPair (ocrypto .ECCModeSecp256r1 )
180175 if err != nil {
181176 return nil , fmt .Errorf ("ocrypto.NewECKeyPair failed :%w" , err )
@@ -198,22 +193,19 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
198193 // If the session key is empty, all responses are errors
199194 spk := response .GetSessionPublicKey ()
200195 if spk == "" {
201- policyResults := make (map [string ]policyResult )
196+ policyResults := make (map [string ][] kaoResult )
202197 err = errors .New ("nanoUnwrap: session public key is empty" )
203198 for _ , results := range response .GetResponses () {
204- policyRes , ok := policyResults [results .GetPolicyId ()]
205- if ! ok {
206- policyRes = policyResult {policyID : results .GetPolicyId (), obligations : []string {}, kaoRes : []kaoResult {}}
207- }
199+ var kaoKeys []kaoResult
208200 for _ , kao := range results .GetResults () {
201+ requiredObligationsForKAO := k .retrieveObligationsFromMetadata (kao .GetMetadata ())
209202 if kao .GetStatus () == statusPermit {
210- policyRes . kaoRes = append (policyRes . kaoRes , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err })
203+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err , RequiredObligations : requiredObligationsForKAO })
211204 } else {
212- policyRes . kaoRes = append (policyRes . kaoRes , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ())})
205+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ()), RequiredObligations : requiredObligationsForKAO })
213206 }
214207 }
215- policyRes .obligations = k .retrieveObligationsFromMetadata (response .GetMetadata (), results .GetPolicyId ())
216- policyResults [results .GetPolicyId ()] = policyRes
208+ policyResults [results .GetPolicyId ()] = kaoKeys
217209 }
218210
219211 return policyResults , nil
@@ -234,33 +226,30 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
234226 return nil , fmt .Errorf ("nanoUnwrap: ocrypto.NewAESGcm failed:%w" , err )
235227 }
236228
237- policyResults := make (map [string ]policyResult )
229+ policyResults := make (map [string ][] kaoResult )
238230 for _ , results := range response .GetResponses () {
239- policyRes , ok := policyResults [results .GetPolicyId ()]
240- if ! ok {
241- policyRes = policyResult {policyID : results .GetPolicyId (), obligations : []string {}, kaoRes : []kaoResult {}}
242- }
231+ var kaoKeys []kaoResult
243232 for _ , kao := range results .GetResults () {
233+ requiredObligationsForKAO := k .retrieveObligationsFromMetadata (kao .GetMetadata ())
244234 if kao .GetStatus () == statusPermit {
245235 wrappedKey := kao .GetKasWrappedKey ()
246236 key , err := aesGcm .Decrypt (wrappedKey )
247237 if err != nil {
248- policyRes . kaoRes = append (policyRes . kaoRes , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err })
238+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err , RequiredObligations : requiredObligationsForKAO })
249239 } else {
250- policyRes . kaoRes = append (policyRes . kaoRes , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key })
240+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key , RequiredObligations : requiredObligationsForKAO })
251241 }
252242 } else {
253- policyRes . kaoRes = append (policyRes . kaoRes , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ())})
243+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ()), RequiredObligations : requiredObligationsForKAO })
254244 }
255245 }
256- policyRes .obligations = k .retrieveObligationsFromMetadata (response .GetMetadata (), results .GetPolicyId ())
257- policyResults [results .GetPolicyId ()] = policyRes
246+ policyResults [results .GetPolicyId ()] = kaoKeys
258247 }
259248
260249 return policyResults , nil
261250}
262251
263- func (k * KASClient ) unwrap (ctx context.Context , requests ... * kas.UnsignedRewrapRequest_WithPolicyRequest ) (map [string ]policyResult , error ) {
252+ func (k * KASClient ) unwrap (ctx context.Context , requests ... * kas.UnsignedRewrapRequest_WithPolicyRequest ) (map [string ][] kaoResult , error ) {
264253 if k .sessionKey == nil {
265254 return nil , errors .New ("session key is nil" )
266255 }
@@ -279,7 +268,7 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR
279268 return k .handleRSAKeyResponse (response )
280269}
281270
282- func (k * KASClient ) handleECKeyResponse (response * kas.RewrapResponse ) (map [string ]policyResult , error ) {
271+ func (k * KASClient ) handleECKeyResponse (response * kas.RewrapResponse ) (map [string ][] kaoResult , error ) {
283272 kasEphemeralPublicKey := response .GetSessionPublicKey ()
284273 clientPrivateKey , err := k .sessionKey .PrivateKeyInPemFormat ()
285274 if err != nil {
@@ -306,60 +295,58 @@ func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[strin
306295 return k .processECResponse (response , aesGcm )
307296}
308297
309- func (k * KASClient ) processECResponse (response * kas.RewrapResponse , aesGcm ocrypto.AesGcm ) (map [string ]policyResult , error ) {
310- policyResults := make (map [string ]policyResult )
298+ func (k * KASClient ) processECResponse (response * kas.RewrapResponse , aesGcm ocrypto.AesGcm ) (map [string ][] kaoResult , error ) {
299+ policyResults := make (map [string ][] kaoResult )
311300 for _ , results := range response .GetResponses () {
312301 var kaoKeys []kaoResult
313302 for _ , kao := range results .GetResults () {
303+ requiredObligationsForKAO := k .retrieveObligationsFromMetadata (kao .GetMetadata ())
314304 if kao .GetStatus () == statusPermit {
315305 key , err := aesGcm .Decrypt (kao .GetKasWrappedKey ())
316306 if err != nil {
317- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err })
307+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err , RequiredObligations : requiredObligationsForKAO })
318308 } else {
319- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key })
309+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key , RequiredObligations : requiredObligationsForKAO })
320310 }
321311 } else {
322- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ())})
312+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ()), RequiredObligations : requiredObligationsForKAO })
323313 }
324314 }
325- requiredObligations := k .retrieveObligationsFromMetadata (response .GetMetadata (), results .GetPolicyId ())
326- policyResults [results .GetPolicyId ()] = policyResult {policyID : results .GetPolicyId (), kaoRes : kaoKeys , obligations : requiredObligations }
315+ policyResults [results .GetPolicyId ()] = kaoKeys
327316 }
328317 return policyResults , nil
329318}
330319
331- func ( k * KASClient ) retrieveObligationsFromMetadata ( metadata map [ string ] * structpb. Value , policyID string ) [] string {
332- var triggeredFQNs [] string
320+ /*
321+ Metadata will be in the following form, per kao:
333322
334- if metadata == nil {
335- return triggeredFQNs
323+ {
324+ "metadata": {
325+ "X-Required-Obligations": [<Required obligation FQNs>]
326+ }
336327 }
328+ */
329+ func (k * KASClient ) retrieveObligationsFromMetadata (metadata map [string ]* structpb.Value ) []string {
330+ var requiredObligations []string
337331
338- triggerOblsValue , ok := metadata [triggeredObligationsHeader ]
339- if ! ok {
340- return triggeredFQNs
341- }
342-
343- fields := triggerOblsValue .GetStructValue ().GetFields ()
344- if fields == nil {
345- return triggeredFQNs
332+ if metadata == nil {
333+ return requiredObligations
346334 }
347335
348- policyOblsValue , ok := fields [ policyID ]
336+ triggerOblsValue , ok := metadata [ triggeredObligationsHeader ]
349337 if ! ok {
350- return triggeredFQNs
338+ return requiredObligations
351339 }
352340
353- values := policyOblsValue .GetListValue ().GetValues ()
354-
355- for _ , v := range values {
356- triggeredFQNs = append (triggeredFQNs , v .GetStringValue ())
341+ triggerOblsList := triggerOblsValue .GetListValue ().GetValues ()
342+ for _ , v := range triggerOblsList {
343+ requiredObligations = append (requiredObligations , v .GetStringValue ())
357344 }
358345
359- return triggeredFQNs
346+ return requiredObligations
360347}
361348
362- func (k * KASClient ) handleRSAKeyResponse (response * kas.RewrapResponse ) (map [string ]policyResult , error ) {
349+ func (k * KASClient ) handleRSAKeyResponse (response * kas.RewrapResponse ) (map [string ][] kaoResult , error ) {
363350 clientPrivateKey , err := k .sessionKey .PrivateKeyInPemFormat ()
364351 if err != nil {
365352 return nil , fmt .Errorf ("ocrypto.PrivateKeyInPemFormat failed: %w" , err )
@@ -373,24 +360,24 @@ func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[stri
373360 return k .processRSAResponse (response , asymDecryption )
374361}
375362
376- func (k * KASClient ) processRSAResponse (response * kas.RewrapResponse , asymDecryption ocrypto.AsymDecryption ) (map [string ]policyResult , error ) {
377- policyResults := make (map [string ]policyResult )
363+ func (k * KASClient ) processRSAResponse (response * kas.RewrapResponse , asymDecryption ocrypto.AsymDecryption ) (map [string ][] kaoResult , error ) {
364+ policyResults := make (map [string ][] kaoResult )
378365 for _ , results := range response .GetResponses () {
379366 var kaoKeys []kaoResult
380367 for _ , kao := range results .GetResults () {
368+ requiredObligationsForKAO := k .retrieveObligationsFromMetadata (kao .GetMetadata ())
381369 if kao .GetStatus () == statusPermit {
382370 key , err := asymDecryption .Decrypt (kao .GetKasWrappedKey ())
383371 if err != nil {
384- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err })
372+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : err , RequiredObligations : requiredObligationsForKAO })
385373 } else {
386- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key })
374+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), SymmetricKey : key , RequiredObligations : requiredObligationsForKAO })
387375 }
388376 } else {
389- kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ())})
377+ kaoKeys = append (kaoKeys , kaoResult {KeyAccessObjectID : kao .GetKeyAccessObjectId (), Error : errors .New (kao .GetError ()), RequiredObligations : requiredObligationsForKAO })
390378 }
391379 }
392- requiredObligations := k .retrieveObligationsFromMetadata (response .GetMetadata (), results .GetPolicyId ())
393- policyResults [results .GetPolicyId ()] = policyResult {policyID : results .GetPolicyId (), kaoRes : kaoKeys , obligations : requiredObligations }
380+ policyResults [results .GetPolicyId ()] = kaoKeys
394381 }
395382 return policyResults , nil
396383}
0 commit comments