Skip to content

Commit 182b463

Browse files
fix(kas): Fix kas panics on bad requests (#2916)
### Proposed Changes * Added comprehensive nil checks for requests, policies, and key access objects * Updated tdf3Rewrap and nanoTDFRewrap functions to return errors instead of silently failing * Added validation for empty wrapped keys and unsupported key types * Improved error handling in the Rewrap main handler ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent b07a4fe commit 182b463

File tree

2 files changed

+367
-28
lines changed

2 files changed

+367
-28
lines changed

service/kas/access/rewrap.go

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ const (
108108
ErrInternal = Error("internal error")
109109

110110
ErrNanoTDFPolicyModeUnsupported = Error("unsupported policy mode")
111+
112+
errNoValidKeyAccessObjects = Error("no valid KAOs")
111113
)
112114

113115
func err400(s string) error {
@@ -597,10 +599,18 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
597599
return nil, err400(err.Error())
598600
}
599601
if len(tdf3Reqs) > 0 {
600-
resp.SessionPublicKey, results = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
602+
resp.SessionPublicKey, results, err = p.tdf3Rewrap(ctx, tdf3Reqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
603+
if err != nil {
604+
p.Logger.WarnContext(ctx, "status 400, tdf3 rewrap failure", slog.Any("error", err))
605+
return nil, err
606+
}
601607
addResultsToResponse(resp, results)
602608
} else {
603-
resp.SessionPublicKey, results = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
609+
resp.SessionPublicKey, results, err = p.nanoTDFRewrap(ctx, nanoReqs, body.GetClientPublicKey(), entityInfo, additionalRewrapContext)
610+
if err != nil {
611+
p.Logger.WarnContext(ctx, "status 400, nanoTDF rewrap failure", slog.Any("error", err))
612+
return nil, err
613+
}
604614
addResultsToResponse(resp, results)
605615
}
606616

@@ -631,15 +641,31 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
631641
}
632642

633643
func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) (*Policy, map[string]kaoResult, error) {
634-
ctx, span := p.Start(ctx, "tdf3Rewrap")
635-
defer span.End()
644+
// Safe tracer handling - only start span if tracer is available
645+
var span trace.Span
646+
if p.Tracer != nil {
647+
ctx, span = p.Start(ctx, "tdf3Rewrap")
648+
defer span.End()
649+
}
636650

637651
results := make(map[string]kaoResult)
638652
anyValidKAOs := false
653+
policy := &Policy{}
654+
655+
// Check if req is nil
656+
if req == nil {
657+
p.Logger.WarnContext(ctx, "request is nil")
658+
return nil, results, errors.New("request is nil")
659+
}
660+
661+
// Check if policy is nil
662+
if req.GetPolicy() == nil {
663+
p.Logger.WarnContext(ctx, "policy is nil")
664+
return nil, results, errors.New("policy is nil")
665+
}
639666

640667
p.Logger.DebugContext(ctx, "extracting policy", slog.Any("policy", req.GetPolicy()))
641668
sDecPolicy, policyErr := base64.StdEncoding.DecodeString(req.GetPolicy().GetBody())
642-
policy := &Policy{}
643669
if policyErr == nil {
644670
policyErr = json.Unmarshal(sDecPolicy, policy)
645671
}
@@ -650,6 +676,21 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
650676
continue
651677
}
652678

679+
// Check if KeyAccessObject is nil
680+
if kao.GetKeyAccessObject() == nil {
681+
p.Logger.WarnContext(ctx, "key access object is nil", slog.String("kao_id", kao.GetKeyAccessObjectId()))
682+
failedKAORewrap(results, kao, err400("bad request"))
683+
continue
684+
}
685+
686+
// Check if wrapped key is empty
687+
wrappedKey := kao.GetKeyAccessObject().GetWrappedKey()
688+
if len(wrappedKey) == 0 {
689+
p.Logger.WarnContext(ctx, "wrapped key is empty", slog.String("kao_id", kao.GetKeyAccessObjectId()))
690+
failedKAORewrap(results, kao, err400("bad request"))
691+
continue
692+
}
693+
653694
var dek ocrypto.ProtectedKey
654695
var err error
655696
switch kao.GetKeyAccessObject().GetKeyType() {
@@ -754,13 +795,28 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
754795
}
755796
dek, err = p.KeyDelegator.Decrypt(ctx, kid, kao.GetKeyAccessObject().GetWrappedKey(), nil)
756797
}
798+
default:
799+
// handle unsupported key types
800+
keyType := kao.GetKeyAccessObject().GetKeyType()
801+
p.Logger.WarnContext(ctx, "unsupported key type",
802+
slog.String("key_type", keyType),
803+
slog.String("kao_id", kao.GetKeyAccessObjectId()))
804+
failedKAORewrap(results, kao, err400("bad request"))
805+
continue
757806
}
758807
if err != nil {
759808
p.Logger.WarnContext(ctx, "failure to decrypt dek", slog.Any("error", err))
760809
failedKAORewrap(results, kao, err400("bad request"))
761810
continue
762811
}
763812

