Skip to content
24 changes: 24 additions & 0 deletions examples/cmd/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
dataAttributes []string
collection int
alg string
policyMode string
)

func init() {
Expand All @@ -40,6 +41,7 @@ func init() {
encryptCmd.Flags().StringVarP(&outputName, "output", "o", "sensitive.txt.tdf", "name or path of output file; - for stdout")
encryptCmd.Flags().StringVarP(&alg, "key-encapsulation-algorithm", "A", "rsa:2048", "Key wrap algorithm algorithm:parameters")
encryptCmd.Flags().IntVarP(&collection, "collection", "c", 0, "number of nano's to create for collection. If collection >0 (default) then output will be <iteration>_<output>")
encryptCmd.Flags().StringVar(&policyMode, "policy-mode", "plaintext", "Store policy as encrypted instead of plaintext (nanoTDF only) [plaintext|encrypted|encrypted-policy-key-access]")

ExamplesCmd.AddCommand(&encryptCmd)
}
Expand Down Expand Up @@ -146,6 +148,28 @@ func encrypt(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}

// Handle policy mode if nanoTDF
switch policyMode {
case "plaintext":
err := nanoTDFConfig.SetPolicyMode(sdk.NanoTDFPolicyModePlainText)
if err != nil {
return err
}
case "encrypted":
err := nanoTDFConfig.SetPolicyMode(sdk.NanoTDFPolicyModeEncrypted)
if err != nil {
return err
}
case "encrypted-policy-key-access":
err := nanoTDFConfig.SetPolicyMode(sdk.NanoTDFPolicyModeEncryptedPolicyKeyAccess)
if err != nil {
return err
}
default:
return fmt.Errorf("unsupported policy mode: %s", policyMode)
}

for i, writer := range writer {
input := plainText
if collection > 0 {
Expand Down
4 changes: 2 additions & 2 deletions sdk/granter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func TestAttributeFromMalformedURL(t *testing.T) {
t.Run(tc.n, func(t *testing.T) {
a, err := NewAttributeNameFQN(tc.u)
require.ErrorIs(t, err, ErrInvalid)
assert.Equal(t, "", a.String())
assert.Empty(t, a.String())
})
}
}
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestAttributeValueFromMalformedURL(t *testing.T) {
t.Run(tc.n, func(t *testing.T) {
a, err := NewAttributeValueFQN(tc.u)
require.ErrorIs(t, err, ErrInvalid)
assert.Equal(t, "", a.String())
assert.Empty(t, a.String())
})
}
}
Expand Down
103 changes: 38 additions & 65 deletions sdk/nanotdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ type NanoTDFHeader struct {
bindCfg bindingConfig
sigCfg signatureConfig
EphemeralKey []byte
EncryptedPolicyBody []byte
PolicyMode policyType
PolicyBody []byte
gmacPolicyBinding []byte
ecdsaPolicyBindingR []byte
ecdsaPolicyBindingS []byte
Expand Down Expand Up @@ -90,7 +91,7 @@ func (header *NanoTDFHeader) VerifyPolicyBinding() (bool, error) {
return false, err
}

digest := ocrypto.CalculateSHA256(header.EncryptedPolicyBody)
digest := ocrypto.CalculateSHA256(header.PolicyBody)
if header.IsEcdsaBindingEnabled() {
ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, header.EphemeralKey)
if err != nil {
Expand Down Expand Up @@ -499,14 +500,15 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32,
}
totalBytes += uint32(l)

// Write policy - (Policy Mode, Policy length, Policy cipherText, Policy binding)
config.policy.body.mode = policyTypeEmbeddedPolicyEncrypted
// Write policy mode
config.policy.body.mode = config.policyMode
l, err = writer.Write([]byte{byte(config.policy.body.mode)})
if err != nil {
return nil, 0, 0, err
}
totalBytes += uint32(l)

// Create policy object
policyObj, err := createPolicyObject(config.attributes)
if err != nil {
return nil, 0, 0, fmt.Errorf("fail to create policy object:%w", err)
Expand All @@ -517,59 +519,25 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32,
return nil, 0, 0, fmt.Errorf("json.Marshal failed:%w", err)
}

ecdhKey, err := ocrypto.ConvertToECDHPrivateKey(config.keyPair.PrivateKey)
if err != nil {
return nil, 0, 0, fmt.Errorf("ocrypto.ConvertToECDHPrivateKey failed:%w", err)
}

symKey, err := ocrypto.ComputeECDHKeyFromECDHKeys(config.kasPublicKey, ecdhKey)
if err != nil {
return nil, 0, 0, fmt.Errorf("ocrypto.ComputeECDHKeyFromEC failed:%w", err)
}

salt := versionSalt()

symmetricKey, err := ocrypto.CalculateHKDF(salt, symKey)
if err != nil {
return nil, 0, 0, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err)
}

aesGcm, err := ocrypto.NewAESGcm(symmetricKey)
if err != nil {
return nil, 0, 0, fmt.Errorf("ocrypto.NewAESGcm failed:%w", err)
}

tagSize, err := SizeOfAuthTagForCipher(config.sigCfg.cipher)
if err != nil {
return nil, 0, 0, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err)
}
var embeddedP embeddedPolicy

