Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdk/auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ func processResponse(resp *http.Response) (*Token, error) {
}

func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) {
slog.Debug("building DPoP Proof")
publicKey, err := jwk.PublicKeyOf(dpopJWK)
const expirationTime = 5 * time.Minute

Expand Down
6 changes: 3 additions & 3 deletions sdk/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
return fmt.Errorf("retrieving platformEndpoint failed: %w", err)
}
// if no kasAllowlist is set, we get the allowlist from the registry
allowlist, err := allowListFromKASRegistry(ctx, s.KeyAccessServerRegistry, platformEndpoint)
allowlist, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint)
if err != nil {
return fmt.Errorf("failed to get allowlist from registry: %w", err)
}
bulkReq.kasAllowlist = allowlist
bulkReq.NanoTDFDecryptOptions = append(bulkReq.NanoTDFDecryptOptions, withNanoKasAllowlist(bulkReq.kasAllowlist))
bulkReq.TDF3DecryptOptions = append(bulkReq.TDF3DecryptOptions, withKasAllowlist(bulkReq.kasAllowlist))
} else {
slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
s.Logger().Error("no KAS allowlist provided and no KeyAccessServerRegistry available")
return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available")
}
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
var err error
for kasurl, rewrapRequests := range kasRewrapRequests {
if bulkReq.ignoreAllowList {
slog.Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
} else if !bulkReq.kasAllowlist.IsAllowed(kasurl) {
// if kas url is not allowed, the result for each kao in each rewrap request is set to error
for _, req := range rewrapRequests {
Expand Down
2 changes: 2 additions & 0 deletions sdk/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"log/slog"
"net/http"
"testing"

Expand All @@ -26,6 +27,7 @@ func writeBytes(writerFunc func(io.Writer) error) []byte {
func newSDK() *SDK {
key, _ := ocrypto.NewRSAKeyPair(tdf3KeySize)
cfg := &config{
logger: slog.Default(),
kasSessionKey: &key,
}
sdk := &SDK{
Expand Down
54 changes: 29 additions & 25 deletions sdk/granter.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ func (a AttributeValueFQN) Name() string {

// Structure capable of generating a split plan from a given set of data tags.
type granter struct {
logger *slog.Logger

// The data attributes (tags) that this granter is responsible for.
tags []AttributeValueFQN

Expand Down Expand Up @@ -235,7 +237,7 @@ func (r *granter) addGrant(fqn AttributeValueFQN, kas string, attr *policy.Attri
func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) error {
key := sk.GetPublicKey()
if key == nil || key.GetKid() == "" || key.GetPem() == "" {
slog.Debug("invalid cached key in policy service",
r.logger.Debug("invalid cached key in policy service",
slog.String("kas", sk.GetKasUri()),
slog.Any("value", fqn),
)
Expand All @@ -252,15 +254,15 @@ func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) e

rl, err := NewResourceLocator(sk.GetKasUri())
if err != nil {
slog.Debug("invalid KAS URL in policy service",
r.logger.Debug("invalid KAS URL in policy service",
slog.String("kas", sk.GetKasUri()),
slog.Any("value", fqn),
slog.Any("error", err),
)
return fmt.Errorf("invalid KAS URL in policy service associated with [%s]: %w", fqn, err)
}
rl.identifier = key.GetKid()
slog.Debug("added mapped key",
r.logger.Debug("added mapped key",
slog.Any("fqn", fqn),
slog.String("kas", sk.GetKasUri()),
slog.String("kid", key.GetKid()),
Expand Down Expand Up @@ -330,7 +332,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
// Check for mapped keys
for _, k := range ag.GetKasKeys() {
if k == nil || k.GetKasUri() == "" {
slog.Debug("invalid KAS key in policy service",
r.logger.Debug("invalid KAS key in policy service",
slog.Any("simple_kas_key", k),
slog.Any("value", fqn),
)
Expand All @@ -341,7 +343,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
result = r.typ
err := r.addMappedKey(fqn, k)
if err != nil {
slog.Debug("failed to add mapped key",
r.logger.Debug("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -367,7 +369,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
for _, k := range g.GetKasKeys() {
err := r.addMappedKey(fqn, k)
if err != nil {
slog.Warn("failed to add mapped key",
r.logger.Warn("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -378,15 +380,15 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
}
ks := g.GetPublicKey().GetCached().GetKeys()
if len(ks) == 0 {
slog.Debug("no cached key in policy service",
r.logger.Debug("no cached key in policy service",
slog.String("kas", kasURI),
slog.Any("value", fqn),
)
continue
}
for _, k := range ks {
if k.GetKid() == "" || k.GetPem() == "" {
slog.Debug("invalid cached key in policy service",
r.logger.Debug("invalid cached key in policy service",
slog.String("kas", kasURI),
slog.Any("value", fqn),
slog.Any("key", k),
Expand All @@ -404,7 +406,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr *
}
err := r.addMappedKey(fqn, sk)
if err != nil {
slog.Warn("failed to add mapped key",
r.logger.Warn("failed to add mapped key",
slog.Any("fqn", fqn),
slog.String("kas", kasURI),
slog.Any("error", err),
Expand All @@ -426,7 +428,7 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant {
}

// Gets a list of directory of KAS grants for a list of attribute FQNs
func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) {
func newGranterFromService(ctx context.Context, logger *slog.Logger, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) {
fqnsStr := make([]string, len(fqns))
for i, v := range fqns {
fqnsStr[i] = v.String()
Expand All @@ -443,6 +445,7 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon
}

grants := granter{
logger: logger,
tags: fqns,
grantTable: make(map[string]*keyAccessGrant),
keyCache: &rlKeyCache{c: make(map[ResourceLocator]*policy.SimpleKasKey)},
Expand All @@ -455,23 +458,23 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon
def := pair.GetAttribute()

if def != nil {
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
}
v := pair.GetValue()
gType := noKeysFound
if v != nil {
gType = grants.addAllGrants(fqn, v, def)
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
}

// If no more specific grant was found, then add the value grants
if gType == noKeysFound && def != nil {
gType = grants.addAllGrants(fqn, def, def)
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
}
if gType == noKeysFound && def.GetNamespace() != nil {
grants.addAllGrants(fqn, def.GetNamespace(), def)
storeKeysToCache(def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache)
}
}

Expand Down Expand Up @@ -515,11 +518,11 @@ func algProto2OcryptoKeyType(e policy.Algorithm) ocrypto.KeyType {
}
}

func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) {
func storeKeysToCache(logger *slog.Logger, kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) {
for _, kas := range kases {
keys := kas.GetPublicKey().GetCached().GetKeys()
if len(keys) == 0 {
slog.Debug("no cached key in policy service", slog.String("kas", kas.GetUri()))
logger.Debug("no cached key in policy service", slog.String("kas", kas.GetUri()))
continue
}
for _, ki := range keys {
Expand All @@ -535,7 +538,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
if kc != nil && ki.GetKid() != "" && ki.GetPem() != "" {
rl, err := NewResourceLocator(kas.GetUri())
if err != nil {
slog.Debug("failed to create ResourceLocator",
logger.Debug("failed to create ResourceLocator",
slog.String("kas", kas.GetUri()),
slog.Any("error", err),
)
Expand Down Expand Up @@ -570,7 +573,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
if kc != nil && key.GetPublicKey().GetKid() != "" && key.GetPublicKey().GetPem() != "" {
rl, err := NewResourceLocator(key.GetKasUri())
if err != nil {
slog.Debug("failed to create ResourceLocator",
logger.Debug("failed to create ResourceLocator",
slog.String("kas", key.GetKasUri()),
slog.Any("error", err),
)
Expand All @@ -585,8 +588,9 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK
// Given a policy (list of data attributes or tags),
// get a set of grants from attribute values to KASes.
// Unlike `newGranterFromService`, this works offline.
func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) {
func newGranterFromAttributes(logger *slog.Logger, keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) {
grants := granter{
logger: logger,
grantTable: make(map[string]*keyAccessGrant),
mapTable: make(map[string][]*ResourceLocator),
tags: make([]AttributeValueFQN, len(attrs)),
Expand All @@ -608,16 +612,16 @@ func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (gr
}

if grants.addAllGrants(fqn, v, def) != noKeysFound {
storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache)
continue
}
// If no more specific grant was found, then add the attr grants
if grants.addAllGrants(fqn, def, def) != noKeysFound {
storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache)
continue
}
grants.addAllGrants(fqn, namespace, def)
storeKeysToCache(namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache)
storeKeysToCache(logger, namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache)
}

return grants, nil
Expand Down Expand Up @@ -846,7 +850,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK
var err error
rl, err = NewResourceLocator(kas)
if err != nil {
slog.Warn("invalid KAS URL in policy service",
r.logger.Warn("invalid KAS URL in policy service",
slog.String("kas", kas),
slog.Any("value", term),
slog.Any("error", err),
Expand All @@ -859,7 +863,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK
}
op := ruleToOperator(clause.def.GetRule())
if op == unspecified {
slog.Warn("unknown attribute rule type", slog.Any("rule", clause))
r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause))
}
kc := keyClause{
operator: op,
Expand Down Expand Up @@ -888,7 +892,7 @@ func (r *granter) assignKeysTo(e attributeBooleanExpression) (booleanKeyExpressi
}
op := ruleToOperator(clause.def.GetRule())
if op == unspecified {
slog.Warn("unknown attribute rule type", slog.Any("rule", clause))
r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause))
}
kc := keyClause{
operator: op,
Expand Down
9 changes: 5 additions & 4 deletions sdk/granter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sdk
import (
"context"
"errors"
"log/slog"
"maps"
"regexp"
"slices"
Expand Down Expand Up @@ -568,7 +569,7 @@ func TestConfigurationServicePutGet(t *testing.T) {
} {
t.Run(tc.n, func(t *testing.T) {
v := valuesToPolicy(tc.policy...)
grants, err := newGranterFromAttributes(newKasKeyCache(), v...)
grants, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), v...)
require.NoError(t, err)
assert.Len(t, grants.grantTable, tc.size)
assert.Subset(t, policyToStringKeys(tc.policy), slices.Collect(maps.Keys(grants.grantTable)))
Expand Down Expand Up @@ -726,7 +727,7 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) {
},
} {
t.Run(tc.n, func(t *testing.T) {
reasoner, err := newGranterFromAttributes(newKasKeyCache(), valuesToPolicy(tc.policy...)...)
reasoner, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), valuesToPolicy(tc.policy...)...)
require.NoError(t, err)

reasoner.keyInfoFetcher = &fakeKeyInfoFetcher{}
Expand Down Expand Up @@ -874,7 +875,7 @@ func TestReasonerSpecificity(t *testing.T) {
},
} {
t.Run(tc.n, func(t *testing.T) {
reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
require.NoError(t, err)
i := 0
plan, err := reasoner.plan(tc.defaults, func() string {
Expand Down Expand Up @@ -1025,7 +1026,7 @@ func TestReasonerSpecificityWithNamespaces(t *testing.T) {
},
} {
t.Run((tc.n + "\n" + tc.desc), func(t *testing.T) {
reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...)
require.NoError(t, err)
i := 0
plan, err := reasoner.plan(tc.defaults, func() string {
Expand Down
6 changes: 4 additions & 2 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func getNewDPoPKey(dpopKeyPair *ocrypto.RsaKeyPair) (string, jwk.Key, *ocrypto.A
// IDPAccessTokenSource credentials that allow us to connect to an IDP and obtain an access token that is bound
// to a DPoP key
type IDPAccessTokenSource struct {
logger *slog.Logger
credentials oauth.ClientCredentials
idpTokenEndpoint url.URL
token *oauth.Token
Expand All @@ -60,7 +61,7 @@ type IDPAccessTokenSource struct {
}

func NewIDPAccessTokenSource(
credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair,
log *slog.Logger, credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair,
) (*IDPAccessTokenSource, error) {
endpoint, err := url.Parse(idpTokenEndpoint)
if err != nil {
Expand All @@ -73,6 +74,7 @@ func NewIDPAccessTokenSource(
}

tokenSource := IDPAccessTokenSource{
logger: log,
credentials: credentials,
idpTokenEndpoint: *endpoint,
token: nil,
Expand All @@ -92,7 +94,7 @@ func (t *IDPAccessTokenSource) AccessToken(ctx context.Context, client *http.Cli
defer t.tokenMutex.Unlock()

if t.token == nil || t.token.Expired() {
slog.DebugContext(ctx, "getting new access token")
t.logger.DebugContext(ctx, "getting new access token")
tok, err := oauth.GetAccessToken(client, t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return "", fmt.Errorf("error getting access token: %w", err)
Expand Down
5 changes: 4 additions & 1 deletion sdk/idp_cert_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sdk

import (
"context"
"log/slog"
"net/http"
"sync"

Expand All @@ -13,6 +14,7 @@ import (

type CertExchangeTokenSource struct {
auth.AccessTokenSource
logger *slog.Logger
IdpEndpoint string
credentials oauth.ClientCredentials
tokenMutex *sync.Mutex
Expand All @@ -21,13 +23,14 @@ type CertExchangeTokenSource struct {
key jwk.Key
}

func NewCertExchangeTokenSource(info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) {
func NewCertExchangeTokenSource(logger *slog.Logger, info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) {
_, dpopKey, _, err := getNewDPoPKey(dpop)
if err != nil {
return nil, err
}

exchangeSource := CertExchangeTokenSource{
logger: logger,
info: info,
IdpEndpoint: idpTokenEndpoint,
credentials: credentials,
Expand Down
Loading
Loading