@@ -12,9 +12,10 @@ import (
1212
1313// BulkTDF: Reader is TDF Content. Writer writes encrypted data. Error is the error that occurs if decrypting fails.
1414type BulkTDF struct {
15- Reader io.ReadSeeker
16- Writer io.Writer
17- Error error
15+ Reader io.ReadSeeker
16+ Writer io.Writer
17+ Error error
18+ TriggeredObligations Obligations
1819}
1920
2021type BulkDecryptRequest struct {
@@ -26,6 +27,15 @@ type BulkDecryptRequest struct {
2627 ignoreAllowList bool
2728}
2829
30+ // BulkDecryptPrepared holds the prepared state for bulk decryption
31+ // The PolicyTDF is a map of created policy IDs to their corresponding BulkTDF
32+ // The policy IDs are generated during the prepareDecryptors function
33+ type BulkDecryptPrepared struct {
34+ PolicyTDF map [string ]* BulkTDF
35+ tdfDecryptors map [string ]decryptor
36+ allRewrapResp map [string ][]kaoResult
37+ }
38+
2939// BulkErrors List of Errors that Failed during Bulk Decryption
3040type BulkErrors []error
3141
@@ -116,17 +126,9 @@ func (s SDK) createDecryptor(tdf *BulkTDF, req *BulkDecryptRequest) (decryptor,
116126 return nil , fmt .Errorf ("unknown tdf type: %s" , req .TDFType )
117127}
118128
119- // BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
120- func (s SDK ) BulkDecrypt (ctx context.Context , opts ... BulkDecryptOption ) error {
121- bulkReq , createError := createBulkRewrapRequest (opts ... )
122- if createError != nil {
123- return fmt .Errorf ("failed to create bulk rewrap request: %w" , createError )
124- }
125- kasRewrapRequests := make (map [string ][]* kas.UnsignedRewrapRequest_WithPolicyRequest )
126- tdfDecryptors := make (map [string ]decryptor )
127- policyTDF := make (map [string ]* BulkTDF )
128-
129- if ! bulkReq .ignoreAllowList && len (bulkReq .kasAllowlist ) == 0 { //nolint:nestif // if kasAllowlist is not set, we get it from the registry
129+ // setupKasAllowlist configures the KAS allowlist for the bulk request
130+ func (s SDK ) setupKasAllowlist (ctx context.Context , bulkReq * BulkDecryptRequest ) error {
131+ if ! bulkReq .ignoreAllowList && len (bulkReq .kasAllowlist ) == 0 { //nolint:nestif // not complex
130132 if s .KeyAccessServerRegistry != nil {
131133 platformEndpoint , err := s .PlatformConfiguration .platformEndpoint ()
132134 if err != nil {
@@ -145,10 +147,18 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
145147 return errors .New ("no KAS allowlist provided and no KeyAccessServerRegistry available" )
146148 }
147149 }
150+ return nil
151+ }
152+
153+ // prepareDecryptors creates decryptors and rewrap requests for all TDFs
154+ func (s SDK ) prepareDecryptors (ctx context.Context , bulkReq * BulkDecryptRequest ) (map [string ][]* kas.UnsignedRewrapRequest_WithPolicyRequest , map [string ]decryptor , map [string ]* BulkTDF ) {
155+ kasRewrapRequests := make (map [string ][]* kas.UnsignedRewrapRequest_WithPolicyRequest )
156+ tdfDecryptors := make (map [string ]decryptor )
157+ policyTDF := make (map [string ]* BulkTDF )
148158
149159 for i , tdf := range bulkReq .TDFs {
150160 policyID := fmt .Sprintf ("policy-%d" , i )
151- decryptor , err := s .createDecryptor (tdf , bulkReq ) //nolint:contextcheck // dont want to change signature of LoadTDF
161+ decryptor , err := s .createDecryptor (tdf , bulkReq ) //nolint:contextcheck // context is not used in createDecryptor
152162 if err != nil {
153163 tdf .Error = err
154164 continue
@@ -167,9 +177,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
167177 }
168178 }
169179
170- kasClient := newKASClient (s .conn .Client , s .conn .Options , s .tokenSource , s .kasSessionKey )
180+ return kasRewrapRequests , tdfDecryptors , policyTDF
181+ }
182+
183+ // performRewraps executes all rewrap requests with KAS servers
184+ func (s SDK ) performRewraps (ctx context.Context , bulkReq * BulkDecryptRequest , kasRewrapRequests map [string ][]* kas.UnsignedRewrapRequest_WithPolicyRequest , fulfillableObligations []string ) (map [string ][]kaoResult , error ) {
185+ kasClient := newKASClient (s .conn .Client , s .conn .Options , s .tokenSource , s .kasSessionKey , fulfillableObligations )
171186 allRewrapResp := make (map [string ][]kaoResult )
172187 var err error
188+
173189 for kasurl , rewrapRequests := range kasRewrapRequests {
174190 if bulkReq .ignoreAllowList {
175191 s .Logger ().Warn ("kasAllowlist is ignored, kas url is allowed" , slog .String ("kas_url" , kasurl ))
@@ -186,6 +202,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
186202 }
187203 continue
188204 }
205+
189206 var rewrapResp map [string ][]kaoResult
190207 switch bulkReq .TDFType {
191208 case Nano :
@@ -198,19 +215,73 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
198215 allRewrapResp [id ] = append (allRewrapResp [id ], res ... )
199216 }
200217 }
218+
201219 if err != nil {
202- return fmt .Errorf ("bulk rewrap failed: %w" , err )
220+ return nil , fmt .Errorf ("bulk rewrap failed: %w" , err )
221+ }
222+
223+ return allRewrapResp , nil
224+ }
225+
226+ // PrepareBulkDecrypt does everything except decrypt from the Bulk Decrypt
227+ // ! Currently you cannot specify fulfillable obligations on an individual TDF basis
228+ func (s SDK ) PrepareBulkDecrypt (ctx context.Context , opts ... BulkDecryptOption ) (* BulkDecryptPrepared , error ) {
229+ bulkReq , createError := createBulkRewrapRequest (opts ... )
230+ if createError != nil {
231+ return nil , fmt .Errorf ("failed to create bulk rewrap request: %w" , createError )
232+ }
233+
234+ // Setup KAS allowlist
235+ if err := s .setupKasAllowlist (ctx , bulkReq ); err != nil {
236+ return nil , err
237+ }
238+
239+ // Prepare decryptors and rewrap requests
240+ kasRewrapRequests , tdfDecryptors , policyTDF := s .prepareDecryptors (ctx , bulkReq )
241+
242+ // Use the default fulfillable obligations unless a decryptor is available to provide its own
243+ fulfillableObligations := s .fulfillableObligationFQNs
244+ if len (tdfDecryptors ) > 0 {
245+ for _ , d := range tdfDecryptors {
246+ fulfillableObligations = getFulfillableObligations (d , s .logger )
247+ break
248+ }
249+ }
250+
251+ // Perform rewraps
252+ allRewrapResp , err := s .performRewraps (ctx , bulkReq , kasRewrapRequests , fulfillableObligations )
253+ if err != nil {
254+ return nil , err
203255 }
204256
205- var errList []error
206257 for id , tdf := range policyTDF {
207- kaoRes , ok := allRewrapResp [id ]
258+ policyRes , ok := allRewrapResp [id ]
259+ if ! ok {
260+ tdf .Error = errors .New ("rewrap did not create a response for this TDF" )
261+ continue
262+ }
263+ tdf .TriggeredObligations = Obligations {FQNs : dedupRequiredObligations (policyRes )}
264+ }
265+
266+ return & BulkDecryptPrepared {
267+ PolicyTDF : policyTDF ,
268+ tdfDecryptors : tdfDecryptors ,
269+ allRewrapResp : allRewrapResp ,
270+ }, nil
271+ }
272+
273+ // Allow the bulk decryption to occur
274+ func (bp * BulkDecryptPrepared ) BulkDecrypt (ctx context.Context ) error {
275+ var errList []error
276+ var err error
277+ for id , tdf := range bp .PolicyTDF {
278+ kaoRes , ok := bp .allRewrapResp [id ]
208279 if ! ok {
209280 tdf .Error = errors .New ("rewrap did not create a response for this TDF" )
210281 errList = append (errList , tdf .Error )
211282 continue
212283 }
213- decryptor := tdfDecryptors [id ]
284+ decryptor := bp . tdfDecryptors [id ]
214285 if _ , err = decryptor .Decrypt (ctx , kaoRes ); err != nil {
215286 tdf .Error = err
216287 errList = append (errList , tdf .Error )
@@ -225,9 +296,36 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
225296 return nil
226297}
227298
299+ // BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
300+ func (s SDK ) BulkDecrypt (ctx context.Context , opts ... BulkDecryptOption ) error {
301+ prepared , err := s .PrepareBulkDecrypt (ctx , opts ... )
302+ if err != nil {
303+ return err
304+ }
305+
306+ return prepared .BulkDecrypt (ctx )
307+ }
308+
228309func (b * BulkDecryptRequest ) appendTDFs (tdfs ... * BulkTDF ) {
229310 b .TDFs = append (
230311 b .TDFs ,
231312 tdfs ... ,
232313 )
233314}
315+
316+ func getFulfillableObligations (decryptor decryptor , logger * slog.Logger ) []string {
317+ if decryptor == nil {
318+ logger .Warn ("decryptor is nil, cannot populate obligations" )
319+ return make ([]string , 0 )
320+ }
321+
322+ switch d := decryptor .(type ) {
323+ case * tdf3DecryptHandler :
324+ return d .reader .config .fulfillableObligationFQNs
325+ case * NanoTDFDecryptHandler :
326+ return d .config .fulfillableObligationFQNs
327+ default :
328+ logger .Warn ("unknown decryptor type, cannot populate obligations" , slog .String ("type" , fmt .Sprintf ("%T" , d )))
329+ return make ([]string , 0 )
330+ }
331+ }
0 commit comments