diff --git a/sdk/bulk.go b/sdk/bulk.go index c418aec49f..e0a7405ed1 100644 --- a/sdk/bulk.go +++ b/sdk/bulk.go @@ -33,7 +33,7 @@ type BulkDecryptRequest struct { type BulkDecryptPrepared struct { PolicyTDF map[string]*BulkTDF tdfDecryptors map[string]decryptor - allRewrapResp map[string]policyResult + allRewrapResp map[string][]kaoResult } // BulkErrors List of Errors that Failed during Bulk Decryption @@ -181,9 +181,9 @@ func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest) } // performRewraps executes all rewrap requests with KAS servers -func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string]policyResult, error) { +func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string][]kaoResult, error) { kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations) - allRewrapResp := make(map[string]policyResult) + allRewrapResp := make(map[string][]kaoResult) var err error for kasurl, rewrapRequests := range kasRewrapRequests { @@ -194,21 +194,16 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka for _, req := range rewrapRequests { id := req.GetPolicy().GetId() for _, kao := range req.GetKeyAccessObjects() { - policyRewrapResp, ok := allRewrapResp[id] - if !ok { - policyRewrapResp = policyResult{policyID: id, obligations: []string{}, kaoRes: []kaoResult{}} - } - policyRewrapResp.kaoRes = append(policyRewrapResp.kaoRes, kaoResult{ + allRewrapResp[id] = append(allRewrapResp[id], kaoResult{ Error: fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasurl), KeyAccessObjectID: kao.GetKeyAccessObjectId(), }) - allRewrapResp[id] = policyRewrapResp } } continue } - var rewrapResp map[string]policyResult + var rewrapResp map[string][]kaoResult switch bulkReq.TDFType { case Nano: rewrapResp, err = kasClient.nanoUnwrap(ctx, rewrapRequests...) @@ -217,16 +212,7 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka } for id, res := range rewrapResp { - // ! It's possible that we already created a policyResult for the policy above for a specific KAS URL. - // ! Meaning for another kas url of the same policy we will end up with an empty list of obligations. - // ! This should be fine since we will error out anyways. - if existingResp, ok := allRewrapResp[id]; !ok { - allRewrapResp[id] = res - } else { - // ! Should not need to append obligations since they should be the same for all TDFs under a policy - existingResp.kaoRes = append(existingResp.kaoRes, res.kaoRes...) - allRewrapResp[id] = existingResp - } + allRewrapResp[id] = append(allRewrapResp[id], res...) } } @@ -274,7 +260,7 @@ func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) tdf.Error = errors.New("rewrap did not create a response for this TDF") continue } - tdf.TriggeredObligations = Obligations{FQNs: policyRes.obligations} + tdf.TriggeredObligations = Obligations{FQNs: dedupRequiredObligations(policyRes)} } return &BulkDecryptPrepared{ @@ -289,14 +275,14 @@ func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error { var errList []error var err error for id, tdf := range bp.PolicyTDF { - policyRes, ok := bp.allRewrapResp[id] + kaoRes, ok := bp.allRewrapResp[id] if !ok { tdf.Error = errors.New("rewrap did not create a response for this TDF") errList = append(errList, tdf.Error) continue } decryptor := bp.tdfDecryptors[id] - if _, err = decryptor.Decrypt(ctx, policyRes.kaoRes); err != nil { + if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil { tdf.Error = err errList = append(errList, tdf.Error) continue diff --git a/sdk/kas_client.go b/sdk/kas_client.go index ebaf9b7a67..d54fa06fd0 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -44,15 +44,10 @@ type KASClient struct { } type kaoResult struct { - SymmetricKey []byte - Error error - KeyAccessObjectID string -} - -type policyResult struct { - policyID string - obligations []string - kaoRes []kaoResult + SymmetricKey []byte + Error error + KeyAccessObjectID string + RequiredObligations []string } type decryptor interface { @@ -175,7 +170,7 @@ func upgradeRewrapErrorV1(err error, requests []*kas.UnsignedRewrapRequest_WithP }, nil } -func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) { +func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) { keypair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) if err != nil { return nil, fmt.Errorf("ocrypto.NewECKeyPair failed :%w", err) @@ -198,22 +193,19 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew // If the session key is empty, all responses are errors spk := response.GetSessionPublicKey() if spk == "" { - policyResults := make(map[string]policyResult) + policyResults := make(map[string][]kaoResult) err = errors.New("nanoUnwrap: session public key is empty") for _, results := range response.GetResponses() { - policyRes, ok := policyResults[results.GetPolicyId()] - if !ok { - policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}} - } + var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { - policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } - policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId()) - policyResults[results.GetPolicyId()] = policyRes + policyResults[results.GetPolicyId()] = kaoKeys } return policyResults, nil @@ -234,33 +226,30 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew return nil, fmt.Errorf("nanoUnwrap: ocrypto.NewAESGcm failed:%w", err) } - policyResults := make(map[string]policyResult) + policyResults := make(map[string][]kaoResult) for _, results := range response.GetResponses() { - policyRes, ok := policyResults[results.GetPolicyId()] - if !ok { - policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}} - } + var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { wrappedKey := kao.GetKasWrappedKey() key, err := aesGcm.Decrypt(wrappedKey) if err != nil { - policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } - policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId()) - policyResults[results.GetPolicyId()] = policyRes + policyResults[results.GetPolicyId()] = kaoKeys } return policyResults, nil } -func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) { +func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) { if k.sessionKey == nil { return nil, errors.New("session key is nil") } @@ -279,7 +268,7 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR return k.handleRSAKeyResponse(response) } -func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) { +func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) { kasEphemeralPublicKey := response.GetSessionPublicKey() clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat() if err != nil { @@ -306,60 +295,58 @@ func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[strin return k.processECResponse(response, aesGcm) } -func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string]policyResult, error) { - policyResults := make(map[string]policyResult) +func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string][]kaoResult, error) { + policyResults := make(map[string][]kaoResult) for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { key, err := aesGcm.Decrypt(kao.GetKasWrappedKey()) if err != nil { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } - requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId()) - policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations} + policyResults[results.GetPolicyId()] = kaoKeys } return policyResults, nil } -func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value, policyID string) []string { - var triggeredFQNs []string +/* +Metadata will be in the following form, per kao: - if metadata == nil { - return triggeredFQNs + { + "metadata": { + "X-Required-Obligations": [] + } } +*/ +func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value) []string { + var requiredObligations []string - triggerOblsValue, ok := metadata[triggeredObligationsHeader] - if !ok { - return triggeredFQNs - } - - fields := triggerOblsValue.GetStructValue().GetFields() - if fields == nil { - return triggeredFQNs + if metadata == nil { + return requiredObligations } - policyOblsValue, ok := fields[policyID] + triggerOblsValue, ok := metadata[triggeredObligationsHeader] if !ok { - return triggeredFQNs + return requiredObligations } - values := policyOblsValue.GetListValue().GetValues() - - for _, v := range values { - triggeredFQNs = append(triggeredFQNs, v.GetStringValue()) + triggerOblsList := triggerOblsValue.GetListValue().GetValues() + for _, v := range triggerOblsList { + requiredObligations = append(requiredObligations, v.GetStringValue()) } - return triggeredFQNs + return requiredObligations } -func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) { +func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) { clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat() if err != nil { return nil, fmt.Errorf("ocrypto.PrivateKeyInPemFormat failed: %w", err) @@ -373,24 +360,24 @@ func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[stri return k.processRSAResponse(response, asymDecryption) } -func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string]policyResult, error) { - policyResults := make(map[string]policyResult) +func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string][]kaoResult, error) { + policyResults := make(map[string][]kaoResult) for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { key, err := asymDecryption.Decrypt(kao.GetKasWrappedKey()) if err != nil { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } - requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId()) - policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations} + policyResults[results.GetPolicyId()] = kaoKeys } return policyResults, nil } diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index be311dc645..e1f0ecd8c8 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -410,23 +410,24 @@ func Test_newConnectRewrapRequest(t *testing.T) { func Test_retrieveObligationsFromMetadata(t *testing.T) { c := newKASClient(nil, nil, nil, nil, nil) - metadata := createMetadataWithObligations(map[string][]string{ - "policy1": {"https://example.com/attr/attr1/value/val1", "https://example.com/attr/attr2/value/val2"}, + metadata := createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + "https://example.com/attr/attr2/value/val2", }) - fqns := c.retrieveObligationsFromMetadata(metadata, "policy1") - assert.Len(t, fqns, 2) - assert.Equal(t, "https://example.com/attr/attr1/value/val1", fqns[0]) - assert.Equal(t, "https://example.com/attr/attr2/value/val2", fqns[1]) + fqns := c.retrieveObligationsFromMetadata(metadata) + require.Len(t, fqns, 2) + require.Equal(t, "https://example.com/attr/attr1/value/val1", fqns[0]) + require.Equal(t, "https://example.com/attr/attr2/value/val2", fqns[1]) } func Test_retrieveObligationsFromMetadata_NoObligations(t *testing.T) { c := newKASClient(nil, nil, nil, nil, nil) - fqns := c.retrieveObligationsFromMetadata(createMetadataWithObligations(nil), "policy1") - assert.Empty(t, fqns) + fqns := c.retrieveObligationsFromMetadata(createMetadataWithObligations(nil)) + require.Empty(t, fqns) } -func Test_retrieveObligationsFromMetadata_NotStructValue(t *testing.T) { +func Test_retrieveObligationsFromMetadata_NotListValue(t *testing.T) { c := newKASClient(nil, nil, nil, nil, nil) metadata := make(map[string]*structpb.Value) metadata[triggeredObligationsHeader] = &structpb.Value{ @@ -434,38 +435,20 @@ func Test_retrieveObligationsFromMetadata_NotStructValue(t *testing.T) { BoolValue: true, }, } - fqns := c.retrieveObligationsFromMetadata(metadata, "policy1") - assert.Empty(t, fqns) + fqns := c.retrieveObligationsFromMetadata(metadata) + require.Empty(t, fqns) } -func Test_retrieveObligationsFromMetadata_PolicyNotPresent(t *testing.T) { +func Test_retrieveObligationsFromMetadata_EmptyList(t *testing.T) { c := newKASClient(nil, nil, nil, nil, nil) metadata := make(map[string]*structpb.Value) metadata[triggeredObligationsHeader] = &structpb.Value{ - Kind: &structpb.Value_StructValue{ - StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{}}, + Kind: &structpb.Value_ListValue{ + ListValue: &structpb.ListValue{Values: []*structpb.Value{}}, }, } - fqns := c.retrieveObligationsFromMetadata(metadata, "policy1") - assert.Empty(t, fqns) -} - -func Test_retrieveObligationsFromMetadata_ListValuesNotPresent(t *testing.T) { - c := newKASClient(nil, nil, nil, nil, nil) - metadata := make(map[string]*structpb.Value) - metadata[triggeredObligationsHeader] = &structpb.Value{ - Kind: &structpb.Value_StructValue{ - StructValue: &structpb.Struct{ - Fields: map[string]*structpb.Value{ - "not_a_list": { - Kind: &structpb.Value_StringValue{StringValue: "not a list"}, - }, - }, - }, - }, - } - fqns := c.retrieveObligationsFromMetadata(metadata, "policy1") - assert.Empty(t, fqns) + fqns := c.retrieveObligationsFromMetadata(metadata) + require.Empty(t, fqns) } func Test_processRSAResponse(t *testing.T) { @@ -493,6 +476,31 @@ func Test_processRSAResponse(t *testing.T) { Responses: []*kaspb.PolicyRewrapResult{ { PolicyId: "policy1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), + }, + { + KeyAccessObjectId: "kao2", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val2", + }), + }, + }, + }, + { + PolicyId: "policy2", Results: []*kaspb.KeyAccessRewrapResult{ { KeyAccessObjectId: "kao1", @@ -500,26 +508,35 @@ func Test_processRSAResponse(t *testing.T) { Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ KasWrappedKey: wrappedKey, }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val3", + }), }, }, }, }, - Metadata: createMetadataWithObligations(map[string][]string{ - "policy1": {"https://example.com/attr/attr1/value/val1"}, - }), } policyResults, err := c.processRSAResponse(response, mockDecryptor) require.NoError(t, err) - require.Len(t, policyResults, 1) + require.Len(t, policyResults, 2) result, ok := policyResults["policy1"] require.True(t, ok) - require.Equal(t, "policy1", result.policyID) - require.Len(t, result.kaoRes, 1) - require.Equal(t, symmetricKey, result.kaoRes[0].SymmetricKey) - require.Len(t, result.obligations, 1) - require.Equal(t, "https://example.com/attr/attr1/value/val1", result.obligations[0]) + require.Len(t, result, 2) + require.Nil(t, result[0].SymmetricKey) + require.Nil(t, result[1].SymmetricKey) + require.Len(t, result[0].RequiredObligations, 1) + require.Len(t, result[1].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result[0].RequiredObligations[0]) + require.Equal(t, "https://example.com/attr/attr1/value/val2", result[1].RequiredObligations[0]) + + result2, ok := policyResults["policy2"] + require.True(t, ok) + require.Len(t, result2, 1) + require.Equal(t, symmetricKey, result2[0].SymmetricKey) + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val3", result2[0].RequiredObligations[0]) } func Test_processECResponse(t *testing.T) { @@ -550,11 +567,6 @@ func Test_processECResponse(t *testing.T) { encryptor, err := ocrypto.NewAESGcm(sessionKey) require.NoError(t, err) - // 4. Encrypt a symmetric key - symmetricKey1 := []byte("supersecretkey1") - wrappedKey1, err := encryptor.Encrypt(symmetricKey1) - require.NoError(t, err) - symmetricKey2 := []byte("supersecretkey2") wrappedKey2, err := encryptor.Encrypt(symmetricKey2) require.NoError(t, err) @@ -567,10 +579,23 @@ func Test_processECResponse(t *testing.T) { Results: []*kaspb.KeyAccessRewrapResult{ { KeyAccessObjectId: "kao1", - Status: "permit", - Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ - KasWrappedKey: wrappedKey1, + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), + }, + { + KeyAccessObjectId: "kao2", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val2", + }), }, }, }, @@ -578,19 +603,18 @@ func Test_processECResponse(t *testing.T) { PolicyId: "policy2", Results: []*kaspb.KeyAccessRewrapResult{ { - KeyAccessObjectId: "kao2", + KeyAccessObjectId: "kao1", Status: "permit", Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ KasWrappedKey: wrappedKey2, }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr2/value/val2", + }), }, }, }, }, - Metadata: createMetadataWithObligations(map[string][]string{ - "policy1": {"https://example.com/attr/attr1/value/val1"}, - "policy2": {"https://example.com/attr/attr2/value/val2"}, - }), } // 6. Create AES-GCM cipher for decryption (using the same session key) @@ -605,20 +629,21 @@ func Test_processECResponse(t *testing.T) { // 8. Assertions for policy1 result1, ok := policyResults["policy1"] require.True(t, ok) - require.Equal(t, "policy1", result1.policyID) - require.Len(t, result1.kaoRes, 1) - require.Equal(t, symmetricKey1, result1.kaoRes[0].SymmetricKey) - require.Len(t, result1.obligations, 1) - require.Equal(t, "https://example.com/attr/attr1/value/val1", result1.obligations[0]) + require.Len(t, result1, 2) + require.Nil(t, result1[0].SymmetricKey) + require.Nil(t, result1[1].SymmetricKey) + require.Len(t, result1[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result1[0].RequiredObligations[0]) + require.Len(t, result1[1].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val2", result1[1].RequiredObligations[0]) // 9. Assertions for policy2 result2, ok := policyResults["policy2"] require.True(t, ok) - require.Equal(t, "policy2", result2.policyID) - require.Len(t, result2.kaoRes, 1) - require.Equal(t, symmetricKey2, result2.kaoRes[0].SymmetricKey) - require.Len(t, result2.obligations, 1) - require.Equal(t, "https://example.com/attr/attr2/value/val2", result2.obligations[0]) + require.Len(t, result2, 1) + require.Equal(t, symmetricKey2, result2[0].SymmetricKey) + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr2/value/val2", result2[0].RequiredObligations[0]) } type mockService interface { @@ -626,8 +651,9 @@ type mockService interface { } type MockKas struct { - t *testing.T - obligations map[string][]string // policyID -> obligations + t *testing.T + obligations map[string][]string // policyID -> obligations + policyDecisions map[string]string // policyID -> "permit" or "fail" } func (f *MockKas) Process(req *http.Request) (*http.Response, error) { @@ -678,27 +704,48 @@ func (f *MockKas) Process(req *http.Request) (*http.Response, error) { rewrapResponse := &kaspb.RewrapResponse{ SessionPublicKey: kasPublicKeyPEM, } - triggeredObligations := make(map[string][]string) for _, req := range unsignedReq.GetRequests() { policyID := req.GetPolicy().GetId() - rewrapResponse.Responses = append(rewrapResponse.Responses, &kaspb.PolicyRewrapResult{ - PolicyId: policyID, - Results: []*kaspb.KeyAccessRewrapResult{ - { - KeyAccessObjectId: req.GetKeyAccessObjects()[0].GetKeyAccessObjectId(), - Status: "permit", - Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ - KasWrappedKey: wrappedKey, - }, - }, - }, - }) + // Determine if this policy should be permitted or failed + decision := "permit" // default to permit + if f.policyDecisions != nil { + if d, exists := f.policyDecisions[policyID]; exists { + decision = d + } + } + + var kaoResult *kaspb.KeyAccessRewrapResult + var metadata map[string]*structpb.Value if fqns, exists := f.obligations[policyID]; exists { - triggeredObligations[policyID] = fqns + metadata = createMetadataWithObligations(fqns) + } + if decision == "permit" { + // For permitted policies: no metadata/obligations + kaoResult = &kaspb.KeyAccessRewrapResult{ + KeyAccessObjectId: req.GetKeyAccessObjects()[0].GetKeyAccessObjectId(), + Status: "permit", + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ + KasWrappedKey: wrappedKey, + }, + Metadata: metadata, + } + } else { + kaoResult = &kaspb.KeyAccessRewrapResult{ + KeyAccessObjectId: req.GetKeyAccessObjects()[0].GetKeyAccessObjectId(), + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "denied by policy", + }, + Metadata: metadata, + } } + + rewrapResponse.Responses = append(rewrapResponse.Responses, &kaspb.PolicyRewrapResult{ + PolicyId: policyID, + Results: []*kaspb.KeyAccessRewrapResult{kaoResult}, + }) } - rewrapResponse.Metadata = createMetadataWithObligations(triggeredObligations) responseBody, err := protojson.Marshal(rewrapResponse) require.NoError(f.t, err) @@ -737,6 +784,10 @@ func Test_nanoUnwrap(t *testing.T) { "policy1": {"https://example.com/attr/attr1/value/val1"}, "policy2": {"https://example.com/attr/attr2/value/val2"}, }, + policyDecisions: map[string]string{ + "policy1": "permit", // policy1 should be permitted + "policy2": "fail", // policy2 should be failed + }, }}, } @@ -774,48 +825,24 @@ func Test_nanoUnwrap(t *testing.T) { require.Len(t, policyResults, 2) // 5. Assertions + // Policy1 should be permitted - has symmetric key, no error, no obligations result1, ok := policyResults["policy1"] require.True(t, ok) - require.Equal(t, "policy1", result1.policyID) - require.Len(t, result1.kaoRes, 1) - require.Equal(t, []byte("supersecretkey1"), result1.kaoRes[0].SymmetricKey) - require.NoError(t, result1.kaoRes[0].Error) - require.Len(t, result1.obligations, 1) - require.Equal(t, "https://example.com/attr/attr1/value/val1", result1.obligations[0]) + require.Len(t, result1, 1) + require.Equal(t, []byte("supersecretkey1"), result1[0].SymmetricKey) + require.NoError(t, result1[0].Error) + require.Len(t, result1[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result1[0].RequiredObligations[0]) + // Policy2 should be failed - has error, no symmetric key, has obligations result2, ok := policyResults["policy2"] require.True(t, ok) - require.Equal(t, "policy2", result2.policyID) - require.Len(t, result2.kaoRes, 1) - require.Equal(t, []byte("supersecretkey1"), result2.kaoRes[0].SymmetricKey) - require.NoError(t, result2.kaoRes[0].Error) - require.Len(t, result2.obligations, 1) - require.Equal(t, "https://example.com/attr/attr2/value/val2", result2.obligations[0]) -} - -func createMetadataWithObligations(obligations map[string][]string) map[string]*structpb.Value { - metadata := make(map[string]*structpb.Value) - if len(obligations) == 0 { - return metadata - } - - fields := make(map[string]*structpb.Value) - for policyID, fqns := range obligations { - listValue := &structpb.ListValue{} - for _, fqn := range fqns { - listValue.Values = append(listValue.Values, structpb.NewStringValue(fqn)) - } - fields[policyID] = structpb.NewListValue(listValue) - } - - metadata[triggeredObligationsHeader] = &structpb.Value{ - Kind: &structpb.Value_StructValue{ - StructValue: &structpb.Struct{ - Fields: fields, - }, - }, - } - return metadata + require.Len(t, result2, 1) + require.Nil(t, result2[0].SymmetricKey, "Failed policies should not have symmetric key") + require.Error(t, result2[0].Error) + require.Contains(t, result2[0].Error.Error(), "denied by policy") + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr2/value/val2", result2[0].RequiredObligations[0]) } func Test_nanoUnwrap_EmptySPK_WithObligations(t *testing.T) { @@ -832,13 +859,13 @@ func Test_nanoUnwrap_EmptySPK_WithObligations(t *testing.T) { Result: &kaspb.KeyAccessRewrapResult_Error{ Error: "denied by policy", }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), }, }, }, }, - Metadata: createMetadataWithObligations(map[string][]string{ - "policy1": {"https://example.com/attr/attr1/value/val1"}, - }), } responseBody, err := protojson.Marshal(rewrapResponse) @@ -883,15 +910,29 @@ func Test_nanoUnwrap_EmptySPK_WithObligations(t *testing.T) { // 6. Assertions result, ok := policyResults["policy1"] require.True(t, ok) - require.Equal(t, "policy1", result.policyID) + require.Len(t, result, 1) // Assert that the KAO result contains an error - require.Len(t, result.kaoRes, 1) - require.Error(t, result.kaoRes[0].Error) - require.Contains(t, result.kaoRes[0].Error.Error(), "denied by policy") - require.Nil(t, result.kaoRes[0].SymmetricKey) + require.Error(t, result[0].Error) + require.Contains(t, result[0].Error.Error(), "denied by policy") + require.Nil(t, result[0].SymmetricKey) // Assert that obligations are still present despite the KAO error - require.Len(t, result.obligations, 1) - require.Equal(t, "https://example.com/attr/attr1/value/val1", result.obligations[0]) + require.Len(t, result[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result[0].RequiredObligations[0]) +} + +func createMetadataWithObligations(obligations []string) map[string]*structpb.Value { + metadata := make(map[string]*structpb.Value) + if len(obligations) == 0 { + return metadata + } + + listValue := &structpb.ListValue{} + for _, fqn := range obligations { + listValue.Values = append(listValue.Values, structpb.NewStringValue(fqn)) + } + + metadata[triggeredObligationsHeader] = structpb.NewListValue(listValue) + return metadata } diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index 511ea7346f..c76e7054de 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -1066,23 +1066,23 @@ func (n *NanoTDFReader) getNanoRewrapKey(ctx context.Context) error { return fmt.Errorf("rewrap failed: %w", err) } result, ok := policyResult["policy"] - if !ok || len(result.kaoRes) != 1 { + if !ok || len(result) != 1 { return errors.New("policy was not found in rewrap response") } // Populate obligations after policy result is found. - n.requiredObligations = &Obligations{FQNs: result.obligations} + n.requiredObligations = &Obligations{FQNs: result[0].RequiredObligations} - if result.kaoRes[0].Error != nil { - errToReturn := fmt.Errorf("rewrapError: %w", result.kaoRes[0].Error) - return getKasErrorToReturn(result.kaoRes[0].Error, errToReturn) + if result[0].Error != nil { + errToReturn := fmt.Errorf("rewrapError: %w", result[0].Error) + return getKasErrorToReturn(result[0].Error, errToReturn) } if n.collectionStore != nil { - n.collectionStore.store(n.headerBuf, result.kaoRes[0].SymmetricKey) + n.collectionStore.store(n.headerBuf, result[0].SymmetricKey) } - n.payloadKey = result.kaoRes[0].SymmetricKey + n.payloadKey = result[0].SymmetricKey return nil } diff --git a/sdk/tdf.go b/sdk/tdf.go index e9397d7b78..db5957386d 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -1422,11 +1422,11 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn err = errors.New("could not find policy in rewrap response") reqFail(err, req) } - // ! Should constantly be the same obligation for the same policy - r.requiredObligations = &Obligations{FQNs: result.obligations} - kaoResults = append(kaoResults, result.kaoRes...) + kaoResults = append(kaoResults, result...) } } + // Deduplicate obligations for all kao results + r.requiredObligations = &Obligations{FQNs: dedupRequiredObligations(kaoResults)} return r.buildKey(ctx, kaoResults) } @@ -1612,3 +1612,22 @@ func getKasAllowList(ctx context.Context, kasAllowList AllowList, s SDK, ignoreA return allowList, nil } + +func dedupRequiredObligations(kaoResults []kaoResult) []string { + seen := make(map[string]struct{}) + dedupedOblgs := make([]string, 0) + for _, kao := range kaoResults { + for _, oblg := range kao.RequiredObligations { + normalizedOblg := strings.TrimSpace(strings.ToLower(oblg)) + if len(normalizedOblg) == 0 { + continue + } + if _, ok := seen[normalizedOblg]; !ok { + seen[normalizedOblg] = struct{}{} + dedupedOblgs = append(dedupedOblgs, normalizedOblg) + } + } + } + + return dedupedOblgs +} diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 158e9c0ab2..a36113d503 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -76,6 +76,7 @@ type tdfTest struct { splitPlan []keySplitStep policy []AttributeValueFQN expectedPlanSize int + opts []TDFReaderOption } type baseKeyTest struct { @@ -1934,36 +1935,50 @@ func (s *TDFSuite) Test_KeySplits() { func (s *TDFSuite) Test_Obligations_Decrypt() { for _, test := range []struct { - n string - fileSize int64 - tdfFileSize float64 - checksum string - obligationFQNs []string - opts []TDFOption + n string + fileSize int64 + tdfFileSize float64 + checksum string + requiredObligationFQNs []string + opts []TDFOption + fulfillableObligations []string + attrValueFQNs []AttributeValueFQN + expectError bool }{ { - n: "with-obligations-same-kas", - fileSize: 5, - tdfFileSize: 2534, - checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", - obligationFQNs: []string{obligationWatermark, obligationRedact}, - opts: []TDFOption{ - WithDataAttributes(oa1.key, oa3.key), - }, - }, - { - n: "with-obligations-different-kas", - fileSize: 5, - tdfFileSize: 2690, - checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", - obligationFQNs: []string{obligationWatermark, obligationRedact, obligationGeofence}, - opts: []TDFOption{ - WithDataAttributes(oa1.key, oa2.key, oa3.key), - }, + n: "two-attributes-same-kas-with-fulfillable-obligations", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, // Both go to obligationKas + fulfillableObligations: []string{obligationWatermark, obligationGeofence}, + attrValueFQNs: []AttributeValueFQN{oa1, oa2}, + expectError: false, + }, + { + n: "two-attributes-same-kas-no-fulfillable-obligations", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, // Both go to obligationKas + fulfillableObligations: []string{}, // No fulfillable obligations + expectError: true, + }, + { + n: "fulfill-one-of-two-attributes-same-kas", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, + fulfillableObligations: []string{obligationWatermark}, + expectError: true, }, } { s.Run(test.n, func() { - // create .txt file + // Create a new SDK instance with limited fulfillable obligations plainTextFileName := test.n + ".txt" tdfFileName := plainTextFileName + ".tdf" decryptedTdfFileName := tdfFileName + ".txt" @@ -1975,7 +1990,7 @@ func (s *TDFSuite) Test_Obligations_Decrypt() { _ = os.Remove(decryptedTdfFileName) }() - // test encrypt + // test encrypt using the default SDK (which has all fulfillable obligations) s.testEncrypt(s.sdk, test.opts, plainTextFileName, tdfFileName, tdfTest{ n: test.n, fileSize: test.fileSize, @@ -1983,7 +1998,6 @@ func (s *TDFSuite) Test_Obligations_Decrypt() { checksum: test.checksum, }) - // test decrypt with reader readSeeker, err := os.Open(tdfFileName) s.Require().NoError(err) defer func(readSeeker *os.File) { @@ -1993,26 +2007,32 @@ func (s *TDFSuite) Test_Obligations_Decrypt() { r, err := s.sdk.LoadTDF(readSeeker) s.Require().NoError(err) - for _, ob := range []string{obligationGeofence, obligationRedact, obligationWatermark} { - s.Require().Contains(r.config.fulfillableObligationFQNs, ob, "Should contain obligation "+ob) + r.config.fulfillableObligationFQNs = test.fulfillableObligations + + if !test.expectError { + // Validate successful decryption + s.testDecryptWithReader(s.sdk, tdfFileName, decryptedTdfFileName, tdfTest{ + n: test.n, + fileSize: test.fileSize, + checksum: test.checksum, + policy: test.attrValueFQNs, + opts: []TDFReaderOption{WithTDFFulfillableObligationFQNs(test.fulfillableObligations)}, + }) + + _, err = r.WriteTo(io.Discard) + s.Require().NoError(err) + } else { + // The decryption should fail due to unmet obligations + _, err = r.WriteTo(io.Discard) + s.Require().Error(err, "Decryption should fail when obligations are not met") } - // Validate decryption - s.testDecryptWithReader(s.sdk, tdfFileName, decryptedTdfFileName, tdfTest{ - n: test.n, - fileSize: test.fileSize, - checksum: test.checksum, - policy: []AttributeValueFQN{oa1, oa3}, - }) - - _, err = r.WriteTo(io.Discard) - s.Require().NoError(err) obligations, err := r.Obligations(s.T().Context()) s.Require().NoError(err) s.Require().NotNil(obligations, "Obligations should not be nil") - s.Require().Len(obligations.FQNs, len(test.obligationFQNs), "Should have correct number of obligations") + s.Require().Len(obligations.FQNs, len(test.requiredObligationFQNs), "Should have correct number of obligations") actualObligations := obligations - for _, ob := range test.obligationFQNs { + for _, ob := range test.requiredObligationFQNs { s.Require().Contains(actualObligations.FQNs, ob, "Actual obligations should contain "+ob) } }) @@ -2178,6 +2198,192 @@ func (s *TDFSuite) Test_Obligations() { } } +func TestDedupRequiredObligations(t *testing.T) { + testCases := []struct { + name string + kaoResults []kaoResult + expectedResult []string + }{ + { + name: "empty input", + kaoResults: []kaoResult{}, + expectedResult: []string{}, + }, + { + name: "single kao with no obligations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{}, + }, + }, + expectedResult: []string{}, + }, + { + name: "single kao with single obligation", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "single kao with multiple obligations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + name: "multiple kaos with same obligations - should dedupe", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "multiple kaos with different obligations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/geofence"}, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + name: "case insensitive deduplication", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/WATERMARK"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "whitespace trimming and deduplication", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{" https://demo.com/obl/test/value/watermark "}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "complex case - mixed duplicates with case and whitespace variations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/WATERMARK", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{ + " https://demo.com/obl/test/value/watermark ", + "https://demo.com/obl/test/value/ENCRYPTION", + }, + }, + { + KeyAccessObjectID: "kao-3", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/encryption", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/encryption", + }, + }, + { + name: "empty string obligations should be normalized", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "", + " ", + "https://demo.com/obl/test/value/watermark", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + }, + }, + { + name: "preserve order of first occurrence", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/watermark", + }, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/watermark", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := dedupRequiredObligations(tc.kaoResults) + assert.Equal(t, tc.expectedResult, result, "Deduplication result should match expected") + }) + } +} + func (s *TDFSuite) Test_Autoconfigure() { for index, test := range []tdfTest{ { @@ -2320,7 +2526,7 @@ func (s *TDFSuite) testDecryptWithReader(sdk *SDK, tdfFile, decryptedTdfFileName s.Require().NoError(err) }(readSeeker) - r, err := sdk.LoadTDF(readSeeker) + r, err := sdk.LoadTDF(readSeeker, test.opts...) s.Require().NoError(err) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(300*time.Minute)) @@ -2444,8 +2650,8 @@ func (s *TDFSuite) startBackend() { s: s, privateKey: ki.private, KASInfo: KASInfo{ URL: ki.url, PublicKey: ki.public, KID: ki.kid, Algorithm: "rsa:2048", }, - legakeys: map[string]keyInfo{}, - obligations: obligationMap, + legakeys: map[string]keyInfo{}, + attrToRequiredObligations: obligationMap, } path, handler := attributesconnect.NewAttributesServiceHandler(fa) mux.Handle(path, handler) @@ -2475,7 +2681,6 @@ func (s *TDFSuite) startBackend() { withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), WithInsecurePlaintextConn(), - WithFulfillableObligationFQNs([]string{obligationGeofence, obligationRedact, obligationWatermark}), ) s.Require().NoError(err) s.sdk = sdk @@ -2554,10 +2759,10 @@ func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *connect.Req type FakeKas struct { kasconnect.UnimplementedAccessServiceHandler KASInfo - privateKey string - s *TDFSuite - legakeys map[string]keyInfo - obligations map[string]string + privateKey string + s *TDFSuite + legakeys map[string]keyInfo + attrToRequiredObligations map[string]string } func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequest]) (*connect.Response[kaspb.RewrapResponse], error) { @@ -2577,7 +2782,24 @@ func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequ if !ok { return nil, errors.New("requestBody not a string") } - result := f.getRewrapResponse(requestBodyStr) + + // Extract fulfillable obligations from header + var fulfillableObligations []string + if val := in.Header().Get("X-Rewrap-Additional-Context"); val != "" { + decoded, err := base64.StdEncoding.DecodeString(val) + if err == nil { + var rewrapContext struct { + Obligations struct { + FulfillableFQNs []string `json:"fulfillableFQNs"` + } `json:"obligations"` + } + if json.Unmarshal(decoded, &rewrapContext) == nil { + fulfillableObligations = rewrapContext.Obligations.FulfillableFQNs + } + } + } + + result := f.getRewrapResponse(requestBodyStr, fulfillableObligations) return connect.NewResponse(result), nil } @@ -2586,14 +2808,35 @@ func (f *FakeKas) PublicKey(_ context.Context, _ *connect.Request[kaspb.PublicKe return connect.NewResponse(&kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}), nil } -func (f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse { +func (f *FakeKas) getRewrapResponse(rewrapRequest string, fulfillableObligations []string) *kaspb.RewrapResponse { bodyData := kaspb.UnsignedRewrapRequest{} err := protojson.Unmarshal([]byte(rewrapRequest), &bodyData) f.s.Require().NoError(err, "json.Unmarshal failed") resp := &kaspb.RewrapResponse{} - policyObligationMap := make(map[string][]string) for _, req := range bodyData.GetRequests() { + requiredObligations := f.s.checkPolicyObligations(f.attrToRequiredObligations, req) + if f.KASInfo.URL == f.s.kasTestURLLookup[obligationKas] { + // Only return failures for obligation kas URL + if !f.s.checkObligationsFulfillment(requiredObligations, fulfillableObligations) { + // Return a deny response if obligations are not fulfilled + results := &kaspb.PolicyRewrapResult{PolicyId: req.GetPolicy().GetId()} + for _, kaoReq := range req.GetKeyAccessObjects() { + kaoResult := &kaspb.KeyAccessRewrapResult{ + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "forbidden", + }, + Status: "deny", + KeyAccessObjectId: kaoReq.GetKeyAccessObjectId(), + Metadata: createMetadataWithObligations(requiredObligations), + } + results.Results = append(results.Results, kaoResult) + } + resp.Responses = append(resp.Responses, results) + continue + } + } + results := &kaspb.PolicyRewrapResult{PolicyId: req.GetPolicy().GetId()} resp.Responses = append(resp.Responses, results) for _, kaoReq := range req.GetKeyAccessObjects() { @@ -2684,33 +2927,48 @@ func (f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: entityWrappedKey}, Status: "permit", KeyAccessObjectId: kaoReq.GetKeyAccessObjectId(), + Metadata: createMetadataWithObligations(requiredObligations), } results.Results = append(results.Results, kaoResult) } - - policyObligationMap[req.GetPolicy().GetId()] = f.s.checkPolicyObligations(f.obligations, req) } - resp.Metadata = createMetadataWithObligations(policyObligationMap) return resp } func (s *TDFSuite) checkPolicyObligations(obligationsMap map[string]string, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) []string { - var obligations []string + var requiredObligations []string sDecPolicy, policyErr := base64.StdEncoding.DecodeString(req.GetPolicy().GetBody()) policy := &Policy{} if policyErr == nil { policyErr = json.Unmarshal(sDecPolicy, policy) if policyErr != nil { - return obligations + return requiredObligations } } for _, attr := range policy.Body.DataAttributes { if val, found := obligationsMap[attr.URI]; found { - obligations = append(obligations, val) + requiredObligations = append(requiredObligations, val) } } - return obligations + return requiredObligations +} + +func (s *TDFSuite) checkObligationsFulfillment(requiredObligations, fulfillableObligations []string) bool { + // Create a set of fulfillable obligations for fast lookup + fulfillableSet := make(map[string]bool) + for _, obligation := range fulfillableObligations { + fulfillableSet[obligation] = true + } + + // Check if all required obligations are in the fulfillable set + for _, required := range requiredObligations { + if !fulfillableSet[required] { + return false + } + } + + return true } func (s *TDFSuite) checkIdentical(file, checksum string) bool {