diff --git a/sdk/auth_config.go b/sdk/auth_config.go index 968e9a5a3e..f080a3846d 100644 --- a/sdk/auth_config.go +++ b/sdk/auth_config.go @@ -1,21 +1,31 @@ package sdk import ( + "bytes" "context" "encoding/json" "fmt" "io" + "log/slog" "net/http" "net/url" "strings" + "time" + "github.com/golang-jwt/jwt/v4" "github.com/opentdf/platform/sdk/internal/crypto" ) type AuthConfig struct { - signingPublicKey string - signingPrivateKey string - authToken string + dpopPublicKeyPEM string + dpopPrivateKeyPEM string + accessToken string +} + +type RequestBody struct { + KeyAccess `json:"keyAccess"` + ClientPublicKey string `json:"clientPublicKey"` + Policy string `json:"policy"` } // NewAuthConfig Create a new instance of authConfig @@ -35,7 +45,7 @@ func NewAuthConfig() (*AuthConfig, error) { return nil, fmt.Errorf("crypto.PrivateKeyInPemFormat failed: %w", err) } - return &AuthConfig{signingPublicKey: publicKey, signingPrivateKey: privateKey}, nil + return &AuthConfig{dpopPublicKeyPEM: publicKey, dpopPrivateKeyPEM: privateKey}, nil } func NewOIDCAuthConfig(ctx context.Context, host, realm, clientId, clientSecret, subjectToken string) (*AuthConfig, error) { @@ -44,7 +54,7 @@ func NewOIDCAuthConfig(ctx context.Context, host, realm, clientId, clientSecret, return nil, err } - authConfig.authToken, err = authConfig.fetchOIDCAccessToken(ctx, host, realm, clientId, clientSecret, subjectToken) + authConfig.accessToken, err = authConfig.fetchOIDCAccessToken(ctx, host, realm, clientId, clientSecret, subjectToken) if err != nil { return nil, fmt.Errorf("Failed to fetch acces token:%w", err) } @@ -62,11 +72,14 @@ func (a *AuthConfig) fetchOIDCAccessToken(ctx context.Context, host, realm, clie } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - certB64 := crypto.Base64Encode([]byte(a.signingPublicKey)) + certB64 := crypto.Base64Encode([]byte(a.dpopPublicKeyPEM)) req.Header.Set("X-VirtruPubKey", string(certB64)) client := &http.Client{} resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error making request to IdP for token exchange: %w", err) + } type keycloakResponsePayload struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -79,3 +92,134 @@ func (a *AuthConfig) fetchOIDCAccessToken(ctx context.Context, host, realm, clie } return "Bearer " + keyResp.AccessToken, nil } + +func (a *AuthConfig) makeKASRequest(kasPath string, body *RequestBody) (*http.Response, error) { + kasURL := body.KasURL + + requestBodyData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + claims := rewrapJWTClaims{ + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + string(requestBodyData), + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + + signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(a.dpopPrivateKeyPEM)) + if err != nil { + return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) + } + + signedToken, err := token.SignedString(signingRSAPrivateKey) + if err != nil { + return nil, fmt.Errorf("jwt.SignedString failed: %w", err) + } + + signedTokenRequestBody, err := json.Marshal(map[string]string{ + kSignedRequestToken: signedToken, + }) + if err != nil { + return nil, fmt.Errorf("json.Marshal failed: %w", err) + } + + kasRequestURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kasPath) + if err != nil { + return nil, fmt.Errorf("url.JoinPath failed: %w", err) + } + request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRequestURL, + bytes.NewBuffer(signedTokenRequestBody)) + if err != nil { + return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err) + } + + // add required headers + request.Header = http.Header{ + kContentTypeKey: {kContentTypeJSONValue}, + kAuthorizationKey: {fmt.Sprintf("Bearer %s", a.accessToken)}, + kAcceptKey: {kContentTypeJSONValue}, + } + + client := &http.Client{} + + response, err := client.Do(request) + if err != nil { + slog.Error("failed http request") + return nil, fmt.Errorf("http request failed: %w", err) + } + + return response, nil +} + +func (a *AuthConfig) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) { + requestBody := RequestBody{ + KeyAccess: keyAccess, + Policy: policy, + ClientPublicKey: a.dpopPublicKeyPEM, + } + + response, err := a.makeKASRequest(kRewrapV2, &requestBody) + defer func() { + if response == nil { + return + } + err := response.Body.Close() + if err != nil { + slog.Error("Fail to close HTTP response") + } + }() + + if err != nil { + slog.Error("failed http request") + return nil, err + } + if response.StatusCode != kHTTPOk { + return nil, fmt.Errorf("http request failed status code:%d", response.StatusCode) + } + + rewrapResponseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("io.ReadAll failed: %w", err) + } + + key, err := getWrappedKey(rewrapResponseBody, a.dpopPrivateKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err) + } + + return key, nil +} + +func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) { + var data map[string]interface{} + err := json.Unmarshal(rewrapResponseBody, &data) + if err != nil { + return nil, fmt.Errorf("json.Unmarshal failed: %w", err) + } + + entityWrappedKey, ok := data[kEntityWrappedKey] + if !ok { + return nil, fmt.Errorf("entityWrappedKey is missing in key access object") + } + + asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) + } + + entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey))) + if err != nil { + return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err) + } + + key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded) + if err != nil { + return nil, fmt.Errorf("crypto.Decrypt failed: %w", err) + } + + return key, nil +} diff --git a/sdk/auth_config_test.go b/sdk/auth_config_test.go index 9e5758d345..4fb2e42142 100644 --- a/sdk/auth_config_test.go +++ b/sdk/auth_config_test.go @@ -41,7 +41,7 @@ func TestNewOIDCAuthConfig(t *testing.T) { t.Fatalf("authconfig failed: %v", err) } - if authConfig.authToken != expectedAccessToken { - t.Fatalf("Auth token expected %s recived %s", expectedAccessToken, authConfig.authToken) + if authConfig.accessToken != expectedAccessToken { + t.Fatalf("Auth token expected %s recived %s", expectedAccessToken, authConfig.accessToken) } } diff --git a/sdk/tdf.go b/sdk/tdf.go index 58c3b18e12..57017a3cd9 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -2,18 +2,13 @@ package sdk import ( "bytes" - "context" "encoding/hex" "encoding/json" "errors" "fmt" "io" - "log/slog" "math" - "net/http" - "net/url" "strings" - "time" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" @@ -67,7 +62,7 @@ type Reader struct { manifest Manifest unencryptedMetadata string tdfReader archive.TDFReader - authConfig AuthConfig + unwrapper Unwrapper cursor int64 aesGcm crypto.AesGcm payloadSize int64 @@ -86,10 +81,8 @@ type rewrapJWTClaims struct { Body string `json:"requestBody"` } -type RequestBody struct { - KeyAccess `json:"keyAccess"` - ClientPublicKey string `json:"clientPublicKey"` - Policy string `json:"policy"` +type Unwrapper interface { + unwrap(keyAccess KeyAccess, policy string) ([]byte, error) } // CreateTDF reads plain text from the given reader and saves it to the writer, subject to the given options @@ -367,7 +360,7 @@ func (t *TDFObject) createPolicyObject(attributes []string) (PolicyObject, error } // LoadTDF loads the tdf and prepare for reading the payload from TDF -func LoadTDF(authConfig AuthConfig, reader io.ReadSeeker) (*Reader, error) { +func LoadTDF(unwrapper Unwrapper, reader io.ReadSeeker) (*Reader, error) { // create tdf reader tdfReader, err := archive.NewTDFReader(reader) if err != nil { @@ -386,9 +379,9 @@ func LoadTDF(authConfig AuthConfig, reader io.ReadSeeker) (*Reader, error) { } return &Reader{ - tdfReader: tdfReader, - manifest: *manifestObj, - authConfig: authConfig, + tdfReader: tdfReader, + manifest: *manifestObj, + unwrapper: unwrapper, }, nil } @@ -621,8 +614,7 @@ func (r *Reader) doPayloadKeyUnwrap() error { //nolint:gocognit var unencryptedMetadata string var payloadKey [kKeySize]byte for _, keyAccessObj := range r.manifest.EncryptionInformation.KeyAccessObjs { - requestBody := RequestBody{keyAccessObj, "", r.manifest.EncryptionInformation.Policy} - wrappedKey, err := rewrap(r.authConfig, &requestBody) + wrappedKey, err := r.unwrapper.unwrap(keyAccessObj, r.manifest.EncryptionInformation.Policy) if err != nil { return fmt.Errorf(" splitKey.rewrap failed:%w", err) } @@ -737,144 +729,3 @@ func validateRootSignature(manifest Manifest, secret []byte) (bool, error) { return false, nil } - -func handleKasRequest(kasPath string, body *RequestBody, authConfig AuthConfig) (*http.Response, error) { - kasURL := body.KasURL - - requestBodyData, err := json.Marshal(body) - if err != nil { - return nil, fmt.Errorf("json.Marshal failed: %w", err) - } - - claims := rewrapJWTClaims{ - jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - string(requestBodyData), - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - - signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(authConfig.signingPrivateKey)) - if err != nil { - return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err) - } - - signedToken, err := token.SignedString(signingRSAPrivateKey) - if err != nil { - return nil, fmt.Errorf("jwt.SignedString failed: %w", err) - } - - signedTokenRequestBody, err := json.Marshal(map[string]string{ - kSignedRequestToken: signedToken, - }) - if err != nil { - return nil, fmt.Errorf("json.Marshal failed: %w", err) - } - - kasRequestURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kasPath) - if err != nil { - return nil, fmt.Errorf("url.JoinPath failed: %w", err) - } - request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRequestURL, - bytes.NewBuffer(signedTokenRequestBody)) - if err != nil { - return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err) - } - - // add required headers - request.Header = http.Header{ - kContentTypeKey: {kContentTypeJSONValue}, - kAuthorizationKey: {authConfig.authToken}, - kAcceptKey: {kContentTypeJSONValue}, - } - - client := &http.Client{} - - response, err := client.Do(request) - if err != nil { - slog.Error("failed http request") - return nil, fmt.Errorf("http request failed: %w", err) - } - - return response, nil -} - -func rewrap(authConfig AuthConfig, requestBody *RequestBody) ([]byte, error) { - clientKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) - if err != nil { - return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) - } - - clientPubKey, err := clientKeyPair.PublicKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) - } - requestBody.ClientPublicKey = clientPubKey - - clientPrivateKey, err := clientKeyPair.PrivateKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) - } - - response, err := handleKasRequest(kRewrapV2, requestBody, authConfig) - defer func() { - if response == nil { - return - } - err := response.Body.Close() - if err != nil { - slog.Error("Fail to close HTTP response") - } - }() - - if err != nil { - slog.Error("failed http request") - return nil, err - } - if response.StatusCode != kHTTPOk { - return nil, fmt.Errorf("http request failed status code:%d", response.StatusCode) - } - - rewrapResponseBody, err := io.ReadAll(response.Body) - if err != nil { - return nil, fmt.Errorf("io.ReadAll failed: %w", err) - } - - key, err := getWrappedKey(rewrapResponseBody, clientPrivateKey) - if err != nil { - return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err) - } - - return key, nil -} - -func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) { - var data map[string]interface{} - err := json.Unmarshal(rewrapResponseBody, &data) - if err != nil { - return nil, fmt.Errorf("json.Unmarshal failed: %w", err) - } - - entityWrappedKey, ok := data[kEntityWrappedKey] - if !ok { - return nil, fmt.Errorf("entityWrappedKey is missing in key access object") - } - - asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey) - if err != nil { - return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) - } - - entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey))) - if err != nil { - return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err) - } - - key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded) - if err != nil { - return nil, fmt.Errorf("crypto.Decrypt failed: %w", err) - } - - return key, nil -} diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 0a4d315462..20b8f0fc12 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -2,6 +2,7 @@ package sdk import ( "bytes" + "crypto/rand" "crypto/sha256" "encoding/json" "errors" @@ -277,7 +278,7 @@ func init() { } func TestSimpleTDF(t *testing.T) { //nolint:gocognit - server, signingPubKey, signingPrivateKey := runKas() + server, authConfig := runKas() defer server.Close() metaDataStr := `{"displayName" : "openTDF go sdk"}` @@ -339,17 +340,7 @@ func TestSimpleTDF(t *testing.T) { //nolint:gocognit } }(readSeeker) - // create auth config - authConfig, err := NewAuthConfig() - if err != nil { - t.Fatalf("Fail to close archive file:%v", err) - } - - // override the signing keys to get the mock working. - authConfig.signingPublicKey = signingPubKey - authConfig.signingPrivateKey = signingPrivateKey - - r, err := LoadTDF(*authConfig, readSeeker) + r, err := LoadTDF(&authConfig, readSeeker) if err != nil { t.Fatalf("Fail to load the tdf:%v", err) } @@ -388,17 +379,8 @@ func TestSimpleTDF(t *testing.T) { //nolint:gocognit }(readSeeker) buf := make([]byte, 8) - // create auth config - authConfig, err := NewAuthConfig() - if err != nil { - t.Fatalf("Fail to close archive file:%v", err) - } - - // override the signing keys to get the mock working. - authConfig.signingPublicKey = signingPubKey - authConfig.signingPrivateKey = signingPrivateKey - r, err := LoadTDF(*authConfig, readSeeker) + r, err := LoadTDF(&authConfig, readSeeker) if err != nil { t.Fatalf("Fail to create reader:%v", err) } @@ -419,7 +401,7 @@ func TestSimpleTDF(t *testing.T) { //nolint:gocognit } func TestTDFReader(t *testing.T) { //nolint:gocognit - server, signingPubKey, signingPrivateKey := runKas() + server, authConfig := runKas() defer server.Close() for _, test := range partialTDFTestHarnesses { // create .txt file @@ -429,12 +411,6 @@ func TestTDFReader(t *testing.T) { //nolint:gocognit kasInfoList[index].publicKey = "" } - // create auth config - authConfig, err := NewAuthConfig() - if err != nil { - t.Fatalf("Fail to close archive file:%v", err) - } - for _, readAtTest := range test.readAtTests { tdfBuf := bytes.Buffer{} readSeeker := bytes.NewReader([]byte(test.payload)) @@ -449,13 +425,9 @@ func TestTDFReader(t *testing.T) { //nolint:gocognit t.Fatalf("tdf.CreateTDF failed: %v", err) } - // override the signing keys to get the mock working. - authConfig.signingPublicKey = signingPubKey - authConfig.signingPrivateKey = signingPrivateKey - // test reader tdfReadSeeker := bytes.NewReader(tdfBuf.Bytes()) - r, err := LoadTDF(*authConfig, tdfReadSeeker) + r, err := LoadTDF(&authConfig, tdfReadSeeker) if err != nil { t.Fatalf("failed to read tdf: %v", err) } @@ -509,7 +481,7 @@ func TestTDFReader(t *testing.T) { //nolint:gocognit } func TestTDF(t *testing.T) { - server, signingPubKey, signingPrivateKey := runKas() + server, authConfig := runKas() defer server.Close() for index, test := range testHarnesses { @@ -527,18 +499,8 @@ func TestTDF(t *testing.T) { // test encrypt testEncrypt(t, kasInfoList, plaintTextFileName, tdfFileName, test) - // create auth config - authConfig, err := NewAuthConfig() - if err != nil { - t.Fatalf("Fail to close archive file:%v", err) - } - - // override the signing keys to get the mock working. - authConfig.signingPublicKey = signingPubKey - authConfig.signingPrivateKey = signingPrivateKey - // test decrypt with reader - testDecryptWithReader(t, *authConfig, tdfFileName, decryptedTdfFileName, test) + testDecryptWithReader(t, authConfig, tdfFileName, decryptedTdfFileName, test) // Remove the test files _ = os.Remove(plaintTextFileName) @@ -557,7 +519,7 @@ func BenchmarkReader(b *testing.B) { }, } - server, signingPubKey, signingPrivateKey := runKas() + server, authConfig := runKas() defer server.Close() kasInfoList := test.kasInfoList @@ -580,18 +542,8 @@ func BenchmarkReader(b *testing.B) { b.Fatalf("tdf.CreateTDF failed: %v", err) } - // create auth config - authConfig, err := NewAuthConfig() - if err != nil { - b.Fatalf("Fail to close archive file:%v", err) - } - - // override the signing keys to get the mock working. - authConfig.signingPublicKey = signingPubKey - authConfig.signingPrivateKey = signingPrivateKey - readSeeker = bytes.NewReader(tdfBuf.Bytes()) - r, err := LoadTDF(*authConfig, readSeeker) + r, err := LoadTDF(&authConfig, readSeeker) if err != nil { b.Fatalf("failed to read tdf: %v", err) } @@ -661,7 +613,7 @@ func testDecryptWithReader(t *testing.T, authConfig AuthConfig, tdfFile, decrypt } }(readSeeker) - r, err := LoadTDF(authConfig, readSeeker) + r, err := LoadTDF(&authConfig, readSeeker) if err != nil { t.Fatalf("failed to read tdf: %v", err) } @@ -734,7 +686,7 @@ func createFileName(buf []byte, filename string, size int64) { } } -func runKas() (*httptest.Server, string, string) { //nolint:gocognit +func runKas() (*httptest.Server, AuthConfig) { //nolint:gocognit signingKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) if err != nil { panic(fmt.Sprintf("crypto.NewRSAKeyPair: %v", err)) @@ -750,6 +702,12 @@ func runKas() (*httptest.Server, string, string) { //nolint:gocognit panic(fmt.Sprintf("crypto.PrivateKeyInPemFormat failed: %v", err)) } + accessTokenBytes := make([]byte, 10) + if _, err := rand.Read(accessTokenBytes); err != nil { + panic("failed to create random access token") + } + accessToken := crypto.Base64Encode(accessTokenBytes) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get(kAcceptKey) != kContentTypeJSONValue { panic(fmt.Sprintf("expected Accept: application/json header, got: %s", r.Header.Get("Accept"))) @@ -773,6 +731,10 @@ func runKas() (*httptest.Server, string, string) { //nolint:gocognit if err != nil { panic(fmt.Sprintf("io.ReadAll failed: %v", err)) } + if r.Header.Get("authorization") != fmt.Sprintf("Bearer %s", accessToken) { + panic(fmt.Sprintf("got a bad auth header: [%s]", r.Header.Get("authorization"))) + } + var data map[string]string err = json.Unmarshal(requestBody, &data) if err != nil { @@ -841,7 +803,7 @@ func runKas() (*httptest.Server, string, string) { //nolint:gocognit } })) - return server, signingPubKey, signingPrivateKey + return server, AuthConfig{dpopPublicKeyPEM: signingPubKey, dpopPrivateKeyPEM: signingPrivateKey, accessToken: string(accessToken)} } func checkIdentical(t *testing.T, file, checksum string) bool {