diff --git a/sdk/auth/oauth/oauth.go b/sdk/auth/oauth/oauth.go index a1d8ab5fea..80545f2fe4 100644 --- a/sdk/auth/oauth/oauth.go +++ b/sdk/auth/oauth/oauth.go @@ -196,7 +196,6 @@ func processResponse(resp *http.Response) (*Token, error) { } func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) { - slog.Debug("building DPoP Proof") publicKey, err := jwk.PublicKeyOf(dpopJWK) const expirationTime = 5 * time.Minute diff --git a/sdk/bulk.go b/sdk/bulk.go index 7130cb6acf..4e1f3f8e3f 100644 --- a/sdk/bulk.go +++ b/sdk/bulk.go @@ -133,7 +133,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { return fmt.Errorf("retrieving platformEndpoint failed: %w", err) } // if no kasAllowlist is set, we get the allowlist from the registry - allowlist, err := allowListFromKASRegistry(ctx, s.KeyAccessServerRegistry, platformEndpoint) + allowlist, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint) if err != nil { return fmt.Errorf("failed to get allowlist from registry: %w", err) } @@ -141,7 +141,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { bulkReq.NanoTDFDecryptOptions = append(bulkReq.NanoTDFDecryptOptions, withNanoKasAllowlist(bulkReq.kasAllowlist)) bulkReq.TDF3DecryptOptions = append(bulkReq.TDF3DecryptOptions, withKasAllowlist(bulkReq.kasAllowlist)) } else { - slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available") + s.Logger().Error("no KAS allowlist provided and no KeyAccessServerRegistry available") return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available") } } @@ -172,7 +172,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { var err error for kasurl, rewrapRequests := range kasRewrapRequests { if bulkReq.ignoreAllowList { - slog.Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl)) + s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl)) } else if !bulkReq.kasAllowlist.IsAllowed(kasurl) { // if kas url is not allowed, the result for each kao in each rewrap request is set to error for _, req := range rewrapRequests { diff --git a/sdk/fuzz_test.go b/sdk/fuzz_test.go index 3faadb8a0d..e62aabf036 100644 --- a/sdk/fuzz_test.go +++ b/sdk/fuzz_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "io" + "log/slog" "net/http" "testing" @@ -26,6 +27,7 @@ func writeBytes(writerFunc func(io.Writer) error) []byte { func newSDK() *SDK { key, _ := ocrypto.NewRSAKeyPair(tdf3KeySize) cfg := &config{ + logger: slog.Default(), kasSessionKey: &key, } sdk := &SDK{ diff --git a/sdk/granter.go b/sdk/granter.go index c6642d4818..92b3b3afd9 100644 --- a/sdk/granter.go +++ b/sdk/granter.go @@ -200,6 +200,8 @@ func (a AttributeValueFQN) Name() string { // Structure capable of generating a split plan from a given set of data tags. type granter struct { + logger *slog.Logger + // The data attributes (tags) that this granter is responsible for. tags []AttributeValueFQN @@ -235,7 +237,7 @@ func (r *granter) addGrant(fqn AttributeValueFQN, kas string, attr *policy.Attri func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) error { key := sk.GetPublicKey() if key == nil || key.GetKid() == "" || key.GetPem() == "" { - slog.Debug("invalid cached key in policy service", + r.logger.Debug("invalid cached key in policy service", slog.String("kas", sk.GetKasUri()), slog.Any("value", fqn), ) @@ -252,7 +254,7 @@ func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) e rl, err := NewResourceLocator(sk.GetKasUri()) if err != nil { - slog.Debug("invalid KAS URL in policy service", + r.logger.Debug("invalid KAS URL in policy service", slog.String("kas", sk.GetKasUri()), slog.Any("value", fqn), slog.Any("error", err), @@ -260,7 +262,7 @@ func (r *granter) addMappedKey(fqn AttributeValueFQN, sk *policy.SimpleKasKey) e return fmt.Errorf("invalid KAS URL in policy service associated with [%s]: %w", fqn, err) } rl.identifier = key.GetKid() - slog.Debug("added mapped key", + r.logger.Debug("added mapped key", slog.Any("fqn", fqn), slog.String("kas", sk.GetKasUri()), slog.String("kid", key.GetKid()), @@ -330,7 +332,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * // Check for mapped keys for _, k := range ag.GetKasKeys() { if k == nil || k.GetKasUri() == "" { - slog.Debug("invalid KAS key in policy service", + r.logger.Debug("invalid KAS key in policy service", slog.Any("simple_kas_key", k), slog.Any("value", fqn), ) @@ -341,7 +343,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * result = r.typ err := r.addMappedKey(fqn, k) if err != nil { - slog.Debug("failed to add mapped key", + r.logger.Debug("failed to add mapped key", slog.Any("fqn", fqn), slog.String("kas", kasURI), slog.Any("error", err), @@ -367,7 +369,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * for _, k := range g.GetKasKeys() { err := r.addMappedKey(fqn, k) if err != nil { - slog.Warn("failed to add mapped key", + r.logger.Warn("failed to add mapped key", slog.Any("fqn", fqn), slog.String("kas", kasURI), slog.Any("error", err), @@ -378,7 +380,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * } ks := g.GetPublicKey().GetCached().GetKeys() if len(ks) == 0 { - slog.Debug("no cached key in policy service", + r.logger.Debug("no cached key in policy service", slog.String("kas", kasURI), slog.Any("value", fqn), ) @@ -386,7 +388,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * } for _, k := range ks { if k.GetKid() == "" || k.GetPem() == "" { - slog.Debug("invalid cached key in policy service", + r.logger.Debug("invalid cached key in policy service", slog.String("kas", kasURI), slog.Any("value", fqn), slog.Any("key", k), @@ -404,7 +406,7 @@ func (r *granter) addAllGrants(fqn AttributeValueFQN, ag grantableObject, attr * } err := r.addMappedKey(fqn, sk) if err != nil { - slog.Warn("failed to add mapped key", + r.logger.Warn("failed to add mapped key", slog.Any("fqn", fqn), slog.String("kas", kasURI), slog.Any("error", err), @@ -426,7 +428,7 @@ func (r granter) byAttribute(fqn AttributeValueFQN) *keyAccessGrant { } // Gets a list of directory of KAS grants for a list of attribute FQNs -func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { +func newGranterFromService(ctx context.Context, logger *slog.Logger, keyCache *kasKeyCache, as sdkconnect.AttributesServiceClient, fqns ...AttributeValueFQN) (granter, error) { fqnsStr := make([]string, len(fqns)) for i, v := range fqns { fqnsStr[i] = v.String() @@ -443,6 +445,7 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon } grants := granter{ + logger: logger, tags: fqns, grantTable: make(map[string]*keyAccessGrant), keyCache: &rlKeyCache{c: make(map[ResourceLocator]*policy.SimpleKasKey)}, @@ -455,23 +458,23 @@ func newGranterFromService(ctx context.Context, keyCache *kasKeyCache, as sdkcon def := pair.GetAttribute() if def != nil { - storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) } v := pair.GetValue() gType := noKeysFound if v != nil { gType = grants.addAllGrants(fqn, v, def) - storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache) } // If no more specific grant was found, then add the value grants if gType == noKeysFound && def != nil { gType = grants.addAllGrants(fqn, def, def) - storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) } if gType == noKeysFound && def.GetNamespace() != nil { grants.addAllGrants(fqn, def.GetNamespace(), def) - storeKeysToCache(def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, def.GetNamespace().GetGrants(), def.GetNamespace().GetKasKeys(), keyCache, grants.keyCache) } } @@ -515,11 +518,11 @@ func algProto2OcryptoKeyType(e policy.Algorithm) ocrypto.KeyType { } } -func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) { +func storeKeysToCache(logger *slog.Logger, kases []*policy.KeyAccessServer, keys []*policy.SimpleKasKey, c *kasKeyCache, kc *rlKeyCache) { for _, kas := range kases { keys := kas.GetPublicKey().GetCached().GetKeys() if len(keys) == 0 { - slog.Debug("no cached key in policy service", slog.String("kas", kas.GetUri())) + logger.Debug("no cached key in policy service", slog.String("kas", kas.GetUri())) continue } for _, ki := range keys { @@ -535,7 +538,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK if kc != nil && ki.GetKid() != "" && ki.GetPem() != "" { rl, err := NewResourceLocator(kas.GetUri()) if err != nil { - slog.Debug("failed to create ResourceLocator", + logger.Debug("failed to create ResourceLocator", slog.String("kas", kas.GetUri()), slog.Any("error", err), ) @@ -570,7 +573,7 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK if kc != nil && key.GetPublicKey().GetKid() != "" && key.GetPublicKey().GetPem() != "" { rl, err := NewResourceLocator(key.GetKasUri()) if err != nil { - slog.Debug("failed to create ResourceLocator", + logger.Debug("failed to create ResourceLocator", slog.String("kas", key.GetKasUri()), slog.Any("error", err), ) @@ -585,8 +588,9 @@ func storeKeysToCache(kases []*policy.KeyAccessServer, keys []*policy.SimpleKasK // Given a policy (list of data attributes or tags), // get a set of grants from attribute values to KASes. // Unlike `newGranterFromService`, this works offline. -func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) { +func newGranterFromAttributes(logger *slog.Logger, keyCache *kasKeyCache, attrs ...*policy.Value) (granter, error) { grants := granter{ + logger: logger, grantTable: make(map[string]*keyAccessGrant), mapTable: make(map[string][]*ResourceLocator), tags: make([]AttributeValueFQN, len(attrs)), @@ -608,16 +612,16 @@ func newGranterFromAttributes(keyCache *kasKeyCache, attrs ...*policy.Value) (gr } if grants.addAllGrants(fqn, v, def) != noKeysFound { - storeKeysToCache(v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, v.GetGrants(), v.GetKasKeys(), keyCache, grants.keyCache) continue } // If no more specific grant was found, then add the attr grants if grants.addAllGrants(fqn, def, def) != noKeysFound { - storeKeysToCache(def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, def.GetGrants(), def.GetKasKeys(), keyCache, grants.keyCache) continue } grants.addAllGrants(fqn, namespace, def) - storeKeysToCache(namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache) + storeKeysToCache(logger, namespace.GetGrants(), namespace.GetKasKeys(), keyCache, grants.keyCache) } return grants, nil @@ -846,7 +850,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK var err error rl, err = NewResourceLocator(kas) if err != nil { - slog.Warn("invalid KAS URL in policy service", + r.logger.Warn("invalid KAS URL in policy service", slog.String("kas", kas), slog.Any("value", term), slog.Any("error", err), @@ -859,7 +863,7 @@ func (r *granter) insertKeysForAttribute(e attributeBooleanExpression) (booleanK } op := ruleToOperator(clause.def.GetRule()) if op == unspecified { - slog.Warn("unknown attribute rule type", slog.Any("rule", clause)) + r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause)) } kc := keyClause{ operator: op, @@ -888,7 +892,7 @@ func (r *granter) assignKeysTo(e attributeBooleanExpression) (booleanKeyExpressi } op := ruleToOperator(clause.def.GetRule()) if op == unspecified { - slog.Warn("unknown attribute rule type", slog.Any("rule", clause)) + r.logger.Warn("unknown attribute rule type", slog.Any("rule", clause)) } kc := keyClause{ operator: op, diff --git a/sdk/granter_test.go b/sdk/granter_test.go index dc44442c8d..19e9519e61 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -3,6 +3,7 @@ package sdk import ( "context" "errors" + "log/slog" "maps" "regexp" "slices" @@ -568,7 +569,7 @@ func TestConfigurationServicePutGet(t *testing.T) { } { t.Run(tc.n, func(t *testing.T) { v := valuesToPolicy(tc.policy...) - grants, err := newGranterFromAttributes(newKasKeyCache(), v...) + grants, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), v...) require.NoError(t, err) assert.Len(t, grants.grantTable, tc.size) assert.Subset(t, policyToStringKeys(tc.policy), slices.Collect(maps.Keys(grants.grantTable))) @@ -726,7 +727,7 @@ func TestReasonerConstructAttributeBoolean(t *testing.T) { }, } { t.Run(tc.n, func(t *testing.T) { - reasoner, err := newGranterFromAttributes(newKasKeyCache(), valuesToPolicy(tc.policy...)...) + reasoner, err := newGranterFromAttributes(slog.Default(), newKasKeyCache(), valuesToPolicy(tc.policy...)...) require.NoError(t, err) reasoner.keyInfoFetcher = &fakeKeyInfoFetcher{} @@ -874,7 +875,7 @@ func TestReasonerSpecificity(t *testing.T) { }, } { t.Run(tc.n, func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { @@ -1025,7 +1026,7 @@ func TestReasonerSpecificityWithNamespaces(t *testing.T) { }, } { t.Run((tc.n + "\n" + tc.desc), func(t *testing.T) { - reasoner, err := newGranterFromService(t.Context(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) + reasoner, err := newGranterFromService(t.Context(), slog.Default(), newKasKeyCache(), &mockAttributesClient{}, tc.policy...) require.NoError(t, err) i := 0 plan, err := reasoner.plan(tc.defaults, func() string { diff --git a/sdk/idp_access_token_source.go b/sdk/idp_access_token_source.go index cf4eb8a1fb..9b84b4022f 100644 --- a/sdk/idp_access_token_source.go +++ b/sdk/idp_access_token_source.go @@ -3,7 +3,6 @@ package sdk import ( "context" "fmt" - "log/slog" "net/http" "net/url" "sync" @@ -60,7 +59,10 @@ type IDPAccessTokenSource struct { } func NewIDPAccessTokenSource( - credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair, + credentials oauth.ClientCredentials, + idpTokenEndpoint string, + scopes []string, + key *ocrypto.RsaKeyPair, ) (*IDPAccessTokenSource, error) { endpoint, err := url.Parse(idpTokenEndpoint) if err != nil { @@ -73,6 +75,7 @@ func NewIDPAccessTokenSource( } tokenSource := IDPAccessTokenSource{ + // logger: log, credentials: credentials, idpTokenEndpoint: *endpoint, token: nil, @@ -87,12 +90,11 @@ func NewIDPAccessTokenSource( } // AccessToken use a pointer receiver so that the token state is shared -func (t *IDPAccessTokenSource) AccessToken(ctx context.Context, client *http.Client) (auth.AccessToken, error) { +func (t *IDPAccessTokenSource) AccessToken(_ context.Context, client *http.Client) (auth.AccessToken, error) { t.tokenMutex.Lock() defer t.tokenMutex.Unlock() if t.token == nil || t.token.Expired() { - slog.DebugContext(ctx, "getting new access token") tok, err := oauth.GetAccessToken(client, t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey) if err != nil { return "", fmt.Errorf("error getting access token: %w", err) diff --git a/sdk/idp_cert_exchange.go b/sdk/idp_cert_exchange.go index 77de36245e..5e0d2f4ee9 100644 --- a/sdk/idp_cert_exchange.go +++ b/sdk/idp_cert_exchange.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "log/slog" "net/http" "sync" @@ -13,6 +14,7 @@ import ( type CertExchangeTokenSource struct { auth.AccessTokenSource + logger *slog.Logger IdpEndpoint string credentials oauth.ClientCredentials tokenMutex *sync.Mutex @@ -21,13 +23,14 @@ type CertExchangeTokenSource struct { key jwk.Key } -func NewCertExchangeTokenSource(info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) { +func NewCertExchangeTokenSource(logger *slog.Logger, info oauth.CertExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, dpop *ocrypto.RsaKeyPair) (auth.AccessTokenSource, error) { _, dpopKey, _, err := getNewDPoPKey(dpop) if err != nil { return nil, err } exchangeSource := CertExchangeTokenSource{ + logger: logger, info: info, IdpEndpoint: idpTokenEndpoint, credentials: credentials, diff --git a/sdk/idp_token_exchange_token_source.go b/sdk/idp_token_exchange_token_source.go index 5641447c49..e9748d0b0d 100644 --- a/sdk/idp_token_exchange_token_source.go +++ b/sdk/idp_token_exchange_token_source.go @@ -2,6 +2,7 @@ package sdk import ( "context" + "log/slog" "net/http" "github.com/lestrrat-go/jwx/v2/jwk" @@ -13,15 +14,17 @@ import ( type IDPTokenExchangeTokenSource struct { IDPAccessTokenSource oauth.TokenExchangeInfo + logger *slog.Logger } -func NewIDPTokenExchangeTokenSource(exchangeInfo oauth.TokenExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair) (*IDPTokenExchangeTokenSource, error) { +func NewIDPTokenExchangeTokenSource(logger *slog.Logger, exchangeInfo oauth.TokenExchangeInfo, credentials oauth.ClientCredentials, idpTokenEndpoint string, scopes []string, key *ocrypto.RsaKeyPair) (*IDPTokenExchangeTokenSource, error) { idpSource, err := NewIDPAccessTokenSource(credentials, idpTokenEndpoint, scopes, key) if err != nil { return nil, err } exchangeSource := IDPTokenExchangeTokenSource{ + logger: logger, IDPAccessTokenSource: *idpSource, TokenExchangeInfo: exchangeInfo, } diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index 32d46d0c50..51af506687 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -96,7 +96,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) size += uint32(resource.getLength()) header.kasURL = *resource - slog.Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("resource_locator", uint64(resource.getLength()))) + getLogger().Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("resource_locator", uint64(resource.getLength()))) // Read ECC and Binding Mode oneBytes := make([]byte, 1) @@ -141,7 +141,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) } size += uint32(l) policyLength := binary.BigEndian.Uint16(twoBytes) - slog.Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("policy_length", uint64(policyLength))) + getLogger().Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("policy_length", uint64(policyLength))) // Read policy body header.PolicyMode = policyMode @@ -204,7 +204,7 @@ func NewNanoTDFHeaderFromReader(reader io.Reader) (NanoTDFHeader, uint32, error) size += uint32(l) header.EphemeralKey = ephemeralKey - slog.Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("header_size", uint64(size))) + getLogger().Debug("checkpoint NewNanoTDFHeaderFromReader", slog.Uint64("header_size", uint64(size))) return header, size, nil } @@ -275,12 +275,12 @@ func (ep embeddedPolicy) writeEmbeddedPolicy(writer io.Writer) error { if _, err := writer.Write(buf); err != nil { return err } - slog.Debug("writeEmbeddedPolicy", slog.Uint64("policy_length", uint64(ep.lengthBody))) + getLogger().Debug("writeEmbeddedPolicy", slog.Uint64("policy_length", uint64(ep.lengthBody))) if _, err := writer.Write(ep.body); err != nil { return err } - slog.Debug("writeEmbeddedPolicy", slog.Uint64("policy_body", uint64(len(ep.body)))) + getLogger().Debug("writeEmbeddedPolicy", slog.Uint64("policy_body", uint64(len(ep.body)))) return nil } @@ -617,7 +617,7 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32, } totalBytes += uint32(l) - slog.Debug("writeNanoTDFHeader", slog.Uint64("magic_number", uint64(len(kNanoTDFMagicStringAndVersion)))) + getLogger().Debug("writeNanoTDFHeader", slog.Uint64("magic_number", uint64(len(kNanoTDFMagicStringAndVersion)))) // Write the kas url err = config.kasURL.writeResourceLocator(writer) @@ -625,7 +625,7 @@ func writeNanoTDFHeader(writer io.Writer, config NanoTDFConfig) ([]byte, uint32, return nil, 0, 0, err } totalBytes += uint32(config.kasURL.getLength()) - slog.Debug("writeNanoTDFHeader", slog.Uint64("resource_locator_number", uint64(config.kasURL.getLength()))) + getLogger().Debug("writeNanoTDFHeader", slog.Uint64("resource_locator_number", uint64(config.kasURL.getLength()))) // Write ECC And Binding Mode l, err = writer.Write([]byte{serializeBindingCfg(config.bindCfg)}) @@ -804,7 +804,7 @@ func (s SDK) CreateNanoTDF(writer io.Writer, reader io.Reader, config NanoTDFCon return 0, fmt.Errorf("writeNanoTDFHeader failed:%w", err) } - slog.Debug("checkpoint CreateNanoTDF", slog.Uint64("header", uint64(totalSize))) + s.Logger().Debug("checkpoint CreateNanoTDF", slog.Uint64("header", uint64(totalSize))) aesGcm, err := ocrypto.NewAESGcm(key) if err != nil { @@ -846,7 +846,7 @@ func (s SDK) CreateNanoTDF(writer io.Writer, reader io.Reader, config NanoTDFCon } totalSize += uint32(l) - slog.Debug("checkpoint CreateNanoTDF", slog.Uint64("payload_length", uint64(len(cipherDataWithoutPadding)))) + s.Logger().Debug("checkpoint CreateNanoTDF", slog.Uint64("payload_length", uint64(len(cipherDataWithoutPadding)))) // write cipher data l, err = writer.Write(cipherDataWithoutPadding) @@ -948,7 +948,7 @@ func (n *NanoTDFDecryptHandler) Decrypt(ctx context.Context, result []kaoResult) } payloadLength := binary.BigEndian.Uint32(payloadLengthBuf) - slog.DebugContext(ctx, "decrypt", slog.Uint64("payload_length", uint64(payloadLength))) + getLogger().DebugContext(ctx, "decrypt", slog.Uint64("payload_length", uint64(payloadLength))) cipherData := make([]byte, payloadLength) _, err = n.reader.Read(cipherData) @@ -1008,7 +1008,7 @@ func (s SDK) ReadNanoTDFContext(ctx context.Context, writer io.Writer, reader io return 0, fmt.Errorf("retrieving platformEndpoint failed: %w", err) } // retrieve the registered kases if not provided - allowList, err := allowListFromKASRegistry(ctx, s.KeyAccessServerRegistry, platformEndpoint) + allowList, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint) if err != nil { return 0, fmt.Errorf("allowListFromKASRegistry failed: %w", err) } @@ -1104,7 +1104,7 @@ func getKasInfoForNanoTDF(s *SDK, config *NanoTDFConfig) (*KASInfo, error) { } } - slog.Debug("getNanoKasInfoFromBaseKey failed, falling back to default kas", slog.String("error", err.Error())) + s.logger.Debug("getNanoKasInfoFromBaseKey failed, falling back to default kas", slog.String("error", err.Error())) kasURL, err := config.kasURL.GetURL() if err != nil { diff --git a/sdk/nanotdf_test.go b/sdk/nanotdf_test.go index f13a349779..17f396343e 100644 --- a/sdk/nanotdf_test.go +++ b/sdk/nanotdf_test.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "errors" "io" + "log/slog" "os" "testing" @@ -431,6 +432,9 @@ func (s *NanoSuite) Test_CreateNanoTDF_BaseKey() { // Create SDK sdk := &SDK{ + config: config{ + logger: slog.Default(), + }, wellknownConfiguration: mockClient, } diff --git a/sdk/options.go b/sdk/options.go index 7ad993388d..e62b8459a1 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -3,6 +3,7 @@ package sdk import ( "crypto/rsa" "crypto/tls" + "log/slog" "net/http" "connectrpc.com/connect" @@ -43,6 +44,7 @@ type config struct { entityResolutionConn *ConnectRPCConnection collectionStore *collectionStore shouldValidatePlatformConnectivity bool + logger *slog.Logger } // Options specific to TDF protocol features @@ -230,3 +232,10 @@ func WithNoKIDInNano() Option { c.nanoFeatures.noKID = true } } + +// WithLogger returns an Option that sets a custom slog.Logger for all SDK logging. +func WithLogger(logger *slog.Logger) Option { + return func(c *config) { + c.logger = logger + } +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 45af3b275d..fa6727130a 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -9,9 +9,11 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "net/url" "strings" + "sync" "connectrpc.com/connect" "github.com/opentdf/platform/lib/ocrypto" @@ -43,12 +45,36 @@ const ( ErrWellKnowConfigEmpty = Error("well-known configuration is empty") ) +var ( + // Package-level logger for internal SDK functions + packageLogger *slog.Logger + loggerMutex sync.RWMutex +) + type Error string func (c Error) Error() string { return string(c) } +// getLogger returns the package-level logger, defaulting to slog.Default() if not set to +// provide access to the logger in exported functions where signatures are unable to be altered +func getLogger() *slog.Logger { + loggerMutex.RLock() + defer loggerMutex.RUnlock() + if packageLogger != nil { + return packageLogger + } + return slog.Default() +} + +// setPackageLogger sets the package-level logger for internal SDK functions +func setPackageLogger(logger *slog.Logger) { + loggerMutex.Lock() + defer loggerMutex.Unlock() + packageLogger = logger +} + type SDK struct { config *kasKeyCache @@ -91,6 +117,14 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { opt(cfg) } + // Set default logger if none provided + if cfg.logger == nil { + cfg.logger = slog.Default() + } + + // Set package-level logger for internal functions + setPackageLogger(cfg.logger) + // If IPC is enabled, we need to have a core connection if cfg.ipc && cfg.coreConn == nil { return nil, errors.New("core connection is required for IPC mode") @@ -241,9 +275,10 @@ func buildIDPTokenSource(c *config) (auth.AccessTokenSource, error) { case c.oauthAccessTokenSource != nil: ts, err = NewOAuthAccessTokenSource(c.oauthAccessTokenSource, c.scopes, c.dpopKey) case c.certExchange != nil: - ts, err = NewCertExchangeTokenSource(*c.certExchange, *c.clientCredentials, c.tokenEndpoint, c.dpopKey) + ts, err = NewCertExchangeTokenSource(c.logger, *c.certExchange, *c.clientCredentials, c.tokenEndpoint, c.dpopKey) case c.tokenExchange != nil: ts, err = NewIDPTokenExchangeTokenSource( + c.logger, *c.tokenExchange, *c.clientCredentials, c.tokenEndpoint, @@ -269,6 +304,11 @@ func (s SDK) Close() error { return nil } +// Logger returns the configured slog.Logger for this SDK instance +func (s SDK) Logger() *slog.Logger { + return s.logger +} + // Conn returns the underlying http connection func (s SDK) Conn() *ConnectRPCConnection { return s.conn diff --git a/sdk/tdf.go b/sdk/tdf.go index b74884d741..eb9de45ad6 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -391,7 +391,7 @@ func (tdfConfig *TDFConfig) initKAOTemplate(ctx context.Context, s SDK) error { if err == nil { err = populateKasInfoFromBaseKey(baseKey, tdfConfig) } else { - slog.Debug("cannot getting base key, falling back to default kas", slog.Any("error", err)) + s.Logger().Debug("cannot getting base key, falling back to default kas", slog.Any("error", err)) dk := s.defaultKases(tdfConfig) tdfConfig.kaoTemplate = nil tdfConfig.splitPlan, err = g.plan(dk, uuidSplitIDGenerator) @@ -450,9 +450,9 @@ func (s SDK) newGranter(ctx context.Context, tdfConfig *TDFConfig) (granter, err var g granter var err error if len(tdfConfig.attributeValues) > 0 { - g, err = newGranterFromAttributes(s.kasKeyCache, tdfConfig.attributeValues...) + g, err = newGranterFromAttributes(s.logger, s.kasKeyCache, tdfConfig.attributeValues...) } else if len(tdfConfig.attributes) > 0 { - g, err = newGranterFromService(ctx, s.kasKeyCache, s.Attributes, tdfConfig.attributes...) + g, err = newGranterFromService(ctx, s.logger, s.kasKeyCache, s.Attributes, tdfConfig.attributes...) } if err != nil { return g, err @@ -729,7 +729,7 @@ func createPolicyObject(attributes []AttributeValueFQN) (PolicyObject, error) { return policyObj, nil } -func allowListFromKASRegistry(ctx context.Context, kasRegistryClient sdkconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { +func allowListFromKASRegistry(ctx context.Context, logger *slog.Logger, kasRegistryClient sdkconnect.KeyAccessServerRegistryServiceClient, platformURL string) (AllowList, error) { kases, err := kasRegistryClient.ListKeyAccessServers(ctx, &kasregistry.ListKeyAccessServersRequest{}) if err != nil { return nil, fmt.Errorf("kasregistry.ListKeyAccessServers failed: %w", err) @@ -742,7 +742,7 @@ func allowListFromKASRegistry(ctx context.Context, kasRegistryClient sdkconnect. } } // grpc target does not have a scheme - slog.Debug("adding platform URL to KAS allowlist", slog.String("platform_url", platformURL)) + logger.Debug("adding platform URL to KAS allowlist", slog.String("platform_url", platformURL)) err = kasAllowlist.Add(platformURL) if err != nil { return nil, fmt.Errorf("kasAllowlist.Add failed: %w", err) @@ -774,7 +774,7 @@ func (s SDK) LoadTDF(reader io.ReadSeeker, opts ...TDFReaderOption) (*Reader, er if err != nil { return nil, fmt.Errorf("retrieving platformEndpoint failed: %w", err) } - allowList, err := allowListFromKASRegistry(context.Background(), s.KeyAccessServerRegistry, platformEndpoint) + allowList, err := allowListFromKASRegistry(context.Background(), s.logger, s.KeyAccessServerRegistry, platformEndpoint) if err != nil { return nil, fmt.Errorf("allowListFromKASRegistry failed: %w", err) } @@ -1379,7 +1379,7 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn // if ignoreing allowlist then warn // if kas url is not allowed then return error if r.config.ignoreAllowList { - slog.WarnContext(ctx, "kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl)) + getLogger().WarnContext(ctx, "kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl)) } else if !r.config.kasAllowlist.IsAllowed(kasurl) { reqFail(fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasurl), req) continue @@ -1479,12 +1479,12 @@ func populateKasInfoFromBaseKey(key *policy.SimpleKasKey, tdfConfig *TDFConfig) // ? Maybe we shouldn't overwrite the key type if tdfConfig.preferredKeyWrapAlg != ocrypto.KeyType(algoString) { - slog.Warn("base key is enabled, setting key type", slog.String("key_type", algoString)) + getLogger().Warn("base key is enabled, setting key type", slog.String("key_type", algoString)) } tdfConfig.preferredKeyWrapAlg = ocrypto.KeyType(algoString) tdfConfig.splitPlan = nil if len(tdfConfig.kasInfoList) > 0 { - slog.Warn("base key is enabled, overwriting kasInfoList with base key info") + getLogger().Warn("base key is enabled, overwriting kasInfoList with base key info") } tdfConfig.kasInfoList = []KASInfo{ {