Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 57 additions & 52 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,12 @@ type kaoResult struct {
Error error

// Optional: Present for EC wrapped responses
EphemeralPublicKey []byte
}

type policyResult struct {
kaoResults map[string]kaoResult
requiredObligations []string
EphemeralPublicKey []byte
RequiredObligations []string
}

// From policy ID to KAO ID to result
type policyKAOResults map[string]policyResult
type policyKAOResults map[string]map[string]kaoResult

type ObligationCtx struct {
FulfillableFQNs []string `json:"fulfillableFQNs,omitempty"`
Expand Down Expand Up @@ -378,6 +374,14 @@ func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, err
return info, nil
}

func failedKAORewrapWithObligations(res map[string]kaoResult, kao *kaspb.UnsignedRewrapRequest_WithKeyAccessObject, err error, requiredObligations []string) {
res[kao.GetKeyAccessObjectId()] = kaoResult{
ID: kao.GetKeyAccessObjectId(),
Error: err,
RequiredObligations: requiredObligations,
}
}

func failedKAORewrap(res map[string]kaoResult, kao *kaspb.UnsignedRewrapRequest_WithKeyAccessObject, err error) {
res[kao.GetKeyAccessObjectId()] = kaoResult{
ID: kao.GetKeyAccessObjectId(),
Expand All @@ -390,12 +394,14 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult
policyResults := &kaspb.PolicyRewrapResult{
PolicyId: policyID,
}
for kaoID, kaoRes := range policyMap.kaoResults {
for kaoID, kaoRes := range policyMap {
// Add metadata
kaoResult := &kaspb.KeyAccessRewrapResult{
KeyAccessObjectId: kaoID,
}
switch {
case kaoRes.Error != nil:
kaoResult.Metadata = createKAOMetadata(kaoRes.RequiredObligations)
kaoResult.Status = kFailedStatus
kaoResult.Result = &kaspb.KeyAccessRewrapResult_Error{Error: kaoRes.Error.Error()}
case kaoRes.Encapped != nil:
Expand All @@ -408,7 +414,6 @@ func addResultsToResponse(response *kaspb.RewrapResponse, result policyKAOResult
policyResults.Results = append(policyResults.Results, kaoResult)
}
response.Responses = append(response.Responses, policyResults)
populateRequiredObligationsOnResponse(response, policyMap.requiredObligations, policyID)
}
}

Expand Down Expand Up @@ -470,16 +475,16 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
p.Logger.WarnContext(ctx, "status 400 due to wrong result set size", slog.Any("results", results))
return nil, err400("invalid request")
}
policyResults := *getMapValue(results)
if len(policyResults.kaoResults) != 1 {
kaoResults := *getMapValue(results)
if len(kaoResults) != 1 {
p.Logger.WarnContext(ctx,
"status 400 due to wrong result set size",
slog.Any("kao_results", policyResults.kaoResults),
slog.Any("kao_results", kaoResults),
slog.Any("results", results),
)
return nil, err400("invalid request")
}
kao := *getMapValue(policyResults.kaoResults)
kao := *getMapValue(kaoResults)

if kao.Error != nil {
p.Logger.DebugContext(ctx, "forwarding legacy err", slog.Any("error", kao.Error))
Expand Down Expand Up @@ -707,10 +712,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
for _, req := range requests {
policy, kaoResults, err := p.verifyRewrapRequests(ctx, req)
policyID := req.GetPolicy().GetId()
results[policyID] = policyResult{
kaoResults: kaoResults,
requiredObligations: []string{},
}
results[policyID] = kaoResults
if err != nil {
p.Logger.WarnContext(ctx,
"rewrap: verifyRewrapRequests failed",
Expand Down Expand Up @@ -765,20 +767,20 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew

for _, pdpAccess := range pdpAccessResults {
policy := pdpAccess.Policy
requiredObligationsForPolicy := pdpAccess.RequiredObligations
req, ok := policyReqs[policy]
if !ok {
//nolint:sloglint // reference to key is intentional
p.Logger.WarnContext(ctx, "policy not found in policyReqs", "policy.uuid", policy.UUID)
continue
}

policyRes, ok := results[req.GetPolicy().GetId()]
kaoResults, ok := results[req.GetPolicy().GetId()]
if !ok { // this should not happen
//nolint:sloglint // reference to key is intentional
p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID)
continue
}
kaoResults := policyRes.kaoResults
access := pdpAccess.Access

// Audit the TDF3 Rewrap
Expand All @@ -802,7 +804,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew

if !access {
p.Logger.Audit.RewrapFailure(ctx, auditEventParams)
failedKAORewrap(kaoResults, kao, err403("forbidden"))
failedKAORewrapWithObligations(kaoResults, kao, err403("forbidden"), requiredObligationsForPolicy)
continue
}

Expand All @@ -823,10 +825,6 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew

p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
}
results[req.GetPolicy().GetId()] = policyResult{
kaoResults: kaoResults,
requiredObligations: pdpAccess.RequiredObligations,
}
}
return sessionKey, results
}
Expand All @@ -842,7 +840,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

for _, req := range requests {
policy, kaoResults := p.verifyNanoRewrapRequests(ctx, req)
results[req.GetPolicy().GetId()] = policyResult{kaoResults: kaoResults, requiredObligations: []string{}}
results[req.GetPolicy().GetId()] = kaoResults
if policy != nil {
policies = append(policies, policy)
policyReqs[policy] = req
Expand Down Expand Up @@ -875,17 +873,17 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

for _, pdpAccess := range pdpAccessResults {
policy := pdpAccess.Policy
requiredObligationsForPolicy := pdpAccess.RequiredObligations
req, ok := policyReqs[policy]
if !ok { // this should not happen
continue
}
policyRes, ok := results[req.GetPolicy().GetId()]
kaoResults, ok := results[req.GetPolicy().GetId()]
if !ok { // this should not happen
//nolint:sloglint // reference to key is intentional
p.Logger.WarnContext(ctx, "policy not found in policyReq response", "policy.uuid", policy.UUID)
continue
}
kaoResults := policyRes.kaoResults
access := pdpAccess.Access

// Audit the Nano Rewrap
Expand All @@ -906,7 +904,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

if !access {
p.Logger.Audit.RewrapFailure(ctx, auditEventParams)
failedKAORewrap(kaoResults, kao, err403("forbidden"))
failedKAORewrapWithObligations(kaoResults, kao, err403("forbidden"), requiredObligationsForPolicy)
continue
}
cipherText, err := kaoInfo.DEK.Export(sessionKey)
Expand All @@ -923,10 +921,6 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned

p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
}
results[req.GetPolicy().GetId()] = policyResult{
kaoResults: kaoResults,
requiredObligations: pdpAccess.RequiredObligations,
}
}
return sessionKeyPEM, results
}
Expand Down Expand Up @@ -1045,37 +1039,48 @@ func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHead
func failAllKaos(reqs []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, results policyKAOResults, err error) {
for _, req := range reqs {
for _, kao := range req.GetKeyAccessObjects() {
failedKAORewrap(results[req.GetPolicy().GetId()].kaoResults, kao, err)
failedKAORewrap(results[req.GetPolicy().GetId()], kao, err)
}
}
}

// Populate response metadata with required obligations
func populateRequiredObligationsOnResponse(response *kaspb.RewrapResponse, obligations []string, policyID string) {
metadata := response.GetMetadata()
if metadata == nil {
metadata = make(map[string]*structpb.Value)
}

var fields map[string]*structpb.Value
obligationValue, ok := metadata[requiredObligationsHeader]
if !ok || obligationValue.GetStructValue() == nil {
fields = make(map[string]*structpb.Value)
metadata[requiredObligationsHeader] = structpb.NewStructValue(&structpb.Struct{
Fields: fields,
})
} else {
fields = obligationValue.GetStructValue().GetFields()
}
// Populate response metadata with required obligations for each key access object response
// Result will look like:
/*
{
"responses": [
{
policy_id: "policy-uuid",
results: [
{
"metadata": {
"X-Required-Obligations": [<required obligations>]
},
"key_access_object_id": "kao-uuid",
},
{
"metadata": {
"X-Required-Obligations": [<required obligations>]
},
"key_access_object_id": "kao-uuid",
},
]
}
]
}
*/
func createKAOMetadata(obligations []string) map[string]*structpb.Value {
metadata := make(map[string]*structpb.Value)

values := make([]*structpb.Value, len(obligations))
for i, obligation := range obligations {
values[i] = structpb.NewStringValue(obligation)
}
fields[policyID] = structpb.NewListValue(&structpb.ListValue{
metadata[requiredObligationsHeader] = structpb.NewListValue(&structpb.ListValue{
Values: values,
})
response.Metadata = metadata

return metadata
}

// Retrieve additional request context needed for rewrap processing
Expand Down
Loading
Loading