const (
kIvLength = 12
)
iv := make([]byte, kIvLength)
cipherText, err := aesGcm.EncryptWithIVAndTagSize(iv, policyObjectAsStr, tagSize)
embeddedP, err = createNanoTDFEmbeddedPolicy(policyObjectAsStr, config)
if err != nil {
return nil, 0, 0, fmt.Errorf("AesGcm.EncryptWithIVAndTagSize failed:%w", err)
return nil, 0, 0, fmt.Errorf("failed to create embedded policy: %w", err)
}

embeddedP := embeddedPolicy{
lengthBody: uint16(len(cipherText) - len(iv)),
body: cipherText[len(iv):],
}
err = embeddedP.writeEmbeddedPolicy(writer)
if err != nil {
return nil, 0, 0, fmt.Errorf("writeEmbeddedPolicy failed:%w", err)
}

// size of uint16
const (
kSizeOfUint16 = 2
)
const kSizeOfUint16 = 2
totalBytes += kSizeOfUint16 + uint32(len(embeddedP.body))

digest := ocrypto.CalculateSHA256(embeddedP.body)
if config.bindCfg.useEcdsaBinding { //nolint:nestif // todo: subfunction

if config.bindCfg.useEcdsaBinding { //nolint:nestif // TODO: refactor
rBytes, sBytes, err := ocrypto.ComputeECDSASig(digest, config.keyPair.PrivateKey)
if err != nil {
return nil, 0, 0, fmt.Errorf("ComputeECDSASig failed:%w", err)
Expand Down Expand Up @@ -617,19 +585,24 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32,
}
totalBytes += uint32(l)

symKey, err := ocrypto.RandomBytes(kKeySize)
if err != nil {
return nil, 0, 0, fmt.Errorf("ocrypto.RandomBytes failed:%w", err)
}

if config.collectionCfg.useCollection {
config.collectionCfg.symKey = symmetricKey
config.collectionCfg.symKey = symKey
}

return symmetricKey, totalBytes, 0, nil
return symKey, totalBytes, 0, nil
}

func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) {
header := NanoTDFHeader{}
var size uint32

// Read and validate magic number
magicNumber := make([]byte, len(kNanoTDFMagicStringAndVersion))

l, err := reader.Read(magicNumber)
if err != nil {
return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err)
Expand All @@ -643,7 +616,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
return header, 0, errors.New("not a valid nano tdf")
}

// read resource locator
// Read resource locator
resource, err := NewResourceLocatorFromReader(reader)
if err != nil {
return header, 0, fmt.Errorf("call to NewResourceLocatorFromReader failed :%w", err)
Expand All @@ -653,7 +626,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)

slog.Debug("NewNanoTDFHeaderFromReader", slog.Uint64("resource locator", uint64(resource.getLength())))

// read ECC and Binding Mode
// Read ECC and Binding Mode
oneBytes := make([]byte, 1)
l, err = reader.Read(oneBytes)
if err != nil {
Expand All @@ -662,12 +635,12 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
size += uint32(l)
header.bindCfg = deserializeBindingCfg(oneBytes[0])

// check ephemeral ECC Params Enum
// Check ephemeral ECC Params Enum
if header.bindCfg.eccMode != ocrypto.ECCModeSecp256r1 {
return header, 0, errors.New("current implementation of nano tdf only support secp256r1 curve")
}

// read Payload and Sig Mode
// Read Payload and Sig Mode
l, err = reader.Read(oneBytes)
if err != nil {
return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err)
Expand All @@ -682,14 +655,13 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
}
size += uint32(l)

