diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index bfb869fe39aa7..37577871f01d3 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -763,7 +763,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (as *Server, err error) {
as.k8sJWKSValidator = kubetoken.ValidateTokenWithJWKS
}
if as.k8sOIDCValidator == nil {
- as.k8sOIDCValidator = kubetoken.ValidateTokenWithOIDC
+ validator, err := kubetoken.NewKubernetesOIDCTokenValidator()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ as.k8sOIDCValidator = validator
}
if as.gcpIDTokenValidator == nil {
@@ -1254,7 +1259,7 @@ type Server struct {
k8sJWKSValidator k8sJWKSValidator
// k8sOIDCValidator allows tokens from Kubernetes to be validated by the
// auth server using a known OIDC endpoint. It can be overridden in tests.
- k8sOIDCValidator k8sOIDCValidator
+ k8sOIDCValidator *kubetoken.KubernetesOIDCTokenValidator
// gcpIDTokenValidator allows ID tokens from GCP to be validated by the auth
// server. It can be overridden for the purpose of tests.
diff --git a/lib/auth/join_kubernetes.go b/lib/auth/join_kubernetes.go
index 4b4ccce091e90..badaa37fc510b 100644
--- a/lib/auth/join_kubernetes.go
+++ b/lib/auth/join_kubernetes.go
@@ -35,13 +35,6 @@ type k8sTokenReviewValidator interface {
type k8sJWKSValidator func(now time.Time, jwksData []byte, clusterName string, token string) (*kubetoken.ValidationResult, error)
-type k8sOIDCValidator func(
- ctx context.Context,
- issuerURL string,
- clusterName string,
- token string,
-) (*kubetoken.ValidationResult, error)
-
func (a *Server) checkKubernetesJoinRequest(
ctx context.Context,
req *types.RegisterUsingTokenRequest,
@@ -77,7 +70,7 @@ func (a *Server) checkKubernetesJoinRequest(
return nil, trace.WrapWithMessage(err, "reviewing kubernetes token with static_jwks")
}
case types.KubernetesJoinTypeOIDC:
- result, err = a.k8sOIDCValidator(
+ result, err = a.k8sOIDCValidator.ValidateToken(
ctx,
token.Spec.Kubernetes.OIDC.Issuer,
clusterName,
diff --git a/lib/kube/token/validator.go b/lib/kube/token/validator.go
index d6702d2377a9e..ecb17473b94d6 100644
--- a/lib/kube/token/validator.go
+++ b/lib/kube/token/validator.go
@@ -29,6 +29,7 @@ import (
"github.com/go-jose/go-jose/v3"
josejwt "github.com/go-jose/go-jose/v3/jwt"
"github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
v1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/version"
@@ -328,20 +329,38 @@ func ValidateTokenWithJWKS(
}, nil
}
+// NewKubernetesOIDCTokenValidator constructs a KubernetesOIDCTokenValidator.
+func NewKubernetesOIDCTokenValidator() (*KubernetesOIDCTokenValidator, error) {
+ validator, err := oidc.NewCachingTokenValidator[*tokenclaims.OIDCServiceAccountClaims](clockwork.NewRealClock())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &KubernetesOIDCTokenValidator{
+ validator: validator,
+ }, nil
+}
+
+// KubernetesOIDCTokenValidator is a validator that can validate Kubernetes
+// projected service account tokens against an external OIDC compatible IdP.
+type KubernetesOIDCTokenValidator struct {
+ validator *oidc.CachingTokenValidator[*tokenclaims.OIDCServiceAccountClaims]
+}
+
// ValidateTokenWithJWKS validates a Kubernetes Service Account JWT using an
// OIDC endpoint.
-func ValidateTokenWithOIDC(
+func (v *KubernetesOIDCTokenValidator) ValidateToken(
ctx context.Context,
issuerURL string,
clusterName string,
token string,
) (*ValidationResult, error) {
- claims, err := oidc.ValidateToken[*tokenclaims.OIDCServiceAccountClaims](
- ctx,
- issuerURL,
- clusterName,
- token,
- )
+ validator, err := v.validator.GetValidator(ctx, issuerURL, clusterName)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ claims, err := validator.ValidateToken(ctx, token)
if err != nil {
return nil, trace.Wrap(err, "validating OIDC token")
}
diff --git a/lib/kube/token/validator_test.go b/lib/kube/token/validator_test.go
index 18ef6a4e83917..9e4ade25b290d 100644
--- a/lib/kube/token/validator_test.go
+++ b/lib/kube/token/validator_test.go
@@ -848,7 +848,10 @@ func TestValidateTokenWithOIDC(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
- result, err := ValidateTokenWithOIDC(ctx, idp.IssuerURL(), tt.audience, tt.token)
+ validator, err := NewKubernetesOIDCTokenValidator()
+ require.NoError(t, err)
+
+ result, err := validator.ValidateToken(ctx, idp.IssuerURL(), tt.audience, tt.token)
tt.assertError(t, err)
require.Empty(t, cmp.Diff(
diff --git a/lib/oidc/caching_token_validator.go b/lib/oidc/caching_token_validator.go
new file mode 100644
index 0000000000000..3425e8460292c
--- /dev/null
+++ b/lib/oidc/caching_token_validator.go
@@ -0,0 +1,259 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package oidc
+
+import (
+ "context"
+ "log/slog"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/zitadel/oidc/v3/pkg/client"
+ "github.com/zitadel/oidc/v3/pkg/client/rp"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+ "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
+)
+
+var log = logutils.NewPackageLogger(teleport.ComponentKey, teleport.Component("oidc"))
+
+const (
+ // discoveryTTL is the maximum duration a discovery configuration will be
+ // cached locally before being discarded
+ discoveryTTL = time.Hour
+
+ // keySetTTL is the maximum duration a particular keyset will be allowed to
+ // exist before being purged, regardless of whether or not it is being used
+ // actively. The underlying library may update its internal cache of keys
+ // within this window.
+ keySetTTL = time.Hour * 24
+
+ // validatorTTL is a maximum time a particular validator instance should
+ // remain in memory before being pruned if left unused.
+ validatorTTL = time.Hour * 24 * 2
+)
+
+// validatorKey is a composite key for the validator instance map
+type validatorKey struct {
+ issuer string
+ audience string
+}
+
+// NewCachingTokenValidator creates a caching validator for the given issuer and
+// audience, using a real clock.
+func NewCachingTokenValidator[C oidc.Claims](clock clockwork.Clock) (*CachingTokenValidator[C], error) {
+ if clock == nil {
+ clock = clockwork.NewRealClock()
+ }
+
+ cache, err := utils.NewFnCache(utils.FnCacheConfig{
+ Clock: clock,
+ TTL: validatorTTL,
+ ReloadOnErr: true,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return &CachingTokenValidator[C]{
+ clock: clock,
+ cache: cache,
+ }, nil
+}
+
+// CachingTokenValidator is a wrapper on top of `CachingValidatorInstance` that
+// automatically manages and prunes validator instances for a given
+// (issuer, audience) pair. This helps to ensure validators and key sets don't
+// remain in memory indefinitely if e.g. a Teleport auth token is modified to
+// use a different issuer or removed outright.
+type CachingTokenValidator[C oidc.Claims] struct {
+ clock clockwork.Clock
+
+ cache *utils.FnCache
+}
+
+// GetValidator retreives a validator for the given issuer and audience. This
+// will create a new validator instance if necessary, and will occasionally
+// prune old instances that have not been used to validate any tokens in some
+// time.
+func (v *CachingTokenValidator[C]) GetValidator(ctx context.Context, issuer, audience string) (*CachingValidatorInstance[C], error) {
+ key := validatorKey{issuer: issuer, audience: audience}
+ instance, err := utils.FnCacheGet(ctx, v.cache, key, func(ctx context.Context) (*CachingValidatorInstance[C], error) {
+ transport, err := defaults.Transport()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &CachingValidatorInstance[C]{
+ client: &http.Client{Transport: otelhttp.NewTransport(transport)},
+ clock: v.clock,
+ issuer: issuer,
+ audience: audience,
+ verifierFn: zoidcTokenVerifier[C],
+ logger: log.With("issuer", issuer, "audience", audience),
+ }, nil
+ })
+
+ return instance, err
+}
+
+// CachingValidatorInstance provides an issuer-specific cache. It separately
+// caches the discovery config and `oidc.KeySet` to ensure each is reasonably
+// fresh, and purges sufficiently old key sets to ensure old keys are not
+// retained indefinitely.
+type CachingValidatorInstance[C oidc.Claims] struct {
+ issuer string
+ audience string
+ clock clockwork.Clock
+ client *http.Client
+ logger *slog.Logger
+
+ mu sync.Mutex
+ discoveryConfig *oidc.DiscoveryConfiguration
+ discoveryConfigExpires time.Time
+ lastJWKSURI string
+ keySet oidc.KeySet
+ keySetExpires time.Time
+
+ // verifierFn is the function that actually verifies the token using the
+ // oidc library. `zitadel/oidc` doesn't provide any way to override the
+ // clock, so we use this for tests.
+ verifierFn func(
+ ctx context.Context,
+ issuer,
+ clientID string,
+ keySet oidc.KeySet,
+ token string,
+ opts ...rp.VerifierOption,
+ ) (C, error)
+}
+
+func (v *CachingValidatorInstance[C]) getKeySet(
+ ctx context.Context,
+) (oidc.KeySet, error) {
+ // Note: We could consider an RWLock or singleflight if perf proves to be
+ // poor here. As written, I don't expect serialized warm-cache requests to
+ // accumulate enough to be worth the added complexity.
+ v.mu.Lock()
+ defer v.mu.Unlock()
+
+ now := v.clock.Now()
+
+ if !v.discoveryConfigExpires.IsZero() && now.After(v.discoveryConfigExpires) {
+ // Invalidate the cached value.
+ v.discoveryConfig = nil
+ v.discoveryConfigExpires = time.Time{}
+
+ v.logger.DebugContext(ctx, "Invalidating expired discovery config")
+ }
+
+ if v.discoveryConfig == nil {
+ v.logger.DebugContext(ctx, "Fetching new discovery config")
+
+ // Note: This is the only blocking call inside the mutex.
+ // In the future, it might be a good idea to fetch the new discovery
+ // config async and keep it available if the refresh fails.
+ dc, err := client.Discover(ctx, v.issuer, v.client)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ v.discoveryConfig = dc
+ v.discoveryConfigExpires = now.Add(discoveryTTL)
+
+ if v.lastJWKSURI != "" && v.lastJWKSURI != dc.JwksURI {
+ // If the JWKS URI has changed, expire the keyset now.
+ v.keySet = nil
+ v.keySetExpires = time.Time{}
+ }
+ v.lastJWKSURI = dc.JwksURI
+ }
+
+ // If this upstream issue is fixed, we can remove this in favor of keeping
+ // the KeySet: https://github.com/zitadel/oidc/issues/747
+ if !v.keySetExpires.IsZero() && now.After(v.keySetExpires) {
+ // Invalidate the cached value.
+ v.keySet = nil
+ v.keySetExpires = time.Time{}
+
+ v.logger.DebugContext(ctx, "Invalidating expired KeySet")
+ }
+
+ if v.keySet == nil {
+ v.logger.DebugContext(ctx, "Creating new remote KeySet")
+ v.keySet = rp.NewRemoteKeySet(v.client, v.discoveryConfig.JwksURI)
+ v.keySetExpires = now.Add(keySetTTL)
+ }
+
+ return v.keySet, nil
+}
+
+func zoidcTokenVerifier[C oidc.Claims](
+ ctx context.Context,
+ issuer,
+ clientID string,
+ keySet oidc.KeySet,
+ token string,
+ opts ...rp.VerifierOption,
+) (C, error) {
+ verifier := rp.NewIDTokenVerifier(issuer, clientID, keySet, opts...)
+
+ // Note: VerifyIDToken() may mutate the KeySet (if the keyset is empty or if
+ // it encounters an unknown `kid`). The keyset manages a mutex of its own,
+ // so we don't need to protect this operation. It's acceptable for this
+ // keyset to be swapped in another thread and still used here; it will just
+ // be GC'd afterward.
+ claims, err := rp.VerifyIDToken[C](ctx, token, verifier)
+ if err != nil {
+ return *new(C), trace.Wrap(err, "verifying token")
+ }
+
+ return claims, nil
+}
+
+// ValidateToken verifies a compact encoded token against the configured
+// issuer and keys, potentially using cached OpenID configuration and JWKS
+// values.
+func (v *CachingValidatorInstance[C]) ValidateToken(
+ ctx context.Context,
+ token string,
+ opts ...rp.VerifierOption,
+) (C, error) {
+ timeoutCtx, cancel := context.WithTimeout(ctx, providerTimeout)
+ defer cancel()
+
+ ks, err := v.getKeySet(timeoutCtx)
+ if err != nil {
+ return *new(C), trace.Wrap(err)
+ }
+
+ claims, err := v.verifierFn(ctx, v.issuer, v.audience, ks, token, opts...)
+ if err != nil {
+ return *new(C), trace.Wrap(err)
+ }
+
+ return claims, nil
+}
diff --git a/lib/oidc/caching_token_validator_test.go b/lib/oidc/caching_token_validator_test.go
new file mode 100644
index 0000000000000..8e51ef93081d5
--- /dev/null
+++ b/lib/oidc/caching_token_validator_test.go
@@ -0,0 +1,396 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package oidc
+
+import (
+ "context"
+ "crypto"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/go-jose/go-jose/v4"
+ "github.com/go-jose/go-jose/v4/jwt"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/oidc/v3/pkg/client/rp"
+ "github.com/zitadel/oidc/v3/pkg/oidc"
+
+ "github.com/gravitational/teleport/lib/cryptosuites"
+)
+
+// fakeIDP provides a minimal fake OIDC provider for use in tests
+type fakeIDP struct {
+ t *testing.T
+ clock *clockwork.FakeClock
+ signer jose.Signer
+ publicKey crypto.PublicKey
+ server *httptest.Server
+ audience string
+
+ useAlternateJWKSEndpoint atomic.Bool
+ configRequests atomic.Uint32
+ jwksRequests atomic.Uint32
+}
+
+func newFakeIDP(t *testing.T, clock *clockwork.FakeClock, audience string) *fakeIDP {
+ privateKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.RSA2048)
+ require.NoError(t, err)
+
+ signer, err := jose.NewSigner(
+ jose.SigningKey{Algorithm: jose.RS256, Key: privateKey},
+ (&jose.SignerOptions{}).WithType("JWT"),
+ )
+ require.NoError(t, err)
+
+ f := &fakeIDP{
+ clock: clock,
+ signer: signer,
+ publicKey: privateKey.Public(),
+ t: t,
+ audience: audience,
+ }
+
+ providerMux := http.NewServeMux()
+ providerMux.HandleFunc(
+ "/.well-known/openid-configuration",
+ f.handleOpenIDConfig,
+ )
+ providerMux.HandleFunc(
+ "/.well-known/jwks",
+ f.handleJWKSEndpoint,
+ )
+ providerMux.HandleFunc(
+ "/.well-known/jwks-alt",
+ f.handleJWKSEndpoint,
+ )
+
+ srv := httptest.NewServer(providerMux)
+ t.Cleanup(srv.Close)
+ f.server = srv
+ return f
+}
+
+func (f *fakeIDP) issuer() string {
+ return f.server.URL
+}
+
+func (f *fakeIDP) handleOpenIDConfig(w http.ResponseWriter, r *http.Request) {
+ jwksURI := f.issuer() + "/.well-known/jwks"
+ if f.useAlternateJWKSEndpoint.Load() {
+ jwksURI += "-alt"
+ }
+
+ response := map[string]any{
+ "claims_supported": []string{
+ "sub",
+ "iss",
+ },
+ "id_token_signing_alg_values_supported": []string{"RS256"},
+ "issuer": f.issuer(),
+ "jwks_uri": jwksURI,
+ "response_types_supported": []string{"id_token"},
+ "scopes_supported": []string{"openid"},
+ "subject_types_supported": []string{"public"},
+ }
+ responseBytes, err := json.Marshal(response)
+ require.NoError(f.t, err)
+ _, err = w.Write(responseBytes)
+ require.NoError(f.t, err)
+
+ f.configRequests.Add(1)
+}
+
+func (f *fakeIDP) handleJWKSEndpoint(w http.ResponseWriter, r *http.Request) {
+ jwks := jose.JSONWebKeySet{
+ Keys: []jose.JSONWebKey{
+ {
+ Key: f.publicKey,
+ },
+ },
+ }
+ responseBytes, err := json.Marshal(jwks)
+ require.NoError(f.t, err)
+ _, err = w.Write(responseBytes)
+ require.NoError(f.t, err)
+
+ f.jwksRequests.Add(1)
+}
+
+func (f *fakeIDP) issueToken(
+ t *testing.T,
+ audience,
+ sub string,
+ ttl time.Duration,
+) string {
+ claims := oidc.TokenClaims{
+ Issuer: f.issuer(),
+ Subject: sub,
+ Audience: oidc.Audience{audience},
+ IssuedAt: oidc.FromTime(f.clock.Now()),
+ NotBefore: oidc.FromTime(f.clock.Now()),
+ Expiration: oidc.FromTime(f.clock.Now().Add(ttl)),
+ }
+
+ token, err := jwt.Signed(f.signer).
+ Claims(claims).
+ Serialize()
+ require.NoError(t, err)
+
+ return token
+}
+
+// TestCachingTokenValidator runs various tests against the caching token
+// validator
+func TestCachingTokenValidator(t *testing.T) {
+ t.Parallel()
+
+ const defaultAudience = "example.teleport.sh"
+
+ // A minimal validator that skips most checks, especially any that depend on
+ // the system clock. We do this mainly to still invoke
+ // `keySet.VerifySignature()`.
+ minimalValidator := func() func(
+ context.Context,
+ string, string, oidc.KeySet, string, ...rp.VerifierOption) (*oidc.TokenClaims, error) {
+ return func(
+ ctx context.Context,
+ issuer,
+ clientID string,
+ keySet oidc.KeySet,
+ token string,
+ opts ...rp.VerifierOption,
+ ) (*oidc.TokenClaims, error) {
+ var claims oidc.TokenClaims
+ _, err := oidc.ParseToken(token, &claims)
+ if err != nil {
+ return nil, err
+ }
+
+ jws, err := jose.ParseSigned(token, []jose.SignatureAlgorithm{jose.RS256})
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = keySet.VerifySignature(ctx, jws)
+ if err != nil {
+ return nil, err
+ }
+
+ return &claims, nil
+ }
+ }
+
+ tests := []struct {
+ name string
+ audience string
+ execute func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims])
+ }{
+ {
+ name: "empty",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ // Do nothing.
+ require.Zero(t, idp.configRequests.Load())
+ require.Zero(t, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "single validator",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience)
+ require.NoError(t, err)
+
+ token := idp.issueToken(t, defaultAudience, "a", time.Hour)
+ claims, err := val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.Equal(t, "a", claims.Subject)
+ require.EqualValues(t, 1, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+
+ token = idp.issueToken(t, defaultAudience, "b", time.Hour)
+ claims, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.Equal(t, "b", claims.Subject)
+ require.EqualValues(t, 1, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "multiple validators",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ v1, err := v.GetValidator(t.Context(), idp.issuer(), "a.teleport.sh")
+ require.NoError(t, err)
+ v2, err := v.GetValidator(t.Context(), idp.issuer(), "b.teleport.sh")
+ require.NoError(t, err)
+
+ token := idp.issueToken(t, "a.teleport.sh", "a", time.Hour)
+ claims, err := v1.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.Equal(t, "a", claims.Subject)
+ require.EqualValues(t, 1, idp.configRequests.Load(), "config")
+ require.EqualValues(t, 1, idp.jwksRequests.Load(), "jwks")
+
+ token = idp.issueToken(t, "b.teleport.sh", "b", time.Hour)
+ claims, err = v2.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.Equal(t, "b", claims.Subject)
+ require.EqualValues(t, 2, idp.configRequests.Load())
+ require.EqualValues(t, 2, idp.jwksRequests.Load())
+
+ // Validating against a bad token should fail, and should not
+ // result in spurious requests.
+ token = idp.issueToken(t, "c.teleport.sh", "c", time.Hour)
+ _, err = v2.ValidateToken(t.Context(), token)
+ require.Error(t, err)
+
+ require.EqualValues(t, 2, idp.configRequests.Load())
+ require.EqualValues(t, 2, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "expired config",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience)
+ require.NoError(t, err)
+ val.verifierFn = minimalValidator()
+
+ token := idp.issueToken(t, defaultAudience, "a", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.EqualValues(t, 1, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+
+ idp.clock.Advance(discoveryTTL + time.Minute)
+ token = idp.issueToken(t, defaultAudience, "b", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ require.EqualValues(t, 2, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "stale config",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience)
+ require.NoError(t, err)
+ val.verifierFn = minimalValidator()
+
+ idp.clock.Advance(validatorTTL + time.Minute)
+
+ token := idp.issueToken(t, defaultAudience, "a", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ // This validation attempt should fetch both the config and JWKS
+ // endpoint.
+ require.EqualValues(t, 1, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+
+ idp.clock.Advance(discoveryTTL + time.Minute)
+ token = idp.issueToken(t, defaultAudience, "b", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ // Config should be reloaded, but the keyset will remain cached
+ require.EqualValues(t, 2, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "changed jwks uri",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ val, err := v.GetValidator(t.Context(), idp.issuer(), defaultAudience)
+ require.NoError(t, err)
+ val.verifierFn = minimalValidator()
+
+ idp.clock.Advance(validatorTTL + time.Minute)
+
+ token := idp.issueToken(t, defaultAudience, "a", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ // This validation attempt should fetch both the config and JWKS
+ // endpoint.
+ require.EqualValues(t, 1, idp.configRequests.Load())
+ require.EqualValues(t, 1, idp.jwksRequests.Load())
+
+ // Switch to the new endpoint, advance the clock enough to
+ // trigger a config refresh, and validate again.
+ idp.useAlternateJWKSEndpoint.Store(true)
+ idp.clock.Advance(discoveryTTL + time.Minute)
+ token = idp.issueToken(t, defaultAudience, "b", time.Hour)
+ _, err = val.ValidateToken(t.Context(), token)
+ require.NoError(t, err)
+
+ // Config should be reloaded, and the keyset should be reloaded.
+ require.EqualValues(t, 2, idp.configRequests.Load())
+ require.EqualValues(t, 2, idp.jwksRequests.Load())
+ },
+ },
+ {
+ name: "validator pruning",
+ audience: defaultAudience,
+ execute: func(t *testing.T, idp *fakeIDP, v *CachingTokenValidator[*oidc.TokenClaims]) {
+ valOld, err := v.GetValidator(t.Context(), idp.issuer(), "a")
+ require.NoError(t, err)
+
+ // After just 1 hour, it should return the same pointer
+ idp.clock.Advance(time.Hour + time.Minute)
+
+ valTemp, err := v.GetValidator(t.Context(), idp.issuer(), "a")
+ require.NoError(t, err)
+ require.Same(t, valOld, valTemp)
+
+ // After 48 hours, make the request again. It's now past its
+ // TTL and should be recreated.
+ idp.clock.Advance(validatorTTL + time.Minute)
+ valNew, err := v.GetValidator(t.Context(), idp.issuer(), "a")
+ require.NoError(t, err)
+ require.NotSame(t, valNew, valOld)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ clock := clockwork.NewFakeClock()
+ idp := newFakeIDP(t, clock, tt.audience)
+
+ validator, err := NewCachingTokenValidator[*oidc.TokenClaims](clock)
+ require.NoError(t, err)
+
+ tt.execute(t, idp, validator)
+ })
+ }
+}