813+
// Check if policy binding is nil
814+
if kao.GetKeyAccessObject().GetPolicyBinding() == nil {
815+
p.Logger.WarnContext(ctx, "policy binding is nil", slog.String("kao_id", kao.GetKeyAccessObjectId()))
816+
failedKAORewrap(results, kao, err400("missing policy binding"))
817+
continue
818+
}
819+
764820
// Store policy binding in context for verification
765821
policyBindingB64Encoded := kao.GetKeyAccessObject().GetPolicyBinding().GetHash()
766822
policyBinding := make([]byte, base64.StdEncoding.DecodedLen(len(policyBindingB64Encoded)))
@@ -801,7 +857,7 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
801857

802858
if !anyValidKAOs {
803859
p.Logger.WarnContext(ctx, "no valid KAOs found")
804-
return policy, results, errors.New("no valid KAOs")
860+
return policy, results, errNoValidKeyAccessObjects
805861
}
806862

807863
return policy, results, nil
@@ -833,7 +889,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier {
833889
return kidsToCheck
834890
}
835891

836-
func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) {
892+
func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults, error) {
837893
if p.Tracer != nil {
838894
var span trace.Span
839895
ctx, span = p.Start(ctx, "rewrap-tdf3")
@@ -844,7 +900,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
844900
var policies []*Policy
845901
policyReqs := make(map[*Policy]*kaspb.UnsignedRewrapRequest_WithPolicyRequest)
846902
for _, req := range requests {
903+
if req == nil || req.GetPolicy() == nil || req.GetPolicy().GetId() == "" {
904+
p.Logger.WarnContext(ctx, "rewrap: nil request or policy")
905+
continue
906+
}
847907
policy, kaoResults, err := p.verifyRewrapRequests(ctx, req)
908+
if err != nil && !errors.Is(err, errNoValidKeyAccessObjects) {
909+
return "", nil, err400("invalid request")
910+
}
848911
policyID := req.GetPolicy().GetId()
849912
results[policyID] = kaoResults
850913
if err != nil {
@@ -872,14 +935,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
872935
slog.Any("error", accessErr),
873936
)
874937
failAllKaos(requests, results, err500("could not perform access"))
875-
return "", results
938+
return "", results, nil
876939
}
877940

878941
asymEncrypt, err := ocrypto.FromPublicPEMWithSalt(clientPublicKey, security.TDFSalt(), nil)
879942
if err != nil {
880943
p.Logger.WarnContext(ctx, "ocrypto.NewAsymEncryption", slog.Any("error", err))
881944
failAllKaos(requests, results, err400("invalid request"))
882-
return "", results
945+
return "", results, nil
883946
}
884947
encap := security.OCEncapsulator{PublicKeyEncryptor: asymEncrypt}
885948

@@ -890,12 +953,12 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
890953
p.Logger.ErrorContext(ctx, "unable to serialize ephemeral key", slog.Any("error", err))
891954
// This may be a 500, but could also be caused by a bad clientPublicKey
892955
failAllKaos(requests, results, err400("invalid request"))
893-
return "", results
956+
return "", results, nil
894957
}
895958
if !p.ECTDFEnabled && !p.Preview.ECTDFEnabled {
896959
p.Logger.ErrorContext(ctx, "ec rewrap not enabled")
897960
failAllKaos(requests, results, err400("invalid request"))
898-
return "", results
961+
return "", results, nil
899962
}
900963
}
901964

@@ -962,10 +1025,10 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
9621025
p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
9631026
}
9641027
}
965-
return sessionKey, results
1028+
return sessionKey, results, nil
9661029
}
9671030

