From a60c253fcc0e6570d2d94d6f47532b4ef9760c92 Mon Sep 17 00:00:00 2001 From: Trey Date: Fri, 3 Apr 2026 08:08:45 -0700 Subject: [PATCH 1/2] In-process jwks retrieval in vMCP --- cmd/vmcp/app/commands.go | 11 +- pkg/vmcp/auth/factory/authz_not_wired_test.go | 6 +- pkg/vmcp/auth/factory/incoming.go | 13 +- .../auth/factory/incoming_keyprovider_test.go | 144 ++++++++++++++++++ pkg/vmcp/auth/factory/incoming_test.go | 2 +- .../auth/factory/incoming_upstream_test.go | 6 +- 6 files changed, 168 insertions(+), 14 deletions(-) create mode 100644 pkg/vmcp/auth/factory/incoming_keyprovider_test.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 9ddd4b712f..c520589fab 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -24,6 +24,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" authserverconfig "github.com/stacklok/toolhive/pkg/authserver" authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner" + "github.com/stacklok/toolhive/pkg/authserver/server/keys" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/telemetry" @@ -589,18 +590,20 @@ func runServe(cmd *cobra.Command, _ []string) error { } } - // Create an upstream token reader from the embedded auth server so that - // the OIDC middleware can enrich Identity with upstream provider tokens. - // This is required for the upstream_inject outgoing auth strategy. + // Extract dependencies from the embedded auth server so the OIDC middleware + // can (a) resolve JWKS keys in-process instead of self-referential HTTP + // calls, and (b) enrich Identity with upstream provider tokens. var upstreamReader upstreamtoken.TokenReader + var keyProvider keys.PublicKeyProvider if embeddedAuthServer != nil { stor := embeddedAuthServer.IDPTokenStorage() refresher := embeddedAuthServer.UpstreamTokenRefresher() upstreamReader = upstreamtoken.NewInProcessService(stor, refresher) + keyProvider = embeddedAuthServer.KeyProvider() } authMiddleware, authzMiddleware, authInfoHandler, err := - factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader) + factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/vmcp/auth/factory/authz_not_wired_test.go b/pkg/vmcp/auth/factory/authz_not_wired_test.go index 0c07c33601..06cb966a2b 100644 --- a/pkg/vmcp/auth/factory/authz_not_wired_test.go +++ b/pkg/vmcp/auth/factory/authz_not_wired_test.go @@ -47,7 +47,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) { }, } - authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil) + authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil) require.NoError(t, err, "middleware creation should succeed") require.NotNil(t, authMw, "auth middleware should not be nil") require.NotNil(t, authzMw, "authz middleware should not be nil") @@ -105,7 +105,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) { }, } - authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil) + authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil) require.NoError(t, err, "middleware creation should succeed") require.NotNil(t, authMw, "auth middleware should not be nil") require.NotNil(t, authzMw, "authz middleware should not be nil") @@ -163,7 +163,7 @@ func TestNewIncomingAuthMiddleware_AuthzApproveAndBlock(t *testing.T) { }, } - authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil) + authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil) require.NoError(t, err, "middleware creation should succeed") require.NotNil(t, authMw, "auth middleware should not be nil") require.NotNil(t, authzMw, "authz middleware should not be nil") diff --git a/pkg/vmcp/auth/factory/incoming.go b/pkg/vmcp/auth/factory/incoming.go index 148ca2c18a..c27d842573 100644 --- a/pkg/vmcp/auth/factory/incoming.go +++ b/pkg/vmcp/auth/factory/incoming.go @@ -11,6 +11,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" + "github.com/stacklok/toolhive/pkg/authserver/server/keys" "github.com/stacklok/toolhive/pkg/authz" "github.com/stacklok/toolhive/pkg/authz/authorizers" "github.com/stacklok/toolhive/pkg/authz/authorizers/cedar" @@ -51,6 +52,7 @@ func NewIncomingAuthMiddleware( cfg *config.IncomingAuthConfig, passThroughTools map[string]struct{}, upstreamReader upstreamtoken.TokenReader, + keyProvider keys.PublicKeyProvider, ) ( authMw func(http.Handler) http.Handler, authzMw func(http.Handler) http.Handler, @@ -65,7 +67,7 @@ func NewIncomingAuthMiddleware( switch cfg.Type { case "oidc": - authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader) + authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader, keyProvider) case "local": authMiddleware, authInfoHandler, err = newLocalAuthMiddleware(ctx) case "anonymous": @@ -151,6 +153,7 @@ func newOIDCAuthMiddleware( ctx context.Context, oidcCfg *config.OIDCConfig, reader upstreamtoken.TokenReader, + keyProvider keys.PublicKeyProvider, ) (func(http.Handler) http.Handler, http.Handler, error) { if oidcCfg == nil { return nil, nil, fmt.Errorf("OIDC configuration required when Type='oidc'") @@ -175,9 +178,13 @@ func newOIDCAuthMiddleware( Scopes: oidcCfg.Scopes, } - // Wire the upstream token reader so the JWT validator can enrich Identity - // with upstream provider tokens (needed for upstream_inject auth strategy). + // Wire optional dependencies from the embedded auth server so the JWT + // validator can (a) resolve JWKS keys in-process instead of self-referential + // HTTP calls, and (b) enrich Identity with upstream provider tokens. var opts []auth.TokenValidatorOption + if keyProvider != nil { + opts = append(opts, auth.WithKeyProvider(keyProvider)) + } if reader != nil { opts = append(opts, auth.WithUpstreamTokenReader(reader)) } diff --git a/pkg/vmcp/auth/factory/incoming_keyprovider_test.go b/pkg/vmcp/auth/factory/incoming_keyprovider_test.go new file mode 100644 index 0000000000..65d33c0bc5 --- /dev/null +++ b/pkg/vmcp/auth/factory/incoming_keyprovider_test.go @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package factory + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + pkgauth "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/authserver/server/keys" + keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// TestNewOIDCAuthMiddleware_KeyProviderWiring verifies that an in-process +// PublicKeyProvider is used for JWKS key resolution, avoiding self-referential +// HTTP calls when the embedded auth server runs in the same process. +func TestNewOIDCAuthMiddleware_KeyProviderWiring(t *testing.T) { + t.Parallel() + + // Generate an ECDSA P-256 key pair (matching the embedded auth server's + // default GeneratingProvider algorithm). + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + const ecdsaKeyID = "test-ecdsa-key-1" + + // Stand up a minimal OIDC discovery server so issuer validation passes. + // The JWKS endpoint returns an empty key set — all key resolution should + // happen through the local provider, not HTTP. + server, _ := newTestOIDCServer(t) + t.Cleanup(server.Close) + + issuer := server.URL + + oidcCfg := &config.OIDCConfig{ + Issuer: issuer, + ClientID: "test-client", + Audience: "test-audience", + InsecureAllowHTTP: true, + JwksAllowPrivateIP: true, + } + + t.Run("keys resolved from local provider instead of HTTP", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl) + mockProvider.EXPECT(). + PublicKeys(gomock.Any()). + Return([]*keys.PublicKeyData{{ + KeyID: ecdsaKeyID, + Algorithm: "ES256", + PublicKey: &privateKey.PublicKey, + CreatedAt: time.Now(), + }}, nil). + AnyTimes() + + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider) + require.NoError(t, err, "middleware creation should succeed with key provider") + require.NotNil(t, authMw) + + var capturedIdentity *pkgauth.Identity + handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) + })) + + // Sign a JWT with the ECDSA private key — only the local provider + // holds the matching public key. + tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ + "iss": issuer, + "aud": "test-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tok.Header["kid"] = ecdsaKeyID + tokenString, err := tok.SignedString(privateKey) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenString) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "request should succeed via local key provider") + require.NotNil(t, capturedIdentity, "identity should be present in context") + assert.Equal(t, "test-user", capturedIdentity.Subject) + }) + + t.Run("falls back to HTTP JWKS when key provider is nil", func(t *testing.T) { + t.Parallel() + + // Use the RSA key from the test OIDC server (served via HTTP JWKS). + httpServer, rsaPrivateKey := newTestOIDCServer(t) + t.Cleanup(httpServer.Close) + + httpIssuer := httpServer.URL + httpOIDCCfg := &config.OIDCConfig{ + Issuer: httpIssuer, + ClientID: "test-client", + Audience: "test-audience", + InsecureAllowHTTP: true, + JwksAllowPrivateIP: true, + } + + authMw, _, err := newOIDCAuthMiddleware(t.Context(), httpOIDCCfg, nil, nil) + require.NoError(t, err) + require.NotNil(t, authMw) + + var capturedIdentity *pkgauth.Identity + handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) + })) + + token := signJWT(t, rsaPrivateKey, jwt.MapClaims{ + "iss": httpIssuer, + "aud": "test-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback") + require.NotNil(t, capturedIdentity, "identity should be present in context") + assert.Equal(t, "test-user", capturedIdentity.Subject) + }) +} diff --git a/pkg/vmcp/auth/factory/incoming_test.go b/pkg/vmcp/auth/factory/incoming_test.go index f93c914036..e27d36024b 100644 --- a/pkg/vmcp/auth/factory/incoming_test.go +++ b/pkg/vmcp/auth/factory/incoming_test.go @@ -135,7 +135,7 @@ func TestNewIncomingAuthMiddleware(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil) + authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil, nil) if tt.wantErr { require.Error(t, err) diff --git a/pkg/vmcp/auth/factory/incoming_upstream_test.go b/pkg/vmcp/auth/factory/incoming_upstream_test.go index 34e4d4f9d2..c3271a6f17 100644 --- a/pkg/vmcp/auth/factory/incoming_upstream_test.go +++ b/pkg/vmcp/auth/factory/incoming_upstream_test.go @@ -113,7 +113,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) { GetAllValidTokens(gomock.Any(), "session-abc"). Return(map[string]string{"google": "gcp-access-token"}, nil) - authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader) + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil) require.NoError(t, err, "middleware creation should succeed with non-nil reader") require.NotNil(t, authMw) @@ -145,7 +145,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) { t.Run("upstream tokens nil when reader is nil", func(t *testing.T) { t.Parallel() - authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil) + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil) require.NoError(t, err) require.NotNil(t, authMw) @@ -181,7 +181,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) { reader := upstreamtokenmocks.NewMockTokenReader(ctrl) // No EXPECT -- reader should not be called when tsid is absent. - authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader) + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil) require.NoError(t, err) require.NotNil(t, authMw) From 7161e2d4c1c11fca4be0de3572af6f9b3caad8d3 Mon Sep 17 00:00:00 2001 From: Trey Date: Fri, 3 Apr 2026 08:52:42 -0700 Subject: [PATCH 2/2] Improve key provider wiring tests for clarity and coverage Address code review feedback on incoming_keyprovider_test.go: - MEDIUM: Split TestNewOIDCAuthMiddleware_KeyProviderWiring into two independent top-level functions so each test is fully self-contained; outer-level setup no longer leaks into subtests that don't use it - LOW: Add TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback to cover the end-to-end kid-miss path at the factory level, confirming the validator falls back to HTTP JWKS when the local provider holds no matching key Co-Authored-By: Claude Sonnet 4.6 --- .../auth/factory/incoming_keyprovider_test.go | 249 +++++++++++------- 1 file changed, 159 insertions(+), 90 deletions(-) diff --git a/pkg/vmcp/auth/factory/incoming_keyprovider_test.go b/pkg/vmcp/auth/factory/incoming_keyprovider_test.go index 65d33c0bc5..21030f83ab 100644 --- a/pkg/vmcp/auth/factory/incoming_keyprovider_test.go +++ b/pkg/vmcp/auth/factory/incoming_keyprovider_test.go @@ -23,10 +23,10 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/config" ) -// TestNewOIDCAuthMiddleware_KeyProviderWiring verifies that an in-process -// PublicKeyProvider is used for JWKS key resolution, avoiding self-referential -// HTTP calls when the embedded auth server runs in the same process. -func TestNewOIDCAuthMiddleware_KeyProviderWiring(t *testing.T) { +// TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution verifies that when a +// PublicKeyProvider is wired in, key resolution happens in-process via the +// local provider rather than through an HTTP JWKS fetch. +func TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution(t *testing.T) { t.Parallel() // Generate an ECDSA P-256 key pair (matching the embedded auth server's @@ -52,93 +52,162 @@ func TestNewOIDCAuthMiddleware_KeyProviderWiring(t *testing.T) { JwksAllowPrivateIP: true, } - t.Run("keys resolved from local provider instead of HTTP", func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl) - mockProvider.EXPECT(). - PublicKeys(gomock.Any()). - Return([]*keys.PublicKeyData{{ - KeyID: ecdsaKeyID, - Algorithm: "ES256", - PublicKey: &privateKey.PublicKey, - CreatedAt: time.Now(), - }}, nil). - AnyTimes() - - authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider) - require.NoError(t, err, "middleware creation should succeed with key provider") - require.NotNil(t, authMw) - - var capturedIdentity *pkgauth.Identity - handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) - })) - - // Sign a JWT with the ECDSA private key — only the local provider - // holds the matching public key. - tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ - "iss": issuer, - "aud": "test-audience", - "sub": "test-user", - "exp": time.Now().Add(time.Hour).Unix(), - }) - tok.Header["kid"] = ecdsaKeyID - tokenString, err := tok.SignedString(privateKey) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("Authorization", "Bearer "+tokenString) - rr := httptest.NewRecorder() - - handler.ServeHTTP(rr, req) - - require.Equal(t, http.StatusOK, rr.Code, "request should succeed via local key provider") - require.NotNil(t, capturedIdentity, "identity should be present in context") - assert.Equal(t, "test-user", capturedIdentity.Subject) + ctrl := gomock.NewController(t) + mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl) + mockProvider.EXPECT(). + PublicKeys(gomock.Any()). + Return([]*keys.PublicKeyData{{ + KeyID: ecdsaKeyID, + Algorithm: "ES256", + PublicKey: &privateKey.PublicKey, + CreatedAt: time.Now(), + }}, nil). + AnyTimes() + + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider) + require.NoError(t, err, "middleware creation should succeed with key provider") + require.NotNil(t, authMw) + + var capturedIdentity *pkgauth.Identity + handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) + })) + + // Sign a JWT with the ECDSA private key — only the local provider + // holds the matching public key. + tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ + "iss": issuer, + "aud": "test-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), }) + tok.Header["kid"] = ecdsaKeyID + tokenString, err := tok.SignedString(privateKey) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenString) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "request should succeed via local key provider") + require.NotNil(t, capturedIdentity, "identity should be present in context") + assert.Equal(t, "test-user", capturedIdentity.Subject) +} + +// TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback verifies that when the +// key provider is nil, key resolution falls back to an HTTP JWKS fetch. +func TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback(t *testing.T) { + t.Parallel() + + // Use the RSA key from the test OIDC server (served via HTTP JWKS). + server, rsaPrivateKey := newTestOIDCServer(t) + t.Cleanup(server.Close) - t.Run("falls back to HTTP JWKS when key provider is nil", func(t *testing.T) { - t.Parallel() - - // Use the RSA key from the test OIDC server (served via HTTP JWKS). - httpServer, rsaPrivateKey := newTestOIDCServer(t) - t.Cleanup(httpServer.Close) - - httpIssuer := httpServer.URL - httpOIDCCfg := &config.OIDCConfig{ - Issuer: httpIssuer, - ClientID: "test-client", - Audience: "test-audience", - InsecureAllowHTTP: true, - JwksAllowPrivateIP: true, - } - - authMw, _, err := newOIDCAuthMiddleware(t.Context(), httpOIDCCfg, nil, nil) - require.NoError(t, err) - require.NotNil(t, authMw) - - var capturedIdentity *pkgauth.Identity - handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) - })) - - token := signJWT(t, rsaPrivateKey, jwt.MapClaims{ - "iss": httpIssuer, - "aud": "test-audience", - "sub": "test-user", - "exp": time.Now().Add(time.Hour).Unix(), - }) - - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("Authorization", "Bearer "+token) - rr := httptest.NewRecorder() - - handler.ServeHTTP(rr, req) - - require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback") - require.NotNil(t, capturedIdentity, "identity should be present in context") - assert.Equal(t, "test-user", capturedIdentity.Subject) + issuer := server.URL + oidcCfg := &config.OIDCConfig{ + Issuer: issuer, + ClientID: "test-client", + Audience: "test-audience", + InsecureAllowHTTP: true, + JwksAllowPrivateIP: true, + } + + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil) + require.NoError(t, err) + require.NotNil(t, authMw) + + var capturedIdentity *pkgauth.Identity + handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) + })) + + token := signJWT(t, rsaPrivateKey, jwt.MapClaims{ + "iss": issuer, + "aud": "test-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback") + require.NotNil(t, capturedIdentity, "identity should be present in context") + assert.Equal(t, "test-user", capturedIdentity.Subject) +} + +// TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback verifies that when the +// local PublicKeyProvider does not hold a key matching the JWT's kid, the +// validator falls back to HTTP JWKS and the request still succeeds. This +// confirms the end-to-end wiring for the kid-miss path at the factory level. +func TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback(t *testing.T) { + t.Parallel() + + // Stand up a real OIDC server that serves the RSA key via HTTP JWKS. + server, rsaPrivateKey := newTestOIDCServer(t) + t.Cleanup(server.Close) + + issuer := server.URL + oidcCfg := &config.OIDCConfig{ + Issuer: issuer, + ClientID: "test-client", + Audience: "test-audience", + InsecureAllowHTTP: true, + JwksAllowPrivateIP: true, + } + + // Wire a mock provider that returns a key with a *different* kid than the + // one in the JWT. The validator should call the local provider first, get a + // kid-miss (nil key returned), and then fall back to HTTP JWKS. + ctrl := gomock.NewController(t) + mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl) + + // Generate a throwaway ECDSA key so the mock returns a non-nil key list + // with a different kid. + throwawayKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + mockProvider.EXPECT(). + PublicKeys(gomock.Any()). + Return([]*keys.PublicKeyData{{ + KeyID: "unrelated-key-id", // does NOT match testKeyID used by signJWT + Algorithm: "ES256", + PublicKey: &throwawayKey.PublicKey, + CreatedAt: time.Now(), + }}, nil). + AnyTimes() + + authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider) + require.NoError(t, err) + require.NotNil(t, authMw) + + var capturedIdentity *pkgauth.Identity + handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context()) + })) + + // Sign the JWT with the RSA key from the test server (kid = testKeyID). + // The mock provider holds a key with a different kid, so the validator must + // fall back to HTTP JWKS to find the matching key. + token := signJWT(t, rsaPrivateKey, jwt.MapClaims{ + "iss": issuer, + "aud": "test-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback on kid-miss") + require.NotNil(t, capturedIdentity, "identity should be present in context") + assert.Equal(t, "test-user", capturedIdentity.Subject) }