diff --git a/go.mod b/go.mod index 4526934e6c..f9fe46a320 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,8 @@ require ( github.com/bufbuild/protovalidate-go v0.4.3 github.com/creasty/defaults v1.7.0 github.com/go-chi/cors v1.2.1 + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/google/uuid v1.4.0 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/jackc/pgx/v5 v5.5.0 @@ -51,7 +53,6 @@ require ( github.com/gobwas/glob v0.2.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/cel-go v0.18.2 // indirect - github.com/google/uuid v1.4.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/klauspost/compress v1.17.2 // indirect diff --git a/internal/archive/tdf3_reader.go b/internal/archive/tdf3_reader.go index 1ba87df0d6..bf19a6f3e4 100644 --- a/internal/archive/tdf3_reader.go +++ b/internal/archive/tdf3_reader.go @@ -9,8 +9,8 @@ type TDFReader struct { } const ( - tdfManifestFileName = "0.manifest.json" - tdfPayloadFileName = "0.payload" + TDFManifestFileName = "0.manifest.json" + TDFPayloadFileName = "0.payload" ) // NewTDFReader Create tdf reader instance. @@ -28,7 +28,7 @@ func NewTDFReader(readSeeker io.ReadSeeker) (TDFReader, error) { // Manifest Return the manifest of the tdf. func (tdfReader TDFReader) Manifest() (string, error) { - fileContent, err := tdfReader.archiveReader.ReadAllFileData(tdfManifestFileName) + fileContent, err := tdfReader.archiveReader.ReadAllFileData(TDFManifestFileName) if err != nil { return "", err } @@ -37,7 +37,7 @@ func (tdfReader TDFReader) Manifest() (string, error) { // ReadPayload Return the payload of given length from index. func (tdfReader TDFReader) ReadPayload(index, length int64) ([]byte, error) { - buf, err := tdfReader.archiveReader.ReadFileData(tdfPayloadFileName, index, length) + buf, err := tdfReader.archiveReader.ReadFileData(TDFPayloadFileName, index, length) if err != nil { return nil, err } @@ -46,7 +46,7 @@ func (tdfReader TDFReader) ReadPayload(index, length int64) ([]byte, error) { // PayloadSize Return the size of the payload. func (tdfReader TDFReader) PayloadSize() (int64, error) { - size, err := tdfReader.archiveReader.ReadFileSize(tdfPayloadFileName) + size, err := tdfReader.archiveReader.ReadFileSize(TDFPayloadFileName) if err != nil { return -1, err } diff --git a/internal/archive/tdf3_writer.go b/internal/archive/tdf3_writer.go index e89f09125e..65754c5fef 100644 --- a/internal/archive/tdf3_writer.go +++ b/internal/archive/tdf3_writer.go @@ -4,6 +4,7 @@ import "io" type TDFWriter struct { archiveWriter *Writer + totalBytes int64 } // NewTDFWriter Create tdf writer instance. @@ -20,12 +21,12 @@ func (tdfWriter *TDFWriter) SetPayloadSize(payloadSize int64) error { tdfWriter.archiveWriter.EnableZip64() } - return tdfWriter.archiveWriter.AddHeader(tdfPayloadFileName, payloadSize) + return tdfWriter.archiveWriter.AddHeader(TDFPayloadFileName, payloadSize) } -// AppendManifest Add the manifest to tdf3 archive. +// AppendManifest Add the manifest to tdf archive. func (tdfWriter *TDFWriter) AppendManifest(manifest string) error { - err := tdfWriter.archiveWriter.AddHeader(tdfManifestFileName, int64(len(manifest))) + err := tdfWriter.archiveWriter.AddHeader(TDFManifestFileName, int64(len(manifest))) if err != nil { return err } @@ -33,12 +34,12 @@ func (tdfWriter *TDFWriter) AppendManifest(manifest string) error { return tdfWriter.archiveWriter.AddData([]byte(manifest)) } -// AppendPayload Add payload to tdf3 archive. +// AppendPayload Add payload to sdk archive. func (tdfWriter *TDFWriter) AppendPayload(data []byte) error { return tdfWriter.archiveWriter.AddData(data) } -// Close Completed adding all the files in zip archive. -func (tdfWriter *TDFWriter) Close() error { - return tdfWriter.archiveWriter.Close() +// Finish Finished adding all the files in zip archive. +func (tdfWriter *TDFWriter) Finish() (int64, error) { + return tdfWriter.archiveWriter.Finish() } diff --git a/internal/archive/tdf3_writer_reader_test.go b/internal/archive/tdf3_writer_reader_test.go index f153b92ed2..53878356ce 100644 --- a/internal/archive/tdf3_writer_reader_test.go +++ b/internal/archive/tdf3_writer_reader_test.go @@ -10,12 +10,14 @@ import ( type TDF3Entry struct { manifest string payloadSize int64 + tdfSize int64 } var TDF3Tests = []TDF3Entry{ //nolint:gochecknoglobals // This global is used as test harness for other tests { manifest: "some manifest", payloadSize: oneKB, + tdfSize: 1291, }, { manifest: `{ @@ -63,6 +65,7 @@ var TDF3Tests = []TDF3Entry{ //nolint:gochecknoglobals // This global is used as } }`, payloadSize: 10 * oneMB, + tdfSize: 10487693, }, { manifest: `{ @@ -110,6 +113,7 @@ var TDF3Tests = []TDF3Entry{ //nolint:gochecknoglobals // This global is used as } }`, payloadSize: 3 * oneGB, + tdfSize: 3145729933, }, { manifest: `{ @@ -157,6 +161,7 @@ var TDF3Tests = []TDF3Entry{ //nolint:gochecknoglobals // This global is used as } }`, payloadSize: 10 * oneGB, + tdfSize: 10485762121, }, } @@ -190,12 +195,6 @@ func writeTDFs(t *testing.T) { }(writer) tdf3Writer := NewTDFWriter(writer) - defer func(tdf3Writer *TDFWriter) { - err := tdf3Writer.Close() - if err != nil { - t.Fatalf("Fail to close tdf3 writer: %v", err) - } - }(tdf3Writer) // write payload totalBytes := tdf3Entry.payloadSize @@ -225,6 +224,15 @@ func writeTDFs(t *testing.T) { if err != nil { t.Fatalf("Fail to add payload to tdf3 writer: %v", err) } + + tdfSize, err := tdf3Writer.Finish() + if err != nil { + t.Fatalf("Fail to close tdf3 writer: %v", err) + } + + if tdfSize != tdf3Entry.tdfSize { + t.Errorf("tdf size test failed expected %v, got %v", tdfSize, tdf3Entry.tdfSize) + } } } diff --git a/internal/archive/writer.go b/internal/archive/writer.go index a90d47e49e..85c2cf5035 100644 --- a/internal/archive/writer.go +++ b/internal/archive/writer.go @@ -62,6 +62,7 @@ type Writer struct { fileInfoEntries []FileInfo writeState WriteState isZip64 bool + totalBytes int64 } // NewWriter Create tdf3 writer instance. @@ -137,13 +138,13 @@ func (writer *Writer) AddData(data []byte) error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } // write the file name - _, err = writer.writer.Write([]byte(writer.FileInfo.filename)) + err = writer.writeData([]byte(writer.FileInfo.filename)) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -161,7 +162,7 @@ func (writer *Writer) AddData(data []byte) error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -176,7 +177,7 @@ func (writer *Writer) AddData(data []byte) error { } // now write the contents - _, err := writer.writer.Write(data) + err := writer.writeData(data) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -214,7 +215,7 @@ func (writer *Writer) AddData(data []byte) error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -238,7 +239,7 @@ func (writer *Writer) AddData(data []byte) error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -256,19 +257,19 @@ func (writer *Writer) AddData(data []byte) error { return nil } -// Close Completed adding all the files in zip archive. -func (writer *Writer) Close() error { +// Finish Finished adding all the files in zip archive. +func (writer *Writer) Finish() (int64, error) { err := writer.writeCentralDirectory() if err != nil { - return err + return writer.totalBytes, err } err = writer.writeEndOfCentralDirectory() if err != nil { - return fmt.Errorf("io.Writer.Write failed: %w", err) + return writer.totalBytes, fmt.Errorf("io.Writer.Write failed: %w", err) } - return nil + return writer.totalBytes, nil } // WriteCentralDirectory write central directory struct into archive. @@ -313,13 +314,13 @@ func (writer *Writer) writeCentralDirectory() error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } // write the filename - _, err = writer.writer.Write([]byte(writer.fileInfoEntries[i].filename)) + err = writer.writeData([]byte(writer.fileInfoEntries[i].filename)) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -339,7 +340,7 @@ func (writer *Writer) writeCentralDirectory() error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -391,7 +392,7 @@ func (writer *Writer) writeEndOfCentralDirectory() error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -420,7 +421,7 @@ func (writer *Writer) WriteZip64EndOfCentralDirectory() error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } @@ -444,10 +445,11 @@ func (writer *Writer) WriteZip64EndOfCentralDirectoryLocator() error { return fmt.Errorf("binary.Write failed: %w", err) } - _, err = writer.writer.Write(buf.Bytes()) + err = writer.writeData(buf.Bytes()) if err != nil { return fmt.Errorf("io.Writer.Write failed: %w", err) } + return nil } @@ -458,3 +460,13 @@ func (writer *Writer) getTimeDateUnMSDosFormat() (uint16, uint16) { dateInDos := (t.Year()-80)<<9 | int((t.Month()+1)<<5) | t.Day() return uint16(timeInDos), uint16(dateInDos) } + +func (writer *Writer) writeData(data []byte) error { + n, err := writer.writer.Write(data) + if err != nil { + return err + } + + writer.totalBytes += int64(n) + return nil +} diff --git a/internal/archive/writer_test.go b/internal/archive/writer_test.go index 11f4ef4b7e..28855990f5 100644 --- a/internal/archive/writer_test.go +++ b/internal/archive/writer_test.go @@ -24,7 +24,8 @@ type ZipEntryInfo struct { } var ArchiveTests = []struct { //nolint:gochecknoglobals // This global is used as test harness for other tests - files []ZipEntryInfo + files []ZipEntryInfo + archiveSize int64 }{ { []ZipEntryInfo{ @@ -41,6 +42,7 @@ var ArchiveTests = []struct { //nolint:gochecknoglobals // This global is used a 10, }, }, + 358, }, { []ZipEntryInfo{ @@ -69,6 +71,7 @@ var ArchiveTests = []struct { //nolint:gochecknoglobals // This global is used a oneKB, }, }, + 6778, }, { []ZipEntryInfo{ @@ -97,6 +100,7 @@ var ArchiveTests = []struct { //nolint:gochecknoglobals // This global is used a oneMB + oneKB, }, }, + 526397048, }, { @@ -114,6 +118,7 @@ var ArchiveTests = []struct { //nolint:gochecknoglobals // This global is used a tenGB, }, }, + 12582912572, }, } @@ -191,10 +196,14 @@ func customZip(t *testing.T) { } } - err = archiveWriter.Close() + archiveSize, err := archiveWriter.Finish() if err != nil { t.Fatalf("Fail to close to archive: %v", err) } + + if archiveSize != test.archiveSize { + t.Errorf("archive size test failed expected %v, got %v", archiveSize, test.archiveSize) + } } } diff --git a/internal/crypto/aes_gcm.go b/internal/crypto/aes_gcm.go index d5d619297f..2a1b34aff7 100644 --- a/internal/crypto/aes_gcm.go +++ b/internal/crypto/aes_gcm.go @@ -31,14 +31,14 @@ func NewAESGcm(key []byte) (AesGcm, error) { } // Encrypt encrypts data with symmetric key. -// NOTE: This method use nonce of 16 bytes and auth tag as aes block size(16 bytes). +// NOTE: This method use nonce of 12 bytes and auth tag as aes block size(16 bytes). func (aesGcm AesGcm) Encrypt(data []byte) ([]byte, error) { - nonce, err := RandomBytes(DefaultNonceSize) + nonce, err := RandomBytes(GcmStandardNonceSize) if err != nil { return nil, err } - gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, DefaultNonceSize) + gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, GcmStandardNonceSize) if err != nil { return nil, fmt.Errorf("cipher.NewGCMWithNonceSize failed: %w", err) } @@ -77,12 +77,12 @@ func (aesGcm AesGcm) EncryptWithIVAndTagSize(iv, data []byte, authTagSize int) ( } // Decrypt decrypts data with symmetric key. -// NOTE: This method use nonce of 16 bytes and auth tag as aes block size(16 bytes) +// NOTE: This method use nonce of 12 bytes and auth tag as aes block size(16 bytes) // also expects IV as preamble of data. func (aesGcm AesGcm) Decrypt(data []byte) ([]byte, error) { // extract nonce and cipherText - nonce, cipherText := data[:DefaultNonceSize], data[DefaultNonceSize:] + nonce, cipherText := data[:GcmStandardNonceSize], data[GcmStandardNonceSize:] - gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, DefaultNonceSize) + gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, GcmStandardNonceSize) if err != nil { return nil, fmt.Errorf("cipher.NewGCMWithNonceSize failed: %w", err) } diff --git a/internal/crypto/aes_gcm_test.go b/internal/crypto/aes_gcm_test.go index 3a3e0f3a36..23364c2dca 100644 --- a/internal/crypto/aes_gcm_test.go +++ b/internal/crypto/aes_gcm_test.go @@ -15,24 +15,23 @@ func TestCreateAesGcm_DecryptWithDefaults(t *testing.T) { }{ { "66af5c10753139c6161d0f0eee125bbc9545d6704d64890e396c5c8d4f4820d4", - "29a8b044b5b6ce00e18bc6fc78ff50c6", - "b3cf733137d865892e5af63dcbca08086ba1ac82aae2", + "d2c32fa42f97341e97a33b58", + "a89d8e00e3bacacc2ed13bbc602a191d60584af3a933", "virtru", }, { "120fba31c537d99ade0a0a8c8e6df535f7de86fb6e1d5948317b4596982a5e1b", - "ec9074bc6c6b81d6520f5a7425f8977a", - "dabe8b28dd100eea2f58d71e3644b43d", + "591a6f1e947dd887d72610c8", + "83aaba876616c02bfaf5120c785ac92c", "", }, { "9895f395913a3cfd974ea53c0735030c7df4602d699c986afdc5fdd10071c0a5", - "5142d90e8499f597802ca68cddb25ec1", - `01c1e44df776bfca60ed217e06421c7b945adaf3289849406ca5b7046c886050fe7 -2cc0ebc429f683f9cfe3a47613e2ca8a812ef9b75d361c32d042124d3dc5d84c75722521df65ed7829327b -5adda0ae020a778b909328a48311cc705d4c0a8b83f49430aa80febba73e27e99b3006d6e768a092d5b9dc -894e7a634235198b1a986a3624912dec108ef03055b319f59f25fc579eb08f01820ea19edc7f9896129c572 -c36440ed80fd61fc71df37`, + "71c291bc41aacde6e0b57e7d", + `7b19b61dc053c3ffeaba57195356025a05600b071a4618912917681480f1eb62afb9ecc +ff7a90d6cba96275bd52bd8d6afa4fcbae6a400ce7033e7abd58e301ab9b4a9c3e7f4c0f55256d250faf8ce0c22bdd +9b79654842a6186df98831289eeee66fac014390a4363034d64e44fc9a2c0e0231d69c78f0a8049d8b458579041858 +d4f6da9f39542d2287d20d19dd99db339c038e3b6e1720c97ff73adda5ca4fac7da70c7d53f97a5aa346e93af`, `In cryptography, Galois/Counter Mode (GCM)[1] is a mode of operation for symmetric-key cryptographic block ciphers which is widely adopted for its performance`, diff --git a/internal/crypto/asym_decryption.go b/internal/crypto/asym_decryption.go index 5c0ffa1250..e817b157b7 100644 --- a/internal/crypto/asym_decryption.go +++ b/internal/crypto/asym_decryption.go @@ -43,7 +43,7 @@ func (asymDecryption AsymDecryption) Decrypt(data []byte) ([]byte, error) { bytes, err := asymDecryption.privateKey.Decrypt(nil, data, - &rsa.OAEPOptions{Hash: crypto.SHA256}) + &rsa.OAEPOptions{Hash: crypto.SHA1}) if err != nil { return nil, fmt.Errorf("x509.ParsePKCS8PrivateKey failed: %w", err) } diff --git a/internal/crypto/asym_encrypt_decrypt_test.go b/internal/crypto/asym_encrypt_decrypt_test.go index 0533b39449..026f77982c 100644 --- a/internal/crypto/asym_encrypt_decrypt_test.go +++ b/internal/crypto/asym_encrypt_decrypt_test.go @@ -171,53 +171,48 @@ ufgiB73q6Fnh5QHf1HNAeMUCAwEAAQ== }, { // Test certificate `-----BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDNxPq9JCuazKAB -0s+xTVqN2kiyDQ1OpJG08yQ7efnd5yinEyxesr008UvCo5+G5AKMbmajAY/weiry -o2G4UVt37TyH0ODkAFS3dVDkLfh2MnwAXSppOadav1XRiYqGYjPPMRUJe/ueNewx -XhknUm1Bn5EXgUnMS2V0XmZdAN0Fn1DxTKZvuPa50be0G+uRDWw+ZVFg1MDPt7TO -/TScQx6ytIgfRbUMqC8dM9aVHOyfUvbdM39ZNjbyZIERSuMZJHOZdx28n7gfiScQ -3ZqgDFQPRm0J6GDCkeXTjElD63X4P3qgrWVzAPZPuvdBiaQlfs+YQtjB345Tgg+B -YJN/Kre5AgMBAAECggEABBDjX39qeSmX89FFl1xO8MSicRo+7BHmayvuyFoVrOPX -cs23L7vab0RhWdw/17uDrWC1GH18aaUQWjEOSkUZSJpgetKOzxKOmf1wdsHNyhAf -USaGIwQnoWxsdrKAET25hluS7dgMVcj8/NC+MH+5dvV/OXatjaLjw1PmM+pDc8v0 -DAGEO198w/gADYli8dQdrwFHWacklKURRhONTSXNvFmGZwB2OcrAQ+Flu5C6zAUt -zaE+orLYUcvZvW5UjTABZ8yEQASsNh0mGddCMgKwidSG1grFrgHy3QmpJfiWR04T -RdUJAr6si3+7/j3gXpd+aHvPcBYOrICWY8zYShr/gQKBgQDqGIsc3jHNBk8R0QCh -f9FIDpXnZCLppsi3vxSkvR1fQgaNJfhg8OsZWyBaNO0gc0GgvUxsr2Yq6EzA6qZB -86hCw2ykl2zWtMV+pihuxiyPsEX6GeI0ojgc5881H0yosw3HWOGQIxtq/j6xs4Kv -CKmAsjqKMYTlERZq/mB/ACtEUQKBgQDhBeoD9NIOueHtV6cnbiJXXOO0hvnARUWX -B3FIw9isoUIYIHhg6Fhz7aIXkIcaG9zJnAnQDiQPE1mru+FSXbb7sqvF8DwgUpNh -c1Wu3KoWG/ll+cm8epFb//8RcsHwrU2oOn2U6SniqQEEhY11g6fXTfiB+2aR0nri -tJKLtYBq6QKBgQCqoudMTJ5qf03Fg9582imX56HXQAO+4ubuISeQCZXOaNdTrbjG -GPaVzfngJzIt9DWDUFjT5GqJnjjMan50FoKw37hipUodmzlWXxGb5XJ37pqjepiL -my9hyoscgssjMRk7FQueQCjtLZRPfbUllx/PAptvPjdLrc/0f6WhDWN8cQKBgBH/ -oRoz6OFYqOONEUesHX8TZPs5mJxybgCIjd0eHSShuWGopzhJHVoYddCgtM230M8n -dfl4SBYUnCWKX5lw+YPkZKzubEDBmhw/V2knKUufbTFk62fx/dJ4BXTl0vFnS0Db -fNP+WmVQ004DSK9Pmos0J15uN2QOi9m6S8Z0/BKRAoGAPXVuUCakXJucjorc2Rnn -TC4h73nhw127gNHT//V0Xok9PQkfCLyppPlhpSMQJ7IaMwIL28c4I/611CHDU8jb -KjNDoTuL3yRqOQNz1zLub5m/+vti4+Sh0//kT5G1SjL7BWSxtGt7EvQiM6rxgihP -Px0iQ2RGtzRV2ejyvAouYos= +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDOpiotrvV2i5h6 +clHMzDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9Ex +yczBWJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp4 +1r/T1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim8 +6Hawu/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3Tf +OEvXF6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU +54ZzDaLZAgMBAAECggEBALb0yK0PlMUyzHnEUwXV1y5AIoAWhsYp0qvJ1msHUVKz ++yQ/VJz4+tQQxI8OvGbbnhNkd5LnWdYkYzsIZl7b/kBCPcQw3Zo+4XLCzhUAn1E1 +M+n42c8le1LtN6Z7mVWoZh7DPONy7t+ABvm7b7S1+1i78DPmgCeWYZGeAhIcPXG6 +5AxWIV3jigxksE6kYY9Y7DmtsZgMRrdV7SU8VtgPtT7tua8z5/U3Av0WINyKBSoM +0yDHsAg57KnM8znx2JWLtHd0Mk5bBuu2DLbtyKNrVUAUuMPzrLGBh9S9QRd934KU +uFAi1TEfgEachnGgSHJpzVzr2ur1tifABnQ7GNXObe0CgYEA6KowK0subdDY+uGW +ciP2XDAMerbJJeL0/UIGPb/LUmskniio2493UBGgY2FsRyvbzJ+/UAOjIPyIxhj7 +78ZyVG8BmIzKan1RRVh//O+5yvks/eTOYjWeQ1Lcgqs3q4YAO13CEBZgKWKTUomg +mskFJq04tndeSIyhDaW+BuWaXA8CgYEA42ABz3pql+DH7oL5C4KYBymK6wFBBOqk +dVk+ftyJQ6PzuZKpfsu4aPIjKm71lkTgK6O9o08s3SckAdu6vLukq2TZFF+a+9OI +lu5ww7GvfdMTgLAaFchD4bPlOInh1KVjBc1MwGXpl0ROde5pi8+WUrv9QJuoQfB/ +4rhYdbJLSpcCgYA41mqSCPm8pgp7r2RbWeGzP6Gs0L5u3PTQcbKonxQCfF4jrPcj +O/b/vm6aGJClClfVsyi/WUQeqNKY4j2Zo7cGXV/cbnh8b0TNVgNePQn8Rcbx91Vb +tJGHDNUFruIYqtGfrxXbbDvtoEExJqHvbjAt9J8oJB0KSCCH/vdfI/QDjQKBgQCD +xLPH5Y24js/O7aAeh4RLQkv7fTKNAt5kE2AgbPYveOhZ9yC7Fpy8VPcENGGmwCuZ +nr7b0ZqSX4iCezBxB92aZktXf0B2CFT0AyLehi7JoHWA8o1rai/MsVB5v45ciawl +RKDiLy18OF2wAoawO5FGSSOvOYX9EL9MSMEbFESF6QKBgCVlZ9pPC+55rGT6AcEL +tUpDs+/wZvcmfsFd8xC5mMUN0DatAVzVAUI95+tQaWU3Uj+bqHq0lC6Wy2VceG0D +D+7EicjdGFN/2WVPXiYX1fblkxasZY+wChYBrPLjA9g0qOzzmXbRBph5QxDuQjJ6 +qcddVKB624a93ZBssn7OivnR -----END PRIVATE KEY-----`, `-----BEGIN CERTIFICATE----- -MIIDazCCAlOgAwIBAgIUUySK3aFhC08WO9hOMq19tFurxeUwDQYJKoZIhvcNAQEL -BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzEyMTIxODQzNDZaFw0yNDEy -MTExODQzNDZaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw -HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB -AQUAA4IBDwAwggEKAoIBAQDNxPq9JCuazKAB0s+xTVqN2kiyDQ1OpJG08yQ7efnd -5yinEyxesr008UvCo5+G5AKMbmajAY/weiryo2G4UVt37TyH0ODkAFS3dVDkLfh2 -MnwAXSppOadav1XRiYqGYjPPMRUJe/ueNewxXhknUm1Bn5EXgUnMS2V0XmZdAN0F -n1DxTKZvuPa50be0G+uRDWw+ZVFg1MDPt7TO/TScQx6ytIgfRbUMqC8dM9aVHOyf -UvbdM39ZNjbyZIERSuMZJHOZdx28n7gfiScQ3ZqgDFQPRm0J6GDCkeXTjElD63X4 -P3qgrWVzAPZPuvdBiaQlfs+YQtjB345Tgg+BYJN/Kre5AgMBAAGjUzBRMB0GA1Ud -DgQWBBTV1245NzOOCFSsUwWAsKdC/9fEnDAfBgNVHSMEGDAWgBTV1245NzOOCFSs -UwWAsKdC/9fEnDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAU -mJ6wTM0nlDfewYSOl6vkVHwL1LIXVEPnBhAAa3Hax9wtvUK+Tbg+7D3l7/Iq5Y/+ -UsHa/BQ9lktj/v2DAsaTVcO33IIyPrz1NCzB44hv+3G5HzUwKgDTY5tN8SCmy+Ti -rx3c0AOqCfpKmuRhzYLbaHjbj7jwCzdelJr318x2dVOGXB5J1yWQ3TfKKicWqA+9 -pDtPNFJ3M79YVI+84gUN6YAeM0edJWOpZRoxui3DvzINVAd3W+p3v+HYc6d+SRTw -Ac3Yy7hVSJCQjkgYm4dWNDSWFFz5TYkASjQiXhEyvofS37lf6p6qm77CAr/xfdKx -rfa3TMt0K2Gs502yHMD4 +MIICmDCCAYACCQC3BCaSANRhYzANBgkqhkiG9w0BAQsFADAOMQwwCgYDVQQDDANr +YXMwHhcNMjEwOTE1MTQxMTQ4WhcNMjIwOTE1MTQxMTQ4WjAOMQwwCgYDVQQDDANr +YXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDOpiotrvV2i5h6clHM +zDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9ExyczB +WJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp41r/T +1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim86Haw +u/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3TfOEvX +F6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU54Zz +DaLZAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABewfZOJ4/KNRE8IQ5TsW/AVn7C1 +l5ty6tUUBSVi8/df7WYts0bHEdQh9yl9agEU5i4rj43y8vMVZNzSeHcurtV/+C0j +fbkHQHeiQ1xn7cq3Sbh4UVRyuu4C5PklEH4AN6gxmgXC3kT15uWw8I4nm/plzYLs +I099IoRfC5djHUYYLMU/VkOIHuPC3sb7J65pSN26eR8bTMVNagk187V/xNwUuvkf ++NUxDO615/5BwQKnAu5xiIVagYnDZqKCOtYS5qhxF33Nlnwlm7hH8iVZ1RI+n52l +wVyElqp317Ksz+GtTIc+DE6oryxK3tZd4hrj9fXT4KiJvQ4pcRjpePgH7B8= -----END CERTIFICATE-----`, }, } diff --git a/internal/crypto/asym_encryption.go b/internal/crypto/asym_encryption.go index d656915c32..bf75cbcb50 100644 --- a/internal/crypto/asym_encryption.go +++ b/internal/crypto/asym_encryption.go @@ -3,7 +3,7 @@ package crypto import ( "crypto/rand" "crypto/rsa" - "crypto/sha256" + "crypto/sha1" "crypto/x509" "encoding/pem" "errors" @@ -57,10 +57,10 @@ func (asymEncryption AsymEncryption) Encrypt(data []byte) ([]byte, error) { return nil, errors.New("failed to encrypt, public key is empty") } - bytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, asymEncryption.publicKey, data, nil) + bytes, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, asymEncryption.publicKey, data, nil) if err != nil { return nil, fmt.Errorf("rsa.EncryptOAEP failed: %w", err) } return bytes, nil -} +} \ No newline at end of file diff --git a/sdk/assertion.go b/sdk/assertion.go new file mode 100644 index 0000000000..3ae68693b5 --- /dev/null +++ b/sdk/assertion.go @@ -0,0 +1,4 @@ +package sdk + +type Assertion struct { +} diff --git a/sdk/auth_config.go b/sdk/auth_config.go new file mode 100644 index 0000000000..60df050f35 --- /dev/null +++ b/sdk/auth_config.go @@ -0,0 +1,32 @@ +package sdk + +import ( + "fmt" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" +) + +type AuthConfig struct { + signingPublicKey string + signingPrivateKey string + authToken string +} + +// NewAuthConfig Create a new instance of authConfig +func NewAuthConfig() (*AuthConfig, error) { + rsaKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) + } + + publicKey, err := rsaKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + privateKey, err := rsaKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + return &AuthConfig{signingPublicKey: publicKey, signingPrivateKey: privateKey}, nil +} diff --git a/sdk/manifest.go b/sdk/manifest.go new file mode 100644 index 0000000000..870c9490bf --- /dev/null +++ b/sdk/manifest.go @@ -0,0 +1,73 @@ +package sdk + +type Segment struct { + Hash string `json:"hash"` + Size int64 `json:"segmentSize"` + EncryptedSize int64 `json:"encryptedSegmentSize"` +} + +type RootSignature struct { + Algorithm string `json:"alg"` + Signature string `json:"sig"` +} + +type IntegrityInformation struct { + RootSignature `json:"rootSignature"` + SegmentHashAlgorithm string `json:"segmentHashAlg"` + DefaultSegmentSize int64 `json:"segmentSizeDefault"` + DefaultEncryptedSegSize int64 `json:"encryptedSegmentSizeDefault"` + Segments []Segment `json:"segments"` +} + +type KeyAccess struct { + KeyType string `json:"type"` + KasURL string `json:"url"` + Protocol string `json:"protocol"` + WrappedKey string `json:"wrappedKey"` + PolicyBinding string `json:"policyBinding"` + EncryptedMetadata string `json:"encryptedMetadata"` +} + +type Method struct { + Algorithm string `json:"algorithm"` + IV string `json:"iv"` + IsStreamable bool `json:"isStreamable"` +} + +type Payload struct { + Type string `json:"type"` + URL string `json:"url"` + Protocol string `json:"protocol"` + MimeType string `json:"mimeType"` + IsEncrypted bool `json:"isEncrypted"` + // IntegrityInformation IntegrityInformation `json:"integrityInformation"` +} + +type EncryptionInformation struct { + KeyAccessType string `json:"type"` + Policy string `json:"policy"` + KeyAccessObjs []KeyAccess `json:"keyAccess"` + Method Method `json:"method"` + IntegrityInformation `json:"integrityInformation"` +} + +type Manifest struct { + EncryptionInformation `json:"encryptionInformation"` + Payload `json:"payload"` +} + +type attributeObject struct { + Attribute string `json:"attribute"` + DisplayName string `json:"displayName"` + IsDefault bool `json:"isDefault"` + PubKey string `json:"pubKey"` + KasURL string `json:"kasURL"` +} + +type policyObject struct { + UUID string `json:"uuid"` + Body struct { + DataAttributes []attributeObject `json:"dataAttributes"` + Dissem []string `json:"dissem"` + } +} diff --git a/sdk/split_key.go b/sdk/split_key.go new file mode 100644 index 0000000000..7827cb296b --- /dev/null +++ b/sdk/split_key.go @@ -0,0 +1,456 @@ +package sdk + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + kKeySize = 32 + kWrapped = "wrapped" + kKasProtocol = "kas" + kSplitKeyType = "split" + kGCMCipherAlgorithm = "AES-256-GCM" + kGMACPayloadLength = 16 + kClientPublicKey = "clientPublicKey" + kSignedRequestToken = "signedRequestToken" + kKasURL = "url" + kRewrapV2 = "/v2/upsert" + kAuthorizationKey = "Authorization" + kContentTypeKey = "Content-Type" + kAcceptKey = "Accept" + kContentTypeJSONValue = "application/json" + kEntityWrappedKey = "entityWrappedKey" + kPolicy = "policy" + kHmacIntegrityAlgorithm = "HS256" + kGmacIntegrityAlgorithm = "GMAC" +) + +type rewrapJWTClaims struct { + jwt.RegisteredClaims + Body string `json:"requestBody"` +} + +type splitKey struct { + attributes []string + tdfKeyAccessObjects []tdfKeyAccess + kasInfoList []KASInfo + key [kKeySize]byte + aesGcm crypto.AesGcm +} + +type tdfKeyAccess struct { + kasPublicKey string + kasURL string + wrappedKey [kKeySize]byte + metaData string +} + +var ( + errInvalidKasInfo = errors.New("split-key: kas information is missing") + errKasPubKeyMissing = errors.New("split-key: kas public key is missing") +) + +// newSplitKeyFromKasInfo create a instance of split key object. +func newSplitKeyFromKasInfo(kasInfoList []KASInfo, attributes []string, metaData string) (splitKey, error) { + if len(kasInfoList) == 0 { + return splitKey{}, errInvalidKasInfo + } + + tdfKeyAccessObjs := make([]tdfKeyAccess, 0) + for _, kasInfo := range kasInfoList { + if len(kasInfo.publicKey) == 0 { + return splitKey{}, errKasPubKeyMissing + } + + keyAccess := tdfKeyAccess{} + keyAccess.kasPublicKey = kasInfo.publicKey + keyAccess.kasURL = kasInfo.url + keyAccess.metaData = metaData + + key, err := crypto.RandomBytes(kKeySize) + if err != nil { + return splitKey{}, fmt.Errorf("crypto.RandomBytes failed:%w", err) + } + + keyAccess.wrappedKey = [kKeySize]byte(key) + tdfKeyAccessObjs = append(tdfKeyAccessObjs, keyAccess) + } + + sKey := splitKey{} + + // create the split key by XOR all the keys in key access object. + for _, keyAccessObj := range tdfKeyAccessObjs { + for keyByteIndex, keyByte := range keyAccessObj.wrappedKey { + sKey.key[keyByteIndex] ^= keyByte + } + } + + gcm, err := crypto.NewAESGcm(sKey.key[:]) + if err != nil { + return splitKey{}, fmt.Errorf(" crypto.NewAESGcm failed:%w", err) + } + + sKey.attributes = attributes + sKey.tdfKeyAccessObjects = tdfKeyAccessObjs + sKey.kasInfoList = kasInfoList + sKey.aesGcm = gcm + + return sKey, nil +} + +// newSplitKeyFromManifest create a instance of split key from(parsing) the manifest. +func newSplitKeyFromManifest(authConfig AuthConfig, manifest Manifest) (splitKey, error) { + sKey := splitKey{} + + for _, keyAccessObj := range manifest.EncryptionInformation.KeyAccessObjs { + keyAccessAsMap, err := structToMap(keyAccessObj) + if err != nil { + return splitKey{}, fmt.Errorf("fail to convert key access object to map:%w", err) + } + + keyAccessAsMap[kPolicy] = manifest.EncryptionInformation.Policy + key, err := sKey.rewrap(authConfig, keyAccessAsMap) + if err != nil { + return splitKey{}, fmt.Errorf(" splitKey.rewrap failed:%w", err) + } + + for keyByteIndex, keyByte := range key { + sKey.key[keyByteIndex] ^= keyByte + } + + keyAccess := tdfKeyAccess{} + keyAccess.kasURL = keyAccessObj.KasURL + keyAccess.wrappedKey = [32]byte(key) + + if len(keyAccessObj.EncryptedMetadata) != 0 { + gcm, err := crypto.NewAESGcm(key) + if err != nil { + return splitKey{}, fmt.Errorf("crypto.NewAESGcm failed:%w", err) + } + + decodedMetaData, err := crypto.Base64Decode([]byte(keyAccessObj.EncryptedMetadata)) + if err != nil { + return splitKey{}, fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + metaData, err := gcm.Decrypt(decodedMetaData) + if err != nil { + return splitKey{}, fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) + } + + keyAccess.metaData = string(metaData) + } + + sKey.tdfKeyAccessObjects = append(sKey.tdfKeyAccessObjects, keyAccess) + } + + gcm, err := crypto.NewAESGcm(sKey.key[:]) + if err != nil { + return splitKey{}, fmt.Errorf(" crypto.NewAESGcm failed:%w", err) + } + sKey.aesGcm = gcm + + return sKey, nil +} + +// getManifest Return the manifest. +func (splitKey splitKey) getManifest() (*Manifest, error) { + manifest := Manifest{} + manifest.EncryptionInformation.KeyAccessType = kSplitKeyType + + policyObj, err := splitKey.createPolicyObject() + if err != nil { + return nil, fmt.Errorf("fail to create policy object:%w", err) + } + + policyObjectAsStr, err := json.Marshal(policyObj) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed:%w", err) + } + + base64PolicyObject := crypto.Base64Encode(policyObjectAsStr) + + for _, keyAccessObj := range splitKey.tdfKeyAccessObjects { + keyAccess := KeyAccess{} + keyAccess.KeyType = kWrapped + keyAccess.KasURL = keyAccessObj.kasURL + keyAccess.Protocol = kKasProtocol + + // wrap the key with kas public key + asymEncrypt, err := crypto.NewAsymEncryption(keyAccessObj.kasPublicKey) + if err != nil { + return nil, fmt.Errorf("crypto.NewAsymEncryption failed:%w", err) + } + + encryptData, err := asymEncrypt.Encrypt(keyAccessObj.wrappedKey[:]) + if err != nil { + return nil, fmt.Errorf("crypto.AsymEncryption.encrypt failed:%w", err) + } + keyAccess.WrappedKey = string(crypto.Base64Encode(encryptData)) + + // add policyBinding + policyBinding := hex.EncodeToString(crypto.CalculateSHA256Hmac(keyAccessObj.wrappedKey[:], base64PolicyObject)) + keyAccess.PolicyBinding = string(crypto.Base64Encode([]byte(policyBinding))) + + // add meta data + if len(keyAccessObj.metaData) > 0 { + gcm, err := crypto.NewAESGcm(keyAccessObj.wrappedKey[:]) + if err != nil { + return nil, fmt.Errorf("crypto.NewAESGcm failed:%w", err) + } + + encryptedMetaData, err := gcm.Encrypt([]byte(keyAccessObj.metaData)) + if err != nil { + return nil, fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) + } + + keyAccess.EncryptedMetadata = string(crypto.Base64Encode(encryptedMetaData)) + } + + manifest.EncryptionInformation.KeyAccessObjs = append(manifest.EncryptionInformation.KeyAccessObjs, keyAccess) + } + + manifest.EncryptionInformation.Policy = string(base64PolicyObject) + manifest.EncryptionInformation.Method.Algorithm = kGCMCipherAlgorithm + + return &manifest, nil +} + +// encrypt the data using the split key. +func (splitKey splitKey) encrypt(data []byte) ([]byte, error) { + buf, err := splitKey.aesGcm.Encrypt(data) + if err != nil { + return nil, fmt.Errorf("AesGcm.encrypt failed:%w", err) + } + + return buf, nil +} + +// decrypt the data using the split key. +func (splitKey splitKey) decrypt(data []byte) ([]byte, error) { + buf, err := splitKey.aesGcm.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("AesGcm.Decrypt failed:%w", err) + } + + return buf, nil +} + +func (splitKey splitKey) validateRootSignature(manifest *Manifest) (bool, error) { + rootSigAlg := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm + rootSigValue := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature + + aggregateHash := &bytes.Buffer{} + for _, segment := range manifest.EncryptionInformation.IntegrityInformation.Segments { + decodedHash, err := crypto.Base64Decode([]byte(segment.Hash)) + if err != nil { + return false, fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + aggregateHash.Write(decodedHash) + } + + sigAlg := HS256 + if strings.EqualFold(gmacIntegrityAlgorithm, rootSigAlg) { + sigAlg = GMAC + } + + sig, err := splitKey.getSignature(aggregateHash.Bytes(), sigAlg) + if err != nil { + return false, fmt.Errorf("splitkey.getSignature failed:%w", err) + } + + if rootSigValue == string(crypto.Base64Encode([]byte(sig))) { + return true, nil + } + + return false, nil +} + +// getSignature calculate signature of data of the given algorithm. +func (splitKey splitKey) getSignature(data []byte, alg IntegrityAlgorithm) (string, error) { + if alg == HS256 { + hmac := crypto.CalculateSHA256Hmac(splitKey.key[:], data) + return hex.EncodeToString(hmac), nil + } + if kGMACPayloadLength > len(data) { + return "", fmt.Errorf("fail to create gmac signature") + } + + return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil +} + +func (splitKey splitKey) createPolicyObject() (policyObject, error) { + uuidObj, err := uuid.NewUUID() + if err != nil { + return policyObject{}, fmt.Errorf("uuid.NewUUID failed: %w", err) + } + + policyObj := policyObject{} + policyObj.UUID = uuidObj.String() + + for _, attribute := range splitKey.attributes { + attributeObj := attributeObject{} + attributeObj.Attribute = attribute + policyObj.Body.DataAttributes = append(policyObj.Body.DataAttributes, attributeObj) + } + + return policyObj, nil +} + +func (splitKey splitKey) rewrap(authConfig AuthConfig, requestBody map[string]interface{}) ([]byte, error) { + kasURL, ok := requestBody[kKasURL] + if !ok { + return nil, fmt.Errorf("kas url is missing in key access object") + } + + clientKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) + } + + clientPubKey, err := clientKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + clientPrivateKey, err := clientKeyPair.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PrivateKeyInPemFormat failed: %w", err) + } + + requestBody[kClientPublicKey] = clientPubKey + requestBodyData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + claims := rewrapJWTClaims{ + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(60 * time.Second)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + string(requestBodyData), + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + + signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(authConfig.signingPrivateKey)) + if err != nil { + return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) + } + + signedToken, err := token.SignedString(signingRSAPrivateKey) + if err != nil { + return nil, fmt.Errorf("jwt.SignedString failed: %w", err) + } + + signedTokenRequestBody, err := json.Marshal(map[string]string{ + kSignedRequestToken: signedToken, + }) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + kasRewrapURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kRewrapV2) + if err != nil { + return nil, fmt.Errorf("url.JoinPath failed: %w", err) + } + + request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRewrapURL, + bytes.NewBuffer(signedTokenRequestBody)) + if err != nil { + return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err) + } + + // add required headers + request.Header = http.Header{ + kContentTypeKey: {kContentTypeJSONValue}, + kAuthorizationKey: {authConfig.authToken}, + kAcceptKey: {kContentTypeJSONValue}, + } + + client := &http.Client{} + + response, err := client.Do(request) + if response.StatusCode != kHTTPOk { + return nil, fmt.Errorf("%s failed status code:%d", kasRewrapURL, response.StatusCode) + } + + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + slog.Error("Fail to close HTTP response") + } + }(response.Body) + + rewrapResponseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("io.ReadAll failed: %w", err) + } + + key, err := getWrappedKey(rewrapResponseBody, clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err) + } + + return key, nil +} + +func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) { + var data map[string]interface{} + err := json.Unmarshal(rewrapResponseBody, &data) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + entityWrappedKey, ok := data[kEntityWrappedKey] + if !ok { + return nil, fmt.Errorf("entityWrappedKey is missing in key access object") + } + + asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) + } + + entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey))) + if err != nil { + return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err) + } + + key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded) + if err != nil { + return nil, fmt.Errorf("crypto.Decrypt failed: %w", err) + } + + return key, nil +} + +func structToMap(structObj interface{}) (map[string]interface{}, error) { + structData, err := json.Marshal(structObj) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + mapData := make(map[string]interface{}) + err = json.Unmarshal(structData, &mapData) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + return mapData, nil +} diff --git a/sdk/split_key_test.go b/sdk/split_key_test.go new file mode 100644 index 0000000000..f192339640 --- /dev/null +++ b/sdk/split_key_test.go @@ -0,0 +1,260 @@ +package sdk + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNewSplitKeyFromKasInfo(t *testing.T) { + attributes := []string{ + "https://example.com/attr/Classification/value/S", + "https://example.com/attr/Classification/value/X", + } + sampleMetaData := `{"displayName" : "openTDF go sdk"}` + + for _, test := range testHarnesses { + kasInfoList := test.kasInfoList + for index := range kasInfoList { + kasInfoList[index].publicKey = mockKasPublicKey + } + + sKey, err := newSplitKeyFromKasInfo(test.kasInfoList, attributes, sampleMetaData) + if err != nil { + t.Fatalf("tdf.newSplitKeyFromKasInfo failed: %v", err) + } + + manifest, err := sKey.getManifest() + if err != nil { + t.Fatalf("tdf.splitKey.getManifest failed: %v", err) + } + + if len(manifest.KeyAccessObjs) == 0 { + t.Fatalf("fail: key access object missing from the manifest") + } + + if len(manifest.KeyAccessObjs[0].EncryptedMetadata) == 0 { + t.Fatalf("fail: meta data missing from the manifest") + } + } +} + +//nolint:gocognit +func TestNewSplitKeyFromManifest(t *testing.T) { + kasPrivateKey := `-----BEGIN PRIVATE KEY----- + MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDOpiotrvV2i5h6 + clHMzDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9Ex + yczBWJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp4 + 1r/T1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim8 + 6Hawu/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3Tf + OEvXF6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU + 54ZzDaLZAgMBAAECggEBALb0yK0PlMUyzHnEUwXV1y5AIoAWhsYp0qvJ1msHUVKz + +yQ/VJz4+tQQxI8OvGbbnhNkd5LnWdYkYzsIZl7b/kBCPcQw3Zo+4XLCzhUAn1E1 + M+n42c8le1LtN6Z7mVWoZh7DPONy7t+ABvm7b7S1+1i78DPmgCeWYZGeAhIcPXG6 + 5AxWIV3jigxksE6kYY9Y7DmtsZgMRrdV7SU8VtgPtT7tua8z5/U3Av0WINyKBSoM + 0yDHsAg57KnM8znx2JWLtHd0Mk5bBuu2DLbtyKNrVUAUuMPzrLGBh9S9QRd934KU + uFAi1TEfgEachnGgSHJpzVzr2ur1tifABnQ7GNXObe0CgYEA6KowK0subdDY+uGW + ciP2XDAMerbJJeL0/UIGPb/LUmskniio2493UBGgY2FsRyvbzJ+/UAOjIPyIxhj7 + 78ZyVG8BmIzKan1RRVh//O+5yvks/eTOYjWeQ1Lcgqs3q4YAO13CEBZgKWKTUomg + mskFJq04tndeSIyhDaW+BuWaXA8CgYEA42ABz3pql+DH7oL5C4KYBymK6wFBBOqk + dVk+ftyJQ6PzuZKpfsu4aPIjKm71lkTgK6O9o08s3SckAdu6vLukq2TZFF+a+9OI + lu5ww7GvfdMTgLAaFchD4bPlOInh1KVjBc1MwGXpl0ROde5pi8+WUrv9QJuoQfB/ + 4rhYdbJLSpcCgYA41mqSCPm8pgp7r2RbWeGzP6Gs0L5u3PTQcbKonxQCfF4jrPcj + O/b/vm6aGJClClfVsyi/WUQeqNKY4j2Zo7cGXV/cbnh8b0TNVgNePQn8Rcbx91Vb + tJGHDNUFruIYqtGfrxXbbDvtoEExJqHvbjAt9J8oJB0KSCCH/vdfI/QDjQKBgQCD + xLPH5Y24js/O7aAeh4RLQkv7fTKNAt5kE2AgbPYveOhZ9yC7Fpy8VPcENGGmwCuZ + nr7b0ZqSX4iCezBxB92aZktXf0B2CFT0AyLehi7JoHWA8o1rai/MsVB5v45ciawl + RKDiLy18OF2wAoawO5FGSSOvOYX9EL9MSMEbFESF6QKBgCVlZ9pPC+55rGT6AcEL + tUpDs+/wZvcmfsFd8xC5mMUN0DatAVzVAUI95+tQaWU3Uj+bqHq0lC6Wy2VceG0D + D+7EicjdGFN/2WVPXiYX1fblkxasZY+wChYBrPLjA9g0qOzzmXbRBph5QxDuQjJ6 + qcddVKB624a93ZBssn7OivnR + -----END PRIVATE KEY-----` + + sampleManifest := `{ + "encryptionInformation": { + "type": "split", + "policy": "eyJ1dWlkIjoiMmQyY2ZjMzQtYjg5MC0xMWVlLWEyMDgtYjJjMDM2M2FlNjI5IiwiQm9keSI6eyJkYXRhQXR0cmlidXRlcyI6W10sImRpc3NlbSI6W119fQ==", + "keyAccess": [ + { + "type": "kWrapped", + "url": "http://localhost:65432/api/kas", + "protocol": "kas", + "wrappedKey": "DfWZxVju4DIkSAu/QRHI04pLnBciASSDRokJ5gdDjx8fnh5jNsoyGQ63ekJgGEQp0r5CZqCIUHny7RU52LyMQuTz+lNLJKsZ3n9jDim5TbfzR2ETYAaAySzEPtUsVUWxwXHeHY8YNvb3nu8DuGCO2VadascqU9lZt6KOZ6Vr5JBOH3TukvTb0twHeJoBfyT+4HKSh27sdSOSNWOSuQkcbKGbcrAuTaV50jABphlW01gCfUv1N0BF3nWF30xOzpVl3BFwS/dA8bVVIckTLP6M456cWL6YrqHefwVA1Igrks/uVolL9sN1xS+nNlVVFCgipVz3I3wwgSTjhg5QD8YUcg==", + "policyBinding": "MDczYTJiYjE0MmZiODIxNTA3MjI2ZDBiYmNhMTM0ZmQyNDQ0YzJkODAwNmRjMjMxYjY2OWVhNTZlNzYyNTY1Nw==", + "encryptedMetadata": "" + }, + { + "type": "kWrapped", + "url": "http://localhost:65432/api/kas", + "protocol": "kas", + "wrappedKey": "rz13UFBazveewf7gHzEZZeg6Y5hjcVaz05W4VTlqVBxcNvJGajcXFIaeVCUgMf1++LOyqlqy6lIT+QpSG4pksXBCr7DeBrzvrXd4PUPlzFVDdZFbV22AZviSNQWe9IJyiZLt8L6RaHZcUfK2Gy2rUvXVr8o70xSjOvNAzp4nGJZPTSfbgSTo0aFPqgSvk+SmWNZl6eA98woCYO/SnSkHDWzuz7eSKcooiWoZD/XV71SpY+vHZaNwToEH4lhOxBTzNvPCX8cxi/2a6bygw4ma/bpepwwERS3SLg0cqDdQhQ95j34Y2aVzx3tSUntr33X0DHLimp1RKOTFdiPiAAnfuQ==", + "policyBinding": "MWQ3NmEwNjk2NWU5ZDZiNDQzM2U2ZTQ3MTU0NTEyYTQ0NjYwZGFiZDkyYjYzMTI3ZDUzMjE5NDJmMDg4YTNhOQ==", + "encryptedMetadata": "" + } + ], + "method": { + "algorithm": "AES-256-GCM", + "iv": "", + "isStreamable": true + }, + "integrityInformation": { + "rootSignature": { + "alg": "HS256", + "sig": "MWI0NWNmMzJkMDliOWI5YjJmNDk1YTk0NzhjMmJjMzMyODFhM2U5YjgxOTE0ZWY0NDI2ZGFkODkyMDEzY2VlMg==" + }, + "segmentHashAlg": "GMAC", + "segmentSizeDefault": 2097152, + "encryptedSegmentSizeDefault": 2097180, + "segments": [ + { + "hash": "NTZkZTg4NmE2MDhkNTU5OTU0N2RiNmRiNjNmMWExY2U=", + "segmentSize": 1024, + "encryptedSegmentSize": 1052 + } + ] + } + }, + "payload": { + "type": "reference", + "url": "0.payload", + "protocol": "zip", + "mimeType": "application/octet-stream", + "isEncrypted": true + } +}` + signingKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + t.Fatalf("crypto.NewRSAKeyPair: %v", err) + } + + signingPubKey, err := signingKeyPair.PublicKeyInPemFormat() + if err != nil { + t.Fatalf("crypto.PublicKeyInPemFormat failed: %v", err) + } + + signingPrivateKey, err := signingKeyPair.PrivateKeyInPemFormat() + if err != nil { + t.Fatalf("crypto.PrivateKeyInPemFormat failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != kRewrapV2 { + t.Fatalf("expected to request '%s', got: %s", kRewrapV2, r.URL.Path) + } + if r.Header.Get(kAcceptKey) != kContentTypeJSONValue { + t.Fatalf("expected Accept: application/json header, got: %s", r.Header.Get("Accept")) + } + + requestBody, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("io.ReadAll failed: %v", err) + } + + var data map[string]string + err = json.Unmarshal(requestBody, &data) + if err != nil { + t.Fatalf("json.Unmarsha failed: %v", err) + } + + tokenString, ok := data[kSignedRequestToken] + if !ok { + t.Fatalf("signed token missing in rewrap response") + } + + token, err := jwt.ParseWithClaims(tokenString, &rewrapJWTClaims{}, func(token *jwt.Token) (interface{}, error) { + signingRSAPublicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(signingPubKey)) + if err != nil { + return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) + } + + return signingRSAPublicKey, nil + }) + + var rewrapRequest = "" + if err != nil { + t.Fatalf("jwt.ParseWithClaims failed:%v", err) + } else if claims, fine := token.Claims.(*rewrapJWTClaims); fine { + rewrapRequest = claims.Body + } else { + t.Fatalf("unknown claims type, cannot proceed") + } + + err = json.Unmarshal([]byte(rewrapRequest), &data) + if err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + wrappedKey, err := crypto.Base64Decode([]byte(data["wrappedKey"])) + if err != nil { + t.Fatalf("crypto.Base64Decode failed: %v", err) + } + + kasPrivateKey = strings.ReplaceAll(kasPrivateKey, "\n\t", "\n") + asymDecrypt, err := crypto.NewAsymDecryption(kasPrivateKey) + if err != nil { + t.Fatalf("crypto.NewAsymDecryption failed: %v", err) + } + + symmetricKey, err := asymDecrypt.Decrypt(wrappedKey) + if err != nil { + t.Fatalf("crypto.Decrypt failed: %v", err) + } + + asymEncrypt, err := crypto.NewAsymEncryption(data[kClientPublicKey]) + if err != nil { + t.Fatalf("crypto.NewAsymEncryption failed: %v", err) + } + + entityWrappedKey, err := asymEncrypt.Encrypt(symmetricKey) + if err != nil { + t.Fatalf("crypto.encrypt failed: %v", err) + } + + response, err := json.Marshal(map[string]string{ + kEntityWrappedKey: string(crypto.Base64Encode(entityWrappedKey)), + }) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(response) + if err != nil { + t.Fatalf("http.ResponseWriter.Write failed: %v", err) + } + })) + defer server.Close() + + manifestObj := &Manifest{} + err = json.Unmarshal([]byte(sampleManifest), manifestObj) + if err != nil { + t.Fatalf("json.Unmarshal failed:%v", err) + } + + // mock the kas url + for index := range manifestObj.EncryptionInformation.KeyAccessObjs { + manifestObj.EncryptionInformation.KeyAccessObjs[index].KasURL = server.URL + } + + authConfig := AuthConfig{signingPrivateKey: signingPrivateKey, signingPublicKey: signingPubKey} + sKey, err := newSplitKeyFromManifest(authConfig, *manifestObj) + if err != nil { + t.Errorf("newSplitKeyFromManifest failed: %v", err) + } + + if len(sKey.tdfKeyAccessObjects) != 2 { + t.Errorf("split key key access objects count don't match: expected %v, got %v", len(sKey.tdfKeyAccessObjects), 2) + } + + expectedSplitKey := "6788741d1a659ac43693ffba933d8eaded57fad1705558fba98a89605fb56ab8" + if hex.EncodeToString(sKey.key[:]) != expectedSplitKey { + t.Errorf("split key is valid explected:%v, got %v", expectedSplitKey, hex.EncodeToString(sKey.key[:])) + } +} diff --git a/sdk/tdf.go b/sdk/tdf.go new file mode 100644 index 0000000000..82dd3b5821 --- /dev/null +++ b/sdk/tdf.go @@ -0,0 +1,339 @@ +package sdk + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/opentdf/opentdf-v2-poc/internal/archive" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" + "io" + "strings" +) + +var ( + errFileTooLarge = errors.New("tdf: can't create tdf larger than 64gb") + errRootSigValidation = errors.New("tdf: failed integrity check on root signature") + errSegSizeMismatch = errors.New("tdf: mismatch encrypted segment size in manifest") + errTDFReaderFailed = errors.New("tdf: fail to read bytes from TDFReader") + errWriteFailed = errors.New("tdf: io.writer fail to write all bytes") + errSegSigValidation = errors.New("tdf: Failed integrity check on segment hash") +) + +const ( + maxFileSizeSupported = 68719476736 // 64gb + defaultMimeType = "application/octet-stream" + tdfAsZip = "zip" + gcmIvSize = 12 + aesBlockSize = 16 + hmacIntegrityAlgorithm = "HS256" + gmacIntegrityAlgorithm = "GMAC" + tdfZipReference = "reference" +) + +// Create tdf +func Create(tdfConfig TDFConfig, reader io.ReadSeeker, writer io.Writer) (int64, error) { + toalBytes := int64(0) + inputSize, err := reader.Seek(0, io.SeekEnd) + if err != nil { + return toalBytes, fmt.Errorf("readSeeker.Seek failed: %w", err) + } + + _, err = reader.Seek(0, io.SeekStart) + if err != nil { + return toalBytes, fmt.Errorf("readSeeker.Seek failed: %w", err) + } + + if inputSize > maxFileSizeSupported { + return toalBytes, errFileTooLarge + } + + // create a split key + splitKey, err := newSplitKeyFromKasInfo(tdfConfig.kasInfoList, tdfConfig.attributes, tdfConfig.metaData) + if err != nil { + return toalBytes, fmt.Errorf("fail to create a new split key: %w", err) + } + + manifest, err := splitKey.getManifest() + if err != nil { + return toalBytes, fmt.Errorf("fail to create manifest: %w", err) + } + + segmentSize := tdfConfig.defaultSegmentSize + totalSegments := inputSize / segmentSize + if inputSize%segmentSize != 0 { + totalSegments++ + } + + // empty payload we still want to create a payload + if totalSegments == 0 { + totalSegments = 1 + } + + encryptedSegmentSize := segmentSize + gcmIvSize + aesBlockSize + payloadSize := inputSize + (totalSegments * (gcmIvSize + aesBlockSize)) + tdfWriter := archive.NewTDFWriter(writer) + + err = tdfWriter.SetPayloadSize(payloadSize) + if err != nil { + return toalBytes, fmt.Errorf("archive.SetPayloadSize failed: %w", err) + } + + var readPos int64 + var aggregateHash string + readBuf := bytes.NewBuffer(make([]byte, 0, tdfConfig.defaultSegmentSize)) + for totalSegments != 0 { // adjust read size + readSize := segmentSize + if (inputSize - readPos) < segmentSize { + readSize = inputSize - readPos + } + + n, err := reader.Read(readBuf.Bytes()[:readSize]) + if err != nil { + return toalBytes, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) + } + + if int64(n) != readSize { + return toalBytes, fmt.Errorf("io.ReadSeeker.Read size missmatch") + } + + cipherData, err := splitKey.encrypt(readBuf.Bytes()[:readSize]) + if err != nil { + return toalBytes, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) + } + + err = tdfWriter.AppendPayload(cipherData) + if err != nil { + return toalBytes, fmt.Errorf("io.writer.Write failed: %w", err) + } + + payloadSig, err := splitKey.getSignature(cipherData, tdfConfig.segmentIntegrityAlgorithm) + if err != nil { + return toalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) + } + + aggregateHash += payloadSig + + segmentInfo := Segment{} + segmentInfo.Hash = string(crypto.Base64Encode([]byte(payloadSig))) + segmentInfo.Size = readSize + segmentInfo.EncryptedSize = int64(len(cipherData)) + manifest.EncryptionInformation.IntegrityInformation.Segments = + append(manifest.EncryptionInformation.IntegrityInformation.Segments, segmentInfo) + + totalSegments-- + readPos += readSize + } + + aggregateHashSig, err := splitKey.getSignature([]byte(aggregateHash), tdfConfig.integrityAlgorithm) + if err != nil { + return toalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) + } + + sig := string(crypto.Base64Encode([]byte(aggregateHashSig))) + manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature = sig + + integrityAlgStr := gmacIntegrityAlgorithm + if tdfConfig.integrityAlgorithm == HS256 { + integrityAlgStr = hmacIntegrityAlgorithm + } + manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm = integrityAlgStr + + manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize = segmentSize + manifest.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize = encryptedSegmentSize + + segIntegrityAlgStr := gmacIntegrityAlgorithm + if tdfConfig.segmentIntegrityAlgorithm == HS256 { + segIntegrityAlgStr = hmacIntegrityAlgorithm + } + + manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm = segIntegrityAlgStr + manifest.EncryptionInformation.Method.IsStreamable = true + + // add payload info + manifest.Payload.MimeType = defaultMimeType + manifest.Payload.Protocol = tdfAsZip + manifest.Payload.Type = tdfZipReference + manifest.Payload.URL = archive.TDFPayloadFileName + manifest.Payload.IsEncrypted = true + + manifestAsStr, err := json.Marshal(manifest) + if err != nil { + return toalBytes, fmt.Errorf("json.Marshal failed:%w", err) + } + + err = tdfWriter.AppendManifest(string(manifestAsStr)) + if err != nil { + return toalBytes, fmt.Errorf("TDFWriter.AppendManifest failed:%w", err) + } + + totalBytes, err := tdfWriter.Finish() + if err != nil { + return toalBytes, fmt.Errorf("TDFWriter.Finish failed:%w", err) + } + + return totalBytes, nil +} + +// GetPayload decrypt the tdf and write the data to writer. +func GetPayload(authConfig AuthConfig, reader io.ReadSeeker, writer io.Writer) (int64, error) { + + totalBytes := int64(0) + + // create tdf reader + tdfReader, err := archive.NewTDFReader(reader) + if err != nil { + return totalBytes, fmt.Errorf("archive.NewTDFReader failed: %w", err) + } + + manifest, err := tdfReader.Manifest() + if err != nil { + return totalBytes, fmt.Errorf("tdfReader.Manifest failed: %w", err) + } + + manifestObj := &Manifest{} + err = json.Unmarshal([]byte(manifest), manifestObj) + if err != nil { + return totalBytes, fmt.Errorf("json.Unmarshal failed:%w", err) + } + + // create a split key + sKey, err := newSplitKeyFromManifest(authConfig, *manifestObj) + if err != nil { + return totalBytes, fmt.Errorf("fail to create a new split key: %w", err) + } + + res, err := sKey.validateRootSignature(manifestObj) + if err != nil { + return totalBytes, fmt.Errorf("splitKey.validateRootSignature failed: %w", err) + } + + if !res { + return totalBytes, errRootSigValidation + } + + segSize := manifestObj.EncryptionInformation.IntegrityInformation.DefaultSegmentSize + encryptedSegSize := manifestObj.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize + + if segSize != encryptedSegSize-(gcmIvSize+aesBlockSize) { + return totalBytes, errSegSizeMismatch + } + + var payloadReadOffset int64 + for _, seg := range manifestObj.EncryptionInformation.IntegrityInformation.Segments { + readBuf, err := tdfReader.ReadPayload(payloadReadOffset, seg.EncryptedSize) + if err != nil { + return totalBytes, fmt.Errorf("TDFReader.ReadPayload failed: %w", err) + } + + if int64(len(readBuf)) != seg.EncryptedSize { + return totalBytes, errTDFReaderFailed + } + + segHashAlg := manifestObj.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm + sigAlg := HS256 + if strings.EqualFold(gmacIntegrityAlgorithm, segHashAlg) { + sigAlg = GMAC + } + + payloadSig, err := sKey.getSignature(readBuf, sigAlg) + if err != nil { + return totalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) + } + + if seg.Hash != string(crypto.Base64Encode([]byte(payloadSig))) { + return totalBytes, errSegSigValidation + } + + writeBuf, err := sKey.decrypt(readBuf) + if err != nil { + return totalBytes, fmt.Errorf("splitKey.decrypt failed: %w", err) + } + + n, err := writer.Write(writeBuf) + if err != nil { + return totalBytes, fmt.Errorf("io.writer.write failed: %w", err) + } + + if n != len(writeBuf) { + return totalBytes, errWriteFailed + } + + payloadReadOffset += seg.EncryptedSize + totalBytes += int64(n) + } + + return totalBytes, nil +} + +// GetMetadata return the meta present in tdf. +func GetMetadata(authConfig AuthConfig, reader io.ReadSeeker) (string, error) { + // create tdf reader + tdfReader, err := archive.NewTDFReader(reader) + if err != nil { + return "", fmt.Errorf("archive.NewTDFReader failed: %w", err) + } + + manifest, err := tdfReader.Manifest() + if err != nil { + return "", fmt.Errorf("tdfReader.Manifest failed: %w", err) + } + + manifestObj := &Manifest{} + err = json.Unmarshal([]byte(manifest), manifestObj) + if err != nil { + return "", fmt.Errorf("json.Unmarshal failed:%w", err) + } + + // create a split key + sKey, err := newSplitKeyFromManifest(authConfig, *manifestObj) + if err != nil { + return "", fmt.Errorf("fail to create a new split key: %w", err) + } + + // There will be at least one key access in tdf + return sKey.tdfKeyAccessObjects[0].metaData, nil +} + +// GetAttributes return the attributes present in tdf. +func GetAttributes(reader io.ReadSeeker) ([]string, error) { + // create tdf reader + tdfReader, err := archive.NewTDFReader(reader) + if err != nil { + return nil, fmt.Errorf("archive.NewTDFReader failed: %w", err) + } + + manifest, err := tdfReader.Manifest() + if err != nil { + return nil, fmt.Errorf("tdfReader.Manifest failed: %w", err) + } + + manifestObj := &Manifest{} + err = json.Unmarshal([]byte(manifest), manifestObj) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed:%w", err) + } + + policy, err := crypto.Base64Decode([]byte(manifestObj.Policy)) + if err != nil { + return nil, fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + return attributesFromPolicy(policy) +} + +func attributesFromPolicy(policy []byte) ([]string, error) { + policyObj := policyObject{} + err := json.Unmarshal(policy, &policyObj) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + attributes := make([]string, 0) + attributeObjs := policyObj.Body.DataAttributes + for _, attributeObj := range attributeObjs { + attributes = append(attributes, attributeObj.Attribute) + } + + return attributes, nil +} diff --git a/sdk/tdf_config.go b/sdk/tdf_config.go new file mode 100644 index 0000000000..a3e291ae63 --- /dev/null +++ b/sdk/tdf_config.go @@ -0,0 +1,163 @@ +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" + "io" + "log/slog" + "net/http" + "net/url" +) + +type TDFFormat = int + +const ( + JSONFormat = iota + XMLFormat +) + +type IntegrityAlgorithm = int + +const ( + HS256 = iota + GMAC +) + +const kHTTPOk = 200 + +type KASInfo struct { + url string + publicKey string // Public key can be empty. +} + +type TDFConfig struct { + defaultSegmentSize int64 + enableEncryption bool + tdfFormat TDFFormat + tdfPublicKey string // TODO: Remove it + tdfPrivateKey string + metaData string + integrityAlgorithm IntegrityAlgorithm + segmentIntegrityAlgorithm IntegrityAlgorithm + assertions []Assertion + attributes []string + kasInfoList []KASInfo +} + +const ( + tdf3KeySize = 2048 + defaultSegmentSize = 2 * 1024 * 1024 // 2mb + kasPublicKeyPath = "/kas_public_key" +) + +// NewTDFConfig Create a new instance of tdf config. +func NewTDFConfig() (*TDFConfig, error) { + rsaKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) + } + + publicKey, err := rsaKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + privateKey, err := rsaKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + return &TDFConfig{ + tdfPrivateKey: privateKey, + tdfPublicKey: publicKey, + defaultSegmentSize: defaultSegmentSize, + enableEncryption: true, + tdfFormat: JSONFormat, + integrityAlgorithm: HS256, + segmentIntegrityAlgorithm: GMAC, + }, nil +} + +// AddKasInformation Add all the kas urls and their corresponding public keys +// that is required to create and read the tdf. +func (tdfConfig *TDFConfig) AddKasInformation(kasInfoList []KASInfo) error { + for _, kasInfo := range kasInfoList { + newEntry := KASInfo{} + newEntry.url = kasInfo.url + newEntry.publicKey = kasInfo.publicKey + + if newEntry.publicKey != "" { + tdfConfig.kasInfoList = append(tdfConfig.kasInfoList, newEntry) + continue + } + + // get kas public + kasPubKeyURL, err := url.JoinPath(kasInfo.url, kasPublicKeyPath) + if err != nil { + return fmt.Errorf("url.Parse failed: %w", err) + } + + request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, kasPubKeyURL, nil) + if err != nil { + return fmt.Errorf("http.NewRequestWithContext failed: %w", err) + } + + // add required headers + request.Header = http.Header{ + kAcceptKey: {kContentTypeJSONValue}, + } + + client := &http.Client{} + + response, err := client.Do(request) + if response.StatusCode != kHTTPOk { + return fmt.Errorf("client.Do failed: %w", err) + } + + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + slog.Error("Fail to close HTTP response") + } + }(response.Body) + + var jsonResponse interface{} + err = json.NewDecoder(response.Body).Decode(&jsonResponse) + if err != nil { + return fmt.Errorf("json.NewDecoder.Decode failed: %w", err) + } + + newEntry.publicKey = fmt.Sprintf("%s", jsonResponse) + + tdfConfig.kasInfoList = append(tdfConfig.kasInfoList, newEntry) + } + + return nil +} + +// AddAttributes Add all the attributes used to create and read the tdf. +func (tdfConfig *TDFConfig) AddAttributes(attributes []string) { + tdfConfig.attributes = append(tdfConfig.attributes, attributes...) +} + +// SetMetaData Set the meta data. +func (tdfConfig *TDFConfig) SetMetaData(metaData string) { + tdfConfig.metaData = metaData +} + +// SetDefaultSegmentSize Set the default segment size. +func (tdfConfig *TDFConfig) SetDefaultSegmentSize(size int64) { + tdfConfig.defaultSegmentSize = size +} + +// SetXMLFormat TDFs created with this config will be in XML format. +func (tdfConfig *TDFConfig) SetXMLFormat() { + tdfConfig.tdfFormat = XMLFormat +} + +// DisableEncryption TDFs create with this config will not be encrypted. +func (tdfConfig *TDFConfig) DisableEncryption() { + tdfConfig.enableEncryption = false +} diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go new file mode 100644 index 0000000000..30e6ab6a2a --- /dev/null +++ b/sdk/tdf_test.go @@ -0,0 +1,530 @@ +package sdk + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/opentdf/opentdf-v2-poc/internal/crypto" + "io" + "net/http" + "net/http/httptest" + "os" + "reflect" + "strconv" + "strings" + "testing" +) + +const ( + oneKB = 1024 + // tenKB = 10 * oneKB + oneMB = 1024 * 1024 + hundredMB = 100 * oneMB + // oneGB = 10 * hundredMB + // tenGB = 10 * oneGB +) + +const ( + stepSize int64 = 2 * oneMB +) + +type tdfTest struct { + fileSize int64 + tdfFileSize int64 + kasInfoList []KASInfo +} + +//nolint:gochecknoglobals +var mockKasPublicKey = `-----BEGIN CERTIFICATE----- +MIICmDCCAYACCQC3BCaSANRhYzANBgkqhkiG9w0BAQsFADAOMQwwCgYDVQQDDANr +YXMwHhcNMjEwOTE1MTQxMTQ4WhcNMjIwOTE1MTQxMTQ4WjAOMQwwCgYDVQQDDANr +YXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDOpiotrvV2i5h6clHM +zDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9ExyczB +WJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp41r/T +1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim86Haw +u/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3TfOEvX +F6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU54Zz +DaLZAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABewfZOJ4/KNRE8IQ5TsW/AVn7C1 +l5ty6tUUBSVi8/df7WYts0bHEdQh9yl9agEU5i4rj43y8vMVZNzSeHcurtV/+C0j +fbkHQHeiQ1xn7cq3Sbh4UVRyuu4C5PklEH4AN6gxmgXC3kT15uWw8I4nm/plzYLs +I099IoRfC5djHUYYLMU/VkOIHuPC3sb7J65pSN26eR8bTMVNagk187V/xNwUuvkf ++NUxDO615/5BwQKnAu5xiIVagYnDZqKCOtYS5qhxF33Nlnwlm7hH8iVZ1RI+n52l +wVyElqp317Ksz+GtTIc+DE6oryxK3tZd4hrj9fXT4KiJvQ4pcRjpePgH7B8= +-----END CERTIFICATE-----` + +//nolint:gochecknoglobals +var mockKasPrivateKey = `-----BEGIN PRIVATE KEY----- + MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDOpiotrvV2i5h6 + clHMzDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9Ex + yczBWJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp4 + 1r/T1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim8 + 6Hawu/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3Tf + OEvXF6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU + 54ZzDaLZAgMBAAECggEBALb0yK0PlMUyzHnEUwXV1y5AIoAWhsYp0qvJ1msHUVKz + +yQ/VJz4+tQQxI8OvGbbnhNkd5LnWdYkYzsIZl7b/kBCPcQw3Zo+4XLCzhUAn1E1 + M+n42c8le1LtN6Z7mVWoZh7DPONy7t+ABvm7b7S1+1i78DPmgCeWYZGeAhIcPXG6 + 5AxWIV3jigxksE6kYY9Y7DmtsZgMRrdV7SU8VtgPtT7tua8z5/U3Av0WINyKBSoM + 0yDHsAg57KnM8znx2JWLtHd0Mk5bBuu2DLbtyKNrVUAUuMPzrLGBh9S9QRd934KU + uFAi1TEfgEachnGgSHJpzVzr2ur1tifABnQ7GNXObe0CgYEA6KowK0subdDY+uGW + ciP2XDAMerbJJeL0/UIGPb/LUmskniio2493UBGgY2FsRyvbzJ+/UAOjIPyIxhj7 + 78ZyVG8BmIzKan1RRVh//O+5yvks/eTOYjWeQ1Lcgqs3q4YAO13CEBZgKWKTUomg + mskFJq04tndeSIyhDaW+BuWaXA8CgYEA42ABz3pql+DH7oL5C4KYBymK6wFBBOqk + dVk+ftyJQ6PzuZKpfsu4aPIjKm71lkTgK6O9o08s3SckAdu6vLukq2TZFF+a+9OI + lu5ww7GvfdMTgLAaFchD4bPlOInh1KVjBc1MwGXpl0ROde5pi8+WUrv9QJuoQfB/ + 4rhYdbJLSpcCgYA41mqSCPm8pgp7r2RbWeGzP6Gs0L5u3PTQcbKonxQCfF4jrPcj + O/b/vm6aGJClClfVsyi/WUQeqNKY4j2Zo7cGXV/cbnh8b0TNVgNePQn8Rcbx91Vb + tJGHDNUFruIYqtGfrxXbbDvtoEExJqHvbjAt9J8oJB0KSCCH/vdfI/QDjQKBgQCD + xLPH5Y24js/O7aAeh4RLQkv7fTKNAt5kE2AgbPYveOhZ9yC7Fpy8VPcENGGmwCuZ + nr7b0ZqSX4iCezBxB92aZktXf0B2CFT0AyLehi7JoHWA8o1rai/MsVB5v45ciawl + RKDiLy18OF2wAoawO5FGSSOvOYX9EL9MSMEbFESF6QKBgCVlZ9pPC+55rGT6AcEL + tUpDs+/wZvcmfsFd8xC5mMUN0DatAVzVAUI95+tQaWU3Uj+bqHq0lC6Wy2VceG0D + D+7EicjdGFN/2WVPXiYX1fblkxasZY+wChYBrPLjA9g0qOzzmXbRBph5QxDuQjJ6 + qcddVKB624a93ZBssn7OivnR + -----END PRIVATE KEY-----` + +var testHarnesses = []tdfTest{ //nolint:gochecknoglobals + { + fileSize: 5, + tdfFileSize: 1580, + kasInfoList: []KASInfo{ + { + url: "http://localhost:65432/api/kas", + publicKey: "", + }, + }, + }, + { + fileSize: oneKB, + tdfFileSize: 2604, + kasInfoList: []KASInfo{ + { + url: "http://localhost:65432/api/kas", + publicKey: "", + }, + }, + }, + { + fileSize: hundredMB, + tdfFileSize: 104866456, + kasInfoList: []KASInfo{ + { + url: "http://localhost:65432/api/kas", + publicKey: mockKasPublicKey, + }, + { + url: "http://localhost:65432/api/kas", + publicKey: mockKasPublicKey, + }, + }, + }, +} + +var buffer []byte //nolint:gochecknoglobals + +func init() { + // create a buffer and write with 0xff + buffer = make([]byte, stepSize) + for index := 0; index < len(buffer); index++ { + buffer[index] = 'a' + } +} + +func TestSimpleTDF(t *testing.T) { + server, signingPubKey, signingPrivateKey := runKas(t) + defer server.Close() + + metaDataStr := `{"displayName" : "openTDF go sdk"}` + + attributes := []string{ + "https://example.com/attr/Classification/value/S", + "https://example.com/attr/Classification/value/X", + } + + expectedTdfSize := int64(1989) + tdfFilename := "secure-text.tdf" + plainText := "Virtru" + { + // Create TDFConfig + tdfConfig, err := NewTDFConfig() + if err != nil { + t.Fatalf("Fail to create tdf config: %v", err) + } + + kasURLs := []KASInfo{ + { + url: server.URL, + publicKey: "", + }, + } + + err = tdfConfig.AddKasInformation(kasURLs) + if err != nil { + t.Fatalf("tdfConfig.AddKasUrls failed: %v", err) + } + + tdfConfig.SetMetaData(metaDataStr) + tdfConfig.AddAttributes(attributes) + + inBuf := bytes.NewBufferString(plainText) + bufReader := bytes.NewReader(inBuf.Bytes()) + + fileWriter, err := os.Create(tdfFilename) + if err != nil { + t.Fatalf("os.Create failed: %v", err) + } + defer func(fileWriter *os.File) { + err := fileWriter.Close() + if err != nil { + t.Fatalf("Fail to close the file: %v", err) + } + }(fileWriter) + + tdfSize, err := Create(*tdfConfig, bufReader, fileWriter) + if err != nil { + t.Fatalf("tdf.Create failed: %v", err) + } + + if tdfSize != expectedTdfSize { + t.Errorf("tdf size test failed expected %v, got %v", tdfSize, expectedTdfSize) + } + } + + // test meta data + { + readSeeker, err := os.Open(tdfFilename) + if err != nil { + t.Fatalf("Fail to open archive file:%s %v", tdfFilename, err) + } + + defer func(readSeeker *os.File) { + err := readSeeker.Close() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + }(readSeeker) + + // create auth config + authConfig, err := NewAuthConfig() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + + // override the signing keys to get the mock working. + authConfig.signingPublicKey = signingPubKey + authConfig.signingPrivateKey = signingPrivateKey + + metaData, err := GetMetadata(*authConfig, readSeeker) + if err != nil { + t.Fatalf("Fail to get meta data from tdf:%v", err) + } + + if metaDataStr != metaData { + t.Errorf("meta data test failed expected %v, got %v", metaDataStr, metaData) + } + + dataAttributes, err := GetAttributes(readSeeker) + if err != nil { + t.Fatalf("Fail to get policy from tdf:%v", err) + } + + if reflect.DeepEqual(attributes, dataAttributes) != true { + t.Errorf("attributes test failed expected %v, got %v", attributes, dataAttributes) + } + } + + // test decrypt + { + readSeeker, err := os.Open(tdfFilename) + if err != nil { + t.Fatalf("Fail to open archive file:%s %v", tdfFilename, err) + } + + defer func(readSeeker *os.File) { + err := readSeeker.Close() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + }(readSeeker) + + // writer + var buf bytes.Buffer + // create auth config + authConfig, err := NewAuthConfig() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + + // override the signing keys to get the mock working. + authConfig.signingPublicKey = signingPubKey + authConfig.signingPrivateKey = signingPrivateKey + + payloadSize, err := GetPayload(*authConfig, readSeeker, &buf) + if err != nil { + t.Fatalf("Fail to decrypt tdf:%v", err) + } + + if string(buf.Bytes()[:payloadSize]) != plainText { + t.Errorf("decrypt test failed expected %v, got %v", plainText, buf.String()) + } + } + + _ = os.Remove(tdfFilename) +} + +func TestTDF(t *testing.T) { + server, signingPubKey, signingPrivateKey := runKas(t) + defer server.Close() + + for index, test := range testHarnesses { // create .txt file + plaintTextFileName := strconv.Itoa(index) + ".txt" + tdfFileName := plaintTextFileName + ".tdf" + decryptedTdfFileName := tdfFileName + ".txt" + + kasInfoList := test.kasInfoList + for index := range kasInfoList { + kasInfoList[index].url = server.URL + kasInfoList[index].publicKey = "" + } + + tdfConfig, err := NewTDFConfig() + if err != nil { + t.Fatalf("Fail to create tdf config: %v", err) + } + + err = tdfConfig.AddKasInformation(kasInfoList) + if err != nil { + t.Fatalf("tdfConfig.AddKasUrls failed: %v", err) + } + + // test encrypt + testEncrypt(t, *tdfConfig, plaintTextFileName, tdfFileName, test) + + // create auth config + authConfig, err := NewAuthConfig() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + + // override the signing keys to get the mock working. + authConfig.signingPublicKey = signingPubKey + authConfig.signingPrivateKey = signingPrivateKey + + // test decrypt + testDecrypt(t, *authConfig, tdfFileName, decryptedTdfFileName, test.fileSize) + + // Remove the test files + _ = os.Remove(plaintTextFileName) + _ = os.Remove(tdfFileName) + _ = os.Remove(decryptedTdfFileName) + } +} + +// create tdf +func testEncrypt(t *testing.T, tdfConfig TDFConfig, plainTextFilename, tdfFileName string, test tdfTest) { + // create a plain text file + createFileName(t, buffer, plainTextFilename, test.fileSize) + + // open file + readSeeker, err := os.Open(plainTextFilename) + if err != nil { + t.Fatalf("Fail to open plain text file:%s %v", plainTextFilename, err) + } + + defer func(readSeeker *os.File) { + err := readSeeker.Close() + if err != nil { + t.Fatalf("Fail to close plain text file:%v", err) + } + }(readSeeker) + + fileWriter, err := os.Create(tdfFileName) + + if err != nil { + t.Fatalf("os.Create failed: %v", err) + } + defer func(fileWriter *os.File) { + err := fileWriter.Close() + if err != nil { + t.Fatalf("Fail to close the tdf file: %v", err) + } + }(fileWriter) // Create TDFConfig + tdfSize, err := Create(tdfConfig, readSeeker, fileWriter) + if err != nil { + t.Fatalf("tdf.Create failed: %v", err) + } + + if tdfSize != test.tdfFileSize { + t.Errorf("tdf size test failed expected %v, got %v", test.tdfFileSize, tdfSize) + } +} + +func testDecrypt(t *testing.T, authConfig AuthConfig, tdfFile, decryptedTdfFileName string, payloadSize int64) { + readSeeker, err := os.Open(tdfFile) + if err != nil { + t.Fatalf("Fail to open archive file:%s %v", tdfFile, err) + } + + defer func(readSeeker *os.File) { + err := readSeeker.Close() + if err != nil { + t.Fatalf("Fail to close archive file:%v", err) + } + }(readSeeker) + + fileWriter, err := os.Create(decryptedTdfFileName) + if err != nil { + t.Fatalf("os.Create failed: %v", err) + } + + defer func(fileWriter *os.File) { + err := fileWriter.Close() + if err != nil { + t.Fatalf("Fail to close the file: %v", err) + } + }(fileWriter) // Create TDFConfig + + decryptedData, err := GetPayload(authConfig, readSeeker, fileWriter) + if err != nil { + t.Fatalf("tdf.Create failed: %v", err) + } + + if payloadSize != decryptedData { + t.Errorf("payload size test failed expected %v, got %v", payloadSize, decryptedData) + } +} + +func createFileName(t *testing.T, buf []byte, filename string, size int64) { + f, err := os.Create(filename) + if err != nil { + t.Fatalf("os.Create failed: %v", err) + } + + totalBytes := size + var bytesToWrite int64 + for totalBytes > 0 { + if totalBytes >= stepSize { + totalBytes -= stepSize + bytesToWrite = stepSize + } else { + bytesToWrite = totalBytes + totalBytes = 0 + } + _, err := f.Write(buf[:bytesToWrite]) + if err != nil { + t.Fatalf("io.Write failed: %v", err) + } + } + err = f.Close() + if err != nil { + t.Fatalf("os.Close failed: %v", err) + } +} + +func runKas(t *testing.T) (*httptest.Server, string, string) { + signingKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + t.Fatalf("crypto.NewRSAKeyPair: %v", err) + } + + signingPubKey, err := signingKeyPair.PublicKeyInPemFormat() + if err != nil { + t.Fatalf("crypto.PublicKeyInPemFormat failed: %v", err) + } + + signingPrivateKey, err := signingKeyPair.PrivateKeyInPemFormat() + if err != nil { + t.Fatalf("crypto.PrivateKeyInPemFormat failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(kAcceptKey) != kContentTypeJSONValue { + t.Fatalf("expected Accept: application/json header, got: %s", r.Header.Get("Accept")) + } + + r.Header.Set(kContentTypeKey, kContentTypeJSONValue) + + switch { + case r.URL.Path == kasPublicKeyPath: + kasPublicKeyResponse, err := json.Marshal(mockKasPublicKey) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + w.WriteHeader(http.StatusOK) + _, err = w.Write(kasPublicKeyResponse) + if err != nil { + t.Fatalf("http.ResponseWriter.Write failed: %v", err) + } + case r.URL.Path == kRewrapV2: + requestBody, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("io.ReadAll failed: %v", err) + } + var data map[string]string + err = json.Unmarshal(requestBody, &data) + if err != nil { + t.Fatalf("json.Unmarsha failed: %v", err) + } + tokenString, ok := data[kSignedRequestToken] + if !ok { + t.Fatalf("signed token missing in rewrap response") + } + token, err := jwt.ParseWithClaims(tokenString, &rewrapJWTClaims{}, func(token *jwt.Token) (interface{}, error) { + signingRSAPublicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(signingPubKey)) + if err != nil { + return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) + } + + return signingRSAPublicKey, nil + }) + var rewrapRequest = "" + if err != nil { + t.Fatalf("jwt.ParseWithClaims failed:%v", err) + } else if claims, fine := token.Claims.(*rewrapJWTClaims); fine { + rewrapRequest = claims.Body + } else { + t.Fatalf("unknown claims type, cannot proceed") + } + err = json.Unmarshal([]byte(rewrapRequest), &data) + if err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + wrappedKey, err := crypto.Base64Decode([]byte(data["wrappedKey"])) + if err != nil { + t.Fatalf("crypto.Base64Decode failed: %v", err) + } + kasPrivateKey := strings.ReplaceAll(mockKasPrivateKey, "\n\t", "\n") + asymDecrypt, err := crypto.NewAsymDecryption(kasPrivateKey) + if err != nil { + t.Fatalf("crypto.NewAsymDecryption failed: %v", err) + } + symmetricKey, err := asymDecrypt.Decrypt(wrappedKey) + if err != nil { + t.Fatalf("crypto.Decrypt failed: %v", err) + } + asymEncrypt, err := crypto.NewAsymEncryption(data[kClientPublicKey]) + if err != nil { + t.Fatalf("crypto.NewAsymEncryption failed: %v", err) + } + entityWrappedKey, err := asymEncrypt.Encrypt(symmetricKey) + if err != nil { + t.Fatalf("crypto.encrypt failed: %v", err) + } + response, err := json.Marshal(map[string]string{ + kEntityWrappedKey: string(crypto.Base64Encode(entityWrappedKey)), + }) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + w.WriteHeader(http.StatusOK) + _, err = w.Write(response) + if err != nil { + t.Fatalf("http.ResponseWriter.Write failed: %v", err) + } + default: + t.Fatalf("expected to request: %s", r.URL.Path) + } + })) + + return server, signingPubKey, signingPrivateKey +}