if oneBytes[0] != uint8(policyTypeEmbeddedPolicyEncrypted) {
return header, 0, errors.New(" current implementation only support embedded policy type")
policyMode := policyType(oneBytes[0])
if !validNanoTDFPolicyMode(policyMode) {
return header, 0, fmt.Errorf("unsupported policy mode: %v", policyMode)
}

// read policy length
const (
kSizeOfUint16 = 2
)
// Read policy length
const kSizeOfUint16 = 2
twoBytes := make([]byte, kSizeOfUint16)
l, err = reader.Read(twoBytes)
if err != nil {
Expand All @@ -699,17 +671,18 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
policyLength := binary.BigEndian.Uint16(twoBytes)
slog.Debug("NewNanoTDFHeaderFromReader", slog.Uint64("policyLength", uint64(policyLength)))

// read policy body
header.EncryptedPolicyBody = make([]byte, policyLength)
l, err = reader.Read(header.EncryptedPolicyBody)
// Read policy body
header.PolicyMode = policyMode
header.PolicyBody = make([]byte, policyLength)
l, err = reader.Read(header.PolicyBody)
if err != nil {
return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err)
}
size += uint32(l)

// read policy binding
if header.bindCfg.useEcdsaBinding { //nolint:nestif // todo: subfunction
// read rBytes len and its contents
// Read policy binding
if header.bindCfg.useEcdsaBinding { //nolint:nestif // TODO: refactor
// Read rBytes len and its contents
l, err = reader.Read(oneBytes)
if err != nil {
return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err)
Expand All @@ -723,7 +696,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
}
size += uint32(l)

// read sBytes len and its contents
// Read sBytes len and its contents
l, err = reader.Read(oneBytes)
if err != nil {
return header, 0, fmt.Errorf(" io.Reader.Read failed :%w", err)
Expand All @@ -750,7 +723,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error)
return header, 0, fmt.Errorf("getECCKeyLength :%w", err)
}

// read ephemeral Key
// Read ephemeral Key
ephemeralKey := make([]byte, ephemeralKeySize)
l, err = reader.Read(ephemeralKey)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions sdk/nanotdf_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type NanoTDFConfig struct {
policy policyInfo
bindCfg bindingConfig
collectionCfg *collectionConfig
policyMode policyType // Added field for policy mode
}

type NanoTDFOption func(*NanoTDFConfig) error
Expand Down Expand Up @@ -55,6 +56,7 @@ func (s SDK) NewNanoTDFConfig() (*NanoTDFConfig, error) {
useCollection: false,
header: []byte{},
},
policyMode: NanoTDFPolicyModePlainText,
}

return c, nil
Expand Down Expand Up @@ -89,6 +91,15 @@ func (config *NanoTDFConfig) EnableCollection() {
config.collectionCfg.useCollection = true
}

// SetPolicyMode sets whether the policy should be encrypted or plaintext
func (config *NanoTDFConfig) SetPolicyMode(mode policyType) error {
if !validNanoTDFPolicyMode(mode) {
return fmt.Errorf("invalid policy mode")
}
config.policyMode = mode
return nil
}

// WithNanoDataAttributes appends the given data attributes to the bound policy
func WithNanoDataAttributes(attributes ...string) NanoTDFOption {
return func(c *NanoTDFConfig) error {
Expand Down
32 changes: 32 additions & 0 deletions sdk/nanotdf_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,35 @@ func TestWithNanoKasAllowlist_with(t *testing.T) {
assert.False(t, config.kasAllowlist.IsAllowed("https://example.com:443"), "Expected KAS URL to not be allowed")
})
}

func TestSetPolicyMode(t *testing.T) {
t.Run("Set to plaintext", func(t *testing.T) {
var s SDK
conf, err := s.NewNanoTDFConfig()
require.NoError(t, err)

err = conf.SetPolicyMode(NanoTDFPolicyModePlainText)
require.NoError(t, err)
assert.Equal(t, NanoTDFPolicyModePlainText, conf.policyMode)
})

t.Run("Set to encrypted", func(t *testing.T) {
var s SDK
conf, err := s.NewNanoTDFConfig()
require.NoError(t, err)

err = conf.SetPolicyMode(NanoTDFPolicyModeEncrypted)
require.NoError(t, err)
assert.Equal(t, NanoTDFPolicyModeEncrypted, conf.policyMode)
})

t.Run("Set to invalid mode", func(t *testing.T) {
var s SDK
conf, err := s.NewNanoTDFConfig()
require.NoError(t, err)

err = conf.SetPolicyMode(policyType(99)) // Assuming 99 is an invalid policyType
require.Error(t, err)
assert.NotEqual(t, policyType(99), conf.policyMode)
})
}
Loading
Loading