diff --git a/examples/cmd/encrypt.go b/examples/cmd/encrypt.go index 9ef5364ae8..5bc11d8ee7 100644 --- a/examples/cmd/encrypt.go +++ b/examples/cmd/encrypt.go @@ -23,6 +23,7 @@ var ( dataAttributes []string collection int alg string + policyMode string ) func init() { @@ -40,6 +41,7 @@ func init() { encryptCmd.Flags().StringVarP(&outputName, "output", "o", "sensitive.txt.tdf", "name or path of output file; - for stdout") encryptCmd.Flags().StringVarP(&alg, "key-encapsulation-algorithm", "A", "rsa:2048", "Key wrap algorithm algorithm:parameters") encryptCmd.Flags().IntVarP(&collection, "collection", "c", 0, "number of nano's to create for collection. If collection >0 (default) then output will be _") + encryptCmd.Flags().StringVar(&policyMode, "policy-mode", "", "Store policy as encrypted instead of plaintext (nanoTDF only) [plaintext|encrypted]") ExamplesCmd.AddCommand(&encryptCmd) } @@ -146,6 +148,21 @@ func encrypt(cmd *cobra.Command, args []string) error { if err != nil { return err } + + // Handle policy mode if nanoTDF + switch policyMode { + case "": // default to encrypted + case "encrypted": + err = nanoTDFConfig.SetPolicyMode(sdk.NanoTDFPolicyModeEncrypted) + case "plaintext": + err = nanoTDFConfig.SetPolicyMode(sdk.NanoTDFPolicyModePlainText) + default: + err = fmt.Errorf("unsupported policy mode: %s", policyMode) + } + if err != nil { + return err + } + for i, writer := range writer { input := plainText if collection > 0 { diff --git a/sdk/granter_test.go b/sdk/granter_test.go index cc748d6889..f7612554e9 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -299,7 +299,7 @@ func TestAttributeFromMalformedURL(t *testing.T) { t.Run(tc.n, func(t *testing.T) { a, err := NewAttributeNameFQN(tc.u) require.ErrorIs(t, err, ErrInvalid) - assert.Equal(t, "", a.String()) + assert.Empty(t, a.String()) }) } } @@ -342,7 +342,7 @@ func TestAttributeValueFromMalformedURL(t *testing.T) { t.Run(tc.n, func(t *testing.T) { a, err := NewAttributeValueFQN(tc.u) require.ErrorIs(t, err, ErrInvalid) - assert.Equal(t, "", a.String()) + assert.Empty(t, a.String()) }) } } diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index b5ef3f153e..6cc73f1a1a 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -61,7 +61,8 @@ type NanoTDFHeader struct { bindCfg bindingConfig sigCfg signatureConfig EphemeralKey []byte - EncryptedPolicyBody []byte + PolicyMode PolicyType + PolicyBody []byte gmacPolicyBinding []byte ecdsaPolicyBindingR []byte ecdsaPolicyBindingS []byte @@ -90,7 +91,7 @@ func (header *NanoTDFHeader) VerifyPolicyBinding() (bool, error) { return false, err } - digest := ocrypto.CalculateSHA256(header.EncryptedPolicyBody) + digest := ocrypto.CalculateSHA256(header.PolicyBody) if header.IsEcdsaBindingEnabled() { ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, header.EphemeralKey) if err != nil { @@ -499,14 +500,15 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32, } totalBytes += uint32(l) - // Write policy - (Policy Mode, Policy length, Policy cipherText, Policy binding) - config.policy.body.mode = policyTypeEmbeddedPolicyEncrypted + // Write policy mode + config.policy.body.mode = config.policyMode l, err = writer.Write([]byte{byte(config.policy.body.mode)}) if err != nil { return nil, 0, 0, err } totalBytes += uint32(l) + // Create policy object policyObj, err := createPolicyObject(config.attributes) if err != nil { return nil, 0, 0, fmt.Errorf("fail to create policy object:%w", err) @@ -517,59 +519,34 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32, return nil, 0, 0, fmt.Errorf("json.Marshal failed:%w", err) } - ecdhKey, err := ocrypto.ConvertToECDHPrivateKey(config.keyPair.PrivateKey) + // Create the symmetric key + symmetricKey, err := createNanoTDFSymmetricKey(config) if err != nil { - return nil, 0, 0, fmt.Errorf("ocrypto.ConvertToECDHPrivateKey failed:%w", err) - } - - symKey, err := ocrypto.ComputeECDHKeyFromECDHKeys(config.kasPublicKey, ecdhKey) - if err != nil { - return nil, 0, 0, fmt.Errorf("ocrypto.ComputeECDHKeyFromEC failed:%w", err) - } - - salt := versionSalt() - - symmetricKey, err := ocrypto.CalculateHKDF(salt, symKey) - if err != nil { - return nil, 0, 0, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err) - } - - aesGcm, err := ocrypto.NewAESGcm(symmetricKey) - if err != nil { - return nil, 0, 0, fmt.Errorf("ocrypto.NewAESGcm failed:%w", err) + return nil, 0, 0, err } - tagSize, err := SizeOfAuthTagForCipher(config.sigCfg.cipher) - if err != nil { - return nil, 0, 0, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) + // Set the symmetric key in the collection config + if config.collectionCfg.useCollection { + config.collectionCfg.symKey = symmetricKey } - const ( - kIvLength = 12 - ) - iv := make([]byte, kIvLength) - cipherText, err := aesGcm.EncryptWithIVAndTagSize(iv, policyObjectAsStr, tagSize) + embeddedP, err := createNanoTDFEmbeddedPolicy(symmetricKey, policyObjectAsStr, config) if err != nil { - return nil, 0, 0, fmt.Errorf("AesGcm.EncryptWithIVAndTagSize failed:%w", err) + return nil, 0, 0, fmt.Errorf("failed to create embedded policy:%w", err) } - embeddedP := embeddedPolicy{ - lengthBody: uint16(len(cipherText) - len(iv)), - body: cipherText[len(iv):], - } err = embeddedP.writeEmbeddedPolicy(writer) if err != nil { return nil, 0, 0, fmt.Errorf("writeEmbeddedPolicy failed:%w", err) } // size of uint16 - const ( - kSizeOfUint16 = 2 - ) + const kSizeOfUint16 = 2 totalBytes += kSizeOfUint16 + uint32(len(embeddedP.body)) digest := ocrypto.CalculateSHA256(embeddedP.body) - if config.bindCfg.useEcdsaBinding { //nolint:nestif // todo: subfunction + + if config.bindCfg.useEcdsaBinding { //nolint:nestif // TODO: refactor rBytes, sBytes, err := ocrypto.ComputeECDSASig(digest, config.keyPair.PrivateKey) if err != nil { return nil, 0, 0, fmt.Errorf("ComputeECDSASig failed:%w", err) @@ -617,10 +594,6 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32, } totalBytes += uint32(l) - if config.collectionCfg.useCollection { - config.collectionCfg.symKey = symmetricKey - } - return symmetricKey, totalBytes, 0, nil } @@ -628,8 +601,8 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) header := NanoTDFHeader{} var size uint32 + // Read and validate magic number magicNumber := make([]byte, len(kNanoTDFMagicStringAndVersion)) - l, err := reader.Read(magicNumber) if err != nil { return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err) @@ -643,7 +616,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) return header, 0, errors.New("not a valid nano tdf") } - // read resource locator + // Read resource locator resource, err := NewResourceLocatorFromReader(reader) if err != nil { return header, 0, fmt.Errorf("call to NewResourceLocatorFromReader failed :%w", err) @@ -653,7 +626,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) slog.Debug("NewNanoTDFHeaderFromReader", slog.Uint64("resource locator", uint64(resource.getLength()))) - // read ECC and Binding Mode + // Read ECC and Binding Mode oneBytes := make([]byte, 1) l, err = reader.Read(oneBytes) if err != nil { @@ -662,12 +635,12 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) size += uint32(l) header.bindCfg = deserializeBindingCfg(oneBytes[0]) - // check ephemeral ECC Params Enum + // Check ephemeral ECC Params Enum if header.bindCfg.eccMode != ocrypto.ECCModeSecp256r1 { return header, 0, errors.New("current implementation of nano tdf only support secp256r1 curve") } - // read Payload and Sig Mode + // Read Payload and Sig Mode l, err = reader.Read(oneBytes) if err != nil { return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err) @@ -682,14 +655,13 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) } size += uint32(l) - if oneBytes[0] != uint8(policyTypeEmbeddedPolicyEncrypted) { - return header, 0, errors.New(" current implementation only support embedded policy type") + policyMode := PolicyType(oneBytes[0]) + if err := validNanoTDFPolicyMode(policyMode); err != nil { + return header, 0, errors.Join(fmt.Errorf("unsupported policy mode: %v", policyMode), err) } - // read policy length - const ( - kSizeOfUint16 = 2 - ) + // Read policy length + const kSizeOfUint16 = 2 twoBytes := make([]byte, kSizeOfUint16) l, err = reader.Read(twoBytes) if err != nil { @@ -699,17 +671,18 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) policyLength := binary.BigEndian.Uint16(twoBytes) slog.Debug("NewNanoTDFHeaderFromReader", slog.Uint64("policyLength", uint64(policyLength))) - // read policy body - header.EncryptedPolicyBody = make([]byte, policyLength) - l, err = reader.Read(header.EncryptedPolicyBody) + // Read policy body + header.PolicyMode = policyMode + header.PolicyBody = make([]byte, policyLength) + l, err = reader.Read(header.PolicyBody) if err != nil { return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err) } size += uint32(l) - // read policy binding - if header.bindCfg.useEcdsaBinding { //nolint:nestif // todo: subfunction - // read rBytes len and its contents + // Read policy binding + if header.bindCfg.useEcdsaBinding { //nolint:nestif // TODO: refactor + // Read rBytes len and its contents l, err = reader.Read(oneBytes) if err != nil { return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err) @@ -723,7 +696,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) } size += uint32(l) - // read sBytes len and its contents + // Read sBytes len and its contents l, err = reader.Read(oneBytes) if err != nil { return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err) @@ -750,7 +723,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) return header, 0, fmt.Errorf("getECCKeyLength :%w", err) } - // read ephemeral Key + // Read ephemeral Key ephemeralKey := make([]byte, ephemeralKeySize) l, err = reader.Read(ephemeralKey) if err != nil { @@ -995,8 +968,8 @@ func (n *NanoTDFDecryptHandler) Decrypt(_ context.Context, result []kaoResult) ( payloadLength := binary.BigEndian.Uint32(payloadLengthBuf) slog.Debug("ReadNanoTDF", slog.Uint64("payloadLength", uint64(payloadLength))) - cipherDate := make([]byte, payloadLength) - _, err = n.reader.Read(cipherDate) + cipherData := make([]byte, payloadLength) + _, err = n.reader.Read(cipherData) if err != nil { return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) } @@ -1009,7 +982,7 @@ func (n *NanoTDFDecryptHandler) Decrypt(_ context.Context, result []kaoResult) ( ivPadded := make([]byte, 0, ocrypto.GcmStandardNonceSize) noncePadding := make([]byte, kIvPadding) ivPadded = append(ivPadded, noncePadding...) - iv := cipherDate[:kNanoTDFIvSize] + iv := cipherData[:kNanoTDFIvSize] ivPadded = append(ivPadded, iv...) tagSize, err := SizeOfAuthTagForCipher(n.header.sigCfg.cipher) @@ -1017,7 +990,7 @@ func (n *NanoTDFDecryptHandler) Decrypt(_ context.Context, result []kaoResult) ( return 0, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) } - decryptedData, err := aesGcm.DecryptWithIVAndTagSize(ivPadded, cipherDate[kNanoTDFIvSize:], tagSize) + decryptedData, err := aesGcm.DecryptWithIVAndTagSize(ivPadded, cipherData[kNanoTDFIvSize:], tagSize) if err != nil { return 0, err } @@ -1108,3 +1081,28 @@ func versionSalt() []byte { digest.Write([]byte(kNanoTDFMagicStringAndVersion)) return digest.Sum(nil) } + +// createNanoTDFSymmetricKey creates the symmetric key for nanoTDF header +func createNanoTDFSymmetricKey(config NanoTDFConfig) ([]byte, error) { + if config.kasPublicKey == nil { + return nil, fmt.Errorf("KAS public key is required for encrypted policy mode") + } + + ecdhKey, err := ocrypto.ConvertToECDHPrivateKey(config.keyPair.PrivateKey) + if err != nil { + return nil, fmt.Errorf("ocrypto.ConvertToECDHPrivateKey failed:%w", err) + } + + symKey, err := ocrypto.ComputeECDHKeyFromECDHKeys(config.kasPublicKey, ecdhKey) + if err != nil { + return nil, fmt.Errorf("ocrypto.ComputeECDHKeyFromEC failed:%w", err) + } + + salt := versionSalt() + symmetricKey, err := ocrypto.CalculateHKDF(salt, symKey) + if err != nil { + return nil, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err) + } + + return symmetricKey, nil +} diff --git a/sdk/nanotdf_config.go b/sdk/nanotdf_config.go index 32f98c89ea..5efde5663d 100644 --- a/sdk/nanotdf_config.go +++ b/sdk/nanotdf_config.go @@ -26,6 +26,7 @@ type NanoTDFConfig struct { policy policyInfo bindCfg bindingConfig collectionCfg *collectionConfig + policyMode PolicyType // Added field for policy mode } type NanoTDFOption func(*NanoTDFConfig) error @@ -55,6 +56,7 @@ func (s SDK) NewNanoTDFConfig() (*NanoTDFConfig, error) { useCollection: false, header: []byte{}, }, + policyMode: NanoTDFPolicyModeDefault, } return c, nil @@ -89,6 +91,15 @@ func (config *NanoTDFConfig) EnableCollection() { config.collectionCfg.useCollection = true } +// SetPolicyMode sets whether the policy should be encrypted or plaintext +func (config *NanoTDFConfig) SetPolicyMode(mode PolicyType) error { + if err := validNanoTDFPolicyMode(mode); err != nil { + return err + } + config.policyMode = mode + return nil +} + // WithNanoDataAttributes appends the given data attributes to the bound policy func WithNanoDataAttributes(attributes ...string) NanoTDFOption { return func(c *NanoTDFConfig) error { diff --git a/sdk/nanotdf_config_test.go b/sdk/nanotdf_config_test.go index 8180cefe34..43625f8bef 100644 --- a/sdk/nanotdf_config_test.go +++ b/sdk/nanotdf_config_test.go @@ -128,3 +128,35 @@ func TestWithNanoKasAllowlist_with(t *testing.T) { assert.False(t, config.kasAllowlist.IsAllowed("https://example.com:443"), "Expected KAS URL to not be allowed") }) } + +func TestSetPolicyMode(t *testing.T) { + t.Run("Set to plaintext", func(t *testing.T) { + var s SDK + conf, err := s.NewNanoTDFConfig() + require.NoError(t, err) + + err = conf.SetPolicyMode(NanoTDFPolicyModePlainText) + require.NoError(t, err) + assert.Equal(t, NanoTDFPolicyModePlainText, conf.policyMode) + }) + + t.Run("Set to encrypted", func(t *testing.T) { + var s SDK + conf, err := s.NewNanoTDFConfig() + require.NoError(t, err) + + err = conf.SetPolicyMode(NanoTDFPolicyModeEncrypted) + require.NoError(t, err) + assert.Equal(t, NanoTDFPolicyModeEncrypted, conf.policyMode) + }) + + t.Run("Set to invalid mode", func(t *testing.T) { + var s SDK + conf, err := s.NewNanoTDFConfig() + require.NoError(t, err) + + err = conf.SetPolicyMode(PolicyType(99)) // Assuming 99 is an invalid policyType + require.Error(t, err) + assert.NotEqual(t, PolicyType(99), conf.policyMode) + }) +} diff --git a/sdk/nanotdf_policy.go b/sdk/nanotdf_policy.go index 4c4467c806..a23b268b03 100644 --- a/sdk/nanotdf_policy.go +++ b/sdk/nanotdf_policy.go @@ -3,7 +3,10 @@ package sdk import ( "encoding/binary" "errors" + "fmt" "io" + + "github.com/opentdf/platform/lib/ocrypto" ) // ============================================================================================================ @@ -11,17 +14,24 @@ import ( // // ============================================================================================================ -type policyType uint8 +type PolicyType uint8 const ( - policyTypeRemotePolicy policyType = 0 - policyTypeEmbeddedPolicyPlainText policyType = 1 - policyTypeEmbeddedPolicyEncrypted policyType = 2 - policyTypeEmbeddedPolicyEncryptedPolicyKeyAccess policyType = 3 + NanoTDFPolicyModeRemote PolicyType = iota + NanoTDFPolicyModePlainText + NanoTDFPolicyModeEncrypted + NanoTDFPolicyModeEncryptedPolicyKeyAccess + + NanoTDFPolicyModeDefault = NanoTDFPolicyModeEncrypted +) + +var ( + ErrNanoTDFUnsupportedPolicyMode = errors.New("unsupported policy mode") + ErrNanoTDFInvalidPolicyMode = errors.New("invalid policy mode") ) type PolicyBody struct { - mode policyType + mode PolicyType rp remotePolicy ep embeddedPolicy } @@ -44,20 +54,20 @@ type PolicyBody struct { // readPolicyBody - helper function to decode input data into a PolicyBody object func (pb *PolicyBody) readPolicyBody(reader io.Reader) error { - var mode policyType + var mode PolicyType if err := binary.Read(reader, binary.BigEndian, &mode); err != nil { return err } switch mode { - case policyTypeRemotePolicy: + case NanoTDFPolicyModeRemote: var rl ResourceLocator if err := rl.readResourceLocator(reader); err != nil { return errors.Join(ErrNanoTDFHeaderRead, err) } pb.rp = remotePolicy{url: rl} - case policyTypeEmbeddedPolicyPlainText: - case policyTypeEmbeddedPolicyEncrypted: - case policyTypeEmbeddedPolicyEncryptedPolicyKeyAccess: + case NanoTDFPolicyModeEncrypted: + case NanoTDFPolicyModeEncryptedPolicyKeyAccess: + case NanoTDFPolicyModePlainText: var ep embeddedPolicy if err := ep.readEmbeddedPolicy(reader); err != nil { return errors.Join(ErrNanoTDFHeaderRead, err) @@ -74,7 +84,7 @@ func (pb *PolicyBody) writePolicyBody(writer io.Writer) error { var err error switch pb.mode { - case policyTypeRemotePolicy: // remote policy - resource locator + case NanoTDFPolicyModeRemote: // remote policy - resource locator if err = binary.Write(writer, binary.BigEndian, pb.mode); err != nil { return err } @@ -82,9 +92,9 @@ func (pb *PolicyBody) writePolicyBody(writer io.Writer) error { return err } return nil - case policyTypeEmbeddedPolicyPlainText: - case policyTypeEmbeddedPolicyEncrypted: - case policyTypeEmbeddedPolicyEncryptedPolicyKeyAccess: + case NanoTDFPolicyModeEncrypted: + case NanoTDFPolicyModeEncryptedPolicyKeyAccess: + case NanoTDFPolicyModePlainText: // embedded policy - inline if err = binary.Write(writer, binary.BigEndian, pb.mode); err != nil { return err @@ -97,3 +107,46 @@ func (pb *PolicyBody) writePolicyBody(writer io.Writer) error { } return nil } + +func validNanoTDFPolicyMode(mode PolicyType) error { + switch mode { + case NanoTDFPolicyModePlainText, NanoTDFPolicyModeEncrypted: + return nil + case NanoTDFPolicyModeRemote, NanoTDFPolicyModeEncryptedPolicyKeyAccess: + return ErrNanoTDFUnsupportedPolicyMode + default: + return ErrNanoTDFInvalidPolicyMode + } +} + +// createEmbeddedPolicy creates an embedded policy object, encrypting it if required by the policy mode +func createNanoTDFEmbeddedPolicy(symmetricKey []byte, policyObjectAsStr []byte, config NanoTDFConfig) (embeddedPolicy, error) { + if config.policyMode == NanoTDFPolicyModeEncrypted { + aesGcm, err := ocrypto.NewAESGcm(symmetricKey) + if err != nil { + return embeddedPolicy{}, fmt.Errorf("ocrypto.NewAESGcm failed:%w", err) + } + + tagSize, err := SizeOfAuthTagForCipher(config.sigCfg.cipher) + if err != nil { + return embeddedPolicy{}, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) + } + + const kIvLength = 12 + iv := make([]byte, kIvLength) + cipherText, err := aesGcm.EncryptWithIVAndTagSize(iv, policyObjectAsStr, tagSize) + if err != nil { + return embeddedPolicy{}, fmt.Errorf("AesGcm.EncryptWithIVAndTagSize failed:%w", err) + } + + return embeddedPolicy{ + lengthBody: uint16(len(cipherText) - len(iv)), + body: cipherText[len(iv):], + }, nil + } + + return embeddedPolicy{ + lengthBody: uint16(len(policyObjectAsStr)), + body: policyObjectAsStr, + }, nil +} diff --git a/sdk/nanotdf_policy_test.go b/sdk/nanotdf_policy_test.go index dd204261c3..aca63f588d 100644 --- a/sdk/nanotdf_policy_test.go +++ b/sdk/nanotdf_policy_test.go @@ -2,8 +2,13 @@ package sdk import ( "bytes" + "crypto/ecdh" + "crypto/rand" "io" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -15,7 +20,7 @@ const ( // TestNanoTDFPolicyWrite - Create a new policy, write it to a buffer func TestNanoTDFPolicy(t *testing.T) { pb := &PolicyBody{ - mode: policyTypeRemotePolicy, + mode: NanoTDFPolicyModeRemote, rp: remotePolicy{ url: ResourceLocator{ protocol: 1, @@ -44,3 +49,43 @@ func TestNanoTDFPolicy(t *testing.T) { t.Fatal(fullURL) } } + +func TestCreateEmbeddedPolicy(t *testing.T) { + // Test data + policyData := []byte(`{"attributes":["https://example.com/attr/Classification/value/S"]}`) + + t.Run("plaintext policy", func(t *testing.T) { + config, err := new(SDK).NewNanoTDFConfig() + require.NoError(t, err) + err = config.SetPolicyMode(NanoTDFPolicyModePlainText) + require.NoError(t, err) + + policy, err := createNanoTDFEmbeddedPolicy(make([]byte, 32), policyData, *config) + require.NoError(t, err) + assert.Equal(t, uint16(len(policyData)), policy.lengthBody) + assert.Equal(t, policyData, policy.body) + }) + + t.Run("encrypted policy", func(t *testing.T) { + config, err := new(SDK).NewNanoTDFConfig() + require.NoError(t, err) + + // Defaults to encrypted policy + + // Setup KAS public key + key, err := ecdh.P256().GenerateKey(rand.Reader) + require.NoError(t, err) + config.kasPublicKey = key.PublicKey() + + policy, err := createNanoTDFEmbeddedPolicy(make([]byte, 32), policyData, *config) + require.NoError(t, err) + + // Verify the encrypted policy is different from input and has expected length + assert.NotEqual(t, policyData, policy.body) + assert.NotEmpty(t, policy.body, "Encrypted policy body should not be empty") + assert.Equal(t, uint16(len(policy.body)), policy.lengthBody) + + assert.NotEqual(t, policyData, policy.body, "Policy body should be encrypted and different from original data") + assert.NotEmpty(t, policy.body, "Policy body should not be empty after encryption") + }) +} diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index e3b67176c7..12b1d431f7 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -75,19 +75,19 @@ func TestNew_ShouldCreateSDK(t *testing.T) { func Test_PlatformConfiguration_BadCases(t *testing.T) { assertions := func(t *testing.T, s *sdk.SDK) { iss, err := s.PlatformConfiguration.Issuer() - assert.Equal(t, "", iss) + assert.Empty(t, iss) require.ErrorIs(t, err, sdk.ErrPlatformIssuerNotFound) authzEndpoint, err := s.PlatformConfiguration.AuthzEndpoint() - assert.Equal(t, "", authzEndpoint) + assert.Empty(t, authzEndpoint) require.ErrorIs(t, err, sdk.ErrPlatformAuthzEndpointNotFound) tokenEndpoint, err := s.PlatformConfiguration.TokenEndpoint() - assert.Equal(t, "", tokenEndpoint) + assert.Empty(t, tokenEndpoint) require.ErrorIs(t, err, sdk.ErrPlatformTokenEndpointNotFound) publicClientID, err := s.PlatformConfiguration.PublicClientID() - assert.Equal(t, "", publicClientID) + assert.Empty(t, publicClientID) require.ErrorIs(t, err, sdk.ErrPlatformPublicClientIDNotFound) } diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 8b3eb0699f..9d3ae5cedc 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -411,7 +411,7 @@ func (s *TDFSuite) Test_SimpleTDF() { unencryptedMetaData, err := r.UnencryptedMetadata() s.Require().NoError(err) - s.EqualValues(metaData, unencryptedMetaData) + s.Equal(metaData, unencryptedMetaData) dataAttributes, err := r.DataAttributes() s.Require().NoError(err) @@ -438,7 +438,7 @@ func (s *TDFSuite) Test_SimpleTDF() { // check version is present if usehex is false if config.useHex { - s.Equal("", r.Manifest().TDFVersion) + s.Empty(r.Manifest().TDFVersion) } else { s.Equal(TDFSpecVersion, r.Manifest().TDFVersion) } diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index 83f846d2ae..40a58c9129 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -81,6 +81,8 @@ const ( kNanoTDFGMACLength = 8 ErrUser = Error("request error") ErrInternal = Error("internal error") + + ErrNanoTDFPolicyModeUnsupported = Error("unsupported policy mode") ) func err400(s string) error { @@ -879,24 +881,38 @@ func extractNanoPolicy(symmetricKey trust.ProtectedKey, header sdk.NanoTDFHeader const ( kIvLen = 12 ) - // The IV is always an empty 12 bytes for the policy. - iv := make([]byte, kIvLen) - tagSize, err := sdk.SizeOfAuthTagForCipher(header.GetCipher()) - if err != nil { - return nil, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) - } - - policyData, err := symmetricKey.DecryptAESGCM(iv, header.EncryptedPolicyBody, tagSize) - if err != nil { - return nil, fmt.Errorf("Error decrypting policy body:%w", err) - } var policy Policy - err = json.Unmarshal(policyData, &policy) - if err != nil { - return nil, fmt.Errorf("Error unmarshalling policy:%w", err) + switch header.PolicyMode { + case sdk.NanoTDFPolicyModePlainText: + err := json.Unmarshal(header.PolicyBody, &policy) + if err != nil { + return nil, fmt.Errorf("error unmarshalling plaintext policy: %w", err) + } + return &policy, nil + + case sdk.NanoTDFPolicyModeEncrypted: + iv := make([]byte, kIvLen) + tagSize, err := sdk.SizeOfAuthTagForCipher(header.GetCipher()) + if err != nil { + return nil, fmt.Errorf("SizeOfAuthTagForCipher failed: %w", err) + } + + policyData, err := symmetricKey.DecryptAESGCM(iv, header.PolicyBody, tagSize) + if err != nil { + return nil, fmt.Errorf("error decrypting policy body: %w", err) + } + + err = json.Unmarshal(policyData, &policy) + if err != nil { + return nil, fmt.Errorf("error unmarshalling encrypted policy: %w", err) + } + return &policy, nil + case sdk.NanoTDFPolicyModeRemote, sdk.NanoTDFPolicyModeEncryptedPolicyKeyAccess: + default: + // noop } - return &policy, nil + return nil, errors.Join(fmt.Errorf("unsupported policy mode: %d", header.PolicyMode), ErrNanoTDFPolicyModeUnsupported) } func failAllKaos(reqs []*kaspb.UnsignedRewrapRequest_WithPolicyRequest, results policyKAOResults, err error) { diff --git a/test/tdf-roundtrips.bats b/test/tdf-roundtrips.bats index 0e323fdcbc..4181d552cc 100755 --- a/test/tdf-roundtrips.bats +++ b/test/tdf-roundtrips.bats @@ -84,7 +84,7 @@ printf '%s\n' "$output" | grep "Hello multikao split" } -@test "examples: roundtrip nanoTDF" { +@test "examples: roundtrip nanoTDF (encrypted policy)" { echo "[INFO] creating nanotdf file" go run ./examples encrypt -o sensitive.txt.ntdf --nano --no-kid-in-nano "Hello NanoTDF" go run ./examples encrypt -o sensitive-kid.txt.ntdf --nano "Hello NanoTDF KID" @@ -96,6 +96,18 @@ go run ./examples decrypt sensitive-kid.txt.ntdf | grep "Hello NanoTDF KID" } +@test "examples: roundtrip nanoTDF (plaintext policy)" { + echo "[INFO] creating nanotdf file" + go run ./examples encrypt -o sensitive-plaintext_policy.txt.ntdf --policy-mode plaintext --nano --no-kid-in-nano "Hello NanoTDF" + go run ./examples encrypt -o sensitive-kid-plaintext_policy.txt.ntdf --policy-mode plaintext --nano "Hello NanoTDF KID" + + echo "[INFO] decrypting nanotdf..." + go run ./examples decrypt sensitive-plaintext_policy.txt.ntdf + go run ./examples decrypt sensitive-plaintext_policy.txt.ntdf | grep "Hello NanoTDF" + go run ./examples decrypt sensitive-kid-plaintext_policy.txt.ntdf + go run ./examples decrypt sensitive-kid-plaintext_policy.txt.ntdf | grep "Hello NanoTDF KID" +} + @test "examples: legacy key support Z-TDF" { echo "[INFO] validating default key is r1" [ "$(grpcurl "localhost:8080" "kas.AccessService/PublicKey" | jq -e -r .kid)" = r1 ]