diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 6d7543bfb7..d9da9b4661 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -54,6 +54,7 @@ jobs: version: v1.55 only-new-issues: true working-directory: ${{ matrix.directory }} + args: --out-format=colored-line-number - run: go test ./... -short working-directory: ${{ matrix.directory }} diff --git a/.golangci.yaml b/.golangci.yaml index 7f34ce6cb9..a35a60f65a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -195,7 +195,7 @@ linters: - goconst # finds repeated strings that could be replaced by a constant - gocritic # provides diagnostics that check for bugs, performance and style issues - gocyclo # computes and checks the cyclomatic complexity of functions - - godot # checks if comments end in a period + # - godot # checks if comments end in a period - goimports # in addition to fixing imports, goimports also formats your code in the same style as gofmt - gomnd # detects magic numbers # - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod diff --git a/sdk/manifest.go b/sdk/manifest.go index f7dae009f9..869cba1f56 100644 --- a/sdk/manifest.go +++ b/sdk/manifest.go @@ -64,7 +64,7 @@ type attributeObject struct { KasURL string `json:"kasURL"` } -type policyObject struct { +type PolicyObject struct { UUID string `json:"uuid"` Body struct { DataAttributes []attributeObject `json:"dataAttributes"` diff --git a/sdk/split_key.go b/sdk/split_key.go deleted file mode 100644 index a0d49ca443..0000000000 --- a/sdk/split_key.go +++ /dev/null @@ -1,468 +0,0 @@ -package sdk - -import ( - "bytes" - "context" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "net/url" - "strings" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - "github.com/opentdf/opentdf-v2-poc/sdk/internal/crypto" -) - -const ( - kKeySize = 32 - kWrapped = "wrapped" - kKasProtocol = "kas" - kSplitKeyType = "split" - kGCMCipherAlgorithm = "AES-256-GCM" - kGMACPayloadLength = 16 - kClientPublicKey = "clientPublicKey" - kSignedRequestToken = "signedRequestToken" - kKasURL = "url" - kRewrapV2 = "/v2/rewrap" - kAuthorizationKey = "Authorization" - kContentTypeKey = "Content-Type" - kAcceptKey = "Accept" - kContentTypeJSONValue = "application/json" - kEntityWrappedKey = "entityWrappedKey" - kPolicy = "policy" - kHmacIntegrityAlgorithm = "HS256" - kGmacIntegrityAlgorithm = "GMAC" -) - -type rewrapJWTClaims struct { - jwt.RegisteredClaims - Body string `json:"requestBody"` -} - -type splitKey struct { - attributes []string - tdfKeyAccessObjects []tdfKeyAccess - kasInfoList []KASInfo - key [kKeySize]byte - aesGcm crypto.AesGcm -} - -type tdfKeyAccess struct { - kasPublicKey string - kasURL string - wrappedKey [kKeySize]byte - metaData string -} - -type RequestBody struct { - KeyAccess `json:"keyAccess"` - ClientPublicKey string `json:"clientPublicKey"` - Policy string `json:"policy"` -} - -var ( - errInvalidKasInfo = errors.New("split-key: kas information is missing") - errKasPubKeyMissing = errors.New("split-key: kas public key is missing") -) - -// newSplitKeyFromKasInfo create a instance of split key object. -func newSplitKeyFromKasInfo(kasInfoList []KASInfo, attributes []string, metaData string) (splitKey, error) { - if len(kasInfoList) == 0 { - return splitKey{}, errInvalidKasInfo - } - - tdfKeyAccessObjs := make([]tdfKeyAccess, 0) - for _, kasInfo := range kasInfoList { - if len(kasInfo.publicKey) == 0 { - return splitKey{}, errKasPubKeyMissing - } - - keyAccess := tdfKeyAccess{} - keyAccess.kasPublicKey = kasInfo.publicKey - keyAccess.kasURL = kasInfo.url - keyAccess.metaData = metaData - - key, err := crypto.RandomBytes(kKeySize) - if err != nil { - return splitKey{}, fmt.Errorf("crypto.RandomBytes failed:%w", err) - } - - keyAccess.wrappedKey = [kKeySize]byte(key) - tdfKeyAccessObjs = append(tdfKeyAccessObjs, keyAccess) - } - - sKey := splitKey{} - - // create the split key by XOR all the keys in key access object. - for _, keyAccessObj := range tdfKeyAccessObjs { - for keyByteIndex, keyByte := range keyAccessObj.wrappedKey { - sKey.key[keyByteIndex] ^= keyByte - } - } - - gcm, err := crypto.NewAESGcm(sKey.key[:]) - if err != nil { - return splitKey{}, fmt.Errorf(" crypto.NewAESGcm failed:%w", err) - } - - sKey.attributes = attributes - sKey.tdfKeyAccessObjects = tdfKeyAccessObjs - sKey.kasInfoList = kasInfoList - sKey.aesGcm = gcm - - return sKey, nil -} - -// newSplitKeyFromManifest create a instance of split key from(parsing) the manifest. -func newSplitKeyFromManifest(authConfig AuthConfig, manifest Manifest) (splitKey, error) { - sKey := splitKey{} - - for _, keyAccessObj := range manifest.EncryptionInformation.KeyAccessObjs { - requestBody := RequestBody{keyAccessObj, "", manifest.EncryptionInformation.Policy} - key, err := sKey.rewrap(authConfig, &requestBody) - if err != nil { - return splitKey{}, fmt.Errorf(" splitKey.rewrap failed:%w", err) - } - - for keyByteIndex, keyByte := range key { - sKey.key[keyByteIndex] ^= keyByte - } - keyAccess := tdfKeyAccess{} - keyAccess.kasURL = keyAccessObj.KasURL - keyAccess.wrappedKey = [32]byte(key) - - if len(keyAccessObj.EncryptedMetadata) != 0 { - gcm, err := crypto.NewAESGcm(key) - if err != nil { - return splitKey{}, fmt.Errorf("crypto.NewAESGcm failed:%w", err) - } - - decodedMetaData, err := crypto.Base64Decode([]byte(keyAccessObj.EncryptedMetadata)) - if err != nil { - return splitKey{}, fmt.Errorf("crypto.Base64Decode failed:%w", err) - } - metadata := EncryptedMetadata{} - err = json.Unmarshal(decodedMetaData, &metadata) - if err != nil { - return splitKey{}, fmt.Errorf("json.Unmarshal failed:%w", err) - - } - encodedCipherText := metadata.Cipher - cipherText, _ := crypto.Base64Decode([]byte(encodedCipherText)) - metaData, err := gcm.Decrypt(cipherText) - if err != nil { - return splitKey{}, fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) - } - - keyAccess.metaData = string(metaData) - } - - sKey.tdfKeyAccessObjects = append(sKey.tdfKeyAccessObjects, keyAccess) - } - - gcm, err := crypto.NewAESGcm(sKey.key[:]) - if err != nil { - return splitKey{}, fmt.Errorf(" crypto.NewAESGcm failed:%w", err) - } - sKey.aesGcm = gcm - - return sKey, nil -} - -// getManifest Return the manifest. -func (splitKey splitKey) getManifest() (*Manifest, error) { - manifest := Manifest{} - manifest.EncryptionInformation.KeyAccessType = kSplitKeyType - - policyObj, err := splitKey.createPolicyObject() - if err != nil { - return nil, fmt.Errorf("fail to create policy object:%w", err) - } - - policyObjectAsStr, err := json.Marshal(policyObj) - if err != nil { - return nil, fmt.Errorf("json.Marshal failed:%w", err) - } - - base64PolicyObject := crypto.Base64Encode(policyObjectAsStr) - - for _, keyAccessObj := range splitKey.tdfKeyAccessObjects { - keyAccess := KeyAccess{} - keyAccess.KeyType = kWrapped - keyAccess.KasURL = keyAccessObj.kasURL - keyAccess.Protocol = kKasProtocol - - // wrap the key with kas public key - asymEncrypt, err := crypto.NewAsymEncryption(keyAccessObj.kasPublicKey) - if err != nil { - return nil, fmt.Errorf("crypto.NewAsymEncryption failed:%w", err) - } - - encryptData, err := asymEncrypt.Encrypt(keyAccessObj.wrappedKey[:]) - if err != nil { - return nil, fmt.Errorf("crypto.AsymEncryption.encrypt failed:%w", err) - } - keyAccess.WrappedKey = string(crypto.Base64Encode(encryptData)) - - // add policyBinding - policyBinding := hex.EncodeToString(crypto.CalculateSHA256Hmac(keyAccessObj.wrappedKey[:], base64PolicyObject)) - keyAccess.PolicyBinding = string(crypto.Base64Encode([]byte(policyBinding))) - - // add meta data - if len(keyAccessObj.metaData) > 0 { - gcm, err := crypto.NewAESGcm(keyAccessObj.wrappedKey[:]) - if err != nil { - return nil, fmt.Errorf("crypto.NewAESGcm failed:%w", err) - } - - encryptedMetaData, err := gcm.Encrypt([]byte(keyAccessObj.metaData)) - if err != nil { - return nil, fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) - } - - iv := encryptedMetaData[:crypto.GcmStandardNonceSize] - metadata := EncryptedMetadata{Cipher: string(crypto.Base64Encode(encryptedMetaData)), Iv: string(crypto.Base64Encode(iv))} - - metadataJson, err := json.Marshal(metadata) - if err != nil { - return nil, fmt.Errorf(" json.Marshal failed:%w", err) - - } - - keyAccess.EncryptedMetadata = string(crypto.Base64Encode(metadataJson)) - } - - manifest.EncryptionInformation.KeyAccessObjs = append(manifest.EncryptionInformation.KeyAccessObjs, keyAccess) - } - - manifest.EncryptionInformation.Policy = string(base64PolicyObject) - manifest.EncryptionInformation.Method.Algorithm = kGCMCipherAlgorithm - return &manifest, nil -} - -// encrypt the data using the split key. -func (splitKey splitKey) encrypt(data []byte) ([]byte, error) { - buf, err := splitKey.aesGcm.Encrypt(data) - if err != nil { - return nil, fmt.Errorf("AesGcm.encrypt failed:%w", err) - } - - return buf, nil -} - -// decrypt the data using the split key. -func (splitKey splitKey) decrypt(data []byte) ([]byte, error) { - buf, err := splitKey.aesGcm.Decrypt(data) - if err != nil { - return nil, fmt.Errorf("AesGcm.Decrypt failed:%w", err) - } - - return buf, nil -} - -func (splitKey splitKey) validateRootSignature(manifest *Manifest) (bool, error) { - rootSigAlg := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm - rootSigValue := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature - - aggregateHash := &bytes.Buffer{} - for _, segment := range manifest.EncryptionInformation.IntegrityInformation.Segments { - decodedHash, err := crypto.Base64Decode([]byte(segment.Hash)) - if err != nil { - return false, fmt.Errorf("crypto.Base64Decode failed:%w", err) - } - - aggregateHash.Write(decodedHash) - } - - sigAlg := HS256 - if strings.EqualFold(gmacIntegrityAlgorithm, rootSigAlg) { - sigAlg = GMAC - } - - sig, err := splitKey.getSignature(aggregateHash.Bytes(), sigAlg) - if err != nil { - return false, fmt.Errorf("splitkey.getSignature failed:%w", err) - } - - if rootSigValue == string(crypto.Base64Encode([]byte(sig))) { - return true, nil - } - - return false, nil -} - -// getSignature calculate signature of data of the given algorithm. -func (splitKey splitKey) getSignature(data []byte, alg IntegrityAlgorithm) (string, error) { - if alg == HS256 { - hmac := crypto.CalculateSHA256Hmac(splitKey.key[:], data) - return hex.EncodeToString(hmac), nil - } - if kGMACPayloadLength > len(data) { - return "", fmt.Errorf("fail to create gmac signature") - } - - return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil -} - -func (splitKey splitKey) createPolicyObject() (policyObject, error) { - uuidObj, err := uuid.NewUUID() - if err != nil { - return policyObject{}, fmt.Errorf("uuid.NewUUID failed: %w", err) - } - - policyObj := policyObject{} - policyObj.UUID = uuidObj.String() - - for _, attribute := range splitKey.attributes { - attributeObj := attributeObject{} - attributeObj.Attribute = attribute - policyObj.Body.DataAttributes = append(policyObj.Body.DataAttributes, attributeObj) - policyObj.Body.Dissem = make([]string, 0) - } - - return policyObj, nil -} - -func (splitKey splitKey) handleKasRequest(kasPath string, body *RequestBody, authConfig AuthConfig) (*http.Response, error) { - kasURL := body.KasURL - - requestBodyData, err := json.Marshal(body) - if err != nil { - return nil, fmt.Errorf("json.Marshal failed: %w", err) - } - - claims := rewrapJWTClaims{ - jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(60 * time.Second)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - string(requestBodyData), - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - - signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(authConfig.signingPrivateKey)) - if err != nil { - return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) - } - - signedToken, err := token.SignedString(signingRSAPrivateKey) - if err != nil { - return nil, fmt.Errorf("jwt.SignedString failed: %w", err) - } - - signedTokenRequestBody, err := json.Marshal(map[string]string{ - kSignedRequestToken: signedToken, - }) - if err != nil { - return nil, fmt.Errorf("json.Marshal failed: %w", err) - } - - kasRequestURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kasPath) - if err != nil { - return nil, fmt.Errorf("url.JoinPath failed: %w", err) - } - request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRequestURL, - bytes.NewBuffer(signedTokenRequestBody)) - if err != nil { - return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err) - } - - // add required headers - request.Header = http.Header{ - kContentTypeKey: {kContentTypeJSONValue}, - kAuthorizationKey: {authConfig.authToken}, - kAcceptKey: {kContentTypeJSONValue}, - } - - client := &http.Client{} - - response, err := client.Do(request) - if err != nil { - slog.Error("failed http request") - return nil, err - } - - return response, nil -} -func (splitKey splitKey) rewrap(authConfig AuthConfig, requestBody *RequestBody) ([]byte, error) { - - clientKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) - if err != nil { - return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) - } - - clientPubKey, err := clientKeyPair.PublicKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) - } - requestBody.ClientPublicKey = clientPubKey - - clientPrivateKey, err := clientKeyPair.PrivateKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) - } - - response, err := splitKey.handleKasRequest(kRewrapV2, requestBody, authConfig) - if err != nil { - slog.Error("failed http request") - return nil, err - } - if response.StatusCode != kHTTPOk { - return nil, fmt.Errorf("http request failed status code:%d", response.StatusCode) - } - - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - slog.Error("Fail to close HTTP response") - } - }(response.Body) - - rewrapResponseBody, err := io.ReadAll(response.Body) - if err != nil { - return nil, fmt.Errorf("io.ReadAll failed: %w", err) - } - - key, err := getWrappedKey(rewrapResponseBody, clientPrivateKey) - if err != nil { - return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err) - } - - return key, nil -} - -func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) { - var data map[string]interface{} - err := json.Unmarshal(rewrapResponseBody, &data) - if err != nil { - return nil, fmt.Errorf("json.Unmarshal failed: %w", err) - } - - entityWrappedKey, ok := data[kEntityWrappedKey] - if !ok { - return nil, fmt.Errorf("entityWrappedKey is missing in key access object") - } - - asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey) - if err != nil { - return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) - } - - entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey))) - if err != nil { - return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err) - } - - key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded) - if err != nil { - return nil, fmt.Errorf("crypto.Decrypt failed: %w", err) - } - - return key, nil -} diff --git a/sdk/split_key_test.go b/sdk/split_key_test.go deleted file mode 100644 index 0d7b7153e3..0000000000 --- a/sdk/split_key_test.go +++ /dev/null @@ -1,261 +0,0 @@ -package sdk - -import ( - "encoding/hex" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/golang-jwt/jwt/v4" - "github.com/opentdf/opentdf-v2-poc/sdk/internal/crypto" -) - -func TestNewSplitKeyFromKasInfo(t *testing.T) { - attributes := []string{ - "https://example.com/attr/Classification/value/S", - "https://example.com/attr/Classification/value/X", - } - sampleMetaData := `{"displayName" : "openTDF go sdk"}` - - for _, test := range testHarnesses { - kasInfoList := test.kasInfoList - for index := range kasInfoList { - kasInfoList[index].publicKey = mockKasPublicKey - } - - sKey, err := newSplitKeyFromKasInfo(test.kasInfoList, attributes, sampleMetaData) - if err != nil { - t.Fatalf("tdf.newSplitKeyFromKasInfo failed: %v", err) - } - - manifest, err := sKey.getManifest() - if err != nil { - t.Fatalf("tdf.splitKey.getManifest failed: %v", err) - } - - if len(manifest.KeyAccessObjs) == 0 { - t.Fatalf("fail: key access object missing from the manifest") - } - - if len(manifest.KeyAccessObjs[0].EncryptedMetadata) == 0 { - t.Fatalf("fail: meta data missing from the manifest") - } - } -} - -//nolint:gocognit -func TestNewSplitKeyFromManifest(t *testing.T) { - kasPrivateKey := `-----BEGIN PRIVATE KEY----- - MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDOpiotrvV2i5h6 - clHMzDGgh3h/kMa0LoGx2OkDPd8jogycUh7pgE5GNiN2lpSmFkjxwYMXnyrwr9Ex - yczBWJ7sRGDCDaQg5fjVUIloZ8FJVbn+sEcfQ9iX6vmI9/S++oGK79QM3V8M8cp4 - 1r/T1YVmuzUHE1say/TLHGhjtGkxHDF8qFy6Z2rYFTCVJQHNqGmwNVGd0qG7gim8 - 6Hawu/CMYj4jG9oITlj8rJtQOaJ6ZqemQVoNmb3j1LkyeUKzRIt+86aoBiz+T3Tf - OEvXF6xgBj3XoiOhPYK+abFPYcrArvb6oubT8NjjQoj3j0sXWUnIIMg+e4f+XNVU - 54ZzDaLZAgMBAAECggEBALb0yK0PlMUyzHnEUwXV1y5AIoAWhsYp0qvJ1msHUVKz - +yQ/VJz4+tQQxI8OvGbbnhNkd5LnWdYkYzsIZl7b/kBCPcQw3Zo+4XLCzhUAn1E1 - M+n42c8le1LtN6Z7mVWoZh7DPONy7t+ABvm7b7S1+1i78DPmgCeWYZGeAhIcPXG6 - 5AxWIV3jigxksE6kYY9Y7DmtsZgMRrdV7SU8VtgPtT7tua8z5/U3Av0WINyKBSoM - 0yDHsAg57KnM8znx2JWLtHd0Mk5bBuu2DLbtyKNrVUAUuMPzrLGBh9S9QRd934KU - uFAi1TEfgEachnGgSHJpzVzr2ur1tifABnQ7GNXObe0CgYEA6KowK0subdDY+uGW - ciP2XDAMerbJJeL0/UIGPb/LUmskniio2493UBGgY2FsRyvbzJ+/UAOjIPyIxhj7 - 78ZyVG8BmIzKan1RRVh//O+5yvks/eTOYjWeQ1Lcgqs3q4YAO13CEBZgKWKTUomg - mskFJq04tndeSIyhDaW+BuWaXA8CgYEA42ABz3pql+DH7oL5C4KYBymK6wFBBOqk - dVk+ftyJQ6PzuZKpfsu4aPIjKm71lkTgK6O9o08s3SckAdu6vLukq2TZFF+a+9OI - lu5ww7GvfdMTgLAaFchD4bPlOInh1KVjBc1MwGXpl0ROde5pi8+WUrv9QJuoQfB/ - 4rhYdbJLSpcCgYA41mqSCPm8pgp7r2RbWeGzP6Gs0L5u3PTQcbKonxQCfF4jrPcj - O/b/vm6aGJClClfVsyi/WUQeqNKY4j2Zo7cGXV/cbnh8b0TNVgNePQn8Rcbx91Vb - tJGHDNUFruIYqtGfrxXbbDvtoEExJqHvbjAt9J8oJB0KSCCH/vdfI/QDjQKBgQCD - xLPH5Y24js/O7aAeh4RLQkv7fTKNAt5kE2AgbPYveOhZ9yC7Fpy8VPcENGGmwCuZ - nr7b0ZqSX4iCezBxB92aZktXf0B2CFT0AyLehi7JoHWA8o1rai/MsVB5v45ciawl - RKDiLy18OF2wAoawO5FGSSOvOYX9EL9MSMEbFESF6QKBgCVlZ9pPC+55rGT6AcEL - tUpDs+/wZvcmfsFd8xC5mMUN0DatAVzVAUI95+tQaWU3Uj+bqHq0lC6Wy2VceG0D - D+7EicjdGFN/2WVPXiYX1fblkxasZY+wChYBrPLjA9g0qOzzmXbRBph5QxDuQjJ6 - qcddVKB624a93ZBssn7OivnR - -----END PRIVATE KEY-----` - - sampleManifest := `{ - "encryptionInformation": { - "type": "split", - "policy": "eyJ1dWlkIjoiMmQyY2ZjMzQtYjg5MC0xMWVlLWEyMDgtYjJjMDM2M2FlNjI5IiwiQm9keSI6eyJkYXRhQXR0cmlidXRlcyI6W10sImRpc3NlbSI6W119fQ==", - "keyAccess": [ - { - "type": "kWrapped", - "url": "http://localhost:65432/api/kas", - "protocol": "kas", - "wrappedKey": "DfWZxVju4DIkSAu/QRHI04pLnBciASSDRokJ5gdDjx8fnh5jNsoyGQ63ekJgGEQp0r5CZqCIUHny7RU52LyMQuTz+lNLJKsZ3n9jDim5TbfzR2ETYAaAySzEPtUsVUWxwXHeHY8YNvb3nu8DuGCO2VadascqU9lZt6KOZ6Vr5JBOH3TukvTb0twHeJoBfyT+4HKSh27sdSOSNWOSuQkcbKGbcrAuTaV50jABphlW01gCfUv1N0BF3nWF30xOzpVl3BFwS/dA8bVVIckTLP6M456cWL6YrqHefwVA1Igrks/uVolL9sN1xS+nNlVVFCgipVz3I3wwgSTjhg5QD8YUcg==", - "policyBinding": "MDczYTJiYjE0MmZiODIxNTA3MjI2ZDBiYmNhMTM0ZmQyNDQ0YzJkODAwNmRjMjMxYjY2OWVhNTZlNzYyNTY1Nw==", - "encryptedMetadata": "" - }, - { - "type": "kWrapped", - "url": "http://localhost:65432/api/kas", - "protocol": "kas", - "wrappedKey": "rz13UFBazveewf7gHzEZZeg6Y5hjcVaz05W4VTlqVBxcNvJGajcXFIaeVCUgMf1++LOyqlqy6lIT+QpSG4pksXBCr7DeBrzvrXd4PUPlzFVDdZFbV22AZviSNQWe9IJyiZLt8L6RaHZcUfK2Gy2rUvXVr8o70xSjOvNAzp4nGJZPTSfbgSTo0aFPqgSvk+SmWNZl6eA98woCYO/SnSkHDWzuz7eSKcooiWoZD/XV71SpY+vHZaNwToEH4lhOxBTzNvPCX8cxi/2a6bygw4ma/bpepwwERS3SLg0cqDdQhQ95j34Y2aVzx3tSUntr33X0DHLimp1RKOTFdiPiAAnfuQ==", - "policyBinding": "MWQ3NmEwNjk2NWU5ZDZiNDQzM2U2ZTQ3MTU0NTEyYTQ0NjYwZGFiZDkyYjYzMTI3ZDUzMjE5NDJmMDg4YTNhOQ==", - "encryptedMetadata": "" - } - ], - "method": { - "algorithm": "AES-256-GCM", - "iv": "", - "isStreamable": true - }, - "integrityInformation": { - "rootSignature": { - "alg": "HS256", - "sig": "MWI0NWNmMzJkMDliOWI5YjJmNDk1YTk0NzhjMmJjMzMyODFhM2U5YjgxOTE0ZWY0NDI2ZGFkODkyMDEzY2VlMg==" - }, - "segmentHashAlg": "GMAC", - "segmentSizeDefault": 2097152, - "encryptedSegmentSizeDefault": 2097180, - "segments": [ - { - "hash": "NTZkZTg4NmE2MDhkNTU5OTU0N2RiNmRiNjNmMWExY2U=", - "segmentSize": 1024, - "encryptedSegmentSize": 1052 - } - ] - } - }, - "payload": { - "type": "reference", - "url": "0.payload", - "protocol": "zip", - "mimeType": "application/octet-stream", - "isEncrypted": true - } -}` - signingKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) - if err != nil { - t.Fatalf("crypto.NewRSAKeyPair: %v", err) - } - - signingPubKey, err := signingKeyPair.PublicKeyInPemFormat() - if err != nil { - t.Fatalf("crypto.PublicKeyInPemFormat failed: %v", err) - } - - signingPrivateKey, err := signingKeyPair.PrivateKeyInPemFormat() - if err != nil { - t.Fatalf("crypto.PrivateKeyInPemFormat failed: %v", err) - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != kRewrapV2 { - t.Fatalf("expected to request '%s', got: %s", kRewrapV2, r.URL.Path) - } - if r.Header.Get(kAcceptKey) != kContentTypeJSONValue { - t.Fatalf("expected Accept: application/json header, got: %s", r.Header.Get("Accept")) - } - - requestBody, err := io.ReadAll(r.Body) - if err != nil { - t.Fatalf("io.ReadAll failed: %v", err) - } - - var data map[string]string - err = json.Unmarshal(requestBody, &data) - if err != nil { - t.Fatalf("json.Unmarsha failed: %v", err) - } - - tokenString, ok := data[kSignedRequestToken] - if !ok { - t.Fatalf("signed token missing in rewrap response") - } - - token, err := jwt.ParseWithClaims(tokenString, &rewrapJWTClaims{}, func(token *jwt.Token) (interface{}, error) { - signingRSAPublicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(signingPubKey)) - if err != nil { - return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) - } - - return signingRSAPublicKey, nil - }) - - var rewrapRequest = "" - if err != nil { - t.Fatalf("jwt.ParseWithClaims failed:%v", err) - } else if claims, fine := token.Claims.(*rewrapJWTClaims); fine { - rewrapRequest = claims.Body - } else { - t.Fatalf("unknown claims type, cannot proceed") - } - reqBody := RequestBody{} - err = json.Unmarshal([]byte(rewrapRequest), &reqBody) - if err != nil { - t.Fatalf("json.Unmarshal failed: %v", err) - } - - wrappedKey, err := crypto.Base64Decode([]byte(reqBody.WrappedKey)) - if err != nil { - t.Fatalf("crypto.Base64Decode failed: %v", err) - } - - kasPrivateKey = strings.ReplaceAll(kasPrivateKey, "\n\t", "\n") - asymDecrypt, err := crypto.NewAsymDecryption(kasPrivateKey) - if err != nil { - t.Fatalf("crypto.NewAsymDecryption failed: %v", err) - } - - symmetricKey, err := asymDecrypt.Decrypt(wrappedKey) - if err != nil { - t.Fatalf("crypto.Decrypt failed: %v", err) - } - - asymEncrypt, err := crypto.NewAsymEncryption(reqBody.ClientPublicKey) - if err != nil { - t.Fatalf("crypto.NewAsymEncryption failed: %v", err) - } - - entityWrappedKey, err := asymEncrypt.Encrypt(symmetricKey) - if err != nil { - t.Fatalf("crypto.encrypt failed: %v", err) - } - - response, err := json.Marshal(map[string]string{ - kEntityWrappedKey: string(crypto.Base64Encode(entityWrappedKey)), - }) - if err != nil { - t.Fatalf("json.Marshal failed: %v", err) - } - - w.WriteHeader(http.StatusOK) - _, err = w.Write(response) - if err != nil { - t.Fatalf("http.ResponseWriter.Write failed: %v", err) - } - })) - defer server.Close() - - manifestObj := &Manifest{} - err = json.Unmarshal([]byte(sampleManifest), manifestObj) - if err != nil { - t.Fatalf("json.Unmarshal failed:%v", err) - } - - // mock the kas url - for index := range manifestObj.EncryptionInformation.KeyAccessObjs { - manifestObj.EncryptionInformation.KeyAccessObjs[index].KasURL = server.URL - } - - authConfig := AuthConfig{signingPrivateKey: signingPrivateKey, signingPublicKey: signingPubKey} - sKey, err := newSplitKeyFromManifest(authConfig, *manifestObj) - if err != nil { - t.Errorf("newSplitKeyFromManifest failed: %v", err) - } - - if len(sKey.tdfKeyAccessObjects) != 2 { - t.Errorf("split key key access objects count don't match: expected %v, got %v", len(sKey.tdfKeyAccessObjects), 2) - } - - expectedSplitKey := "6788741d1a659ac43693ffba933d8eaded57fad1705558fba98a89605fb56ab8" - if hex.EncodeToString(sKey.key[:]) != expectedSplitKey { - t.Errorf("split key is valid explected:%v, got %v", expectedSplitKey, hex.EncodeToString(sKey.key[:])) - } -} diff --git a/sdk/tdf.go b/sdk/tdf.go index 7fdb127a50..f14617640c 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -2,13 +2,21 @@ package sdk import ( "bytes" + "context" + "encoding/hex" "encoding/json" "errors" "fmt" "io" + "log/slog" "math" + "net/http" + "net/url" "strings" + "time" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" "github.com/opentdf/opentdf-v2-poc/sdk/internal/archive" "github.com/opentdf/opentdf-v2-poc/sdk/internal/crypto" ) @@ -21,6 +29,8 @@ var ( errWriteFailed = errors.New("tdf: io.writer fail to write all bytes") errSegSigValidation = errors.New("tdf: failed integrity check on segment hash") errTDFPayloadReadFail = errors.New("tdf: fail to read payload from tdf") + errInvalidKasInfo = errors.New("tdf: kas information is missing") + errKasPubKeyMissing = errors.New("tdf: kas public key is missing") errTDFPayloadInvalidOffset = errors.New("sdk.Reader.ReadAt: negative offset") ) @@ -33,42 +43,75 @@ const ( hmacIntegrityAlgorithm = "HS256" gmacIntegrityAlgorithm = "GMAC" tdfZipReference = "reference" + kKeySize = 32 + kWrapped = "wrapped" + kKasProtocol = "kas" + kSplitKeyType = "split" + kGCMCipherAlgorithm = "AES-256-GCM" + kGMACPayloadLength = 16 + // kClientPublicKey = "clientPublicKey" + kSignedRequestToken = "signedRequestToken" + // kKasURL = "url" + kRewrapV2 = "/v2/rewrap" + kAuthorizationKey = "Authorization" + kContentTypeKey = "Content-Type" + kAcceptKey = "Accept" + kContentTypeJSONValue = "application/json" + kEntityWrappedKey = "entityWrappedKey" + // kPolicy = "policy" + // kHmacIntegrityAlgorithm = "HS256" + // kGmacIntegrityAlgorithm = "GMAC" ) type Reader struct { - tdfReader archive.TDFReader - sKey splitKey - cursor int64 - payloadSize int64 - manifest Manifest + Manifest + unencryptedMetadata string + tdfReader archive.TDFReader + authConfig AuthConfig + cursor int64 + aesGcm crypto.AesGcm + payloadSize int64 + payloadKey []byte } -// CreateTDF tdf -func CreateTDF(tdfConfig TDFConfig, reader io.ReadSeeker, writer io.Writer) (int64, error) { +type TDFObject struct { + Manifest + TdfSize int64 + aesGcm crypto.AesGcm + payloadKey [kKeySize]byte +} - inputSize, err := reader.Seek(0, io.SeekEnd) - if err != nil { - return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) - } +type rewrapJWTClaims struct { + jwt.RegisteredClaims + Body string `json:"requestBody"` +} - _, err = reader.Seek(0, io.SeekStart) +type RequestBody struct { + KeyAccess `json:"keyAccess"` + ClientPublicKey string `json:"clientPublicKey"` + Policy string `json:"policy"` +} + +// CreateTDF tdf +func CreateTDF(tdfConfig TDFConfig, reader io.ReadSeeker, writer io.Writer) (*TDFObject, error) { //nolint:funlen, gocognit, lll + inputSize, err := reader.Seek(0, io.SeekEnd) if err != nil { - return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) + return nil, fmt.Errorf("readSeeker.Seek failed: %w", err) } if inputSize > maxFileSizeSupported { - return 0, errFileTooLarge + return nil, errFileTooLarge } - // create a split key - splitKey, err := newSplitKeyFromKasInfo(tdfConfig.kasInfoList, tdfConfig.attributes, tdfConfig.metaData) + _, err = reader.Seek(0, io.SeekStart) if err != nil { - return 0, fmt.Errorf("fail to create a new split key: %w", err) + return nil, fmt.Errorf("readSeeker.Seek failed: %w", err) } - manifest, err := splitKey.getManifest() + tdfObject := &TDFObject{} + err = tdfObject.prepareManifest(tdfConfig) if err != nil { - return 0, fmt.Errorf("fail to create manifest: %w", err) + return nil, fmt.Errorf("fail to create a new split key: %w", err) } segmentSize := tdfConfig.defaultSegmentSize @@ -88,7 +131,7 @@ func CreateTDF(tdfConfig TDFConfig, reader io.ReadSeeker, writer io.Writer) (int err = tdfWriter.SetPayloadSize(payloadSize) if err != nil { - return 0, fmt.Errorf("archive.SetPayloadSize failed: %w", err) + return nil, fmt.Errorf("archive.SetPayloadSize failed: %w", err) } var readPos int64 @@ -102,141 +145,235 @@ func CreateTDF(tdfConfig TDFConfig, reader io.ReadSeeker, writer io.Writer) (int n, err := reader.Read(readBuf.Bytes()[:readSize]) if err != nil { - return 0, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) + return nil, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) } if int64(n) != readSize { - return 0, fmt.Errorf("io.ReadSeeker.Read size missmatch") + return nil, fmt.Errorf("io.ReadSeeker.Read size missmatch") } - cipherData, err := splitKey.encrypt(readBuf.Bytes()[:readSize]) + cipherData, err := tdfObject.aesGcm.Encrypt(readBuf.Bytes()[:readSize]) if err != nil { - return 0, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) + return nil, fmt.Errorf("io.ReadSeeker.Read failed: %w", err) } err = tdfWriter.AppendPayload(cipherData) if err != nil { - return 0, fmt.Errorf("io.writer.Write failed: %w", err) + return nil, fmt.Errorf("io.writer.Write failed: %w", err) } - payloadSig, err := splitKey.getSignature(cipherData, tdfConfig.segmentIntegrityAlgorithm) + segmentSig, err := calculateSignature(cipherData, tdfObject.payloadKey[:], tdfConfig.segmentIntegrityAlgorithm) if err != nil { - return 0, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) + return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) } - aggregateHash += payloadSig + aggregateHash += segmentSig + segmentInfo := Segment{ + Hash: string(crypto.Base64Encode([]byte(segmentSig))), + Size: readSize, + EncryptedSize: int64(len(cipherData)), + } - segmentInfo := Segment{} - segmentInfo.Hash = string(crypto.Base64Encode([]byte(payloadSig))) - segmentInfo.Size = readSize - segmentInfo.EncryptedSize = int64(len(cipherData)) - manifest.EncryptionInformation.IntegrityInformation.Segments = - append(manifest.EncryptionInformation.IntegrityInformation.Segments, segmentInfo) + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.Segments = + append(tdfObject.Manifest.EncryptionInformation.IntegrityInformation.Segments, segmentInfo) totalSegments-- readPos += readSize } - aggregateHashSig, err := splitKey.getSignature([]byte(aggregateHash), tdfConfig.integrityAlgorithm) + rootSignature, err := calculateSignature([]byte(aggregateHash), tdfObject.payloadKey[:], tdfConfig.integrityAlgorithm) if err != nil { - return 0, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) + return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) } - sig := string(crypto.Base64Encode([]byte(aggregateHashSig))) - manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature = sig + sig := string(crypto.Base64Encode([]byte(rootSignature))) + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature = sig integrityAlgStr := gmacIntegrityAlgorithm if tdfConfig.integrityAlgorithm == HS256 { integrityAlgStr = hmacIntegrityAlgorithm } - manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm = integrityAlgStr + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm = integrityAlgStr - manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize = segmentSize - manifest.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize = encryptedSegmentSize + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize = segmentSize + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize = encryptedSegmentSize segIntegrityAlgStr := gmacIntegrityAlgorithm if tdfConfig.segmentIntegrityAlgorithm == HS256 { segIntegrityAlgStr = hmacIntegrityAlgorithm } - manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm = segIntegrityAlgStr - manifest.EncryptionInformation.Method.IsStreamable = true + tdfObject.Manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm = segIntegrityAlgStr + tdfObject.Manifest.EncryptionInformation.Method.IsStreamable = true // add payload info - manifest.Payload.MimeType = defaultMimeType - manifest.Payload.Protocol = tdfAsZip - manifest.Payload.Type = tdfZipReference - manifest.Payload.URL = archive.TDFPayloadFileName - manifest.Payload.IsEncrypted = true + tdfObject.Manifest.Payload.MimeType = defaultMimeType + tdfObject.Manifest.Payload.Protocol = tdfAsZip + tdfObject.Manifest.Payload.Type = tdfZipReference + tdfObject.Manifest.Payload.URL = archive.TDFPayloadFileName + tdfObject.Manifest.Payload.IsEncrypted = true - manifestAsStr, err := json.Marshal(manifest) + manifestAsStr, err := json.Marshal(tdfObject.Manifest) if err != nil { - return 0, fmt.Errorf("json.Marshal failed:%w", err) + return nil, fmt.Errorf("json.Marshal failed:%w", err) } err = tdfWriter.AppendManifest(string(manifestAsStr)) if err != nil { - return 0, fmt.Errorf("TDFWriter.AppendManifest failed:%w", err) + return nil, fmt.Errorf("TDFWriter.AppendManifest failed:%w", err) } - totalBytes, err := tdfWriter.Finish() + tdfObject.TdfSize, err = tdfWriter.Finish() if err != nil { - return 0, fmt.Errorf("TDFWriter.Finish failed:%w", err) + return nil, fmt.Errorf("TDFWriter.Finish failed:%w", err) } - return totalBytes, nil + return tdfObject, nil } -func NewReader(authConfig AuthConfig, reader io.ReadSeeker) (*Reader, error) { - // create tdf reader - tdfReader, err := archive.NewTDFReader(reader) - if err != nil { - return nil, fmt.Errorf("archive.NewTDFReader failed: %w", err) +// prepare the manifest for TDF +func (tdfObject *TDFObject) prepareManifest(tdfConfig TDFConfig) error { //nolint:funlen,gocognit + manifest := Manifest{} + if len(tdfConfig.kasInfoList) == 0 { + return errInvalidKasInfo } - manifest, err := tdfReader.Manifest() + manifest.EncryptionInformation.KeyAccessType = kSplitKeyType + + policyObj, err := tdfObject.createPolicyObject(tdfConfig.attributes) if err != nil { - return nil, fmt.Errorf("tdfReader.Manifest failed: %w", err) + return fmt.Errorf("fail to create policy object:%w", err) } - manifestObj := &Manifest{} - err = json.Unmarshal([]byte(manifest), manifestObj) + policyObjectAsStr, err := json.Marshal(policyObj) if err != nil { - return nil, fmt.Errorf("json.Unmarshal failed:%w", err) + return fmt.Errorf("json.Marshal failed:%w", err) } - // create a split key - sKey, err := newSplitKeyFromManifest(authConfig, *manifestObj) + base64PolicyObject := crypto.Base64Encode(policyObjectAsStr) + symKeys := [][]byte{} + for _, kasInfo := range tdfConfig.kasInfoList { + if len(kasInfo.publicKey) == 0 { + return errKasPubKeyMissing + } + + symKey, err := crypto.RandomBytes(kKeySize) + if err != nil { + return fmt.Errorf("crypto.RandomBytes failed:%w", err) + } + + keyAccess := KeyAccess{} + keyAccess.KeyType = kWrapped + keyAccess.KasURL = kasInfo.url + keyAccess.Protocol = kKasProtocol + + // add policyBinding + policyBinding := hex.EncodeToString(crypto.CalculateSHA256Hmac(symKey, base64PolicyObject)) + keyAccess.PolicyBinding = string(crypto.Base64Encode([]byte(policyBinding))) + + // wrap the key with kas public key + asymEncrypt, err := crypto.NewAsymEncryption(kasInfo.publicKey) + if err != nil { + return fmt.Errorf("crypto.NewAsymEncryption failed:%w", err) + } + + wrappedKey, err := asymEncrypt.Encrypt(symKey) + if err != nil { + return fmt.Errorf("crypto.AsymEncryption.encrypt failed:%w", err) + } + keyAccess.WrappedKey = string(crypto.Base64Encode(wrappedKey)) + + // add meta data + if len(tdfConfig.metaData) > 0 { + gcm, err := crypto.NewAESGcm(symKey) + if err != nil { + return fmt.Errorf("crypto.NewAESGcm failed:%w", err) + } + + encryptedMetaData, err := gcm.Encrypt([]byte(tdfConfig.metaData)) + if err != nil { + return fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) + } + + iv := encryptedMetaData[:crypto.GcmStandardNonceSize] + metadata := EncryptedMetadata{Cipher: string(crypto.Base64Encode(encryptedMetaData)), + Iv: string(crypto.Base64Encode(iv))} + + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf(" json.Marshal failed:%w", err) + } + + keyAccess.EncryptedMetadata = string(crypto.Base64Encode(metadataJSON)) + } + + symKeys = append(symKeys, symKey) + manifest.EncryptionInformation.KeyAccessObjs = append(manifest.EncryptionInformation.KeyAccessObjs, keyAccess) + } + + manifest.EncryptionInformation.Policy = string(base64PolicyObject) + manifest.EncryptionInformation.Method.Algorithm = kGCMCipherAlgorithm + + // create the payload key by XOR all the keys in key access object. + for _, symKey := range symKeys { + for keyByteIndex, keyByte := range symKey { + tdfObject.payloadKey[keyByteIndex] ^= keyByte + } + } + + gcm, err := crypto.NewAESGcm(tdfObject.payloadKey[:]) if err != nil { - return nil, fmt.Errorf("fail to create a new split key: %w", err) + return fmt.Errorf(" crypto.NewAESGcm failed:%w", err) } - res, err := sKey.validateRootSignature(manifestObj) + tdfObject.Manifest = manifest + tdfObject.aesGcm = gcm + return nil +} + +// create policy object +func (tdfObject *TDFObject) createPolicyObject(attributes []string) (PolicyObject, error) { + uuidObj, err := uuid.NewUUID() if err != nil { - return nil, fmt.Errorf("splitKey.validateRootSignature failed: %w", err) + return PolicyObject{}, fmt.Errorf("uuid.NewUUID failed: %w", err) } - if !res { - return nil, errRootSigValidation + policyObj := PolicyObject{} + policyObj.UUID = uuidObj.String() + + for _, attribute := range attributes { + attributeObj := attributeObject{} + attributeObj.Attribute = attribute + policyObj.Body.DataAttributes = append(policyObj.Body.DataAttributes, attributeObj) + policyObj.Body.Dissem = make([]string, 0) } - segSize := manifestObj.EncryptionInformation.IntegrityInformation.DefaultSegmentSize - encryptedSegSize := manifestObj.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize + return policyObj, nil +} - if segSize != encryptedSegSize-(gcmIvSize+aesBlockSize) { - return nil, errSegSizeMismatch +// LoadTDF loads the tdf and prepare for reading the payload from TDF +func LoadTDF(authConfig AuthConfig, reader io.ReadSeeker) (*Reader, error) { + // create tdf reader + tdfReader, err := archive.NewTDFReader(reader) + if err != nil { + return nil, fmt.Errorf("archive.NewTDFReader failed: %w", err) } - var payloadSize int64 - for _, seg := range manifestObj.EncryptionInformation.IntegrityInformation.Segments { - payloadSize += seg.Size + manifest, err := tdfReader.Manifest() + if err != nil { + return nil, fmt.Errorf("tdfReader.Manifest failed: %w", err) + } + + manifestObj := &Manifest{} + err = json.Unmarshal([]byte(manifest), manifestObj) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed:%w", err) } return &Reader{ - tdfReader: tdfReader, - manifest: *manifestObj, - payloadSize: payloadSize, - sKey: sKey, + tdfReader: tdfReader, + Manifest: *manifestObj, + authConfig: authConfig, }, nil } @@ -244,6 +381,13 @@ func NewReader(authConfig AuthConfig, reader io.ReadSeeker) (*Reader, error) { // read (0 <= n <= len(p)) and any error encountered. It returns an // io.EOF error when the stream ends. func (reader *Reader) Read(p []byte) (int, error) { + if reader.payloadKey == nil { + err := reader.getPayloadKey() + if err != nil { + return 0, fmt.Errorf("reader.getPayloadKey failed: %w", err) + } + } + n, err := reader.ReadAt(p, reader.cursor) reader.cursor += int64(n) return n, err @@ -251,10 +395,17 @@ func (reader *Reader) Read(p []byte) (int, error) { // WriteTo writes data to writer until there's no more data to write or // when an error occurs. -func (reader *Reader) WriteTo(writer io.Writer) (n int64, err error) { +func (reader *Reader) WriteTo(writer io.Writer) (int64, error) { + if reader.payloadKey == nil { + err := reader.getPayloadKey() + if err != nil { + return 0, fmt.Errorf("reader.getPayloadKey failed: %w", err) + } + } + var totalBytes int64 var payloadReadOffset int64 - for _, seg := range reader.manifest.EncryptionInformation.IntegrityInformation.Segments { + for _, seg := range reader.Manifest.EncryptionInformation.IntegrityInformation.Segments { readBuf, err := reader.tdfReader.ReadPayload(payloadReadOffset, seg.EncryptedSize) if err != nil { return totalBytes, fmt.Errorf("TDFReader.ReadPayload failed: %w", err) @@ -264,13 +415,13 @@ func (reader *Reader) WriteTo(writer io.Writer) (n int64, err error) { return totalBytes, errTDFReaderFailed } - segHashAlg := reader.manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm + segHashAlg := reader.Manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm sigAlg := HS256 if strings.EqualFold(gmacIntegrityAlgorithm, segHashAlg) { sigAlg = GMAC } - payloadSig, err := reader.sKey.getSignature(readBuf, sigAlg) + payloadSig, err := calculateSignature(readBuf, reader.payloadKey, sigAlg) if err != nil { return totalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) } @@ -279,7 +430,7 @@ func (reader *Reader) WriteTo(writer io.Writer) (n int64, err error) { return totalBytes, errSegSigValidation } - writeBuf, err := reader.sKey.decrypt(readBuf) + writeBuf, err := reader.aesGcm.Decrypt(readBuf) if err != nil { return totalBytes, fmt.Errorf("splitKey.decrypt failed: %w", err) } @@ -305,13 +456,19 @@ func (reader *Reader) WriteTo(writer io.Writer) (n int64, err error) { // of bytes read (0 <= n <= len(p)) and any error encountered. It returns an // io.EOF error when the stream ends. // NOTE: For larger tdf sizes use sdk.GetTDFPayload for better performance -func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { +func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { //nolint:funlen, gocognit + if reader.payloadKey == nil { + err := reader.getPayloadKey() + if err != nil { + return 0, fmt.Errorf("reader.getPayloadKey failed: %w", err) + } + } if offset < 0 { return 0, errTDFPayloadInvalidOffset } - defaultSegmentSize := reader.manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize + defaultSegmentSize := reader.Manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize var start = math.Floor(float64(offset) / float64(defaultSegmentSize)) var end = math.Ceil(float64(offset+int64(len(buf))) / float64(defaultSegmentSize)) @@ -329,7 +486,7 @@ func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { var decryptedBuf bytes.Buffer var payloadReadOffset int64 - for index, seg := range reader.manifest.EncryptionInformation.IntegrityInformation.Segments { + for index, seg := range reader.Manifest.EncryptionInformation.IntegrityInformation.Segments { if firstSegment > int64(index) { payloadReadOffset += seg.EncryptedSize continue @@ -344,13 +501,13 @@ func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { return 0, errTDFReaderFailed } - segHashAlg := reader.manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm + segHashAlg := reader.Manifest.EncryptionInformation.IntegrityInformation.SegmentHashAlgorithm sigAlg := HS256 if strings.EqualFold(gmacIntegrityAlgorithm, segHashAlg) { sigAlg = GMAC } - payloadSig, err := reader.sKey.getSignature(readBuf, sigAlg) + payloadSig, err := calculateSignature(readBuf, reader.payloadKey, sigAlg) if err != nil { return 0, fmt.Errorf("splitKey.GetSignaturefailed: %w", err) } @@ -359,7 +516,7 @@ func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { return 0, errSegSigValidation } - writeBuf, err := reader.sKey.decrypt(readBuf) + writeBuf, err := reader.aesGcm.Decrypt(readBuf) if err != nil { return 0, fmt.Errorf("splitKey.decrypt failed: %w", err) } @@ -381,7 +538,7 @@ func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { } } - var err error = nil + var err error bufLen := int64(len(buf)) if (offset + int64(len(buf))) > reader.payloadSize { bufLen = reader.payloadSize - offset @@ -393,30 +550,47 @@ func (reader *Reader) ReadAt(buf []byte, offset int64) (int, error) { return int(bufLen), err } -// Manifest return the manifest as json string. -func (reader *Reader) Manifest() (string, error) { - manifestAsStr, err := json.Marshal(reader.manifest) - if err != nil { - return "", fmt.Errorf("json.Marshal failed:%w", err) +// GetManifest return the manifest in TDF. +func (reader *Reader) GetManifest() Manifest { + return reader.Manifest +} + +// GetUnencryptedMetadata return decrypted metadata in manifest. +func (reader *Reader) GetUnencryptedMetadata() (string, error) { + if reader.payloadKey == nil { + err := reader.getPayloadKey() + if err != nil { + return "", fmt.Errorf("reader.getPayloadKey failed: %w", err) + } } - return string(manifestAsStr), nil + return reader.unencryptedMetadata, nil } -// UnencryptedMetadata return the meta present in tdf. -func (reader *Reader) UnencryptedMetadata() string { - // There will be at least one key access in tdf - return reader.sKey.tdfKeyAccessObjects[0].metaData +// GetPolicy return policy object in manifest. +func (reader *Reader) GetPolicy() (PolicyObject, error) { + policyObj := PolicyObject{} + policy, err := crypto.Base64Decode([]byte(reader.Manifest.Policy)) + if err != nil { + return policyObj, fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + err = json.Unmarshal(policy, &policyObj) + if err != nil { + return policyObj, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + return policyObj, nil } // DataAttributes return the data attributes present in tdf. func (reader *Reader) DataAttributes() ([]string, error) { - policy, err := crypto.Base64Decode([]byte(reader.manifest.Policy)) + policy, err := crypto.Base64Decode([]byte(reader.Manifest.Policy)) if err != nil { return nil, fmt.Errorf("crypto.Base64Decode failed:%w", err) } - policyObj := policyObject{} + policyObj := PolicyObject{} err = json.Unmarshal(policy, &policyObj) if err != nil { return nil, fmt.Errorf("json.Unmarshal failed: %w", err) @@ -430,3 +604,264 @@ func (reader *Reader) DataAttributes() ([]string, error) { return attributes, nil } + +// Get the payload key th +func (reader *Reader) getPayloadKey() error { //nolint:gocognit + var unencryptedMetadata string + var payloadKey [kKeySize]byte + for _, keyAccessObj := range reader.Manifest.EncryptionInformation.KeyAccessObjs { + requestBody := RequestBody{keyAccessObj, "", reader.Manifest.EncryptionInformation.Policy} + wrappedKey, err := rewrap(reader.authConfig, &requestBody) + if err != nil { + return fmt.Errorf(" splitKey.rewrap failed:%w", err) + } + + for keyByteIndex, keyByte := range wrappedKey { + payloadKey[keyByteIndex] ^= keyByte + } + + if len(keyAccessObj.EncryptedMetadata) != 0 { + gcm, err := crypto.NewAESGcm(wrappedKey) + if err != nil { + return fmt.Errorf("crypto.NewAESGcm failed:%w", err) + } + + decodedMetaData, err := crypto.Base64Decode([]byte(keyAccessObj.EncryptedMetadata)) + if err != nil { + return fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + metadata := EncryptedMetadata{} + err = json.Unmarshal(decodedMetaData, &metadata) + if err != nil { + return fmt.Errorf("json.Unmarshal failed:%w", err) + } + + encodedCipherText := metadata.Cipher + cipherText, _ := crypto.Base64Decode([]byte(encodedCipherText)) + metaData, err := gcm.Decrypt(cipherText) + if err != nil { + return fmt.Errorf("crypto.AesGcm.encrypt failed:%w", err) + } + + unencryptedMetadata = string(metaData) + } + } + + res, err := validateRootSignature(reader.Manifest, payloadKey[:]) + if err != nil { + return fmt.Errorf("splitKey.validateRootSignature failed: %w", err) + } + + if !res { + return errRootSigValidation + } + + segSize := reader.Manifest.EncryptionInformation.IntegrityInformation.DefaultSegmentSize + encryptedSegSize := reader.Manifest.EncryptionInformation.IntegrityInformation.DefaultEncryptedSegSize + + if segSize != encryptedSegSize-(gcmIvSize+aesBlockSize) { + return errSegSizeMismatch + } + + var payloadSize int64 + for _, seg := range reader.Manifest.EncryptionInformation.IntegrityInformation.Segments { + payloadSize += seg.Size + } + + gcm, err := crypto.NewAESGcm(payloadKey[:]) + if err != nil { + return fmt.Errorf(" crypto.NewAESGcm failed:%w", err) + } + + reader.payloadSize = payloadSize + reader.unencryptedMetadata = unencryptedMetadata + reader.payloadKey = payloadKey[:] + reader.aesGcm = gcm + + return nil +} + +// calculateSignature calculate signature of data of the given algorithm. +func calculateSignature(data []byte, secret []byte, alg IntegrityAlgorithm) (string, error) { + if alg == HS256 { + hmac := crypto.CalculateSHA256Hmac(secret, data) + return hex.EncodeToString(hmac), nil + } + if kGMACPayloadLength > len(data) { + return "", fmt.Errorf("fail to create gmac signature") + } + + return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil +} + +// validate the root signature +func validateRootSignature(manifest Manifest, secret []byte) (bool, error) { + rootSigAlg := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm + rootSigValue := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature + + aggregateHash := &bytes.Buffer{} + for _, segment := range manifest.EncryptionInformation.IntegrityInformation.Segments { + decodedHash, err := crypto.Base64Decode([]byte(segment.Hash)) + if err != nil { + return false, fmt.Errorf("crypto.Base64Decode failed:%w", err) + } + + aggregateHash.Write(decodedHash) + } + + sigAlg := HS256 + if strings.EqualFold(gmacIntegrityAlgorithm, rootSigAlg) { + sigAlg = GMAC + } + + sig, err := calculateSignature(aggregateHash.Bytes(), secret, sigAlg) + if err != nil { + return false, fmt.Errorf("splitkey.getSignature failed:%w", err) + } + + if rootSigValue == string(crypto.Base64Encode([]byte(sig))) { + return true, nil + } + + return false, nil +} + +func handleKasRequest(kasPath string, body *RequestBody, authConfig AuthConfig) (*http.Response, error) { + kasURL := body.KasURL + + requestBodyData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + claims := rewrapJWTClaims{ + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), // Set expiration to be one minute from now + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + string(requestBodyData), + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + + signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(authConfig.signingPrivateKey)) + if err != nil { + return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) + } + + signedToken, err := token.SignedString(signingRSAPrivateKey) + if err != nil { + return nil, fmt.Errorf("jwt.SignedString failed: %w", err) + } + + signedTokenRequestBody, err := json.Marshal(map[string]string{ + kSignedRequestToken: signedToken, + }) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + kasRequestURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kasPath) + if err != nil { + return nil, fmt.Errorf("url.JoinPath failed: %w", err) + } + request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRequestURL, + bytes.NewBuffer(signedTokenRequestBody)) + if err != nil { + return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err) + } + + // add required headers + request.Header = http.Header{ + kContentTypeKey: {kContentTypeJSONValue}, + kAuthorizationKey: {authConfig.authToken}, + kAcceptKey: {kContentTypeJSONValue}, + } + + client := &http.Client{} + + response, err := client.Do(request) + if err != nil { + slog.Error("failed http request") + return nil, fmt.Errorf("http request failed: %w", err) + } + + return response, nil +} + +func rewrap(authConfig AuthConfig, requestBody *RequestBody) ([]byte, error) { + clientKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) + } + + clientPubKey, err := clientKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + requestBody.ClientPublicKey = clientPubKey + + clientPrivateKey, err := clientKeyPair.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + response, err := handleKasRequest(kRewrapV2, requestBody, authConfig) + + defer func() { + err := response.Body.Close() + if err != nil { + slog.Error("Fail to close HTTP response") + } + }() + + if err != nil { + slog.Error("failed http request") + return nil, fmt.Errorf("http request error: %w", err) + } + if response.StatusCode != kHTTPOk { + return nil, fmt.Errorf("http request failed status code:%d", response.StatusCode) + } + + rewrapResponseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("io.ReadAll failed: %w", err) + } + + key, err := getWrappedKey(rewrapResponseBody, clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err) + } + + return key, nil +} + +func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) { + var data map[string]interface{} + err := json.Unmarshal(rewrapResponseBody, &data) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + entityWrappedKey, ok := data[kEntityWrappedKey] + if !ok { + return nil, fmt.Errorf("entityWrappedKey is missing in key access object") + } + + asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) + } + + entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey))) + if err != nil { + return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err) + } + + key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded) + if err != nil { + return nil, fmt.Errorf("crypto.Decrypt failed: %w", err) + } + + return key, nil +} diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 8c01868ef2..807d329a4d 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -326,13 +326,13 @@ func TestSimpleTDF(t *testing.T) { } }(fileWriter) - tdfSize, err := CreateTDF(*tdfConfig, bufReader, fileWriter) + tdfObj, err := CreateTDF(*tdfConfig, bufReader, fileWriter) if err != nil { t.Fatalf("tdf.CreateTDF failed: %v", err) } - if tdfSize != expectedTdfSize { - t.Errorf("tdf size test failed expected %v, got %v", tdfSize, expectedTdfSize) + if tdfObj.TdfSize != expectedTdfSize { + t.Errorf("tdf size test failed expected %v, got %v", tdfObj.TdfSize, expectedTdfSize) } } @@ -360,12 +360,16 @@ func TestSimpleTDF(t *testing.T) { authConfig.signingPublicKey = signingPubKey authConfig.signingPrivateKey = signingPrivateKey - r, err := NewReader(*authConfig, readSeeker) + r, err := LoadTDF(*authConfig, readSeeker) + if err != nil { + t.Fatalf("Fail to load the tdf:%v", err) + } + + unencryptedMetaData, err := r.GetUnencryptedMetadata() if err != nil { t.Fatalf("Fail to get meta data from tdf:%v", err) } - unencryptedMetaData := r.UnencryptedMetadata() if metaDataStr != unencryptedMetaData { t.Errorf("meta data test failed expected %v, got %v", metaDataStr, unencryptedMetaData) } @@ -405,7 +409,7 @@ func TestSimpleTDF(t *testing.T) { authConfig.signingPublicKey = signingPubKey authConfig.signingPrivateKey = signingPrivateKey - r, err := NewReader(*authConfig, readSeeker) + r, err := LoadTDF(*authConfig, readSeeker) if err != nil { t.Fatalf("Fail to create reader:%v", err) } @@ -468,7 +472,7 @@ func TestTDFReader(t *testing.T) { // test reader tdfReadSeeker := bytes.NewReader(tdfBuf.Bytes()) - r, err := NewReader(*authConfig, tdfReadSeeker) + r, err := LoadTDF(*authConfig, tdfReadSeeker) if err != nil { t.Fatalf("failed to read tdf: %v", err) } @@ -624,7 +628,7 @@ func BenchmarkReader(b *testing.B) { authConfig.signingPrivateKey = signingPrivateKey readSeeker = bytes.NewReader(tdfBuf.Bytes()) - r, err := NewReader(*authConfig, readSeeker) + r, err := LoadTDF(*authConfig, readSeeker) if err != nil { b.Fatalf("failed to read tdf: %v", err) } @@ -671,13 +675,13 @@ func testEncrypt(t *testing.T, tdfConfig TDFConfig, plainTextFilename, tdfFileNa t.Fatalf("Fail to close the tdf file: %v", err) } }(fileWriter) // CreateTDF TDFConfig - tdfSize, err := CreateTDF(tdfConfig, readSeeker, fileWriter) + tdfObj, err := CreateTDF(tdfConfig, readSeeker, fileWriter) if err != nil { t.Fatalf("tdf.CreateTDF failed: %v", err) } - if tdfSize != test.tdfFileSize { - t.Errorf("tdf size test failed expected %v, got %v", test.tdfFileSize, tdfSize) + if tdfObj.TdfSize != test.tdfFileSize { + t.Errorf("tdf size test failed expected %v, got %v", test.tdfFileSize, tdfObj.TdfSize) } } @@ -694,7 +698,7 @@ func testDecryptWithReader(t *testing.T, authConfig AuthConfig, tdfFile, decrypt } }(readSeeker) - r, err := NewReader(authConfig, readSeeker) + r, err := LoadTDF(authConfig, readSeeker) if err != nil { t.Fatalf("failed to read tdf: %v", err) }