Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 9 additions & 23 deletions sdk/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type BulkDecryptRequest struct {
type BulkDecryptPrepared struct {
PolicyTDF map[string]*BulkTDF
tdfDecryptors map[string]decryptor
allRewrapResp map[string]policyResult
allRewrapResp map[string][]kaoResult
}

// BulkErrors List of Errors that Failed during Bulk Decryption
Expand Down Expand Up @@ -181,9 +181,9 @@ func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest)
}

// performRewraps executes all rewrap requests with KAS servers
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string]policyResult, error) {
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string][]kaoResult, error) {
kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations)
allRewrapResp := make(map[string]policyResult)
allRewrapResp := make(map[string][]kaoResult)
var err error

for kasurl, rewrapRequests := range kasRewrapRequests {
Expand All @@ -194,21 +194,16 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka
for _, req := range rewrapRequests {
id := req.GetPolicy().GetId()
for _, kao := range req.GetKeyAccessObjects() {
policyRewrapResp, ok := allRewrapResp[id]
if !ok {
policyRewrapResp = policyResult{policyID: id, obligations: []string{}, kaoRes: []kaoResult{}}
}
policyRewrapResp.kaoRes = append(policyRewrapResp.kaoRes, kaoResult{
allRewrapResp[id] = append(allRewrapResp[id], kaoResult{
Error: fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasurl),
KeyAccessObjectID: kao.GetKeyAccessObjectId(),
})
allRewrapResp[id] = policyRewrapResp
}
}
continue
}

var rewrapResp map[string]policyResult
var rewrapResp map[string][]kaoResult
switch bulkReq.TDFType {
case Nano:
rewrapResp, err = kasClient.nanoUnwrap(ctx, rewrapRequests...)
Expand All @@ -217,16 +212,7 @@ func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, ka
}

for id, res := range rewrapResp {
// ! It's possible that we already created a policyResult for the policy above for a specific KAS URL.
// ! Meaning for another kas url of the same policy we will end up with an empty list of obligations.
// ! This should be fine since we will error out anyways.
if existingResp, ok := allRewrapResp[id]; !ok {
allRewrapResp[id] = res
} else {
// ! Should not need to append obligations since they should be the same for all TDFs under a policy
existingResp.kaoRes = append(existingResp.kaoRes, res.kaoRes...)
allRewrapResp[id] = existingResp
}
allRewrapResp[id] = append(allRewrapResp[id], res...)
}
}

Expand Down Expand Up @@ -274,7 +260,7 @@ func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption)
tdf.Error = errors.New("rewrap did not create a response for this TDF")
continue
}
tdf.TriggeredObligations = Obligations{FQNs: policyRes.obligations}
tdf.TriggeredObligations = Obligations{FQNs: dedupRequiredObligations(policyRes)}
}

