Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/testcontainers/testcontainers-go v0.28.0
github.com/xeipuuv/gojsonschema v1.2.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/oauth2 v0.16.0
google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0
)
Expand Down Expand Up @@ -86,6 +87,7 @@ require (
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240311173647-c811ad7063a7 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
7 changes: 7 additions & 0 deletions sdk/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
Expand Down Expand Up @@ -210,6 +211,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ=
golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down Expand Up @@ -245,6 +248,7 @@ golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
Expand All @@ -264,13 +268,16 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
google.golang.org/genproto/googleapis/api v0.0.0-20240311173647-c811ad7063a7 h1:oqta3O3AnlWbmIE3bFnWbu4bRxZjfbWCp0cKSuZh01E=
google.golang.org/genproto/googleapis/api v0.0.0-20240311173647-c811ad7063a7/go.mod h1:VQW3tUculP/D4B+xVCo+VgSq8As6wA9ZjHl//pmk+6s=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 h1:8EeVk1VKMD+GD/neyEHGmz7pFblqPjHoi+PGQIlLx2s=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk=
google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
60 changes: 60 additions & 0 deletions sdk/idp_oauth_access_token_source.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sdk

import (
"context"
"fmt"
"net/http"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/sdk/auth"
"golang.org/x/oauth2"
)

// OAuthAccessTokenSource allow connecting to an IDP and obtain a DPoP bound access token
type OAuthAccessTokenSource struct {
source oauth2.TokenSource
scopes []string
dpopKey jwk.Key
asymDecryption ocrypto.AsymDecryption
dpopPEM string
}

func NewOAuthAccessTokenSource(
source oauth2.TokenSource, scopes []string, key *ocrypto.RsaKeyPair,
) (*OAuthAccessTokenSource, error) {
dpopPublicKeyPEM, dpopKey, asymDecryption, err := getNewDPoPKey(key)
if err != nil {
return nil, err
}

tokenSource := OAuthAccessTokenSource{
source: source,
scopes: scopes,
asymDecryption: *asymDecryption,
dpopKey: dpopKey,
dpopPEM: dpopPublicKeyPEM,
}

return &tokenSource, nil
}

// AccessToken use a pointer receiver so that the token state is shared
func (t *OAuthAccessTokenSource) AccessToken(_ context.Context, _ *http.Client) (auth.AccessToken, error) { // must satisfy auth.AccessTokenSource interface
tok, err := t.source.Token()
if err != nil {
return "", fmt.Errorf("error getting access token: %w", err)
}

// Non-nil with AccessToken and not Expired
if !tok.Valid() {
return "", ErrAccessTokenInvalid
// TODO: refresh tokens if expired?
}

return auth.AccessToken(tok.AccessToken), nil
}

func (t *OAuthAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(t.dpopKey)
}
97 changes: 97 additions & 0 deletions sdk/idp_oauth_access_token_source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package sdk

import (
"context"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/sdk/auth"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestNewOAuthAccessTokenSource_Success(t *testing.T) {
mockToken := "mockToken"
// Expected
mockSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: mockToken})
mockScopes := []string{"scope1", "scope2"}
mockKey, _ := ocrypto.NewRSAKeyPair(dpopKeySize)
dpopPublicKeyPEM, dpopKey, asymDecryption, _ := getNewDPoPKey(&mockKey)

// Testable
tokenSource, err := NewOAuthAccessTokenSource(mockSource, mockScopes, &mockKey)

// Sanity Checks
require.NoError(t, err)
assert.NotNil(t, tokenSource)
assert.Equal(t, mockSource, tokenSource.source)
assert.Equal(t, mockScopes, tokenSource.scopes)
// DPoP values
assert.Equal(t, asymDecryption, &tokenSource.asymDecryption)
assert.Equal(t, dpopPublicKeyPEM, tokenSource.dpopPEM)
assert.Equal(t, dpopKey, tokenSource.dpopKey)
// Interface checks
tok, err := tokenSource.AccessToken(context.Background(), nil)
require.NoError(t, err)
assert.Equal(t, tok, auth.AccessToken(mockToken))
made, err := tokenSource.MakeToken(func(jwk.Key) ([]byte, error) { return []byte(mockToken), nil })
require.NoError(t, err)
assert.Equal(t, made, []byte(mockToken))
}

