diff --git a/lib/auth/auth.go b/lib/auth/auth.go index bfb869fe39aa7..37577871f01d3 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -763,7 +763,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (as *Server, err error) { as.k8sJWKSValidator = kubetoken.ValidateTokenWithJWKS } if as.k8sOIDCValidator == nil { - as.k8sOIDCValidator = kubetoken.ValidateTokenWithOIDC + validator, err := kubetoken.NewKubernetesOIDCTokenValidator() + if err != nil { + return nil, trace.Wrap(err) + } + + as.k8sOIDCValidator = validator } if as.gcpIDTokenValidator == nil { @@ -1254,7 +1259,7 @@ type Server struct { k8sJWKSValidator k8sJWKSValidator // k8sOIDCValidator allows tokens from Kubernetes to be validated by the // auth server using a known OIDC endpoint. It can be overridden in tests. - k8sOIDCValidator k8sOIDCValidator + k8sOIDCValidator *kubetoken.KubernetesOIDCTokenValidator // gcpIDTokenValidator allows ID tokens from GCP to be validated by the auth // server. It can be overridden for the purpose of tests. diff --git a/lib/auth/join_kubernetes.go b/lib/auth/join_kubernetes.go index 4b4ccce091e90..badaa37fc510b 100644 --- a/lib/auth/join_kubernetes.go +++ b/lib/auth/join_kubernetes.go @@ -35,13 +35,6 @@ type k8sTokenReviewValidator interface { type k8sJWKSValidator func(now time.Time, jwksData []byte, clusterName string, token string) (*kubetoken.ValidationResult, error) -type k8sOIDCValidator func( - ctx context.Context, - issuerURL string, - clusterName string, - token string, -) (*kubetoken.ValidationResult, error) - func (a *Server) checkKubernetesJoinRequest( ctx context.Context, req *types.RegisterUsingTokenRequest, @@ -77,7 +70,7 @@ func (a *Server) checkKubernetesJoinRequest( return nil, trace.WrapWithMessage(err, "reviewing kubernetes token with static_jwks") } case types.KubernetesJoinTypeOIDC: - result, err = a.k8sOIDCValidator( + result, err = a.k8sOIDCValidator.ValidateToken( ctx, token.Spec.Kubernetes.OIDC.Issuer, clusterName, diff --git a/lib/kube/token/validator.go b/lib/kube/token/validator.go index d6702d2377a9e..ecb17473b94d6 100644 --- a/lib/kube/token/validator.go +++ b/lib/kube/token/validator.go @@ -29,6 +29,7 @@ import ( "github.com/go-jose/go-jose/v3" josejwt "github.com/go-jose/go-jose/v3/jwt" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" v1 "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/version" @@ -328,20 +329,38 @@ func ValidateTokenWithJWKS( }, nil } +// NewKubernetesOIDCTokenValidator constructs a KubernetesOIDCTokenValidator. +func NewKubernetesOIDCTokenValidator() (*KubernetesOIDCTokenValidator, error) { + validator, err := oidc.NewCachingTokenValidator[*tokenclaims.OIDCServiceAccountClaims](clockwork.NewRealClock()) + if err != nil { + return nil, trace.Wrap(err) + } + + return &KubernetesOIDCTokenValidator{ + validator: validator, + }, nil +} + +// KubernetesOIDCTokenValidator is a validator that can validate Kubernetes +// projected service account tokens against an external OIDC compatible IdP. +type KubernetesOIDCTokenValidator struct { + validator *oidc.CachingTokenValidator[*tokenclaims.OIDCServiceAccountClaims] +} + // ValidateTokenWithJWKS validates a Kubernetes Service Account JWT using an // OIDC endpoint. -func ValidateTokenWithOIDC( +func (v *KubernetesOIDCTokenValidator) ValidateToken( ctx context.Context, issuerURL string, clusterName string, token string, ) (*ValidationResult, error) { - claims, err := oidc.ValidateToken[*tokenclaims.OIDCServiceAccountClaims]( - ctx, - issuerURL, - clusterName, - token, - ) + validator, err := v.validator.GetValidator(ctx, issuerURL, clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + claims, err := validator.ValidateToken(ctx, token) if err != nil { return nil, trace.Wrap(err, "validating OIDC token") } diff --git a/lib/kube/token/validator_test.go b/lib/kube/token/validator_test.go index 18ef6a4e83917..9e4ade25b290d 100644 --- a/lib/kube/token/validator_test.go +++ b/lib/kube/token/validator_test.go @@ -848,7 +848,10 @@ func TestValidateTokenWithOIDC(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - result, err := ValidateTokenWithOIDC(ctx, idp.IssuerURL(), tt.audience, tt.token) + validator, err := NewKubernetesOIDCTokenValidator() + require.NoError(t, err) + + result, err := validator.ValidateToken(ctx, idp.IssuerURL(), tt.audience, tt.token) tt.assertError(t, err) require.Empty(t, cmp.Diff( diff --git a/lib/oidc/caching_token_validator.go b/lib/oidc/caching_token_validator.go new file mode 100644 index 0000000000000..3425e8460292c --- /dev/null +++ b/lib/oidc/caching_token_validator.go @@ -0,0 +1,259 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package oidc + +import ( + "context" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" +) + +var log = logutils.NewPackageLogger(teleport.ComponentKey, teleport.Component("oidc")) + +const ( + // discoveryTTL is the maximum duration a discovery configuration will be + // cached locally before being discarded + discoveryTTL = time.Hour + + // keySetTTL is the maximum duration a particular keyset will be allowed to + // exist before being purged, regardless of whether or not it is being used + // actively. The underlying library may update its internal cache of keys + // within this window. + keySetTTL = time.Hour * 24 + + // validatorTTL is a maximum time a particular validator instance should + // remain in memory before being pruned if left unused. + validatorTTL = time.Hour * 24 * 2 +) + +// validatorKey is a composite key for the validator instance map +type validatorKey struct { + issuer string + audience string +} + +// NewCachingTokenValidator creates a caching validator for the given issuer and +// audience, using a real clock. +func NewCachingTokenValidator[C oidc.Claims](clock clockwork.Clock) (*CachingTokenValidator[C], error) { + if clock == nil { + clock = clockwork.NewRealClock() + } + + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + Clock: clock, + TTL: validatorTTL, + ReloadOnErr: true, + }) + if err != nil { + return nil, err + } + + return &CachingTokenValidator[C]{ + clock: clock, + cache: cache, + }, nil +} + +// CachingTokenValidator is a wrapper on top of `CachingValidatorInstance` that +// automatically manages and prunes validator instances for a given +// (issuer, audience) pair. This helps to ensure validators and key sets don't +// remain in memory indefinitely if e.g. a Teleport auth token is modified to +// use a different issuer or removed outright. +type CachingTokenValidator[C oidc.Claims] struct { + clock clockwork.Clock + + cache *utils.FnCache +} + +// GetValidator retreives a validator for the given issuer and audience. This +// will create a new validator instance if necessary, and will occasionally +// prune old instances that have not been used to validate any tokens in some +// time. +func (v *CachingTokenValidator[C]) GetValidator(ctx context.Context, issuer, audience string) (*CachingValidatorInstance[C], error) { + key := validatorKey{issuer: issuer, audience: audience} + instance, err := utils.FnCacheGet(ctx, v.cache, key, func(ctx context.Context) (*CachingValidatorInstance[C], error) { + transport, err := defaults.Transport() + if err != nil { + return nil, trace.Wrap(err) + } + + return &CachingValidatorInstance[C]{ + client: &http.Client{Transport: otelhttp.NewTransport(transport)}, + clock: v.clock, + issuer: issuer, + audience: audience, + verifierFn: zoidcTokenVerifier[C], + logger: log.With("issuer", issuer, "audience", audience), + }, nil + }) + + return instance, err +} + +// CachingValidatorInstance provides an issuer-specific cache. It separately +// caches the discovery config and `oidc.KeySet` to ensure each is reasonably +// fresh, and purges sufficiently old key sets to ensure old keys are not +// retained indefinitely. +type CachingValidatorInstance[C oidc.Claims] struct { + issuer string + audience string + clock clockwork.Clock + client *http.Client + logger *slog.Logger + + mu sync.Mutex + discoveryConfig *oidc.DiscoveryConfiguration + discoveryConfigExpires time.Time + lastJWKSURI string + keySet oidc.KeySet + keySetExpires time.Time + + // verifierFn is the function that actually verifies the token using the + // oidc library. `zitadel/oidc` doesn't provide any way to override the + // clock, so we use this for tests. + verifierFn func( + ctx context.Context, + issuer, + clientID string, + keySet oidc.KeySet, + token string, + opts ...rp.VerifierOption, + ) (C, error) +} + +func (v *CachingValidatorInstance[C]) getKeySet( + ctx context.Context, +) (oidc.KeySet, error) { + // Note: We could consider an RWLock or singleflight if perf proves to be + // poor here. As written, I don't expect serialized warm-cache requests to + // accumulate enough to be worth the added complexity. + v.mu.Lock() + defer v.mu.Unlock() + + now := v.clock.Now() + + if !v.discoveryConfigExpires.IsZero() && now.After(v.discoveryConfigExpires) { + // Invalidate the cached value. + v.discoveryConfig = nil + v.discoveryConfigExpires = time.Time{} + + v.logger.DebugContext(ctx, "Invalidating expired discovery config") + } + + if v.discoveryConfig == nil { + v.logger.DebugContext(ctx, "Fetching new discovery config") + + // Note: This is the only blocking call inside the mutex. + // In the future, it might be a good idea to fetch the new discovery + // config async and keep it available if the refresh fails. + dc, err := client.Discover(ctx, v.issuer, v.client) + if err != nil { + return nil, trace.Wrap(err) + } + + v.discoveryConfig = dc + v.discoveryConfigExpires = now.Add(discoveryTTL) + + if v.lastJWKSURI != "" && v.lastJWKSURI != dc.JwksURI { + // If the JWKS URI has changed, expire the keyset now. + v.keySet = nil + v.keySetExpires = time.Time{} + } + v.lastJWKSURI = dc.JwksURI + } + + // If this upstream issue is fixed, we can remove this in favor of keeping + // the KeySet: https://github.com/zitadel/oidc/issues/747 + if !v.keySetExpires.IsZero() && now.After(v.keySetExpires) { + // Invalidate the cached value. + v.keySet = nil + v.keySetExpires = time.Time{} + + v.logger.DebugContext(ctx, "Invalidating expired KeySet") + } + + if v.keySet == nil { + v.logger.DebugContext(ctx, "Creating new remote KeySet") + v.keySet = rp.NewRemoteKeySet(v.client, v.discoveryConfig.JwksURI) + v.keySetExpires = now.Add(keySetTTL) + } + + return v.keySet, nil +} + +func zoidcTokenVerifier[C oidc.Claims]( + ctx context.Context, + issuer, + clientID string, + keySet oidc.KeySet, + token string, + opts ...rp.VerifierOption, +) (C, error) { + verifier := rp.NewIDTokenVerifier(issuer, clientID, keySet, opts...) + + // Note: VerifyIDToken() may mutate the KeySet (if the keyset is empty or if + // it encounters an unknown `kid`). The keyset manages a mutex of its own, + // so we don't need to protect this operation. It's acceptable for this + // keyset to be swapped in another thread and still used here; it will just + // be GC'd afterward. + claims, err := rp.VerifyIDToken[C](ctx, token, verifier) + if err != nil { + return *new(C), trace.Wrap(err, "verifying token") + } + + return claims, nil +} + +// ValidateToken verifies a compact encoded token against the configured +// issuer and keys, potentially using cached OpenID configuration and JWKS +// values. +func (v *CachingValidatorInstance[C]) ValidateToken( + ctx context.Context, + token string, + opts ...rp.VerifierOption, +) (C, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, providerTimeout) + defer cancel() + + ks, err := v.getKeySet(timeoutCtx) + if err != nil { + return *new(C), trace.Wrap(err) + } + + claims, err := v.verifierFn(ctx, v.issuer, v.audience, ks, token, opts...) + if err != nil { + return *new(C), trace.Wrap(err) + } + + return claims, nil +} diff --git a/lib/oidc/caching_token_validator_test.go b/lib/oidc/caching_token_validator_test.go new file mode 100644 index 0000000000000..8e51ef93081d5 --- /dev/null +++ b/lib/oidc/caching_token_validator_test.go @@ -0,0 +1,396 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package oidc + +import ( + "context" + "crypto" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + + "github.com/gravitational/teleport/lib/cryptosuites" +) + +// fakeIDP provides a minimal fake OIDC provider for use in tests +type fakeIDP struct { + t *testing.T + clock *clockwork.FakeClock + signer jose.Signer + publicKey crypto.PublicKey + server *httptest.Server + audience string + + useAlternateJWKSEndpoint atomic.Bool + configRequests atomic.Uint32 + jwksRequests atomic.Uint32 +} + +func newFakeIDP(t *testing.T, clock *clockwork.FakeClock, audience string) *fakeIDP { + privateKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.RSA2048) + require.NoError(t, err) + + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, + (&jose.SignerOptions{}).WithType("JWT"), + ) + require.NoError(t, err) + + f := &fakeIDP{ + clock: clock, + signer: signer, + publicKey: privateKey.Public(), + t: t, + audience: audience, + } + + providerMux := http.NewServeMux() + providerMux.HandleFunc( + "/.well-known/openid-configuration", + f.handleOpenIDConfig, + ) + providerMux.HandleFunc( + "/.well-known/jwks", + f.handleJWKSEndpoint, + ) + providerMux.HandleFunc( + "/.well-known/jwks-alt", + f.handleJWKSEndpoint, + ) + + srv := httptest.NewServer(providerMux) + t.Cleanup(srv.Close) + f.server = srv + return f +} + +func (f *fakeIDP) issuer() string { + return f.server.URL +} + +func (f *fakeIDP) handleOpenIDConfig(w http.ResponseWriter, r *http.Request) { + jwksURI := f.issuer() + "/.well-known/jwks" + if f.useAlternateJWKSEndpoint.Load() { + jwksURI += "-alt" + } + + response := map[string]any{ + "claims_supported": []string{ + "sub", + "iss", + }, + "id_token_signing_alg_values_supported": []string{"RS256"}, + "issuer": f.issuer(), + "jwks_uri": jwksURI, + "response_types_supported": []string{"id_token"}, + "scopes_supported": []string{"openid"}, + "subject_types_supported": []string{"public"}, + } + responseBytes, err := json.Marshal(response) + require.NoError(f.t, err) + _, err = w.Write(responseBytes) + require.NoError(f.t, err) + + f.configRequests.Add(1) +} + +func (f *fakeIDP) handleJWKSEndpoint(w http.ResponseWriter, r *http.Request) { + jwks := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: f.publicKey, + }, + }, + } + responseBytes, err := json.Marshal(jwks) + require.NoError(f.t, err) + _, err = w.Write(responseBytes) + require.NoError(f.t, err) + + f.jwksRequests.Add(1) +} + +func (f *fakeIDP) issueToken( + t *testing.T, + audience, + sub string, + ttl time.Duration, +) string { + claims := oidc.TokenClaims{ + Issuer: f.issuer(), + Subject: sub, + Audience: oidc.Audience{audience}, + IssuedAt: oidc.FromTime(f.clock.Now()), + NotBefore: oidc.FromTime(f.clock.Now()), + Expiration: oidc.FromTime(f.clock.Now().Add(ttl)), + } + + token, err := jwt.Signed(f.signer). + Claims(claims). + Serialize() + require.NoError(t, err) + + return token +} + +// TestCachingTokenValidator runs various tests against the caching token +// validator +func TestCachingTokenValidator(t *testing.T) { + t.Parallel() + + const defaultAudience = "example.teleport.sh" + + // A minimal validator that skips most checks, especially any that depend on + // the system clock. We do this mainly to still invoke + // `keySet.VerifySignature()`. + minimalValidator := func() func( + context.Context, + string, string, oidc.KeySet, string, ...rp.VerifierOption) (*oidc.TokenClaims, error) { + return func( + ctx context.Context, + issuer, + clientID string, + keySet oidc.KeySet, + token string, + opts ...rp.VerifierOption, + ) (*oidc.TokenClaims, error) { + var claims oidc.TokenClaims + _, err := oidc.ParseToken(token, &claims) + if err != nil { + return nil, err + } + + jws, err := jose.ParseSigned(token, []jose.SignatureAlgorithm{jose.RS256}) + if err != nil { + return nil, err + } + + _, err = keySet.VerifySignature(ctx, jws) + if err != nil { + return nil, err + } + + return &claims, nil + } + } + + tests := []struct { + name string + audience string + execute func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) + }{ + { + name: "empty", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + // Do nothing. + require.Zero(t, idp.configRequests.Load()) + require.Zero(t, idp.jwksRequests.Load()) + }, + }, + { + name: "single validator", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience) + require.NoError(t, err) + + token := idp.issueToken(t, defaultAudience, "a", time.Hour) + claims, err := val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.Equal(t, "a", claims.Subject) + require.EqualValues(t, 1, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + + token = idp.issueToken(t, defaultAudience, "b", time.Hour) + claims, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.Equal(t, "b", claims.Subject) + require.EqualValues(t, 1, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + }, + }, + { + name: "multiple validators", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + v1, err := v.GetValidator(t.Context(), idp.issuer(), "a.teleport.sh") + require.NoError(t, err) + v2, err := v.GetValidator(t.Context(), idp.issuer(), "b.teleport.sh") + require.NoError(t, err) + + token := idp.issueToken(t, "a.teleport.sh", "a", time.Hour) + claims, err := v1.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.Equal(t, "a", claims.Subject) + require.EqualValues(t, 1, idp.configRequests.Load(), "config") + require.EqualValues(t, 1, idp.jwksRequests.Load(), "jwks") + + token = idp.issueToken(t, "b.teleport.sh", "b", time.Hour) + claims, err = v2.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.Equal(t, "b", claims.Subject) + require.EqualValues(t, 2, idp.configRequests.Load()) + require.EqualValues(t, 2, idp.jwksRequests.Load()) + + // Validating against a bad token should fail, and should not + // result in spurious requests. + token = idp.issueToken(t, "c.teleport.sh", "c", time.Hour) + _, err = v2.ValidateToken(t.Context(), token) + require.Error(t, err) + + require.EqualValues(t, 2, idp.configRequests.Load()) + require.EqualValues(t, 2, idp.jwksRequests.Load()) + }, + }, + { + name: "expired config", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience) + require.NoError(t, err) + val.verifierFn = minimalValidator() + + token := idp.issueToken(t, defaultAudience, "a", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.EqualValues(t, 1, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + + idp.clock.Advance(discoveryTTL + time.Minute) + token = idp.issueToken(t, defaultAudience, "b", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + require.EqualValues(t, 2, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + }, + }, + { + name: "stale config", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience) + require.NoError(t, err) + val.verifierFn = minimalValidator() + + idp.clock.Advance(validatorTTL + time.Minute) + + token := idp.issueToken(t, defaultAudience, "a", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + // This validation attempt should fetch both the config and JWKS + // endpoint. + require.EqualValues(t, 1, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + + idp.clock.Advance(discoveryTTL + time.Minute) + token = idp.issueToken(t, defaultAudience, "b", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + // Config should be reloaded, but the keyset will remain cached + require.EqualValues(t, 2, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + }, + }, + { + name: "changed jwks uri", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience) + require.NoError(t, err) + val.verifierFn = minimalValidator() + + idp.clock.Advance(validatorTTL + time.Minute) + + token := idp.issueToken(t, defaultAudience, "a", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + // This validation attempt should fetch both the config and JWKS + // endpoint. + require.EqualValues(t, 1, idp.configRequests.Load()) + require.EqualValues(t, 1, idp.jwksRequests.Load()) + + // Switch to the new endpoint, advance the clock enough to + // trigger a config refresh, and validate again. + idp.useAlternateJWKSEndpoint.Store(true) + idp.clock.Advance(discoveryTTL + time.Minute) + token = idp.issueToken(t, defaultAudience, "b", time.Hour) + _, err = val.ValidateToken(t.Context(), token) + require.NoError(t, err) + + // Config should be reloaded, and the keyset should be reloaded. + require.EqualValues(t, 2, idp.configRequests.Load()) + require.EqualValues(t, 2, idp.jwksRequests.Load()) + }, + }, + { + name: "validator pruning", + audience: defaultAudience, + execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) { + valOld, err := v.GetValidator(t.Context(), idp.issuer(), "a") + require.NoError(t, err) + + // After just 1 hour, it should return the same pointer + idp.clock.Advance(time.Hour + time.Minute) + + valTemp, err := v.GetValidator(t.Context(), idp.issuer(), "a") + require.NoError(t, err) + require.Same(t, valOld, valTemp) + + // After 48 hours, make the request again. It's now past its + // TTL and should be recreated. + idp.clock.Advance(validatorTTL + time.Minute) + valNew, err := v.GetValidator(t.Context(), idp.issuer(), "a") + require.NoError(t, err) + require.NotSame(t, valNew, valOld) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clock := clockwork.NewFakeClock() + idp := newFakeIDP(t, clock, tt.audience) + + validator, err := NewCachingTokenValidator[*oidc.TokenClaims](clock) + require.NoError(t, err) + + tt.execute(t, idp, validator) + }) + } +}