968-
func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults) {
1031+
func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, clientPublicKey string, entityInfo *entityInfo, additionalRewrapContext *AdditionalRewrapContext) (string, policyKAOResults, error) {
9691032
ctx, span := p.Start(ctx, "nanoTDFRewrap")
9701033
defer span.End()
9711034

@@ -975,7 +1038,10 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
9751038
policyReqs := make(map[*Policy]*kaspb.UnsignedRewrapRequest_WithPolicyRequest)
9761039

9771040
for _, req := range requests {
978-
policy, kaoResults := p.verifyNanoRewrapRequests(ctx, req)
1041+
policy, kaoResults, err := p.verifyNanoRewrapRequests(ctx, req)
1042+
if err != nil {
1043+
return "", nil, err400("invalid request")
1044+
}
9791045
results[req.GetPolicy().GetId()] = kaoResults
9801046
if policy != nil {
9811047
policies = append(policies, policy)
@@ -991,20 +1057,20 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
9911057
pdpAccessResults, accessErr := p.canAccess(ctx, tok, policies, additionalRewrapContext.Obligations.FulfillableFQNs)
9921058
if accessErr != nil {
9931059
failAllKaos(requests, results, err500("could not perform access"))
994-
return "", results
1060+
return "", results, nil
9951061
}
9961062

9971063
sessionKey, err := p.KeyDelegator.GenerateECSessionKey(ctx, clientPublicKey)
9981064
if err != nil {
9991065
p.Logger.WarnContext(ctx, "failure in GenerateNanoTDFSessionKey", slog.Any("error", err))
10001066
failAllKaos(requests, results, err400("keypair mismatch"))
1001-
return "", results
1067+
return "", results, nil
10021068
}
10031069
sessionKeyPEM, err := sessionKey.PublicKeyAsPEM()
10041070
if err != nil {
10051071
p.Logger.WarnContext(ctx, "failure in PublicKeyToPem", slog.Any("error", err))
10061072
failAllKaos(requests, results, err500(""))
1007-
return "", results
1073+
return "", results, nil
10081074
}
10091075

10101076
for _, pdpAccess := range pdpAccessResults {
@@ -1061,12 +1127,18 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
10611127
p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
10621128
}
10631129
}
1064-
return sessionKeyPEM, results
1130+
return sessionKeyPEM, results, nil
10651131
}
10661132

1067-
func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) (*Policy, map[string]kaoResult) {
1133+
func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) (*Policy, map[string]kaoResult, error) {
10681134
results := make(map[string]kaoResult)
10691135

1136+
// Check if req is nil
1137+
if req == nil {
1138+
p.Logger.WarnContext(ctx, "request is nil")
1139+
return nil, nil, errors.New("request is nil")
1140+
}
1141+
10701142
for _, kao := range req.GetKeyAccessObjects() {
10711143
// there should never be multiple KAOs in policy
10721144
if len(req.GetKeyAccessObjects()) != 1 {
@@ -1078,7 +1150,7 @@ func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.Unsi
10781150
header, _, err := sdk.NewNanoTDFHeaderFromReader(headerReader)
10791151
if err != nil {
10801152
failedKAORewrap(results, kao, fmt.Errorf("failed to parse NanoTDF header: %w", err))
1081-
return nil, results
1153+
return nil, results, nil
10821154
}
10831155
// Lookup KID from nano header
10841156
kid, err := header.GetKasURL().GetIdentifier()
@@ -1101,48 +1173,48 @@ func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.Unsi
11011173
ecCurve, err := header.ECCurve()
11021174
if err != nil {
11031175
failedKAORewrap(results, kao, fmt.Errorf("ECCurve failed: %w", err))
1104-
return nil, results
1176+
return nil, results, nil
11051177
}
11061178

11071179
symmetricKey, err := p.KeyDelegator.DeriveKey(ctx, trust.KeyIdentifier(kid), header.EphemeralKey, ecCurve)
11081180
if err != nil {
11091181
failedKAORewrap(results, kao, fmt.Errorf("failed to generate symmetric key: %w", err))
1110-
return nil, results
1182+
return nil, results, nil
11111183
}
11121184

11131185
// extract the policy
11141186
policy, err := extractNanoPolicy(symmetricKey, header)
11151187
if err != nil {
11161188
failedKAORewrap(results, kao, fmt.Errorf("Error extracting policy: %w", err))
1117-
return nil, results
1189+
return nil, results, nil
11181190
}
11191191

11201192
// check the policy binding
11211193
binding, err := header.PolicyBinding()
11221194
if err != nil {
11231195
failedKAORewrap(results, kao, fmt.Errorf("failed to retrieve policy binding: %w", err))
1124-
return nil, results
1196+
return nil, results, nil
11251197
}
11261198

11271199
verify, err := binding.Verify()
11281200
if err != nil {
11291201
failedKAORewrap(results, kao, fmt.Errorf("error verifying policy binding: %w", err))
1130-
return nil, results
1202+
return nil, results, nil
11311203
}
11321204

11331205
if !verify {
11341206
failedKAORewrap(results, kao, errors.New("policy binding verification failed"))
1135-
return nil, results
1207+
return nil, results, nil
11361208
}
11371209
results[kao.GetKeyAccessObjectId()] = kaoResult{
11381210
ID: kao.GetKeyAccessObjectId(),
11391211
DEK: symmetricKey,
11401212
KeyID: kid,
11411213
PolicyBinding: binding.String(),
11421214
}
1143-
return policy, results
1215+
return policy, results, nil
11441216
}
1145-
return nil, results
1217+
return nil, results, nil
11461218
}
11471219

11481220
func extractNanoPolicy(symmetricKey ocrypto.ProtectedKey, header sdk.NanoTDFHeader) (*Policy, error) {

0 commit comments

Comments
 (0)