From a7c691b39590ae9fa8c7fe9972652218f9b9c2e3 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Tue, 23 Sep 2025 13:19:31 -0700 Subject: [PATCH 1/9] feat(core): propagate token clientID on configured claim via interceptor into shared context --- service/internal/auth/authn.go | 63 ++++++++++++++----- service/internal/auth/authn_test.go | 26 +++++++- service/internal/auth/config.go | 2 + service/pkg/auth/context_auth.go | 57 ++++++++++++++++- service/pkg/auth/context_auth_test.go | 89 +++++++++++++++++++++++++++ 5 files changed, 220 insertions(+), 17 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 5a26aa27e5..ef1a6c011d 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -22,7 +22,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" - "google.golang.org/grpc/metadata" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/logger" @@ -71,6 +70,9 @@ const ( ActionDelete = "delete" ActionUnsafe = "unsafe" ActionOther = "other" + + mdAccessTokenKey = "access_token" + mdClientIDKey = "client_id" ) // Authentication holds a jwks cache and information about the openid configuration @@ -242,12 +244,16 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { return } - md, ok := metadata.FromIncomingContext(ctxWithJWK) - if !ok { - md = metadata.New(nil) + var clientID string + clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim + if clientIDClaim != "" { + if id, ok := accessTok.Get(clientIDClaim); ok { + if clientIDClaimValue, ok := id.(string); ok { + clientID = clientIDClaimValue + } + } } - md.Append("access_token", ctxAuth.GetRawAccessTokenFromContext(ctxWithJWK, nil)) - ctxWithJWK = metadata.NewIncomingContext(ctxWithJWK, md) + ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) // Check if the token is allowed to access the resource var action string @@ -266,6 +272,8 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject()), + slog.String("configured_client_id_claim_name", clientIDClaim), + slog.String("client_id", clientID), slog.Any("error", err), ) http.Error(w, "permission denied", http.StatusForbidden) @@ -274,12 +282,18 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { http.Error(w, "internal server error", http.StatusInternalServerError) return } else if !allow { - a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject())) + a.logger.WarnContext( + r.Context(), + "permission denied", + slog.String("azp", accessTok.Subject()), + slog.String("configured_client_id_claim_name", clientIDClaim), + slog.String("client_id", clientID), + ) http.Error(w, "permission denied", http.StatusForbidden) return } - r = r.WithContext(ctxWithJWK) + r = r.WithContext(ctxWithMetadata) handler.ServeHTTP(w, r) }) } @@ -319,7 +333,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor resource := p[1] + "/" + p[2] action := getAction(p[2]) - token, newCtx, err := a.checkToken( + token, ctxWithJWK, err := a.checkToken( ctx, header, ri, @@ -329,11 +343,24 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) } + var clientID string + clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim + if clientIDClaim != "" { + if id, ok := token.Get(clientIDClaim); ok { + if idStr, ok := id.(string); ok { + clientID = idStr + } + } + } + ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + // Check if the token is allowed to access the resource if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil { if err.Error() == "permission denied" { a.logger.Warn("permission denied", slog.String("azp", token.Subject()), + slog.String("configured_client_id_claim_name", clientIDClaim), + slog.String("client_id", clientID), slog.Any("error", err), ) return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) @@ -344,7 +371,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) } - return next(newCtx, req) + return next(ctxWithMetadata, req) }) } return connect.UnaryInterceptorFunc(interceptor) @@ -431,12 +458,12 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw) return accessToken, ctx, nil } - key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader) + dpopKey, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader) if err != nil { a.logger.Warn("failed to validate dpop", slog.Any("err", err)) return nil, nil, err } - ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, dpopKey, accessToken, tokenRaw) return accessToken, ctx, nil } @@ -668,7 +695,7 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header u = append(u, a.lookupGatewayPaths(ctx, path, header)...) // Validate the token and create a JWT token - _, nextCtx, err := a.checkToken(ctx, authHeader, receiverInfo{ + token, ctxWithJWK, err := a.checkToken(ctx, authHeader, receiverInfo{ u: u, m: []string{http.MethodPost}, }, header["Dpop"]) @@ -677,7 +704,15 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header } // Return the next context with the token - return nextCtx, nil + var clientID string + if clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim; clientIDClaim != "" { + if id, ok := token.Get(clientIDClaim); ok { + if idStr, ok := id.(string); ok { + clientID = idStr + } + } + } + return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil } } return ctx, nil diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 686bc8ad7d..b5b7f302ed 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -56,6 +56,7 @@ type FakeAccessTokenSource struct { } type FakeAccessServiceServer struct { + clientID string accessToken []string dpopKey jwk.Key kas.UnimplementedAccessServiceServer @@ -72,6 +73,7 @@ func (f *FakeAccessServiceServer) LegacyPublicKey(_ context.Context, _ *connect. func (f *FakeAccessServiceServer) Rewrap(ctx context.Context, req *connect.Request[kas.RewrapRequest]) (*connect.Response[kas.RewrapResponse], error) { f.accessToken = req.Header()["Authorization"] f.dpopKey = ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger()) + f.clientID, _ = ctxAuth.GetClientIDFromContext(ctx) return &connect.Response[kas.RewrapResponse]{Msg: &kas.RewrapResponse{}}, nil } @@ -148,7 +150,9 @@ func (s *AuthSuite) SetupTest() { } })) - policyCfg := PolicyConfig{} + policyCfg := PolicyConfig{ + ClientIDClaim: "cid", + } err = defaults.Set(&policyCfg) s.Require().NoError(err) @@ -214,6 +218,8 @@ func TestNormalizeUrl(t *testing.T) { func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { // Mock the checkToken method to return a valid token and context mockToken := jwt.New() + mockToken.Set("cid", "mockClientID") + type contextKey string mockCtx := context.WithValue(context.Background(), contextKey("mockKey"), "mockValue") s.auth._testCheckTokenFunc = func(_ context.Context, authHeader []string, _ receiverInfo, _ []string) (jwt.Token, context.Context, error) { @@ -234,6 +240,9 @@ func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { s.Require().NoError(err) s.Require().NotNil(nextCtx) s.Equal("mockValue", nextCtx.Value(contextKey("mockKey"))) + clientID, err := ctxAuth.GetClientIDFromContext(nextCtx) + s.Require().NoError(err) + s.Equal("mockClientID", clientID) // Test with a route not requiring reauthorization nextCtx, err = s.auth.ipcReauthCheck(context.Background(), "/kas.AccessService/PublicKey", nil) @@ -482,7 +491,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour))) s.Require().NoError(tok.Set("iss", s.server.URL)) s.Require().NoError(tok.Set("aud", "test")) - s.Require().NoError(tok.Set("cid", "client2")) + s.Require().NoError(tok.Set("cid", "client-123")) s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}})) thumbprint, err := dpopKey.Thumbprint(crypto.SHA256) s.Require().NoError(err) @@ -517,6 +526,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { _, err = client.Rewrap(context.Background(), &kas.RewrapRequest{}) s.Require().NoError(err) + s.Equal(fakeServer.clientID, "client-123") s.NotNil(fakeServer.dpopKey) dpopJWKFromRequest, ok := fakeServer.dpopKey.(jwk.RSAPublicKey) s.True(ok) @@ -552,12 +562,15 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { jwkChan := make(chan jwk.Key, 1) timeout := make(chan string, 1) + clientIDChan := make(chan string, 1) go func() { time.Sleep(5 * time.Second) timeout <- "" }() server := httptest.NewServer(s.auth.MuxHandler(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { jwkChan <- ctxAuth.GetJWKFromContext(req.Context(), logger.CreateTestLogger()) + cid, _ := ctxAuth.GetClientIDFromContext(req.Context()) + clientIDChan <- cid }))) defer server.Close() @@ -585,6 +598,15 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { case <-timeout: s.Require().FailNow("timed out waiting for call to complete") } + var clientID string + select { + case cid := <-clientIDChan: + clientID = cid + case <-timeout: + s.Require().FailNow("timed out waiting for call to complete") + } + + s.Equal(clientID, "client2") s.NotNil(dpopKeyFromRequest) dpopJWKFromRequest, ok := dpopKeyFromRequest.(jwk.RSAPublicKey) diff --git a/service/internal/auth/config.go b/service/internal/auth/config.go index 7fe8fd9a76..5e48877cf1 100644 --- a/service/internal/auth/config.go +++ b/service/internal/auth/config.go @@ -34,6 +34,8 @@ type PolicyConfig struct { UserNameClaim string `mapstructure:"username_claim" json:"username_claim" default:"preferred_username"` // Claim to use for group/role information GroupsClaim string `mapstructure:"groups_claim" json:"groups_claim" default:"realm_access.roles"` + // Claim to use to reference idP clientID + ClientIDClaim string `mapstructure:"client_id_claim" json:"client_id_claim" default:"azp"` // Deprecated: Use GroupClain instead RoleClaim string `mapstructure:"claim" json:"claim" default:"realm_access.roles"` // Deprecated: Use Casbin grouping statements g, , diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index 20f5b62d8d..f4d857a777 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -2,13 +2,25 @@ package auth import ( "context" + "errors" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/service/logger" + "google.golang.org/grpc/metadata" ) -var authnContextKey = authContextKey{} +var ( + authnContextKey = authContextKey{} + ErrNoMetadataFound = errors.New("no metadata found within context") + ErrMissingClientID = errors.New("context metadata missing authn idP clientID that should have been set by interceptor") + ErrConflictClientID = errors.New("context metadata has more than one authn idP clientID and should only ever have one") +) + +const ( + accessTokenKey = "access_token" + clientIDKey = "client_id" +) type authContextKey struct{} @@ -60,3 +72,46 @@ func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string } return "" } + +// ContextWithAuthnMetadata adds the access token and client ID to context metadata +// +// Adding the authn into to gRPC metadata propagates it across services rather than strictly +// in-process within Go alone +func ContextWithAuthnMetadata(ctx context.Context, clientID string) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(nil) + } else { + // Do not modify original metadata from parent context + md = md.Copy() + } + + if rawToken := GetRawAccessTokenFromContext(ctx, nil); rawToken != "" { + md.Set(accessTokenKey, rawToken) + } + + // Add client ID to metadata for downstream services + if clientID != "" { + md.Set(clientIDKey, clientID) + } + + return metadata.NewIncomingContext(ctx, md) +} + +// GetClientIDFromContext retrieves the client ID from the metadata in the context +func GetClientIDFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", ErrNoMetadataFound + } + + clientIDs := md.Get(clientIDKey) + if len(clientIDs) == 0 { + return "", ErrMissingClientID + } + if len(clientIDs) > 1 { + return "", ErrConflictClientID + } + + return clientIDs[0], nil +} diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index c015fe8270..ef8f68dd4c 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -8,6 +8,8 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/service/logger" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" ) func TestContextWithAuthNInfo(t *testing.T) { @@ -70,3 +72,90 @@ func TestGetContextDetailsInvalidType(t *testing.T) { retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid") } + +func TestContextWithAuthnMetadata(t *testing.T) { + mockClientID := "test-client-id" + + t.Run("should add access token and client id to metadata", func(t *testing.T) { + ctx := ContextWithAuthNInfo(context.Background(), nil, nil, "raw-token-string") + enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + accessToken := md.Get("access_token") + require.Len(t, accessToken, 1) + assert.Equal(t, "raw-token-string", accessToken[0]) + + clientIDs := md.Get(clientIDKey) + require.Len(t, clientIDs, 1) + assert.Equal(t, mockClientID, clientIDs[0]) + }) + + t.Run("should not set client id if empty", func(t *testing.T) { + ctx := ContextWithAuthNInfo(context.Background(), nil, nil, "raw-token-string") + enrichedCtx := ContextWithAuthnMetadata(ctx, "") + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + clientIDs := md.Get(clientIDKey) + assert.Empty(t, clientIDs) + }) + + t.Run("should preserve existing metadata", func(t *testing.T) { + originalMD := metadata.New(map[string]string{"original-key": "original-value"}) + ctx := metadata.NewIncomingContext(context.Background(), originalMD) + ctx = ContextWithAuthNInfo(ctx, nil, nil, "raw-token-string") + + enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) + + md, ok := metadata.FromIncomingContext(enrichedCtx) + require.True(t, ok) + + originalValue := md.Get("original-key") + require.Len(t, originalValue, 1) + assert.Equal(t, "original-value", originalValue[0]) + + clientIDs := md.Get(clientIDKey) + require.Len(t, clientIDs, 1) + assert.Equal(t, mockClientID, clientIDs[0]) + }) +} + +func TestGetClientIDFromContext(t *testing.T) { + mockClientID := "test-client-id" + + t.Run("good - should retrieve client id from context", func(t *testing.T) { + md := metadata.New(map[string]string{clientIDKey: mockClientID}) + ctx := metadata.NewIncomingContext(t.Context(), md) + + clientID, err := GetClientIDFromContext(ctx) + require.NoError(t, err) + assert.Equal(t, mockClientID, clientID) + }) + + t.Run("bad - should return error if client_id key is not present", func(t *testing.T) { + md := metadata.New(map[string]string{"other-key": "other-value"}) + ctx := metadata.NewIncomingContext(t.Context(), md) + + _, err := GetClientIDFromContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, ErrMissingClientID) + }) + + t.Run("bad - should return error if no metadata in context", func(t *testing.T) { + _, err := GetClientIDFromContext(t.Context()) + require.Error(t, err) + require.ErrorIs(t, err, ErrNoMetadataFound) + }) + + t.Run("bad - should return error if more than one metadata client_id key in context", func(t *testing.T) { + md := metadata.Pairs(clientIDKey, "id-1", clientIDKey, "id-2") + ctx := metadata.NewIncomingContext(t.Context(), md) + + _, err := GetClientIDFromContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, ErrConflictClientID) + }) +} From d3eeba358000f019c9c650be117daf4c83327619 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Tue, 23 Sep 2025 15:47:56 -0700 Subject: [PATCH 2/9] ConnectUnaryServerInterceptor test --- service/internal/auth/authn_test.go | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index b5b7f302ed..d537666bfc 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -263,6 +264,63 @@ func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { s.Contains(err.Error(), "unauthenticated") } +func (s *AuthSuite) Test_ConnectUnaryServerInterceptor_ClientIDPropagated() { + tok := jwt.New() + s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour))) + s.Require().NoError(tok.Set("iss", s.server.URL)) + s.Require().NoError(tok.Set("aud", "test")) + // default client ID claim in policy config is 'azp' + s.Require().NoError(tok.Set("azp", "test-client-id")) + s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}})) + + policyCfg := new(PolicyConfig) + err := defaults.Set(policyCfg) + s.Require().NoError(err) + + authnConfig := AuthNConfig{ + Issuer: s.server.URL, + Audience: "test", + Policy: *policyCfg, + } + config := Config{ + AuthNConfig: authnConfig, + } + auth, err := NewAuthenticator(context.Background(), config, &logger.Logger{ + Logger: slog.New(slog.Default().Handler()), + }, func(_ string, _ any) error { return nil }) + s.Require().NoError(err) + + // Sign the token + signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key)) + s.Require().NoError(err) + + // Create a minimal connect server setup to properly test the interceptor + // This is necessary because connect requests need proper procedure routing + interceptor := connect.WithInterceptors(auth.ConnectUnaryServerInterceptor()) + + fakeServer := &FakeAccessServiceServer{} + mux := http.NewServeMux() + path, handler := kasconnect.NewAccessServiceHandler(fakeServer, interceptor) + mux.Handle(path, handler) + + server := memhttp.New(mux) + defer server.Close() + + // Create a connect client that sends a Bearer token + conn, _ := grpc.NewClient("passthrough://bufconn", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return server.Listener.DialContext(ctx, "tcp", "http://localhost:8080") + }), grpc.WithTransportCredentials(insecure.NewCredentials())) + + client := kas.NewAccessServiceClient(conn) + + // Make the request + _, err = client.Rewrap(metadata.AppendToOutgoingContext(s.T().Context(), "authorization", "Bearer "+string(signedTok)), &kas.RewrapRequest{}) + s.Require().NoError(err) + + // Assert that the client ID was properly extracted and set in the context + s.Equal("test-client-id", fakeServer.clientID) +} + func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() { tok := jwt.New() s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC))) @@ -526,7 +584,10 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { _, err = client.Rewrap(context.Background(), &kas.RewrapRequest{}) s.Require().NoError(err) + + // interceptor propagated clientID from the token at the configured claim s.Equal(fakeServer.clientID, "client-123") + s.NotNil(fakeServer.dpopKey) dpopJWKFromRequest, ok := fakeServer.dpopKey.(jwk.RSAPublicKey) s.True(ok) From 80fa33a0e6e43e20998240ad1a4dc21e9e11510f Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 04:38:54 -0700 Subject: [PATCH 3/9] suggestions --- service/internal/auth/authn.go | 46 ++++++++------------- service/internal/auth/authn_test.go | 63 ++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index ef1a6c011d..88994d39d7 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -70,9 +70,6 @@ const ( ActionDelete = "delete" ActionUnsafe = "unsafe" ActionOther = "other" - - mdAccessTokenKey = "access_token" - mdClientIDKey = "client_id" ) // Authentication holds a jwks cache and information about the openid configuration @@ -244,15 +241,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { return } - var clientID string - clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim - if clientIDClaim != "" { - if id, ok := accessTok.Get(clientIDClaim); ok { - if clientIDClaimValue, ok := id.(string); ok { - clientID = clientIDClaimValue - } - } - } + clientID, clientIDClaim := a.getClientIDFromToken(accessTok) ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) // Check if the token is allowed to access the resource @@ -343,15 +332,7 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) } - var clientID string - clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim - if clientIDClaim != "" { - if id, ok := token.Get(clientIDClaim); ok { - if idStr, ok := id.(string); ok { - clientID = idStr - } - } - } + clientID, clientIDClaim := a.getClientIDFromToken(token) ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) // Check if the token is allowed to access the resource @@ -704,16 +685,23 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header } // Return the next context with the token - var clientID string - if clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim; clientIDClaim != "" { - if id, ok := token.Get(clientIDClaim); ok { - if idStr, ok := id.(string); ok { - clientID = idStr - } - } - } + clientID, _ := a.getClientIDFromToken(token) return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil } } return ctx, nil } + +// getClientIDFromToken returns the client ID from the token and the configured claim name +func (a *Authentication) getClientIDFromToken(tok jwt.Token) (string, string) { + var clientID string + clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim + if clientIDClaim != "" { + if val, ok := tok.Get(clientIDClaim); ok { + if strVal, ok := val.(string); ok { + clientID = strVal + } + } + } + return clientID, clientIDClaim +} diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index d537666bfc..f57806f2b1 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -584,7 +584,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { _, err = client.Rewrap(context.Background(), &kas.RewrapRequest{}) s.Require().NoError(err) - + // interceptor propagated clientID from the token at the configured claim s.Equal(fakeServer.clientID, "client-123") @@ -876,3 +876,64 @@ func (s *AuthSuite) Test_LookupGatewayPaths() { }) } } + +func Test_GetClientIDFromToken(t *testing.T) { + tests := []struct { + name string + claims map[string]interface{} + clientIDClaim string + expectedClientID string + expectedClaimName string + }{ + { + name: "Happy Path", + claims: map[string]interface{}{ + "cid": "test-client-id", + }, + clientIDClaim: "cid", + expectedClientID: "test-client-id", + expectedClaimName: "cid", + }, + { + name: "Claim not found", + claims: map[string]interface{}{ + "other-claim": "some-value", + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedClaimName: "cid", + }, + { + name: "Other claim name", + claims: map[string]interface{}{ + "client": "test-client-id", + }, + clientIDClaim: "client", + expectedClientID: "test-client-id", + expectedClaimName: "client", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &Authentication{ + oidcConfiguration: AuthNConfig{ + Policy: PolicyConfig{ + ClientIDClaim: tt.clientIDClaim, + }, + }, + } + + tok := jwt.New() + for k, v := range tt.claims { + err := tok.Set(k, v) + require.NoError(t, err) + } + + clientID, clientIDClaimName := auth.getClientIDFromToken(tok) + + assert.Equal(t, tt.expectedClientID, clientID) + assert.Equal(t, tt.expectedClaimName, clientIDClaimName) + }) + } +} From 6de91b035b46a08621c07c17c795ef95108c173e Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 07:55:50 -0700 Subject: [PATCH 4/9] lint fixes --- service/internal/auth/authn.go | 2 +- service/internal/auth/authn_test.go | 7 ++++--- service/pkg/auth/context_auth_test.go | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 88994d39d7..34e1be0851 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -697,7 +697,7 @@ func (a *Authentication) getClientIDFromToken(tok jwt.Token) (string, string) { var clientID string clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim if clientIDClaim != "" { - if val, ok := tok.Get(clientIDClaim); ok { + if val, exists := tok.Get(clientIDClaim); exists { if strVal, ok := val.(string); ok { clientID = strVal } diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index f57806f2b1..8c14546fdd 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -219,7 +219,8 @@ func TestNormalizeUrl(t *testing.T) { func (s *AuthSuite) Test_IPCUnaryServerInterceptor() { // Mock the checkToken method to return a valid token and context mockToken := jwt.New() - mockToken.Set("cid", "mockClientID") + err := mockToken.Set("cid", "mockClientID") + s.Require().NoError(err) type contextKey string mockCtx := context.WithValue(context.Background(), contextKey("mockKey"), "mockValue") @@ -586,7 +587,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_GRPC() { s.Require().NoError(err) // interceptor propagated clientID from the token at the configured claim - s.Equal(fakeServer.clientID, "client-123") + s.Equal("client-123", fakeServer.clientID) s.NotNil(fakeServer.dpopKey) dpopJWKFromRequest, ok := fakeServer.dpopKey.(jwk.RSAPublicKey) @@ -667,7 +668,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { s.Require().FailNow("timed out waiting for call to complete") } - s.Equal(clientID, "client2") + s.Equal("client2", clientID) s.NotNil(dpopKeyFromRequest) dpopJWKFromRequest, ok := dpopKeyFromRequest.(jwk.RSAPublicKey) diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index ef8f68dd4c..75b7795e9e 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -77,7 +77,7 @@ func TestContextWithAuthnMetadata(t *testing.T) { mockClientID := "test-client-id" t.Run("should add access token and client id to metadata", func(t *testing.T) { - ctx := ContextWithAuthNInfo(context.Background(), nil, nil, "raw-token-string") + ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) md, ok := metadata.FromIncomingContext(enrichedCtx) @@ -93,7 +93,7 @@ func TestContextWithAuthnMetadata(t *testing.T) { }) t.Run("should not set client id if empty", func(t *testing.T) { - ctx := ContextWithAuthNInfo(context.Background(), nil, nil, "raw-token-string") + ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") enrichedCtx := ContextWithAuthnMetadata(ctx, "") md, ok := metadata.FromIncomingContext(enrichedCtx) @@ -105,7 +105,7 @@ func TestContextWithAuthnMetadata(t *testing.T) { t.Run("should preserve existing metadata", func(t *testing.T) { originalMD := metadata.New(map[string]string{"original-key": "original-value"}) - ctx := metadata.NewIncomingContext(context.Background(), originalMD) + ctx := metadata.NewIncomingContext(t.Context(), originalMD) ctx = ContextWithAuthNInfo(ctx, nil, nil, "raw-token-string") enrichedCtx := ContextWithAuthnMetadata(ctx, mockClientID) From 8a2c136d5b67410ac251151e01adc7f465e836e9 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 13:00:31 -0700 Subject: [PATCH 5/9] docs and example configs --- docs/Configuring.md | 3 +++ opentdf-dev.yaml | 2 ++ opentdf-ers-mode.yaml | 2 ++ opentdf-example.yaml | 2 ++ opentdf-kas-mode.yaml | 2 ++ 5 files changed, 11 insertions(+) diff --git a/docs/Configuring.md b/docs/Configuring.md index 65ab172cc0..3167fa14c9 100644 --- a/docs/Configuring.md +++ b/docs/Configuring.md @@ -352,6 +352,9 @@ server: ## Dot notation is used to access the groups claim group_claim: "realm_access.roles" + + # Dot notation is used to access the claim the represents the idP client ID + client_id_claim: # azp ## Deprecated: Use standard casbin policy groupings (g, , ) ## Maps the external role to the OpenTDF role diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index a0a565af62..3361bd6ec4 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -72,6 +72,8 @@ server: username_claim: # preferred_username # That claim to access groups (i.e. realm_access.roles) groups_claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Extends the builtin policy extension: | g, opentdf-admin, role:admin diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index 1b0e5f3f7e..a396b963a8 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -28,6 +28,8 @@ server: default: #"role:standard" ## Dot notation is used to access nested claims (i.e. realm_access.roles) claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Maps the external role to the opentdf role ## Note: left side is used in the policy, right side is the external role map: diff --git a/opentdf-example.yaml b/opentdf-example.yaml index 3c012632ba..c110295121 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -53,6 +53,8 @@ server: username_claim: # preferred_username # That claim to access groups (i.e. realm_access.roles) groups_claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Extends the builtin policy extension: | g, opentdf-admin, role:admin diff --git a/opentdf-kas-mode.yaml b/opentdf-kas-mode.yaml index e7532d4e63..cbfaee1f06 100644 --- a/opentdf-kas-mode.yaml +++ b/opentdf-kas-mode.yaml @@ -45,6 +45,8 @@ server: default: #"role:standard" ## Dot notation is used to access nested claims (i.e. realm_access.roles) claim: # realm_access.roles + # Claim the represents the idP client ID + client_id_claim: # azp ## Maps the external role to the opentdf role ## Note: left side is used in the policy, right side is the external role map: From c7ab27c27c1a4327447d1a8ca2fda414c6f564ad Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 13:00:45 -0700 Subject: [PATCH 6/9] implementation of dot notation for config claim specification --- service/internal/auth/authn.go | 103 +++++++++++++++++++--------- service/internal/auth/authn_test.go | 99 ++++++++++++++++++++------ 2 files changed, 149 insertions(+), 53 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 34e1be0851..b049037aa1 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -61,6 +61,11 @@ var ( jwa.PS384: true, jwa.PS512: true, } + + // Exported error variables for client ID processing + ErrClientIDClaimNotConfigured = errors.New("no client ID claim configured") + ErrClientIDClaimNotFound = errors.New("client ID claim not found") + ErrClientIDClaimNotString = errors.New("client ID claim is not a string") ) const ( @@ -163,7 +168,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we // Try an register oidc issuer to wellknown service but don't return an error if it fails if err := wellknownRegistration("platform_issuer", cfg.Issuer); err != nil { - logger.Warn("failed to register platform issuer", slog.String("error", err.Error())) + logger.Warn("failed to register platform issuer", slog.Any("error", err)) } var oidcConfigMap map[string]any @@ -179,7 +184,7 @@ func NewAuthenticator(ctx context.Context, cfg Config, logger *logger.Logger, we } if err := wellknownRegistration("idp", oidcConfigMap); err != nil { - logger.Warn("failed to register platform idp information", slog.String("error", err.Error())) + logger.Warn("failed to register platform idp information", slog.Any("error", err)) } return a, nil @@ -211,6 +216,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { } dp := r.Header.Values("Dpop") + log := a.logger // Verify the token header := r.Header["Authorization"] @@ -227,12 +233,13 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { origin = "http://" + strings.TrimSuffix(origin, ":80") } } - accessTok, ctxWithJWK, err := a.checkToken(r.Context(), header, receiverInfo{ + ctx := r.Context() + accessTok, ctxWithJWK, err := a.checkToken(ctx, header, receiverInfo{ u: []string{normalizeURL(origin, r.URL)}, m: []string{r.Method}, }, dp) if err != nil { - slog.WarnContext(r.Context(), + log.WarnContext(ctx, "unauthenticated", slog.Any("error", err), slog.Any("dpop", dp), @@ -241,8 +248,19 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { return } - clientID, clientIDClaim := a.getClientIDFromToken(accessTok) - ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + clientID, err := a.getClientIDFromToken(ctx, accessTok) + if err != nil { + log.WarnContext( + ctx, + "could not determine client ID from token", + slog.Any("err", err), + ) + } else { + log = log. + With("client_id", clientID). + With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim) + ctx = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + } // Check if the token is allowed to access the resource var action string @@ -258,11 +276,9 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { } if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil { if err.Error() == "permission denied" { - a.logger.WarnContext(r.Context(), + log.WarnContext(ctx, "permission denied", slog.String("azp", accessTok.Subject()), - slog.String("configured_client_id_claim_name", clientIDClaim), - slog.String("client_id", clientID), slog.Any("error", err), ) http.Error(w, "permission denied", http.StatusForbidden) @@ -271,18 +287,16 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { http.Error(w, "internal server error", http.StatusInternalServerError) return } else if !allow { - a.logger.WarnContext( - r.Context(), + log.WarnContext( + ctx, "permission denied", slog.String("azp", accessTok.Subject()), - slog.String("configured_client_id_claim_name", clientIDClaim), - slog.String("client_id", clientID), ) http.Error(w, "permission denied", http.StatusForbidden) return } - r = r.WithContext(ctxWithMetadata) + r = r.WithContext(ctx) handler.ServeHTTP(w, r) }) } @@ -299,6 +313,8 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return next(ctx, req) } + log := a.logger + ri := receiverInfo{ u: []string{req.Spec().Procedure}, m: []string{http.MethodPost}, @@ -332,27 +348,38 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) } - clientID, clientIDClaim := a.getClientIDFromToken(token) - ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + clientID, err := a.getClientIDFromToken(ctx, token) + if err != nil { + log.WarnContext( + ctx, + "could not determine client ID from token", + slog.Any("err", err), + ) + } else { + log = log. + With("client_id", clientID). + With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim) + ctxWithJWK = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + } // Check if the token is allowed to access the resource if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil { if err.Error() == "permission denied" { - a.logger.Warn("permission denied", + log.WarnContext( + ctxWithJWK, + "permission denied", slog.String("azp", token.Subject()), - slog.String("configured_client_id_claim_name", clientIDClaim), - slog.String("client_id", clientID), slog.Any("error", err), ) return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) } return nil, err } else if !allowed { - a.logger.Warn("permission denied", slog.String("azp", token.Subject())) + log.WarnContext(ctxWithJWK, "permission denied", slog.String("azp", token.Subject())) return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) } - return next(ctxWithMetadata, req) + return next(ctxWithJWK, req) }) } return connect.UnaryInterceptorFunc(interceptor) @@ -407,7 +434,7 @@ func (a *Authentication) checkToken(ctx context.Context, authHeader []string, dp case strings.HasPrefix(authHeader[0], "Bearer "): tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ") default: - a.logger.Warn("failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0])) + a.logger.WarnContext(ctx, "failed to validate authentication header: not of type bearer or dpop", slog.String("header", authHeader[0])) return nil, nil, errors.New("not of type bearer or dpop") } @@ -685,23 +712,33 @@ func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header } // Return the next context with the token - clientID, _ := a.getClientIDFromToken(token) + clientID, err := a.getClientIDFromToken(ctxWithJWK, token) + if err != nil { + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) + } return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil } } return ctx, nil } -// getClientIDFromToken returns the client ID from the token and the configured claim name -func (a *Authentication) getClientIDFromToken(tok jwt.Token) (string, string) { - var clientID string +// getClientIDFromToken returns the client ID from the token if found (dot notation) +func (a *Authentication) getClientIDFromToken(ctx context.Context, tok jwt.Token) (string, error) { clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim - if clientIDClaim != "" { - if val, exists := tok.Get(clientIDClaim); exists { - if strVal, ok := val.(string); ok { - clientID = strVal - } - } + if clientIDClaim == "" { + return "", ErrClientIDClaimNotConfigured + } + claimsMap, err := tok.AsMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to parse token as a map and find claim at [%s]: %w", clientIDClaim, err) + } + found := dotNotation(claimsMap, clientIDClaim) + if found == nil { + return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotFound, clientIDClaim) + } + clientID, isString := found.(string) + if !isString { + return "", fmt.Errorf("%w at [%s]", ErrClientIDClaimNotString, clientIDClaim) } - return clientID, clientIDClaim + return clientID, nil } diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 8c14546fdd..5d515eff00 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -880,38 +880,91 @@ func (s *AuthSuite) Test_LookupGatewayPaths() { func Test_GetClientIDFromToken(t *testing.T) { tests := []struct { - name string - claims map[string]interface{} - clientIDClaim string - expectedClientID string - expectedClaimName string + name string + claims map[string]interface{} + clientIDClaim string + expectedClientID string + expectedErr error + expectError bool }{ { - name: "Happy Path", + name: "Happy Path - simple claim", claims: map[string]interface{}{ "cid": "test-client-id", }, - clientIDClaim: "cid", - expectedClientID: "test-client-id", - expectedClaimName: "cid", + clientIDClaim: "cid", + expectedClientID: "test-client-id", + expectError: false, }, { - name: "Claim not found", + name: "Happy Path - different claim name", + claims: map[string]interface{}{ + "client": "test-client-id", + }, + clientIDClaim: "client", + expectedClientID: "test-client-id", + expectError: false, + }, + { + name: "Happy Path - dot notation", + claims: map[string]interface{}{ + "client": map[string]interface{}{ + "info": map[string]interface{}{ + "id": "test-client-id", + }, + }, + }, + clientIDClaim: "client.info.id", + expectedClientID: "test-client-id", + expectError: false, + }, + { + name: "Error - no client ID claim configured", + claims: map[string]interface{}{"cid": "test"}, + clientIDClaim: "", // empty claim name + expectedClientID: "", + expectedErr: ErrClientIDClaimNotConfigured, + expectError: true, + }, + { + name: "Error - claim not found", claims: map[string]interface{}{ "other-claim": "some-value", }, - clientIDClaim: "cid", - expectedClientID: "", - expectedClaimName: "cid", + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotFound, + expectError: true, }, { - name: "Other claim name", + name: "Error - claim is not a string (int)", claims: map[string]interface{}{ - "client": "test-client-id", + "cid": 12345, + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, + }, + { + name: "Error - claim is not a string (bool)", + claims: map[string]interface{}{ + "cid": true, + }, + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, + }, + { + name: "Error - claim is not a string (object)", + claims: map[string]interface{}{ + "cid": map[string]interface{}{"nested": "value"}, }, - clientIDClaim: "client", - expectedClientID: "test-client-id", - expectedClaimName: "client", + clientIDClaim: "cid", + expectedClientID: "", + expectedErr: ErrClientIDClaimNotString, + expectError: true, }, } @@ -931,10 +984,16 @@ func Test_GetClientIDFromToken(t *testing.T) { require.NoError(t, err) } - clientID, clientIDClaimName := auth.getClientIDFromToken(tok) + clientID, err := auth.getClientIDFromToken(context.Background(), tok) assert.Equal(t, tt.expectedClientID, clientID) - assert.Equal(t, tt.expectedClaimName, clientIDClaimName) + + if tt.expectError { + require.Error(t, err) + assert.ErrorIs(t, err, tt.expectedErr) + } else { + require.NoError(t, err) + } }) } } From 61e536bd432ec750a806b1c46d0cdcbbc9a41b8d Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 14:03:54 -0700 Subject: [PATCH 7/9] fixes --- service/internal/auth/authn.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index b049037aa1..cc2e4151eb 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -233,8 +233,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { origin = "http://" + strings.TrimSuffix(origin, ":80") } } - ctx := r.Context() - accessTok, ctxWithJWK, err := a.checkToken(ctx, header, receiverInfo{ + accessTok, ctx, err := a.checkToken(r.Context(), header, receiverInfo{ u: []string{normalizeURL(origin, r.URL)}, m: []string{r.Method}, }, dp) @@ -259,7 +258,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { log = log. With("client_id", clientID). With("configured_client_id_claim_name", a.oidcConfiguration.Policy.ClientIDClaim) - ctx = ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID) + ctx = ctxAuth.ContextWithAuthnMetadata(ctx, clientID) } // Check if the token is allowed to access the resource @@ -276,7 +275,8 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { } if allow, err := a.enforcer.Enforce(accessTok, r.URL.Path, action); err != nil { if err.Error() == "permission denied" { - log.WarnContext(ctx, + log.WarnContext( + ctx, "permission denied", slog.String("azp", accessTok.Subject()), slog.Any("error", err), @@ -348,10 +348,10 @@ func (a Authentication) ConnectUnaryServerInterceptor() connect.UnaryInterceptor return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated")) } - clientID, err := a.getClientIDFromToken(ctx, token) + clientID, err := a.getClientIDFromToken(ctxWithJWK, token) if err != nil { log.WarnContext( - ctx, + ctxWithJWK, "could not determine client ID from token", slog.Any("err", err), ) From 2fc13d46104d824c494499543a3da3c8179ddd97 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 14:19:31 -0700 Subject: [PATCH 8/9] lint fix --- service/internal/auth/authn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 5d515eff00..05dd178744 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -984,7 +984,7 @@ func Test_GetClientIDFromToken(t *testing.T) { require.NoError(t, err) } - clientID, err := auth.getClientIDFromToken(context.Background(), tok) + clientID, err := auth.getClientIDFromToken(t.Context(), tok) assert.Equal(t, tt.expectedClientID, clientID) From 99fea3433b0a5a55bf05ca783f9b4561259d4cec Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Wed, 24 Sep 2025 14:28:53 -0700 Subject: [PATCH 9/9] fixes --- service/internal/auth/authn_test.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 05dd178744..0883703025 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -180,9 +180,7 @@ func (s *AuthSuite) SetupTest() { "/static-doublestar4/x/**", }, }, - &logger.Logger{ - Logger: slog.New(slog.Default().Handler()), - }, + logger.CreateTestLogger(), func(_ string, _ any) error { return nil }, ) @@ -286,9 +284,7 @@ func (s *AuthSuite) Test_ConnectUnaryServerInterceptor_ClientIDPropagated() { config := Config{ AuthNConfig: authnConfig, } - auth, err := NewAuthenticator(context.Background(), config, &logger.Logger{ - Logger: slog.New(slog.Default().Handler()), - }, func(_ string, _ any) error { return nil }) + auth, err := NewAuthenticator(s.T().Context(), config, logger.CreateTestLogger(), func(_ string, _ any) error { return nil }) s.Require().NoError(err) // Sign the token