Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 150 additions & 6 deletions sdk/auth_config.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
package sdk

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/opentdf/platform/sdk/internal/crypto"
)

type AuthConfig struct {
signingPublicKey string
signingPrivateKey string
authToken string
dpopPublicKeyPEM string
dpopPrivateKeyPEM string
accessToken string
}

type RequestBody struct {
KeyAccess `json:"keyAccess"`
ClientPublicKey string `json:"clientPublicKey"`
Policy string `json:"policy"`
}

// NewAuthConfig Create a new instance of authConfig
Expand All @@ -35,7 +45,7 @@ func NewAuthConfig() (*AuthConfig, error) {
return nil, fmt.Errorf("crypto.PrivateKeyInPemFormat failed: %w", err)
}

return &AuthConfig{signingPublicKey: publicKey, signingPrivateKey: privateKey}, nil
return &AuthConfig{dpopPublicKeyPEM: publicKey, dpopPrivateKeyPEM: privateKey}, nil
}

func NewOIDCAuthConfig(ctx context.Context, host, realm, clientId, clientSecret, subjectToken string) (*AuthConfig, error) {
Expand All @@ -44,7 +54,7 @@ func NewOIDCAuthConfig(ctx context.Context, host, realm, clientId, clientSecret,
return nil, err
}

authConfig.authToken, err = authConfig.fetchOIDCAccessToken(ctx, host, realm, clientId, clientSecret, subjectToken)
authConfig.accessToken, err = authConfig.fetchOIDCAccessToken(ctx, host, realm, clientId, clientSecret, subjectToken)
if err != nil {
return nil, fmt.Errorf("Failed to fetch acces token:%w", err)
}
Expand All @@ -62,11 +72,14 @@ func (a *AuthConfig) fetchOIDCAccessToken(ctx context.Context, host, realm, clie
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

certB64 := crypto.Base64Encode([]byte(a.signingPublicKey))
certB64 := crypto.Base64Encode([]byte(a.dpopPublicKeyPEM))
req.Header.Set("X-VirtruPubKey", string(certB64))

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error making request to IdP for token exchange: %w", err)
}
type keycloakResponsePayload struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand All @@ -79,3 +92,134 @@ func (a *AuthConfig) fetchOIDCAccessToken(ctx context.Context, host, realm, clie
}
return "Bearer " + keyResp.AccessToken, nil
}

func (a *AuthConfig) makeKASRequest(kasPath string, body *RequestBody) (*http.Response, error) {
kasURL := body.KasURL

requestBodyData, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("json.Marshal failed: %w", err)
}

claims := rewrapJWTClaims{
jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
string(requestBodyData),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)

signingRSAPrivateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(a.dpopPrivateKeyPEM))
if err != nil {
return nil, fmt.Errorf("jwt.ParseRSAPrivateKeyFromPEM failed: %w", err)
}

signedToken, err := token.SignedString(signingRSAPrivateKey)
if err != nil {
return nil, fmt.Errorf("jwt.SignedString failed: %w", err)
}

signedTokenRequestBody, err := json.Marshal(map[string]string{
kSignedRequestToken: signedToken,
})
if err != nil {
return nil, fmt.Errorf("json.Marshal failed: %w", err)
}

kasRequestURL, err := url.JoinPath(fmt.Sprintf("%v", kasURL), kasPath)
if err != nil {
return nil, fmt.Errorf("url.JoinPath failed: %w", err)
}
request, err := http.NewRequestWithContext(context.Background(), http.MethodPost, kasRequestURL,
bytes.NewBuffer(signedTokenRequestBody))
if err != nil {
return nil, fmt.Errorf("http.NewRequestWithContext failed: %w", err)
}

// add required headers
request.Header = http.Header{
kContentTypeKey: {kContentTypeJSONValue},
kAuthorizationKey: {fmt.Sprintf("Bearer %s", a.accessToken)},
kAcceptKey: {kContentTypeJSONValue},
}

client := &http.Client{}

response, err := client.Do(request)
if err != nil {
slog.Error("failed http request")
return nil, fmt.Errorf("http request failed: %w", err)
}

return response, nil
}

func (a *AuthConfig) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
requestBody := RequestBody{
KeyAccess: keyAccess,
Policy: policy,
ClientPublicKey: a.dpopPublicKeyPEM,
}

response, err := a.makeKASRequest(kRewrapV2, &requestBody)
defer func() {
if response == nil {
return
}
err := response.Body.Close()
if err != nil {
slog.Error("Fail to close HTTP response")
}
}()

if err != nil {
slog.Error("failed http request")
return nil, err
}
if response.StatusCode != kHTTPOk {
return nil, fmt.Errorf("http request failed status code:%d", response.StatusCode)
}

rewrapResponseBody, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("io.ReadAll failed: %w", err)
}

key, err := getWrappedKey(rewrapResponseBody, a.dpopPrivateKeyPEM)
if err != nil {
return nil, fmt.Errorf("failed to unwrap the wrapped key:%w", err)
}

return key, nil
}

func getWrappedKey(rewrapResponseBody []byte, clientPrivateKey string) ([]byte, error) {
var data map[string]interface{}
err := json.Unmarshal(rewrapResponseBody, &data)
if err != nil {
return nil, fmt.Errorf("json.Unmarshal failed: %w", err)
}

entityWrappedKey, ok := data[kEntityWrappedKey]
if !ok {
return nil, fmt.Errorf("entityWrappedKey is missing in key access object")
}

asymDecrypt, err := crypto.NewAsymDecryption(clientPrivateKey)
if err != nil {
return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err)
}

entityWrappedKeyDecoded, err := crypto.Base64Decode([]byte(fmt.Sprintf("%v", entityWrappedKey)))
if err != nil {
return nil, fmt.Errorf("crypto.Base64Decode failed: %w", err)
}

key, err := asymDecrypt.Decrypt(entityWrappedKeyDecoded)
if err != nil {
return nil, fmt.Errorf("crypto.Decrypt failed: %w", err)
}

return key, nil
}
4 changes: 2 additions & 2 deletions sdk/auth_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestNewOIDCAuthConfig(t *testing.T) {
t.Fatalf("authconfig failed: %v", err)
}

if authConfig.authToken != expectedAccessToken {
t.Fatalf("Auth token expected %s recived %s", expectedAccessToken, authConfig.authToken)
if authConfig.accessToken != expectedAccessToken {
t.Fatalf("Auth token expected %s recived %s", expectedAccessToken, authConfig.accessToken)
}
}
Loading