@@ -108,6 +108,8 @@ const (
108108 ErrInternal = Error ("internal error" )
109109
110110 ErrNanoTDFPolicyModeUnsupported = Error ("unsupported policy mode" )
111+
112+ errNoValidKeyAccessObjects = Error ("no valid KAOs" )
111113)
112114
113115func err400 (s string ) error {
@@ -597,10 +599,18 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
597599 return nil , err400 (err .Error ())
598600 }
599601 if len (tdf3Reqs ) > 0 {
600- resp .SessionPublicKey , results = p .tdf3Rewrap (ctx , tdf3Reqs , body .GetClientPublicKey (), entityInfo , additionalRewrapContext )
602+ resp .SessionPublicKey , results , err = p .tdf3Rewrap (ctx , tdf3Reqs , body .GetClientPublicKey (), entityInfo , additionalRewrapContext )
603+ if err != nil {
604+ p .Logger .WarnContext (ctx , "status 400, tdf3 rewrap failure" , slog .Any ("error" , err ))
605+ return nil , err
606+ }
601607 addResultsToResponse (resp , results )
602608 } else {
603- resp .SessionPublicKey , results = p .nanoTDFRewrap (ctx , nanoReqs , body .GetClientPublicKey (), entityInfo , additionalRewrapContext )
609+ resp .SessionPublicKey , results , err = p .nanoTDFRewrap (ctx , nanoReqs , body .GetClientPublicKey (), entityInfo , additionalRewrapContext )
610+ if err != nil {
611+ p .Logger .WarnContext (ctx , "status 400, nanoTDF rewrap failure" , slog .Any ("error" , err ))
612+ return nil , err
613+ }
604614 addResultsToResponse (resp , results )
605615 }
606616
@@ -631,15 +641,31 @@ func (p *Provider) Rewrap(ctx context.Context, req *connect.Request[kaspb.Rewrap
631641}
632642
633643func (p * Provider ) verifyRewrapRequests (ctx context.Context , req * kaspb.UnsignedRewrapRequest_WithPolicyRequest ) (* Policy , map [string ]kaoResult , error ) {
634- ctx , span := p .Start (ctx , "tdf3Rewrap" )
635- defer span .End ()
644+ // Safe tracer handling - only start span if tracer is available
645+ var span trace.Span
646+ if p .Tracer != nil {
647+ ctx , span = p .Start (ctx , "tdf3Rewrap" )
648+ defer span .End ()
649+ }
636650
637651 results := make (map [string ]kaoResult )
638652 anyValidKAOs := false
653+ policy := & Policy {}
654+
655+ // Check if req is nil
656+ if req == nil {
657+ p .Logger .WarnContext (ctx , "request is nil" )
658+ return nil , results , errors .New ("request is nil" )
659+ }
660+
661+ // Check if policy is nil
662+ if req .GetPolicy () == nil {
663+ p .Logger .WarnContext (ctx , "policy is nil" )
664+ return nil , results , errors .New ("policy is nil" )
665+ }
639666
640667 p .Logger .DebugContext (ctx , "extracting policy" , slog .Any ("policy" , req .GetPolicy ()))
641668 sDecPolicy , policyErr := base64 .StdEncoding .DecodeString (req .GetPolicy ().GetBody ())
642- policy := & Policy {}
643669 if policyErr == nil {
644670 policyErr = json .Unmarshal (sDecPolicy , policy )
645671 }
@@ -650,6 +676,21 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
650676 continue
651677 }
652678
679+ // Check if KeyAccessObject is nil
680+ if kao .GetKeyAccessObject () == nil {
681+ p .Logger .WarnContext (ctx , "key access object is nil" , slog .String ("kao_id" , kao .GetKeyAccessObjectId ()))
682+ failedKAORewrap (results , kao , err400 ("bad request" ))
683+ continue
684+ }
685+
686+ // Check if wrapped key is empty
687+ wrappedKey := kao .GetKeyAccessObject ().GetWrappedKey ()
688+ if len (wrappedKey ) == 0 {
689+ p .Logger .WarnContext (ctx , "wrapped key is empty" , slog .String ("kao_id" , kao .GetKeyAccessObjectId ()))
690+ failedKAORewrap (results , kao , err400 ("bad request" ))
691+ continue
692+ }
693+
653694 var dek ocrypto.ProtectedKey
654695 var err error
655696 switch kao .GetKeyAccessObject ().GetKeyType () {
@@ -754,13 +795,28 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
754795 }
755796 dek , err = p .KeyDelegator .Decrypt (ctx , kid , kao .GetKeyAccessObject ().GetWrappedKey (), nil )
756797 }
798+ default :
799+ // handle unsupported key types
800+ keyType := kao .GetKeyAccessObject ().GetKeyType ()
801+ p .Logger .WarnContext (ctx , "unsupported key type" ,
802+ slog .String ("key_type" , keyType ),
803+ slog .String ("kao_id" , kao .GetKeyAccessObjectId ()))
804+ failedKAORewrap (results , kao , err400 ("bad request" ))
805+ continue
757806 }
758807 if err != nil {
759808 p .Logger .WarnContext (ctx , "failure to decrypt dek" , slog .Any ("error" , err ))
760809 failedKAORewrap (results , kao , err400 ("bad request" ))
761810 continue
762811 }
763812
813+ // Check if policy binding is nil
814+ if kao .GetKeyAccessObject ().GetPolicyBinding () == nil {
815+ p .Logger .WarnContext (ctx , "policy binding is nil" , slog .String ("kao_id" , kao .GetKeyAccessObjectId ()))
816+ failedKAORewrap (results , kao , err400 ("missing policy binding" ))
817+ continue
818+ }
819+
764820 // Store policy binding in context for verification
765821 policyBindingB64Encoded := kao .GetKeyAccessObject ().GetPolicyBinding ().GetHash ()
766822 policyBinding := make ([]byte , base64 .StdEncoding .DecodedLen (len (policyBindingB64Encoded )))
@@ -801,7 +857,7 @@ func (p *Provider) verifyRewrapRequests(ctx context.Context, req *kaspb.Unsigned
801857
802858 if ! anyValidKAOs {
803859 p .Logger .WarnContext (ctx , "no valid KAOs found" )
804- return policy , results , errors . New ( "no valid KAOs" )
860+ return policy , results , errNoValidKeyAccessObjects
805861 }
806862
807863 return policy , results , nil
@@ -833,7 +889,7 @@ func (p *Provider) listLegacyKeys(ctx context.Context) []trust.KeyIdentifier {
833889 return kidsToCheck
834890}
835891
836- func (p * Provider ) tdf3Rewrap (ctx context.Context , requests []* kaspb.UnsignedRewrapRequest_WithPolicyRequest , clientPublicKey string , entityInfo * entityInfo , additionalRewrapContext * AdditionalRewrapContext ) (string , policyKAOResults ) {
892+ func (p * Provider ) tdf3Rewrap (ctx context.Context , requests []* kaspb.UnsignedRewrapRequest_WithPolicyRequest , clientPublicKey string , entityInfo * entityInfo , additionalRewrapContext * AdditionalRewrapContext ) (string , policyKAOResults , error ) {
837893 if p .Tracer != nil {
838894 var span trace.Span
839895 ctx , span = p .Start (ctx , "rewrap-tdf3" )
@@ -844,7 +900,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
844900 var policies []* Policy
845901 policyReqs := make (map [* Policy ]* kaspb.UnsignedRewrapRequest_WithPolicyRequest )
846902 for _ , req := range requests {
903+ if req == nil || req .GetPolicy () == nil || req .GetPolicy ().GetId () == "" {
904+ p .Logger .WarnContext (ctx , "rewrap: nil request or policy" )
905+ continue
906+ }
847907 policy , kaoResults , err := p .verifyRewrapRequests (ctx , req )
908+ if err != nil && ! errors .Is (err , errNoValidKeyAccessObjects ) {
909+ return "" , nil , err400 ("invalid request" )
910+ }
848911 policyID := req .GetPolicy ().GetId ()
849912 results [policyID ] = kaoResults
850913 if err != nil {
@@ -872,14 +935,14 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
872935 slog .Any ("error" , accessErr ),
873936 )
874937 failAllKaos (requests , results , err500 ("could not perform access" ))
875- return "" , results
938+ return "" , results , nil
876939 }
877940
878941 asymEncrypt , err := ocrypto .FromPublicPEMWithSalt (clientPublicKey , security .TDFSalt (), nil )
879942 if err != nil {
880943 p .Logger .WarnContext (ctx , "ocrypto.NewAsymEncryption" , slog .Any ("error" , err ))
881944 failAllKaos (requests , results , err400 ("invalid request" ))
882- return "" , results
945+ return "" , results , nil
883946 }
884947 encap := security.OCEncapsulator {PublicKeyEncryptor : asymEncrypt }
885948
@@ -890,12 +953,12 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
890953 p .Logger .ErrorContext (ctx , "unable to serialize ephemeral key" , slog .Any ("error" , err ))
891954 // This may be a 500, but could also be caused by a bad clientPublicKey
892955 failAllKaos (requests , results , err400 ("invalid request" ))
893- return "" , results
956+ return "" , results , nil
894957 }
895958 if ! p .ECTDFEnabled && ! p .Preview .ECTDFEnabled {
896959 p .Logger .ErrorContext (ctx , "ec rewrap not enabled" )
897960 failAllKaos (requests , results , err400 ("invalid request" ))
898- return "" , results
961+ return "" , results , nil
899962 }
900963 }
901964
@@ -962,10 +1025,10 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew
9621025 p .Logger .Audit .RewrapSuccess (ctx , auditEventParams )
9631026 }
9641027 }
965- return sessionKey , results
1028+ return sessionKey , results , nil
9661029}
9671030
968- func (p * Provider ) nanoTDFRewrap (ctx context.Context , requests []* kaspb.UnsignedRewrapRequest_WithPolicyRequest , clientPublicKey string , entityInfo * entityInfo , additionalRewrapContext * AdditionalRewrapContext ) (string , policyKAOResults ) {
1031+ func (p * Provider ) nanoTDFRewrap (ctx context.Context , requests []* kaspb.UnsignedRewrapRequest_WithPolicyRequest , clientPublicKey string , entityInfo * entityInfo , additionalRewrapContext * AdditionalRewrapContext ) (string , policyKAOResults , error ) {
9691032 ctx , span := p .Start (ctx , "nanoTDFRewrap" )
9701033 defer span .End ()
9711034
@@ -975,7 +1038,10 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
9751038 policyReqs := make (map [* Policy ]* kaspb.UnsignedRewrapRequest_WithPolicyRequest )
9761039
9771040 for _ , req := range requests {
978- policy , kaoResults := p .verifyNanoRewrapRequests (ctx , req )
1041+ policy , kaoResults , err := p .verifyNanoRewrapRequests (ctx , req )
1042+ if err != nil {
1043+ return "" , nil , err400 ("invalid request" )
1044+ }
9791045 results [req .GetPolicy ().GetId ()] = kaoResults
9801046 if policy != nil {
9811047 policies = append (policies , policy )
@@ -991,20 +1057,20 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
9911057 pdpAccessResults , accessErr := p .canAccess (ctx , tok , policies , additionalRewrapContext .Obligations .FulfillableFQNs )
9921058 if accessErr != nil {
9931059 failAllKaos (requests , results , err500 ("could not perform access" ))
994- return "" , results
1060+ return "" , results , nil
9951061 }
9961062
9971063 sessionKey , err := p .KeyDelegator .GenerateECSessionKey (ctx , clientPublicKey )
9981064 if err != nil {
9991065 p .Logger .WarnContext (ctx , "failure in GenerateNanoTDFSessionKey" , slog .Any ("error" , err ))
10001066 failAllKaos (requests , results , err400 ("keypair mismatch" ))
1001- return "" , results
1067+ return "" , results , nil
10021068 }
10031069 sessionKeyPEM , err := sessionKey .PublicKeyAsPEM ()
10041070 if err != nil {
10051071 p .Logger .WarnContext (ctx , "failure in PublicKeyToPem" , slog .Any ("error" , err ))
10061072 failAllKaos (requests , results , err500 ("" ))
1007- return "" , results
1073+ return "" , results , nil
10081074 }
10091075
10101076 for _ , pdpAccess := range pdpAccessResults {
@@ -1061,12 +1127,18 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
10611127 p .Logger .Audit .RewrapSuccess (ctx , auditEventParams )
10621128 }
10631129 }
1064- return sessionKeyPEM , results
1130+ return sessionKeyPEM , results , nil
10651131}
10661132
1067- func (p * Provider ) verifyNanoRewrapRequests (ctx context.Context , req * kaspb.UnsignedRewrapRequest_WithPolicyRequest ) (* Policy , map [string ]kaoResult ) {
1133+ func (p * Provider ) verifyNanoRewrapRequests (ctx context.Context , req * kaspb.UnsignedRewrapRequest_WithPolicyRequest ) (* Policy , map [string ]kaoResult , error ) {
10681134 results := make (map [string ]kaoResult )
10691135
1136+ // Check if req is nil
1137+ if req == nil {
1138+ p .Logger .WarnContext (ctx , "request is nil" )
1139+ return nil , nil , errors .New ("request is nil" )
1140+ }
1141+
10701142 for _ , kao := range req .GetKeyAccessObjects () {
10711143 // there should never be multiple KAOs in policy
10721144 if len (req .GetKeyAccessObjects ()) != 1 {
@@ -1078,7 +1150,7 @@ func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.Unsi
10781150 header , _ , err := sdk .NewNanoTDFHeaderFromReader (headerReader )
10791151 if err != nil {
10801152 failedKAORewrap (results , kao , fmt .Errorf ("failed to parse NanoTDF header: %w" , err ))
1081- return nil , results
1153+ return nil , results , nil
10821154 }
10831155 // Lookup KID from nano header
10841156 kid , err := header .GetKasURL ().GetIdentifier ()
@@ -1101,48 +1173,48 @@ func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.Unsi
11011173 ecCurve , err := header .ECCurve ()
11021174 if err != nil {
11031175 failedKAORewrap (results , kao , fmt .Errorf ("ECCurve failed: %w" , err ))
1104- return nil , results
1176+ return nil , results , nil
11051177 }
11061178
11071179 symmetricKey , err := p .KeyDelegator .DeriveKey (ctx , trust .KeyIdentifier (kid ), header .EphemeralKey , ecCurve )
11081180 if err != nil {
11091181 failedKAORewrap (results , kao , fmt .Errorf ("failed to generate symmetric key: %w" , err ))
1110- return nil , results
1182+ return nil , results , nil
11111183 }
11121184
11131185 // extract the policy
11141186 policy , err := extractNanoPolicy (symmetricKey , header )
11151187 if err != nil {
11161188 failedKAORewrap (results , kao , fmt .Errorf ("Error extracting policy: %w" , err ))
1117- return nil , results
1189+ return nil , results , nil
11181190 }
11191191
11201192 // check the policy binding
11211193 binding , err := header .PolicyBinding ()
11221194 if err != nil {
11231195 failedKAORewrap (results , kao , fmt .Errorf ("failed to retrieve policy binding: %w" , err ))
1124- return nil , results
1196+ return nil , results , nil
11251197 }
11261198
11271199 verify , err := binding .Verify ()
11281200 if err != nil {
11291201 failedKAORewrap (results , kao , fmt .Errorf ("error verifying policy binding: %w" , err ))
1130- return nil , results
1202+ return nil , results , nil
11311203 }
11321204
11331205 if ! verify {
11341206 failedKAORewrap (results , kao , errors .New ("policy binding verification failed" ))
1135- return nil , results
1207+ return nil , results , nil
11361208 }
11371209 results [kao .GetKeyAccessObjectId ()] = kaoResult {
11381210 ID : kao .GetKeyAccessObjectId (),
11391211 DEK : symmetricKey ,
11401212 KeyID : kid ,
11411213 PolicyBinding : binding .String (),
11421214 }
1143- return policy , results
1215+ return policy , results , nil
11441216 }
1145- return nil , results
1217+ return nil , results , nil
11461218}
11471219
11481220func extractNanoPolicy (symmetricKey ocrypto.ProtectedKey , header sdk.NanoTDFHeader ) (* Policy , error ) {
0 commit comments