func TestNewOAuthAccessTokenSource_ExpiredToken(t *testing.T) {
// Expected
pastTime := time.Now().Add(-time.Hour)
mockSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "mockToken", Expiry: pastTime})
mockScopes := []string{"scope1"}
mockKey, _ := ocrypto.NewRSAKeyPair(dpopKeySize)

// Testable
tokenSource, err := NewOAuthAccessTokenSource(mockSource, mockScopes, &mockKey)

// Sanity Checks
require.NoError(t, err)
assert.NotNil(t, tokenSource)
assert.Equal(t, mockSource, tokenSource.source)
// Interface checks
tok, err := tokenSource.AccessToken(context.Background(), nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrAccessTokenInvalid)
assert.Empty(t, tok)
}

func TestNewOAuthAccessTokenSource_InvalidTokenSource(t *testing.T) {
// Expected
mockSource := oauth2.StaticTokenSource(&oauth2.Token{})
mockScopes := []string{"scope1"}
mockKey, _ := ocrypto.NewRSAKeyPair(dpopKeySize)

// Testable
tokenSource, err := NewOAuthAccessTokenSource(mockSource, mockScopes, &mockKey)

// Sanity Checks
require.NoError(t, err)
assert.NotNil(t, tokenSource)
assert.Equal(t, mockSource, tokenSource.source)
// Interface checks
tok, err := tokenSource.AccessToken(context.Background(), nil)
require.Error(t, err)
assert.Empty(t, tok)
}

func TestNewOAuthAccessTokenSource_InvalidKey(t *testing.T) {
// Expected
mockSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "mockToken"})
mockScopes := []string{"scope1"}
badKey := ocrypto.RsaKeyPair{}

// Testable
tokenSource, err := NewOAuthAccessTokenSource(mockSource, mockScopes, &badKey)

// Sanity Checks
require.Error(t, err)
assert.Nil(t, tokenSource)
}
1 change: 0 additions & 1 deletion sdk/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ func GetAccessToken(client *http.Client, tokenEndpoint string, scopes []string,

func processResponse(resp *http.Response) (*Token, error) {
respBytes, err := io.ReadAll(resp.Body)

if err != nil {
return nil, fmt.Errorf("error reading bytes from response: %w", err)
}
Expand Down
9 changes: 9 additions & 0 deletions sdk/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/sdk/auth"
"github.com/opentdf/platform/sdk/internal/oauth"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
Expand All @@ -31,6 +32,7 @@ type config struct {
ipc bool
tdfFeatures tdfFeatures
customAccessTokenSource auth.AccessTokenSource
oauthAccessTokenSource oauth2.TokenSource
coreConn *grpc.ClientConn
}

Expand Down Expand Up @@ -97,6 +99,13 @@ func withCustomAccessTokenSource(a auth.AccessTokenSource) Option {
}
}

// WithOAuthAccessTokenSource directs the SDK to use a standard OAuth2 token source for authentication
func WithOAuthAccessTokenSource(t oauth2.TokenSource) Option {
return func(c *config) {
c.oauthAccessTokenSource = t
}
}

// Deprecated: Use WithCustomCoreConnection instead
func WithCustomPolicyConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
Expand Down
6 changes: 4 additions & 2 deletions sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
ErrPlatformAuthzEndpointNotFound = Error("authorization_endpoint not found in well-known idp configuration")
ErrPlatformTokenEndpointNotFound = Error("token_endpoint not found in well-known idp configuration")
ErrPlatformPublicClientIDNotFound = Error("public_client_id not found in well-known idp configuration")
ErrAccessTokenInvalid = Error("access token is invalid")
)

type Error string
Expand Down Expand Up @@ -215,9 +216,8 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) {
return c.customAccessTokenSource, nil
}

// If we don't have client-credentials, just return a KAS client that can only get public keys.
// There are uses for uncredentialed clients (i.e. consuming the well-known configuration).
if c.clientCredentials == nil {
if c.clientCredentials == nil && c.oauthAccessTokenSource == nil {
return nil, nil //nolint:nilnil // not having credentials is not an error
}

Expand All @@ -237,6 +237,8 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) {
var err error

switch {
case c.oauthAccessTokenSource != nil:
ts, err = NewOAuthAccessTokenSource(c.oauthAccessTokenSource, c.scopes, c.dpopKey)
case c.certExchange != nil:
ts, err = NewCertExchangeTokenSource(*c.certExchange, *c.clientCredentials, c.tokenEndpoint, c.dpopKey)
case c.tokenExchange != nil:
Expand Down