return &BulkDecryptPrepared{
Expand All @@ -289,14 +275,14 @@ func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error {
var errList []error
var err error
for id, tdf := range bp.PolicyTDF {
policyRes, ok := bp.allRewrapResp[id]
kaoRes, ok := bp.allRewrapResp[id]
if !ok {
tdf.Error = errors.New("rewrap did not create a response for this TDF")
errList = append(errList, tdf.Error)
continue
}
decryptor := bp.tdfDecryptors[id]
if _, err = decryptor.Decrypt(ctx, policyRes.kaoRes); err != nil {
if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil {
tdf.Error = err
errList = append(errList, tdf.Error)
continue
Expand Down
117 changes: 52 additions & 65 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,10 @@ type KASClient struct {
}

type kaoResult struct {
SymmetricKey []byte
Error error
KeyAccessObjectID string
}

type policyResult struct {
policyID string
obligations []string
kaoRes []kaoResult
SymmetricKey []byte
Error error
KeyAccessObjectID string
RequiredObligations []string
}

type decryptor interface {
Expand Down Expand Up @@ -175,7 +170,7 @@ func upgradeRewrapErrorV1(err error, requests []*kas.UnsignedRewrapRequest_WithP
}, nil
}

func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) {
func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) {
keypair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1)
if err != nil {
return nil, fmt.Errorf("ocrypto.NewECKeyPair failed :%w", err)
Expand All @@ -198,22 +193,19 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
// If the session key is empty, all responses are errors
spk := response.GetSessionPublicKey()
if spk == "" {
policyResults := make(map[string]policyResult)
policyResults := make(map[string][]kaoResult)
err = errors.New("nanoUnwrap: session public key is empty")
for _, results := range response.GetResponses() {
policyRes, ok := policyResults[results.GetPolicyId()]
if !ok {
policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}}
}
var kaoKeys []kaoResult
for _, kao := range results.GetResults() {
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
if kao.GetStatus() == statusPermit {
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
} else {
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
}
}
policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
policyResults[results.GetPolicyId()] = policyRes
policyResults[results.GetPolicyId()] = kaoKeys
}

return policyResults, nil
Expand All @@ -234,33 +226,30 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew
return nil, fmt.Errorf("nanoUnwrap: ocrypto.NewAESGcm failed:%w", err)
}

policyResults := make(map[string]policyResult)
policyResults := make(map[string][]kaoResult)
for _, results := range response.GetResponses() {
policyRes, ok := policyResults[results.GetPolicyId()]
if !ok {
policyRes = policyResult{policyID: results.GetPolicyId(), obligations: []string{}, kaoRes: []kaoResult{}}
}
var kaoKeys []kaoResult
for _, kao := range results.GetResults() {
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
if kao.GetStatus() == statusPermit {
wrappedKey := kao.GetKasWrappedKey()
key, err := aesGcm.Decrypt(wrappedKey)
if err != nil {
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
} else {
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
}
} else {
policyRes.kaoRes = append(policyRes.kaoRes, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
}
}
policyRes.obligations = k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
policyResults[results.GetPolicyId()] = policyRes
policyResults[results.GetPolicyId()] = kaoKeys
}

return policyResults, nil
}

func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string]policyResult, error) {
func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapRequest_WithPolicyRequest) (map[string][]kaoResult, error) {
if k.sessionKey == nil {
return nil, errors.New("session key is nil")
}
Expand All @@ -279,7 +268,7 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR
return k.handleRSAKeyResponse(response)
}

func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) {
func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) {
kasEphemeralPublicKey := response.GetSessionPublicKey()
clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat()
if err != nil {
Expand All @@ -306,60 +295,58 @@ func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[strin
return k.processECResponse(response, aesGcm)
}

func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string]policyResult, error) {
policyResults := make(map[string]policyResult)
func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string][]kaoResult, error) {
policyResults := make(map[string][]kaoResult)
for _, results := range response.GetResponses() {
var kaoKeys []kaoResult
for _, kao := range results.GetResults() {
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
if kao.GetStatus() == statusPermit {
key, err := aesGcm.Decrypt(kao.GetKasWrappedKey())
if err != nil {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
} else {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
}
} else {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
}
}
requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations}
policyResults[results.GetPolicyId()] = kaoKeys
}
return policyResults, nil
}

func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value, policyID string) []string {
var triggeredFQNs []string
/*
Metadata will be in the following form, per kao:

if metadata == nil {
return triggeredFQNs
{
"metadata": {
"X-Required-Obligations": [<Required obligation FQNs>]
}
}
*/
func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value) []string {
var requiredObligations []string

triggerOblsValue, ok := metadata[triggeredObligationsHeader]
if !ok {
return triggeredFQNs
}

fields := triggerOblsValue.GetStructValue().GetFields()
if fields == nil {
return triggeredFQNs
if metadata == nil {
return requiredObligations
}

policyOblsValue, ok := fields[policyID]
triggerOblsValue, ok := metadata[triggeredObligationsHeader]
if !ok {
return triggeredFQNs
return requiredObligations
}

values := policyOblsValue.GetListValue().GetValues()

for _, v := range values {
triggeredFQNs = append(triggeredFQNs, v.GetStringValue())
triggerOblsList := triggerOblsValue.GetListValue().GetValues()
for _, v := range triggerOblsList {
requiredObligations = append(requiredObligations, v.GetStringValue())
}

return triggeredFQNs
return requiredObligations
}

func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string]policyResult, error) {
func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) {
clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat()
if err != nil {
return nil, fmt.Errorf("ocrypto.PrivateKeyInPemFormat failed: %w", err)
Expand All @@ -373,24 +360,24 @@ func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[stri
return k.processRSAResponse(response, asymDecryption)
}

func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string]policyResult, error) {
policyResults := make(map[string]policyResult)
func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string][]kaoResult, error) {
policyResults := make(map[string][]kaoResult)
for _, results := range response.GetResponses() {
var kaoKeys []kaoResult
for _, kao := range results.GetResults() {
requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata())
if kao.GetStatus() == statusPermit {
key, err := asymDecryption.Decrypt(kao.GetKasWrappedKey())
if err != nil {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO})
} else {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO})
}
} else {
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())})
kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO})
}
}
requiredObligations := k.retrieveObligationsFromMetadata(response.GetMetadata(), results.GetPolicyId())
policyResults[results.GetPolicyId()] = policyResult{policyID: results.GetPolicyId(), kaoRes: kaoKeys, obligations: requiredObligations}
policyResults[results.GetPolicyId()] = kaoKeys
}
return policyResults, nil
}
Expand Down
Loading
Loading