Skip to content

Commit 3cccfd2

Browse files
authored
feat(sdk): Add obligations support. (#2759)
### Proposed Changes 1.) Expose `Obligations` function to TDF/NanoTDF 2.) Expose new `NanoTDFReader` with `LoadNanoTDF`,`Init` functions 3.) Refactor `bulk` to allow consumers to retrieve Obligations before decrypting content 4.) Expose `PrepareBulkDecrypt` to retrieve Obligations before decrypting the content 5.) Expose `ErrRewrapForbidden` for consumption by downstream consumers of the SDK. Manual E2E tests: - Ran benchmark and rt tests for bulk - Nano/ZTDF no access to resource regardless of obligations (returns empty slice of FQNs) - Nano/ZTDF has access to resource, but does not fulfill obligations (returns required obligations) - Nano/ZTDF has access to resource and fulfills obligations -> Decrypt successfully. - Multi-KAS decrypt successful when able to fulfill obligations, failure with returned obligations when not **Caveats** >[!NOTE] >Bulk decrypt is designed in a way that will only allow >a global set of fulfillable obligations to be applied, not for >an individual TDF. ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 206abe3 commit 3cccfd2

File tree

12 files changed

+2214
-192
lines changed

12 files changed

+2214
-192
lines changed

sdk/bulk.go

Lines changed: 118 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ import (
1212

1313
// BulkTDF: Reader is TDF Content. Writer writes encrypted data. Error is the error that occurs if decrypting fails.
1414
type BulkTDF struct {
15-
Reader io.ReadSeeker
16-
Writer io.Writer
17-
Error error
15+
Reader io.ReadSeeker
16+
Writer io.Writer
17+
Error error
18+
TriggeredObligations Obligations
1819
}
1920

2021
type BulkDecryptRequest struct {
@@ -26,6 +27,15 @@ type BulkDecryptRequest struct {
2627
ignoreAllowList bool
2728
}
2829

30+
// BulkDecryptPrepared holds the prepared state for bulk decryption
31+
// The PolicyTDF is a map of created policy IDs to their corresponding BulkTDF
32+
// The policy IDs are generated during the prepareDecryptors function
33+
type BulkDecryptPrepared struct {
34+
PolicyTDF map[string]*BulkTDF
35+
tdfDecryptors map[string]decryptor
36+
allRewrapResp map[string][]kaoResult
37+
}
38+
2939
// BulkErrors List of Errors that Failed during Bulk Decryption
3040
type BulkErrors []error
3141

@@ -116,17 +126,9 @@ func (s SDK) createDecryptor(tdf *BulkTDF, req *BulkDecryptRequest) (decryptor,
116126
return nil, fmt.Errorf("unknown tdf type: %s", req.TDFType)
117127
}
118128

119-
// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
120-
func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
121-
bulkReq, createError := createBulkRewrapRequest(opts...)
122-
if createError != nil {
123-
return fmt.Errorf("failed to create bulk rewrap request: %w", createError)
124-
}
125-
kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest)
126-
tdfDecryptors := make(map[string]decryptor)
127-
policyTDF := make(map[string]*BulkTDF)
128-
129-
if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // if kasAllowlist is not set, we get it from the registry
129+
// setupKasAllowlist configures the KAS allowlist for the bulk request
130+
func (s SDK) setupKasAllowlist(ctx context.Context, bulkReq *BulkDecryptRequest) error {
131+
if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // not complex
130132
if s.KeyAccessServerRegistry != nil {
131133
platformEndpoint, err := s.PlatformConfiguration.platformEndpoint()
132134
if err != nil {
@@ -145,10 +147,18 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
145147
return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available")
146148
}
147149
}
150+
return nil
151+
}
152+
153+
// prepareDecryptors creates decryptors and rewrap requests for all TDFs
154+
func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest) (map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, map[string]decryptor, map[string]*BulkTDF) {
155+
kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest)
156+
tdfDecryptors := make(map[string]decryptor)
157+
policyTDF := make(map[string]*BulkTDF)
148158

149159
for i, tdf := range bulkReq.TDFs {
150160
policyID := fmt.Sprintf("policy-%d", i)
151-
decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // dont want to change signature of LoadTDF
161+
decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // context is not used in createDecryptor
152162
if err != nil {
153163
tdf.Error = err
154164
continue
@@ -167,9 +177,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
167177
}
168178
}
169179

170-
kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey)
180+
return kasRewrapRequests, tdfDecryptors, policyTDF
181+
}
182+
183+
// performRewraps executes all rewrap requests with KAS servers
184+
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string][]kaoResult, error) {
185+
kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations)
171186
allRewrapResp := make(map[string][]kaoResult)
172187
var err error
188+
173189
for kasurl, rewrapRequests := range kasRewrapRequests {
174190
if bulkReq.ignoreAllowList {
175191
s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
@@ -186,6 +202,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
186202
}
187203
continue
188204
}
205+
189206
var rewrapResp map[string][]kaoResult
190207
switch bulkReq.TDFType {
191208
case Nano:
@@ -198,19 +215,73 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
198215
allRewrapResp[id] = append(allRewrapResp[id], res...)
199216
}
200217
}
218+
201219
if err != nil {
202-
return fmt.Errorf("bulk rewrap failed: %w", err)
220+
return nil, fmt.Errorf("bulk rewrap failed: %w", err)
221+
}
222+
223+
return allRewrapResp, nil
224+
}
225+
226+
// PrepareBulkDecrypt does everything except decrypt from the Bulk Decrypt
227+
// ! Currently you cannot specify fulfillable obligations on an individual TDF basis
228+
func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) (*BulkDecryptPrepared, error) {
229+
bulkReq, createError := createBulkRewrapRequest(opts...)
230+
if createError != nil {
231+
return nil, fmt.Errorf("failed to create bulk rewrap request: %w", createError)
232+
}
233+
234+
// Setup KAS allowlist
235+
if err := s.setupKasAllowlist(ctx, bulkReq); err != nil {
236+
return nil, err
237+
}
238+
239+
// Prepare decryptors and rewrap requests
240+
kasRewrapRequests, tdfDecryptors, policyTDF := s.prepareDecryptors(ctx, bulkReq)
241+
242+
// Use the default fulfillable obligations unless a decryptor is available to provide its own
243+
fulfillableObligations := s.fulfillableObligationFQNs
244+
if len(tdfDecryptors) > 0 {
245+
for _, d := range tdfDecryptors {
246+
fulfillableObligations = getFulfillableObligations(d, s.logger)
247+
break
248+
}
249+
}
250+
251+
// Perform rewraps
252+
allRewrapResp, err := s.performRewraps(ctx, bulkReq, kasRewrapRequests, fulfillableObligations)
253+
if err != nil {
254+
return nil, err
203255
}
204256

205-
var errList []error
206257
for id, tdf := range policyTDF {
207-
kaoRes, ok := allRewrapResp[id]
258+
policyRes, ok := allRewrapResp[id]
259+
if !ok {
260+
tdf.Error = errors.New("rewrap did not create a response for this TDF")
261+
continue
262+
}
263+
tdf.TriggeredObligations = Obligations{FQNs: dedupRequiredObligations(policyRes)}
264+
}
265+
266+
return &BulkDecryptPrepared{
267+
PolicyTDF: policyTDF,
268+
tdfDecryptors: tdfDecryptors,
269+
allRewrapResp: allRewrapResp,
270+
}, nil
271+
}
272+
273+
// Allow the bulk decryption to occur
274+
func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error {
275+
var errList []error
276+
var err error
277+
for id, tdf := range bp.PolicyTDF {
278+
kaoRes, ok := bp.allRewrapResp[id]
208279
if !ok {
209280
tdf.Error = errors.New("rewrap did not create a response for this TDF")
210281
errList = append(errList, tdf.Error)
211282
continue
212283
}
213-
decryptor := tdfDecryptors[id]
284+
decryptor := bp.tdfDecryptors[id]
214285
if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil {
215286
tdf.Error = err
216287
errList = append(errList, tdf.Error)
@@ -225,9 +296,36 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
225296
return nil
226297
}
227298

299+
// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
300+
func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
301+
prepared, err := s.PrepareBulkDecrypt(ctx, opts...)
302+
if err != nil {
303+
return err
304+
}
305+
306+
return prepared.BulkDecrypt(ctx)
307+
}
308+
228309
func (b *BulkDecryptRequest) appendTDFs(tdfs ...*BulkTDF) {
229310
b.TDFs = append(
230311
b.TDFs,
231312
tdfs...,
232313
)
233314
}
315+
316+
func getFulfillableObligations(decryptor decryptor, logger *slog.Logger) []string {
317+
if decryptor == nil {
318+
logger.Warn("decryptor is nil, cannot populate obligations")
319+
return make([]string, 0)
320+
}
321+
322+
switch d := decryptor.(type) {
323+
case *tdf3DecryptHandler:
324+
return d.reader.config.fulfillableObligationFQNs
325+
case *NanoTDFDecryptHandler:
326+
return d.config.fulfillableObligationFQNs
327+
default:
328+
logger.Warn("unknown decryptor type, cannot populate obligations", slog.String("type", fmt.Sprintf("%T", d)))
329+
return make([]string, 0)
330+
}
331+
}

sdk/granter_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ const (
3030
specifiedKas = "https://attr.kas.com/"
3131
evenMoreSpecificKas = "https://value.kas.com/"
3232
lessSpecificKas = "https://namespace.kas.com/"
33+
obligationKas = "https://obligation.kas.com/"
3334
fakePem = mockRSAPublicKey1
3435
)
3536

@@ -75,6 +76,21 @@ var (
7576
mpc, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/c")
7677
mpd, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/d")
7778
mpu, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/unspecified")
79+
80+
// Attributes for testing obligations
81+
82+
OBLIGATIONATTR, _ = NewAttributeNameFQN("https://virtru.com/attr/obligation_test")
83+
oa1, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value1")
84+
oa2, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value2")
85+
oa3, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value3")
86+
obligationWatermark = "https://virtru.com/obl/obligation_test/value/watermark"
87+
obligationGeofence = "https://virtru.com/obl/obligation_test/value/geofence"
88+
obligationRedact = "https://virtru.com/obl/obligation_test/value/redact"
89+
obligationMap = map[string]string{
90+
oa1.key: obligationWatermark,
91+
oa2.key: obligationGeofence,
92+
oa3.key: obligationRedact,
93+
}
7894
)
7995

8096
func spongeCase(s string) string {
@@ -211,6 +227,14 @@ func mockAttributeFor(fqn AttributeNameFQN) *policy.Attribute {
211227
Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF,
212228
Fqn: fqn.String(),
213229
}
230+
case OBLIGATIONATTR.key:
231+
return &policy.Attribute{
232+
Id: "OBL",
233+
Namespace: &nsOne,
234+
Name: "obligation",
235+
Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF,
236+
Fqn: fqn.String(),
237+
}
214238
}
215239
return nil
216240
}
@@ -452,6 +476,18 @@ func mockValueFor(fqn AttributeValueFQN) *policy.Value {
452476
p.Grants = make([]*policy.KeyAccessServer, 1)
453477
p.Grants[0] = mockGrant(evenMoreSpecificKas, "r1")
454478
}
479+
case OBLIGATIONATTR.key:
480+
switch strings.ToLower(fqn.Value()) {
481+
case "value1":
482+
p.KasKeys = make([]*policy.SimpleKasKey, 1)
483+
p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3")
484+
case "value2":
485+
p.KasKeys = make([]*policy.SimpleKasKey, 1)
486+
p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3")
487+
case "value3":
488+
p.KasKeys = make([]*policy.SimpleKasKey, 1)
489+
p.KasKeys[0] = mockSimpleKasKey("https://d.kas/", "e1")
490+
}
455491
}
456492
return &p
457493
}

0 commit comments

Comments
 (0)