From 354a8e5120bd5728bb2ac6d617c59ee3645e46e7 Mon Sep 17 00:00:00 2001 From: Chris Reed Date: Wed, 15 Oct 2025 19:23:36 -0500 Subject: [PATCH 1/2] feat(kas): Add required obligations to kao metadata.: --- service/kas/access/rewrap.go | 109 ++++--- service/kas/access/rewrap_test.go | 510 +++++++++++++++++++----------- 2 files changed, 381 insertions(+), 238 deletions(-) diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index 2fd9bf5d6..ecd8479b0 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -81,16 +81,12 @@ type kaoResult struct { Error error // Optional: Present for EC wrapped responses - EphemeralPublicKey []byte -} - -type policyResult struct { - kaoResults map[string]kaoResult - requiredObligations []string + EphemeralPublicKey []byte + RequiredObligations []string } // From policy ID to KAO ID to result -type policyKAOResults map[string]policyResult +type policyKAOResults map[string]map[string]kaoResult type ObligationCtx struct { FulfillableFQNs []string `json:"fulfillableFQNs,omitempty"` @@ -378,6 +374,14 @@ func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, err return info, nil } +func failedKAORewrapWithObligations(res map[string]kaoResult, kao *kaspb.UnsignedRewrapRequest_WithKeyAccessObject, err error, requiredObligations []string) { + res[kao.GetKeyAccessObjectId()] = kaoResult{ + ID: kao.GetKeyAccessObjectId(), + Error: err, + RequiredObligations: requiredObligations, + } +} + func failedKAORewrap(res map[string]kaoResult, kao *kaspb.UnsignedRewrapRequest_WithKeyAccessObject, err error) { res[kao.GetKeyAccessObjectId()] = kaoResult{ ID: kao.GetKeyAccessObjectId(), @@ -390,12 +394,14 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult policyResults := &kaspb.PolicyRewrapResult{ PolicyId: policyID, } - for kaoID, kaoRes := range policyMap.kaoResults { + for kaoID, kaoRes := range policyMap { + // Add metadata kaoResult := &kaspb.KeyAccessRewrapResult{ KeyAccessObjectId: kaoID, } switch { case kaoRes.Error != nil: + kaoResult.Metadata = createKAOMetadata(kaoRes.RequiredObligations) kaoResult.Status = kFailedStatus kaoResult.Result = &kaspb.KeyAccessRewrapResult_Error{Error: kaoRes.Error.Error()} case kaoRes.Encapped != nil: @@ -408,7 +414,6 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult policyResults.Results = append(policyResults.Results, kaoResult) } response.Responses = append(response.Responses, policyResults) - populateRequiredObligationsOnResponse(response, policyMap.requiredObligations, policyID) } } @@ -470,16 +475,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", slog.Any("results", results)) return nil, err400("invalid request") } - policyResults := *getMapValue(results) - if len(policyResults.kaoResults) != 1 { + kaoResults := *getMapValue(results) + if len(kaoResults) != 1 { p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", - slog.Any("kao_results", policyResults.kaoResults), + slog.Any("kao_results", kaoResults), slog.Any("results", results), ) return nil, err400("invalid request") } - kao := *getMapValue(policyResults.kaoResults) + kao := *getMapValue(kaoResults) if kao.Error != nil { p.Logger.DebugContext(ctx, "forwarding legacy err", slog.Any("error", kao.Error)) @@ -707,10 +712,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew for _, req := range requests { policy, kaoResults, err := p.verifyRewrapRequests(ctx, req) policyID := req.GetPolicy().GetId() - results[policyID] = policyResult{ - kaoResults: kaoResults, - requiredObligations: []string{}, - } + results[policyID] = kaoResults if err != nil { p.Logger.WarnContext(ctx, "rewrap: verifyRewrapRequests failed", @@ -765,6 +767,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew for _, pdpAccess := range pdpAccessResults { policy := pdpAccess.Policy + requiredObligationsForPolicy := pdpAccess.RequiredObligations req, ok := policyReqs[policy] if !ok { //nolint:sloglint // reference to key is intentional @@ -772,13 +775,12 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew continue } - policyRes, ok := results[req.GetPolicy().GetId()] + kaoResults, ok := results[req.GetPolicy().GetId()] if !ok { // this should not happen //nolint:sloglint // reference to key is intentional p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID) continue } - kaoResults := policyRes.kaoResults access := pdpAccess.Access // Audit the TDF3 Rewrap @@ -802,7 +804,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew if !access { p.Logger.Audit.RewrapFailure(ctx, auditEventParams) - failedKAORewrap(kaoResults, kao, err403("forbidden")) + failedKAORewrapWithObligations(kaoResults, kao, err403("forbidden"), requiredObligationsForPolicy) continue } @@ -823,10 +825,6 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) } - results[req.GetPolicy().GetId()] = policyResult{ - kaoResults: kaoResults, - requiredObligations: pdpAccess.RequiredObligations, - } } return sessionKey, results } @@ -842,7 +840,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned for _, req := range requests { policy, kaoResults := p.verifyNanoRewrapRequests(ctx, req) - results[req.GetPolicy().GetId()] = policyResult{kaoResults: kaoResults, requiredObligations: []string{}} + results[req.GetPolicy().GetId()] = kaoResults if policy != nil { policies = append(policies, policy) policyReqs[policy] = req @@ -875,17 +873,17 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned for _, pdpAccess := range pdpAccessResults { policy := pdpAccess.Policy + requiredObligationsForPolicy := pdpAccess.RequiredObligations req, ok := policyReqs[policy] if !ok { // this should not happen continue } - policyRes, ok := results[req.GetPolicy().GetId()] + kaoResults, ok := results[req.GetPolicy().GetId()] if !ok { // this should not happen //nolint:sloglint // reference to key is intentional p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID) continue } - kaoResults := policyRes.kaoResults access := pdpAccess.Access // Audit the Nano Rewrap @@ -906,7 +904,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned if !access { p.Logger.Audit.RewrapFailure(ctx, auditEventParams) - failedKAORewrap(kaoResults, kao, err403("forbidden")) + failedKAORewrapWithObligations(kaoResults, kao, err403("forbidden"), requiredObligationsForPolicy) continue } cipherText, err := kaoInfo.DEK.Export(sessionKey) @@ -923,10 +921,6 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) } - results[req.GetPolicy().GetId()] = policyResult{ - kaoResults: kaoResults, - requiredObligations: pdpAccess.RequiredObligations, - } } return sessionKeyPEM, results } @@ -1045,37 +1039,48 @@ func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHead func failAllKaos(reqs []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, results policyKAOResults, err error) { for _, req := range reqs { for _, kao := range req.GetKeyAccessObjects() { - failedKAORewrap(results[req.GetPolicy().GetId()].kaoResults, kao, err) + failedKAORewrap(results[req.GetPolicy().GetId()], kao, err) } } } -// Populate response metadata with required obligations -func populateRequiredObligationsOnResponse(response *kaspb.RewrapResponse, obligations []string, policyID string) { - metadata := response.GetMetadata() - if metadata == nil { - metadata = make(map[string]*structpb.Value) - } - - var fields map[string]*structpb.Value - obligationValue, ok := metadata[requiredObligationsHeader] - if !ok || obligationValue.GetStructValue() == nil { - fields = make(map[string]*structpb.Value) - metadata[requiredObligationsHeader] = structpb.NewStructValue(&structpb.Struct{ - Fields: fields, - }) - } else { - fields = obligationValue.GetStructValue().GetFields() - } +// Populate response metadata with required obligations for each key access object response +// Result will look like: +/* + { + "responses": [ + { + policy_id: "policy-uuid", + results: [ + { + "metadata": { + "X-Required-Obligations": [] + }, + "key_access_object_id": "kao-uuid", + }, + { + "metadata": { + "X-Required-Obligations": [] + }, + "key_access_object_id": "kao-uuid", + }, + ] + } + ] + } +*/ +func createKAOMetadata(obligations []string) map[string]*structpb.Value { + metadata := make(map[string]*structpb.Value) values := make([]*structpb.Value, len(obligations)) for i, obligation := range obligations { values[i] = structpb.NewStringValue(obligation) } - fields[policyID] = structpb.NewListValue(&structpb.ListValue{ + metadata[requiredObligationsHeader] = structpb.NewListValue(&structpb.ListValue{ Values: values, }) - response.Metadata = metadata + + return metadata } // Retrieve additional request context needed for rewrap processing diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index d3873a067..af72df7fc 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -754,235 +754,373 @@ func TestGetAdditionalRewrapContext(t *testing.T) { } } -func TestPopulateRequiredObligationsOnResponse(t *testing.T) { - type policyObligation struct { +func TestCreateKAOMetadata(t *testing.T) { + tests := []struct { + name string obligations []string - policyID string + expected map[string]*structpb.Value + }{ + { + name: "single obligation", + obligations: []string{"https://demo.com/obl/test/value/watermark"}, + expected: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("https://demo.com/obl/test/value/watermark"), + }, + }), + }, + }, + { + name: "multiple obligations", + obligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + "https://example.com/obl/test/value/mfa", + }, + expected: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("https://demo.com/obl/test/value/watermark"), + structpb.NewStringValue("https://demo.com/obl/test/value/geofence"), + structpb.NewStringValue("https://example.com/obl/test/value/mfa"), + }, + }), + }, + }, + { + name: "empty obligations list", + obligations: []string{}, + expected: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, + }, + { + name: "nil obligations list", + obligations: nil, + expected: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := createKAOMetadata(tt.obligations) + + // Verify the result has the correct structure + require.NotNil(t, result) + require.Contains(t, result, requiredObligationsHeader) + + // Get the list value from both result and expected + resultList := result[requiredObligationsHeader].GetListValue() + expectedList := tt.expected[requiredObligationsHeader].GetListValue() + + require.NotNil(t, resultList) + require.NotNil(t, expectedList) + + // Verify the number of values matches + require.Len(t, resultList.GetValues(), len(expectedList.GetValues())) + + // Verify each obligation value + for i, expectedValue := range expectedList.GetValues() { + actualValue := resultList.GetValues()[i] + require.Equal(t, expectedValue.GetStringValue(), actualValue.GetStringValue()) + } + }) + } +} + +func TestAddResultsToResponse(t *testing.T) { tests := []struct { name string - response *kaspb.RewrapResponse - policies []policyObligation - validate func(t *testing.T, response *kaspb.RewrapResponse) + input policyKAOResults + expected *kaspb.RewrapResponse }{ { - name: "single policy with single obligation", - response: &kaspb.RewrapResponse{ - Metadata: make(map[string]*structpb.Value), - }, - policies: []policyObligation{ - { - obligations: []string{"https://demo.com/obl/test/value/watermark"}, - policyID: "policy1", + name: "single policy with successful KAO", + input: policyKAOResults{ + "policy-1": { + "kao-1": kaoResult{ + ID: "kao-1", + Encapped: []byte("encrypted-key-data"), + }, }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field - - structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field - require.NotNil(t, structValue) - require.Contains(t, structValue.GetFields(), "policy1") - - fields := structValue.GetFields() - listValue := fields["policy1"].GetListValue() - require.NotNil(t, listValue) - require.Len(t, listValue.GetValues(), 1) - assert.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-1", + Status: kPermitStatus, + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-data")}, + }, + }, + }, + }, }, }, { - name: "single policy with multiple obligations", - response: &kaspb.RewrapResponse{ - Metadata: make(map[string]*structpb.Value), - }, - policies: []policyObligation{ - { - obligations: []string{ - "https://demo.com/obl/test/value/watermark", - "https://demo.com/obl/test/value/geofence", + name: "single policy with failed KAO and obligations", + input: policyKAOResults{ + "policy-1": { + "kao-1": kaoResult{ + ID: "kao-1", + Error: errors.New("access denied"), + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, }, - policyID: "policy1", }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.Contains(t, metadata, requiredObligationsHeader) - - structValue := metadata[requiredObligationsHeader].GetStructValue() - require.NotNil(t, structValue) - require.Contains(t, structValue.GetFields(), "policy1") - - fields := structValue.GetFields() - listValue := fields["policy1"].GetListValue() - require.NotNil(t, listValue) - require.Len(t, listValue.GetValues(), 2) - assert.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) - assert.Equal(t, "https://demo.com/obl/test/value/geofence", listValue.GetValues()[1].GetStringValue()) + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-1", + Status: kFailedStatus, + Result: &kaspb.KeyAccessRewrapResult_Error{Error: "access denied"}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("https://demo.com/obl/test/value/watermark"), + }, + }), + }, + }, + }, + }, + }, }, }, { - name: "multiple policies with different obligations", - response: &kaspb.RewrapResponse{ - Metadata: make(map[string]*structpb.Value), - }, - policies: []policyObligation{ - { - obligations: []string{"https://demo.com/obl/test/value/watermark"}, - policyID: "policy1", - }, - { - obligations: []string{"https://demo.com/obl/test/value/geofence"}, - policyID: "policy2", - }, - { - obligations: []string{"https://example.com/obl/test/value/mfa"}, - policyID: "policy3", + name: "single policy with failed KAO and no obligations", + input: policyKAOResults{ + "policy-1": { + "kao-1": kaoResult{ + ID: "kao-1", + Error: errors.New("invalid key"), + }, }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.Contains(t, metadata, requiredObligationsHeader) - - structValue := metadata[requiredObligationsHeader].GetStructValue() - require.NotNil(t, structValue) - fields := structValue.GetFields() - - // Verify policy1 - require.Contains(t, fields, "policy1") - listValue1 := fields["policy1"].GetListValue() - require.NotNil(t, listValue1) - require.Len(t, listValue1.GetValues(), 1) - require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue1.GetValues()[0].GetStringValue()) - - // Verify policy2 - require.Contains(t, fields, "policy2") - listValue2 := fields["policy2"].GetListValue() - require.NotNil(t, listValue2) - require.Len(t, listValue2.GetValues(), 1) - require.Equal(t, "https://demo.com/obl/test/value/geofence", listValue2.GetValues()[0].GetStringValue()) - - // Verify policy3 - require.Contains(t, fields, "policy3") - listValue3 := fields["policy3"].GetListValue() - require.NotNil(t, listValue3) - require.Len(t, listValue3.GetValues(), 1) - require.Equal(t, "https://example.com/obl/test/value/mfa", listValue3.GetValues()[0].GetStringValue()) + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-1", + Status: kFailedStatus, + Result: &kaspb.KeyAccessRewrapResult_Error{Error: "invalid key"}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, + }, + }, + }, + }, }, }, { - name: "empty obligations list", - response: &kaspb.RewrapResponse{ - Metadata: make(map[string]*structpb.Value), - }, - policies: []policyObligation{ - { - obligations: []string{}, - policyID: "policy1", + name: "single policy with unprocessed KAO", + input: policyKAOResults{ + "policy-1": { + "kao-1": kaoResult{ + ID: "kao-1", + // No Error and no Encapped data + }, }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field - - structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field - require.NotNil(t, structValue) - require.Contains(t, structValue.GetFields(), "policy1") - - fields := structValue.GetFields() - listValue := fields["policy1"].GetListValue() - require.NotNil(t, listValue) - require.Empty(t, listValue.GetValues()) + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-1", + Status: kFailedStatus, + Result: &kaspb.KeyAccessRewrapResult_Error{Error: "kao not processed by kas"}, + }, + }, + }, + }, }, }, { - name: "nil response metadata", - response: &kaspb.RewrapResponse{ - Metadata: nil, - }, - policies: []policyObligation{ - { - obligations: []string{"https://demo.com/obl/test/value/watermark"}, - policyID: "policy1", + name: "multiple policies with mixed results", + input: policyKAOResults{ + "policy-1": { + "kao-1": kaoResult{ + ID: "kao-1", + Encapped: []byte("encrypted-key-1"), + }, + "kao-2": kaoResult{ + ID: "kao-2", + Error: errors.New("forbidden"), + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark", "https://demo.com/obl/test/value/geofence"}, + }, + }, + "policy-2": { + "kao-3": kaoResult{ + ID: "kao-3", + Encapped: []byte("encrypted-key-3"), + }, }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.NotNil(t, metadata) //nolint:staticcheck // testing deprecated field - require.Contains(t, metadata, requiredObligationsHeader) //nolint:staticcheck // testing deprecated field - - structValue := metadata[requiredObligationsHeader].GetStructValue() //nolint:staticcheck // testing deprecated field - require.NotNil(t, structValue) - require.Contains(t, structValue.GetFields(), "policy1") - - fields := structValue.GetFields() - listValue := fields["policy1"].GetListValue() - require.NotNil(t, listValue) - require.Len(t, listValue.GetValues(), 1) - require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-1", + Status: kPermitStatus, + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-1")}, + }, + { + KeyAccessObjectId: "kao-2", + Status: kFailedStatus, + Result: &kaspb.KeyAccessRewrapResult_Error{Error: "forbidden"}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("https://demo.com/obl/test/value/watermark"), + structpb.NewStringValue("https://demo.com/obl/test/value/geofence"), + }, + }), + }, + }, + }, + }, + { + PolicyId: "policy-2", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-3", + Status: kPermitStatus, + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-3")}, + }, + }, + }, + }, }, }, { - name: "preserve existing metadata when adding obligations", - response: &kaspb.RewrapResponse{ - Metadata: map[string]*structpb.Value{ - "existing-header": structpb.NewStringValue("existing-value"), - "session-info": structpb.NewStructValue(&structpb.Struct{ - Fields: map[string]*structpb.Value{ - "sessionId": structpb.NewStringValue("session-123"), - "timestamp": structpb.NewNumberValue(1672531200), - }, - }), - }, + name: "empty input", + input: policyKAOResults{}, + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{}, + }, + }, + { + name: "policy with empty KAO map", + input: policyKAOResults{ + "policy-1": {}, }, - policies: []policyObligation{ - { - obligations: []string{ - "https://demo.com/obl/test/value/watermark", + expected: &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy-1", + Results: []*kaspb.KeyAccessRewrapResult{}, }, - policyID: "policy1", }, }, - validate: func(t *testing.T, response *kaspb.RewrapResponse) { - metadata := response.GetMetadata() - require.NotNil(t, metadata) //nolint:staticcheck // testing deprecated field - - // Verify existing metadata is preserved - require.Contains(t, metadata, "existing-header") - require.Equal(t, "existing-value", metadata["existing-header"].GetStringValue()) - - require.Contains(t, metadata, "session-info") - sessionInfo := metadata["session-info"].GetStructValue() - require.NotNil(t, sessionInfo) - sessionFields := sessionInfo.GetFields() - require.Contains(t, sessionFields, "sessionId") - require.Equal(t, "session-123", sessionFields["sessionId"].GetStringValue()) - require.Contains(t, sessionFields, "timestamp") - require.InDelta(t, float64(1672531200), sessionFields["timestamp"].GetNumberValue(), 0.001) - - // Verify new obligations are added - require.Contains(t, metadata, requiredObligationsHeader) - structValue := metadata[requiredObligationsHeader].GetStructValue() - require.NotNil(t, structValue) - require.Contains(t, structValue.GetFields(), "policy1") - - obligationFields := structValue.GetFields() - listValue := obligationFields["policy1"].GetListValue() - require.NotNil(t, listValue) - require.Len(t, listValue.GetValues(), 1) - require.Equal(t, "https://demo.com/obl/test/value/watermark", listValue.GetValues()[0].GetStringValue()) - }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Call populateRequiredObligationsOnResponse for each policy - for _, policy := range tt.policies { - populateRequiredObligationsOnResponse(tt.response, policy.obligations, policy.policyID) + response := &kaspb.RewrapResponse{} + addResultsToResponse(response, tt.input) + + // Verify the number of policy responses + require.Len(t, response.GetResponses(), len(tt.expected.GetResponses())) + + // Sort both responses by PolicyId for consistent comparison + actualPolicies := make(map[string]*kaspb.PolicyRewrapResult) + for _, policyResult := range response.GetResponses() { + actualPolicies[policyResult.GetPolicyId()] = policyResult + } + + expectedPolicies := make(map[string]*kaspb.PolicyRewrapResult) + for _, policyResult := range tt.expected.GetResponses() { + expectedPolicies[policyResult.GetPolicyId()] = policyResult + } + + // Verify each policy response + for policyID, expectedPolicy := range expectedPolicies { + actualPolicy, exists := actualPolicies[policyID] + require.True(t, exists, "Expected policy %s not found in response", policyID) + require.Equal(t, expectedPolicy.GetPolicyId(), actualPolicy.GetPolicyId()) + + // Verify the number of KAO results + require.Len(t, actualPolicy.GetResults(), len(expectedPolicy.GetResults())) + + // Sort KAO results by KeyAccessObjectId for consistent comparison + actualKAOs := make(map[string]*kaspb.KeyAccessRewrapResult) + for _, kaoResult := range actualPolicy.GetResults() { + actualKAOs[kaoResult.GetKeyAccessObjectId()] = kaoResult + } + + expectedKAOs := make(map[string]*kaspb.KeyAccessRewrapResult) + for _, kaoResult := range expectedPolicy.GetResults() { + expectedKAOs[kaoResult.GetKeyAccessObjectId()] = kaoResult + } + + // Verify each KAO result + for kaoID, expectedKAO := range expectedKAOs { + actualKAO, actualKAOExists := actualKAOs[kaoID] + require.True(t, actualKAOExists, "Expected KAO %s not found in policy %s", kaoID, policyID) + + require.Equal(t, expectedKAO.GetKeyAccessObjectId(), actualKAO.GetKeyAccessObjectId()) + require.Equal(t, expectedKAO.GetStatus(), actualKAO.GetStatus()) + + // Verify result content + switch expectedResult := expectedKAO.GetResult().(type) { + case *kaspb.KeyAccessRewrapResult_KasWrappedKey: + actualResult, ok := actualKAO.GetResult().(*kaspb.KeyAccessRewrapResult_KasWrappedKey) + require.True(t, ok, "Expected KasWrappedKey result for KAO %s", kaoID) + require.Equal(t, expectedResult.KasWrappedKey, actualResult.KasWrappedKey) + case *kaspb.KeyAccessRewrapResult_Error: + actualResult, ok := actualKAO.GetResult().(*kaspb.KeyAccessRewrapResult_Error) + require.True(t, ok, "Expected Error result for KAO %s", kaoID) + require.Equal(t, expectedResult.Error, actualResult.Error) + } + + // Verify metadata if expected + if expectedKAO.GetMetadata() != nil { + require.NotNil(t, actualKAO.GetMetadata(), "Expected metadata for KAO %s", kaoID) + + // Verify required obligations header + if expectedObligations, oblExists := expectedKAO.GetMetadata()[requiredObligationsHeader]; oblExists { + actualObligations, actualExists := actualKAO.GetMetadata()[requiredObligationsHeader] + require.True(t, actualExists, "Expected obligations header in metadata for KAO %s", kaoID) + + expectedList := expectedObligations.GetListValue() + actualList := actualObligations.GetListValue() + require.NotNil(t, expectedList) + require.NotNil(t, actualList) + require.Len(t, actualList.GetValues(), len(expectedList.GetValues())) + + for i, expectedValue := range expectedList.GetValues() { + actualValue := actualList.GetValues()[i] + require.Equal(t, expectedValue.GetStringValue(), actualValue.GetStringValue()) + } + } + } else if actualKAO.GetMetadata() != nil { + // If no metadata is expected, actualKAO.Metadata should be nil or empty + require.Empty(t, actualKAO.GetMetadata(), "Unexpected metadata for KAO %s", kaoID) + } + } } - tt.validate(t, tt.response) }) } } From 64bf9003ac1e140284e011938b8750dca628292b Mon Sep 17 00:00:00 2001 From: Chris Reed Date: Thu, 16 Oct 2025 13:18:00 -0500 Subject: [PATCH 2/2] add for success too. --- service/kas/access/rewrap.go | 14 ++++++++------ service/kas/access/rewrap_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index ecd8479b0..1193d8009 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -401,7 +401,6 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult } switch { case kaoRes.Error != nil: - kaoResult.Metadata = createKAOMetadata(kaoRes.RequiredObligations) kaoResult.Status = kFailedStatus kaoResult.Result = &kaspb.KeyAccessRewrapResult_Error{Error: kaoRes.Error.Error()} case kaoRes.Encapped != nil: @@ -411,6 +410,7 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult kaoResult.Status = kFailedStatus kaoResult.Result = &kaspb.KeyAccessRewrapResult_Error{Error: "kao not processed by kas"} } + kaoResult.Metadata = createKAOMetadata(kaoRes.RequiredObligations) policyResults.Results = append(policyResults.Results, kaoResult) } response.Responses = append(response.Responses, policyResults) @@ -818,9 +818,10 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew continue } kaoResults[kaoID] = kaoResult{ - ID: kaoID, - Encapped: encryptedKey, - EphemeralPublicKey: asymEncrypt.EphemeralKey(), + ID: kaoID, + Encapped: encryptedKey, + EphemeralPublicKey: asymEncrypt.EphemeralKey(), + RequiredObligations: requiredObligationsForPolicy, } p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) @@ -915,8 +916,9 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned } kaoResults[kao.GetKeyAccessObjectId()] = kaoResult{ - ID: kao.GetKeyAccessObjectId(), - Encapped: cipherText, + ID: kao.GetKeyAccessObjectId(), + Encapped: cipherText, + RequiredObligations: requiredObligationsForPolicy, } p.Logger.Audit.RewrapSuccess(ctx, auditEventParams) diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index af72df7fc..813e6c096 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -846,8 +846,9 @@ func TestAddResultsToResponse(t *testing.T) { input: policyKAOResults{ "policy-1": { "kao-1": kaoResult{ - ID: "kao-1", - Encapped: []byte("encrypted-key-data"), + ID: "kao-1", + Encapped: []byte("encrypted-key-data"), + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, }, }, }, @@ -860,6 +861,13 @@ func TestAddResultsToResponse(t *testing.T) { KeyAccessObjectId: "kao-1", Status: kPermitStatus, Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-data")}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("https://demo.com/obl/test/value/watermark"), + }, + }), + }, }, }, }, @@ -948,6 +956,11 @@ func TestAddResultsToResponse(t *testing.T) { KeyAccessObjectId: "kao-1", Status: kFailedStatus, Result: &kaspb.KeyAccessRewrapResult_Error{Error: "kao not processed by kas"}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, }, }, }, @@ -984,6 +997,11 @@ func TestAddResultsToResponse(t *testing.T) { KeyAccessObjectId: "kao-1", Status: kPermitStatus, Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-1")}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, }, { KeyAccessObjectId: "kao-2", @@ -1007,6 +1025,11 @@ func TestAddResultsToResponse(t *testing.T) { KeyAccessObjectId: "kao-3", Status: kPermitStatus, Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: []byte("encrypted-key-3")}, + Metadata: map[string]*structpb.Value{ + requiredObligationsHeader: structpb.NewListValue(&structpb.ListValue{ + Values: []*structpb.Value{}, + }), + }, }, }, },