Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions sdk/auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (t Token) Expired() bool {
return time.Now().After(expirationTime.Add(-tokenExpirationBuffer))
}

func getAccessTokenRequest(tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, privateJWK *jwk.Key) (*http.Request, error) {
func getAccessTokenRequest(logger *slog.Logger, tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, privateJWK *jwk.Key) (*http.Request, error) {
req, err := http.NewRequest(http.MethodPost, tokenEndpoint, nil) //nolint: noctx // TODO(#455): AccessToken methods should take contexts
if err != nil {
return nil, err
}
dpop, err := getDPoPAssertion(*privateJWK, http.MethodPost, tokenEndpoint, dpopNonce)
dpop, err := getDPoPAssertion(logger, *privateJWK, http.MethodPost, tokenEndpoint, dpopNonce)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,8 +141,8 @@ func getSignedToken(clientID, tokenEndpoint string, key jwk.Key) ([]byte, error)
// GetAccessToken this misses the flow where the Authorization server can tell us the next nonce to use.
// missing this flow costs us a bit in efficiency (a round trip per access token) but this is
// still correct because
func GetAccessToken(client *http.Client, tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*Token, error) {
req, err := getAccessTokenRequest(tokenEndpoint, "", scopes, clientCredentials, &dpopPrivateKey)
func GetAccessToken(logger *slog.Logger, client *http.Client, tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*Token, error) {
req, err := getAccessTokenRequest(logger, tokenEndpoint, "", scopes, clientCredentials, &dpopPrivateKey)
if err != nil {
return nil, err
}
Expand All @@ -158,7 +158,7 @@ func GetAccessToken(client *http.Client, tokenEndpoint string, scopes []string,
defer resp.Body.Close()

if nonceHeader := resp.Header.Get("dpop-nonce"); nonceHeader != "" && resp.StatusCode == http.StatusBadRequest {
nonceReq, err := getAccessTokenRequest(tokenEndpoint, nonceHeader, scopes, clientCredentials, &dpopPrivateKey)
nonceReq, err := getAccessTokenRequest(logger, tokenEndpoint, nonceHeader, scopes, clientCredentials, &dpopPrivateKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -195,8 +195,8 @@ func processResponse(resp *http.Response) (*Token, error) {
return token, nil
}

func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) {
slog.Debug("building DPoP Proof")
func getDPoPAssertion(logger *slog.Logger, dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) {
logger.Debug("building DPoP Proof")
publicKey, err := jwk.PublicKeyOf(dpopJWK)
const expirationTime = 5 * time.Minute

Expand Down Expand Up @@ -247,8 +247,8 @@ func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce str
return string(proof), nil
}

func DoTokenExchange(ctx context.Context, client *http.Client, tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, tokenExchange TokenExchangeInfo, key jwk.Key) (*Token, error) {
req, err := getTokenExchangeRequest(ctx, tokenEndpoint, "", scopes, clientCredentials, tokenExchange, &key)
func DoTokenExchange(ctx context.Context, log *slog.Logger, client *http.Client, tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, tokenExchange TokenExchangeInfo, key jwk.Key) (*Token, error) {
req, err := getTokenExchangeRequest(ctx, log, tokenEndpoint, "", scopes, clientCredentials, tokenExchange, &key)
if err != nil {
return nil, err
}
Expand All @@ -262,7 +262,7 @@ func DoTokenExchange(ctx context.Context, client *http.Client, tokenEndpoint str
defer resp.Body.Close()

if nonceHeader := resp.Header.Get("dpop-nonce"); nonceHeader != "" && resp.StatusCode == http.StatusBadRequest {
nonceReq, err := getTokenExchangeRequest(ctx, tokenEndpoint, nonceHeader, scopes, clientCredentials, tokenExchange, &key)
nonceReq, err := getTokenExchangeRequest(ctx, log, tokenEndpoint, nonceHeader, scopes, clientCredentials, tokenExchange, &key)
if err != nil {
return nil, err
}
Expand All @@ -279,7 +279,7 @@ func DoTokenExchange(ctx context.Context, client *http.Client, tokenEndpoint str
return processResponse(resp)
}

func getTokenExchangeRequest(ctx context.Context, tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, tokenExchange TokenExchangeInfo, privateJWK *jwk.Key) (*http.Request, error) {
func getTokenExchangeRequest(ctx context.Context, logger *slog.Logger, tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, tokenExchange TokenExchangeInfo, privateJWK *jwk.Key) (*http.Request, error) {
data := url.Values{
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
"subject_token": {tokenExchange.SubjectToken},
Expand All @@ -299,7 +299,7 @@ func getTokenExchangeRequest(ctx context.Context, tokenEndpoint, dpopNonce strin
if err != nil {
return nil, fmt.Errorf("error getting HTTP request: %w", err)
}
dpop, err := getDPoPAssertion(*privateJWK, http.MethodPost, tokenEndpoint, dpopNonce)
dpop, err := getDPoPAssertion(logger, *privateJWK, http.MethodPost, tokenEndpoint, dpopNonce)
if err != nil {
return nil, err
}
Expand All @@ -314,8 +314,8 @@ func getTokenExchangeRequest(ctx context.Context, tokenEndpoint, dpopNonce strin
return req, nil
}

func DoCertExchange(ctx context.Context, tokenEndpoint string, exchangeInfo CertExchangeInfo, clientCredentials ClientCredentials, key jwk.Key) (*Token, error) {
req, err := getCertExchangeRequest(ctx, tokenEndpoint, clientCredentials, exchangeInfo, key)
func DoCertExchange(ctx context.Context, log *slog.Logger, tokenEndpoint string, exchangeInfo CertExchangeInfo, clientCredentials ClientCredentials, key jwk.Key) (*Token, error) {
req, err := getCertExchangeRequest(ctx, log, tokenEndpoint, clientCredentials, exchangeInfo, key)
if err != nil {
return nil, err
}
Expand All @@ -333,7 +333,7 @@ func DoCertExchange(ctx context.Context, tokenEndpoint string, exchangeInfo Cert
return processResponse(resp)
}

func getCertExchangeRequest(ctx context.Context, tokenEndpoint string, clientCredentials ClientCredentials, exchangeInfo CertExchangeInfo, key jwk.Key) (*http.Request, error) {
func getCertExchangeRequest(ctx context.Context, logger *slog.Logger, tokenEndpoint string, clientCredentials ClientCredentials, exchangeInfo CertExchangeInfo, key jwk.Key) (*http.Request, error) {
data := url.Values{"grant_type": {"password"}, "client_id": {clientCredentials.ClientID}, "username": {""}, "password": {""}}

for _, a := range exchangeInfo.Audience {
Expand All @@ -346,7 +346,7 @@ func getCertExchangeRequest(ctx context.Context, tokenEndpoint string, clientCre
return nil, err
}

dpop, err := getDPoPAssertion(key, http.MethodPost, tokenEndpoint, "")
dpop, err := getDPoPAssertion(logger, key, http.MethodPost, tokenEndpoint, "")
if err != nil {
return nil, err
}
Expand Down
12 changes: 8 additions & 4 deletions sdk/auth/oauth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
_ "embed"
"encoding/base64"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -90,6 +91,7 @@ func (s *OAuthSuite) TestCertExchangeFromKeycloak() {

tok, err := DoCertExchange(
context.Background(),
slog.Default(),
s.keycloakHTTPSEndpoint,
exhcangeInfo,
clientCredentials,
Expand Down Expand Up @@ -128,6 +130,7 @@ func (s *OAuthSuite) TestGettingAccessTokenFromKeycloak() {
}

tok, err := GetAccessToken(
slog.Default(),
http.DefaultClient,
s.keycloakEndpoint,
[]string{"testscope"},
Expand Down Expand Up @@ -183,6 +186,7 @@ func (s *OAuthSuite) TestDoingTokenExchangeWithKeycloak() {
}

subjectToken, err := GetAccessToken(
slog.Default(),
http.DefaultClient,
s.keycloakEndpoint,
[]string{"testscope"},
Expand All @@ -200,7 +204,7 @@ func (s *OAuthSuite) TestDoingTokenExchangeWithKeycloak() {
Audience: []string{"opentdf-sdk"},
}

exchangedTok, err := DoTokenExchange(ctx, http.DefaultClient, s.keycloakEndpoint, []string{}, exchangeCredentials, tokenExchange, s.dpopJWK)
exchangedTok, err := DoTokenExchange(ctx, slog.Default(), http.DefaultClient, s.keycloakEndpoint, []string{}, exchangeCredentials, tokenExchange, s.dpopJWK)
s.Require().NoError(err)

tokenDetails, err := jwt.ParseString(exchangedTok.AccessToken, jwt.WithVerify(false))
Expand Down Expand Up @@ -274,7 +278,7 @@ func (s *OAuthSuite) TestClientSecretNoNonce() {
ClientID: "theclient",
ClientAuth: "thesecret",
}
_, err := GetAccessToken(http.DefaultClient, server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, s.dpopJWK)
_, err := GetAccessToken(slog.Default(), http.DefaultClient, server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, s.dpopJWK)
s.Require().NoError(err, "didn't get a token back from the IdP")
}

Expand Down Expand Up @@ -337,7 +341,7 @@ func (s *OAuthSuite) TestClientSecretWithNonce() {
ClientID: "theclient",
ClientAuth: "thesecret",
}
_, err := GetAccessToken(http.DefaultClient, server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, s.dpopJWK)
_, err := GetAccessToken(slog.Default(), http.DefaultClient, server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, s.dpopJWK)
if err != nil {
s.T().Errorf("didn't get a token back from the IdP: %v", err)
}
Expand Down Expand Up @@ -457,7 +461,7 @@ func (s *OAuthSuite) TestSignedJWTWithNonce() {

url = server.URL + "/token"

_, err = GetAccessToken(http.DefaultClient, url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(slog.Default(), http.DefaultClient, url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
if err != nil {
s.T().Errorf("didn't get a token back from the IdP: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions sdk/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
return fmt.Errorf("retrieving platformEndpoint failed: %w", err)
}
// if no kasAllowlist is set, we get the allowlist from the registry
allowlist, err := allowListFromKASRegistry(ctx, s.KeyAccessServerRegistry, platformEndpoint)
allowlist, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint)
if err != nil {
return fmt.Errorf("failed to get allowlist from registry: %w", err)
}
bulkReq.kasAllowlist = allowlist
bulkReq.NanoTDFDecryptOptions = append(bulkReq.NanoTDFDecryptOptions, withNanoKasAllowlist(bulkReq.kasAllowlist))
bulkReq.TDF3DecryptOptions = append(bulkReq.TDF3DecryptOptions, withKasAllowlist(bulkReq.kasAllowlist))
} else {
slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
s.Logger().Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available")
}
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
var err error
for kasurl, rewrapRequests := range kasRewrapRequests {
if bulkReq.ignoreAllowList {
slog.Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
} else if !bulkReq.kasAllowlist.IsAllowed(kasurl) {
// if kas url is not allowed, the result for each kao in each rewrap request is set to error
for _, req := range rewrapRequests {
Expand Down
2 changes: 2 additions & 0 deletions sdk/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"log/slog"
"net/http"
"testing"

Expand All @@ -26,6 +27,7 @@ func writeBytes(writerFunc func(io.Writer) error) []byte {
func newSDK() *SDK {
key, _ := ocrypto.NewRSAKeyPair(tdf3KeySize)
cfg := &config{
logger: slog.Default(),
kasSessionKey: &key,
}
sdk := &SDK{
Expand Down
Loading
Loading