Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdk/auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ func processResponse(resp *http.Response) (*Token, error) {
}

func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) {
slog.Debug("building DPoP Proof")
publicKey, err := jwk.PublicKeyOf(dpopJWK)
const expirationTime = 5 * time.Minute

Expand Down
6 changes: 3 additions & 3 deletions sdk/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
return fmt.Errorf("retrieving platformEndpoint failed: %w", err)
}
// if no kasAllowlist is set, we get the allowlist from the registry
allowlist, err := allowListFromKASRegistry(ctx, s.KeyAccessServerRegistry, platformEndpoint)
allowlist, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint)
if err != nil {
return fmt.Errorf("failed to get allowlist from registry: %w", err)
}
bulkReq.kasAllowlist = allowlist
bulkReq.NanoTDFDecryptOptions = append(bulkReq.NanoTDFDecryptOptions, withNanoKasAllowlist(bulkReq.kasAllowlist))
bulkReq.TDF3DecryptOptions = append(bulkReq.TDF3DecryptOptions, withKasAllowlist(bulkReq.kasAllowlist))
} else {
slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
s.Logger().Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available")
}
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
var err error
for kasurl, rewrapRequests := range kasRewrapRequests {
if bulkReq.ignoreAllowList {
slog.Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
} else if !bulkReq.kasAllowlist.IsAllowed(kasurl) {
// if kas url is not allowed, the result for each kao in each rewrap request is set to error
for _, req := range rewrapRequests {
Expand Down
2 changes: 2 additions & 0 deletions sdk/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"log/slog"
"net/http"
"testing"

Expand All @@ -26,6 +27,7 @@ func writeBytes(writerFunc func(io.Writer) error) []byte {
func newSDK() *SDK {
key, _ := ocrypto.NewRSAKeyPair(tdf3KeySize)
cfg := &config{
logger: slog.Default(),
kasSessionKey: &key,
}
sdk := &SDK{
Expand Down
54 changes: 29 additions & 25 deletions sdk/granter.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ func (a AttributeValueFQN) Name() string {

// Structure capable of generating a split plan from a given set of data tags.
type granter struct {
logger *slog.Logger

// The data attributes (tags) that this granter is responsible for.
tags []AttributeValueFQN

Expand Down Expand Up @@ -235,7 +237,7 @@ func (r *granter) addGrant(fqn AttributeValueFQN, kas string, attr *policy.Attri
func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) error {
key := sk.GetPublicKey()
if key == nil || key.GetKid() == "" || key.GetPem() == "" {
slog.Debug("invalid cached key in policy service",
r.logger.Debug("invalid cached key in policy service",
slog.String("kas", sk.GetKasUri()),
slog.Any("value", fqn),
)
Expand All @@ -252,15 +254,15 @@ func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) e

rl, err := NewResourceLocator(sk.GetKasUri())
if err != nil {
slog.Debug("invalid KAS URL in policy service",
r.logger.Debug("invalid KAS URL in policy service",
slog.String("kas", sk.GetKasUri()),
slog.Any("value", fqn),
slog.Any("error", err),
)
return fmt.Errorf("invalid KAS URL in policy service associated with [%s]: %w", fqn, err)
}
rl.identifier = key.GetKid()
slog.Debug("added mapped key",
r.logger.Debug("added mapped key",
slog.Any("fqn", fqn),
slog.String("kas", sk.GetKasUri()),
slog.String("kid", key.GetKid()),
Expand Down Expand Up @@ -330,7 +332,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
// Check for mapped keys
for _, k := range ag.GetKasKeys() {
if k == nil || k.GetKasUri() == "" {
slog.Debug("invalid KAS key in policy service",
r.logger.Debug("invalid KAS key in policy service",
slog.Any("simple_kas_key", k),
slog.Any("value", fqn),
)
Expand All @@ -341,7 +343,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
result = r.typ
err := r.addMappedKey(fqn, k)
if err != nil {
slog.Debug("failed to add mapped key",
r.logger.Debug("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -367,7 +369,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
for _, k := range g.GetKasKeys() {
err := r.addMappedKey(fqn, k)
if err != nil {
slog.Warn("failed to add mapped key",
r.logger.Warn("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -378,15 +380,15 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
}
ks := g.GetPublicKey().GetCached().GetKeys()
if len(ks) == 0 {
slog.Debug("no cached key in policy service",
r.logger.Debug("no cached key in policy service",
slog.String("kas", kasURI),
slog.Any("value", fqn),
)
continue
}
for _, k := range ks {
if k.GetKid() == "" || k.GetPem() == "" {
slog.Debug("invalid cached key in policy service",
r.logger.Debug("invalid cached key in policy service",
slog.String("kas", kasURI),
slog.Any("value", fqn),
slog.Any("key", k),
Expand All @@ -404,7 +406,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
}
err := r.addMappedKey(fqn, sk)
if err != nil {
slog.Warn("failed to add mapped key",
r.logger.Warn("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -426,7 +428,7 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant {
}

// Gets a list of directory of KAS grants for a list of attribute FQNs
func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) {
func newGranterFromService(ctx context.Context, logger *slog.Logger, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) {
fqnsStr := make([]string, len(fqns))
for i, v := range fqns {
fqnsStr[i] = v.String()
Expand All @@ -443,6 +445,7 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon
}

grants := granter{
logger: logger,
tags: fqns,
grantTable: make(map[string]*keyAccessGrant),
keyCache: &rlKeyCache{c: make(map[ResourceLocator]*policy.SimpleKasKey)},
Expand All @@ -455,23 +458,23 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon
def := pair.GetAttribute()

if def != nil {
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
}
v := pair.GetValue()
gType := noKeysFound
if v != nil {
gType = grants.addAllGrants(fqn, v, def)
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
}

// If no more specific grant was found, then add the value grants
if gType == noKeysFound && def != nil {
gType = grants.addAllGrants(fqn, def, def)
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
}
if gType == noKeysFound && def.GetNamespace() != nil {
grants.addAllGrants(fqn, def.GetNamespace(), def)
storeKeysToCache(def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache)
}
}

Expand Down Expand Up @@ -515,11 +518,11 @@ func algProto2OcryptoKeyType(e policy.Algorithm) ocrypto.KeyType {
}
}

func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) {
func storeKeysToCache(logger *slog.Logger, kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) {
for _, kas := range kases {
keys := kas.GetPublicKey().GetCached().GetKeys()
if len(keys) == 0 {
slog.Debug("no cached key in policy service", slog.String("kas", kas.GetUri()))
logger.Debug("no cached key in policy service", slog.String("kas", kas.GetUri()))
continue
}
for _, ki := range keys {
Expand All @@ -535,7 +538,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
if kc != nil && ki.GetKid() != "" && ki.GetPem() != "" {
rl, err := NewResourceLocator(kas.GetUri())
if err != nil {
slog.Debug("failed to create ResourceLocator",
logger.Debug("failed to create ResourceLocator",
slog.String("kas", kas.GetUri()),
slog.Any("error", err),
)
Expand Down Expand Up @@ -570,7 +573,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
if kc != nil && key.GetPublicKey().GetKid() != "" && key.GetPublicKey().GetPem() != "" {
rl, err := NewResourceLocator(key.GetKasUri())
if err != nil {
slog.Debug("failed to create ResourceLocator",
logger.Debug("failed to create ResourceLocator",
slog.String("kas", key.GetKasUri()),
slog.Any("error", err),
)
Expand All @@ -585,8 +588,9 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
// Given a policy (list of data attributes or tags),
// get a set of grants from attribute values to KASes.
// Unlike `newGranterFromService`, this works offline.
func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) {
func newGranterFromAttributes(logger *slog.Logger, keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) {
grants := granter{
logger: logger,
grantTable: make(map[string]*keyAccessGrant),
mapTable: make(map[string][]*ResourceLocator),
tags: make([]AttributeValueFQN, len(attrs)),
Expand All @@ -608,16 +612,16 @@ func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (gr
}

if grants.addAllGrants(fqn, v, def) != noKeysFound {
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
continue
}
// If no more specific grant was found, then add the attr grants
if grants.addAllGrants(fqn, def, def) != noKeysFound {
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
continue
}
grants.addAllGrants(fqn, namespace, def)
storeKeysToCache(namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache)
}

return grants, nil
Expand Down Expand Up @@ -846,7 +850,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK
var err error
rl, err = NewResourceLocator(kas)
if err != nil {
slog.Warn("invalid KAS URL in policy service",
r.logger.Warn("invalid KAS URL in policy service",
slog.String("kas", kas),
slog.Any("value", term),
slog.Any("error", err),
Expand All @@ -859,7 +863,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK
}
op := ruleToOperator(clause.def.GetRule())
if op == unspecified {
slog.Warn("unknown attribute rule type", slog.Any("rule", clause))
r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause))
}
kc := keyClause{
operator: op,
Expand Down Expand Up @@ -888,7 +892,7 @@ func (r *granter) assignKeysTo(e attributeBooleanExpression) (booleanKeyExpressi
}
op := ruleToOperator(clause.def.GetRule())
if op == unspecified {
slog.Warn("unknown attribute rule type", slog.Any("rule", clause))
r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause))
}
kc := keyClause{
operator: op,
Expand Down
9 changes: 5 additions & 4 deletions sdk/granter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sdk
import (
"context"
"errors"
"log/slog"
"maps"
"regexp"
"slices"
Expand Down Expand Up @@ -568,7 +569,7 @@ func TestConfigurationServicePutGet(t *testing.T) {
} {
t.Run(tc.n, func(t *testing.T) {
v := valuesToPolicy(tc.policy...)
grants, err := newGranterFromAttributes(newKasKeyCache(), v...)
grants, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), v...)
require.NoError(t, err)
assert.Len(t, grants.grantTable, tc.size)
assert.Subset(t, policyToStringKeys(tc.policy), slices.Collect(maps.Keys(grants.grantTable)))
Expand Down Expand Up @@ -726,7 +727,7 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) {
},
} {
t.Run(tc.n, func(t *testing.T) {
reasoner, err := newGranterFromAttributes(newKasKeyCache(), valuesToPolicy(tc.policy...)...)
reasoner, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), valuesToPolicy(tc.policy...)...)
require.NoError(t, err)

reasoner.keyInfoFetcher = &fakeKeyInfoFetcher{}
Expand Down Expand Up @@ -874,7 +875,7 @@ func TestReasonerSpecificity(t *testing.T) {
},
} {
t.Run(tc.n, func(t *testing.T) {
reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
require.NoError(t, err)
i := 0
plan, err := reasoner.plan(tc.defaults, func() string {
Expand Down Expand Up @@ -1025,7 +1026,7 @@ func TestReasonerSpecificityWithNamespaces(t *testing.T) {
},
} {
t.Run((tc.n + "\n" + tc.desc), func(t *testing.T) {
reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
require.NoError(t, err)
i := 0
plan, err := reasoner.plan(tc.defaults, func() string {
Expand Down
10 changes: 6 additions & 4 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sdk
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"sync"
Expand Down Expand Up @@ -60,7 +59,10 @@ type IDPAccessTokenSource struct {
}

func NewIDPAccessTokenSource(
credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair,
credentials oauth.ClientCredentials,
idpTokenEndpoint string,
scopes []string,
key *ocrypto.RsaKeyPair,
) (*IDPAccessTokenSource, error) {
endpoint, err := url.Parse(idpTokenEndpoint)
if err != nil {
Expand All @@ -73,6 +75,7 @@ func NewIDPAccessTokenSource(
}

tokenSource := IDPAccessTokenSource{
// logger: log,
credentials: credentials,
idpTokenEndpoint: *endpoint,
token: nil,
Expand All @@ -87,12 +90,11 @@ func NewIDPAccessTokenSource(
}

// AccessToken use a pointer receiver so that the token state is shared
func (t *IDPAccessTokenSource) AccessToken(ctx context.Context, client *http.Client) (auth.AccessToken, error) {
func (t *IDPAccessTokenSource) AccessToken(_ context.Context, client *http.Client) (auth.AccessToken, error) {
t.tokenMutex.Lock()
defer t.tokenMutex.Unlock()

if t.token == nil || t.token.Expired() {
slog.DebugContext(ctx, "getting new access token")
tok, err := oauth.GetAccessToken(client, t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return "", fmt.Errorf("error getting access token: %w", err)
Expand Down
5 changes: 4 additions & 1 deletion sdk/idp_cert_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sdk

import (
"context"
"log/slog"
"net/http"
"sync"

Expand All @@ -13,6 +14,7 @@ import (

type CertExchangeTokenSource struct {
auth.AccessTokenSource
logger *slog.Logger
IdpEndpoint string
credentials oauth.ClientCredentials
tokenMutex *sync.Mutex
Expand All @@ -21,13 +23,14 @@ type CertExchangeTokenSource struct {
key jwk.Key
}

func NewCertExchangeTokenSource(info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) {
func NewCertExchangeTokenSource(logger *slog.Logger, info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) {
_, dpopKey, _, err := getNewDPoPKey(dpop)
if err != nil {
return nil, err
}

exchangeSource := CertExchangeTokenSource{
logger: logger,
info: info,
IdpEndpoint: idpTokenEndpoint,
credentials: credentials,
Expand Down
Loading
Loading